package handler

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"math/rand"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"
	"github.com/yourname/cyrene-ai/gateway/internal/config"
	"github.com/yourname/cyrene-ai/gateway/internal/ws"
)

// chatStreamHTTPClient is shared by every chat stream: an http.Client is safe
// for concurrent use and is meant to be reused rather than built per request.
// Timeout bounds the whole exchange, including reading the streamed SSE body.
var chatStreamHTTPClient = &http.Client{Timeout: 120 * time.Second}

// ChatHandler serves the main chat WebSocket endpoint and relays chat
// messages to AI-Core.
type ChatHandler struct {
	cfg      *config.Config
	hub      *ws.Hub
	upgrader websocket.Upgrader
}

// NewChatHandler creates a ChatHandler bound to the given config and hub.
func NewChatHandler(cfg *config.Config, hub *ws.Hub) *ChatHandler {
	return &ChatHandler{
		cfg: cfg,
		hub: hub,
		upgrader: websocket.Upgrader{
			ReadBufferSize:  1024,
			WriteBufferSize: 1024,
			// Development only: every Origin is accepted. Restrict this to
			// trusted origins before production.
			CheckOrigin: func(r *http.Request) bool {
				return true
			},
		},
	}
}

// HandleWebSocket authenticates the request, upgrades it to a WebSocket and
// starts the client's read/write pumps.
//
// The token comes from the "token" query parameter, falling back to an
// "Authorization: Bearer <token>" header. Only user IDs starting with
// "admin_" may use the main chat (403 otherwise). When no "session_id" query
// parameter is supplied a new session ID is generated.
func (h *ChatHandler) HandleWebSocket(c *gin.Context) {
	token := c.Query("token")
	sessionID := c.Query("session_id")

	if token == "" {
		// Fall back to the Authorization header. The auth scheme is
		// case-insensitive (RFC 7235 §2.1), so "bearer x" is accepted too.
		const bearerPrefix = "Bearer "
		authHeader := c.GetHeader("Authorization")
		if len(authHeader) > len(bearerPrefix) && strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
			token = strings.TrimSpace(authHeader[len(bearerPrefix):])
		}
	}
	if token == "" {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "需要认证令牌"})
		return
	}

	userID, err := h.cfg.ValidateToken(token)
	if err != nil {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "认证令牌无效"})
		return
	}

	// The main conversation is restricted to administrators.
	if !strings.HasPrefix(userID, "admin_") {
		c.JSON(http.StatusForbidden, gin.H{
			"error":     "主对话仅限管理员使用",
			"errorType": "admin_only",
			"hint":      "请使用管理员账号 (admin_ 前缀) 登录以访问主对话功能",
		})
		return
	}

	if sessionID == "" {
		sessionID = "session_" + generateID()
	}

	conn, err := h.upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		// Upgrade has already replied to the client with an HTTP error.
		log.Printf("[WS] 升级连接失败: %v", err)
		return
	}

	client := ws.NewClient(h.hub, conn, userID, sessionID)
	h.hub.Register(client)

	go client.WritePump()
	go client.ReadPump(func(client *ws.Client, msg ws.ClientMessage) {
		h.handleMessage(client, msg)
	})
}

// handleMessage dispatches one incoming WebSocket message by its Type.
// Unknown types are logged and otherwise ignored.
func (h *ChatHandler) handleMessage(client *ws.Client, msg ws.ClientMessage) {
	switch msg.Type {
	case "message":
		h.handleChatMessage(client, msg)
	case "voice_input":
		h.handleVoiceInput(client, msg)
	case "history":
		h.handleHistoryRequest(client, msg)
	default:
		log.Printf("[WS] 未知消息类型: %s from user=%s", msg.Type, client.UserID)
	}
}

// handleChatMessage records and caches the user's text message, then streams
// the AI-Core reply from a separate goroutine so ReadPump is never blocked.
// An empty Mode defaults to "text".
func (h *ChatHandler) handleChatMessage(client *ws.Client, msg ws.ClientMessage) {
	mode := msg.Mode
	if mode == "" {
		mode = "text"
	}

	h.hub.RecordMessage(client.SessionID, "user", msg.Content)
	h.hub.UpdateSessionState(client.SessionID, "thinking")

	aiReq := map[string]interface{}{
		"user_id":    client.UserID,
		"session_id": client.SessionID,
		"message":    msg.Content,
		"mode":       mode,
	}
	if len(msg.Attachments) > 0 {
		aiReq["attachments"] = msg.Attachments
	}

	reqBody, err := json.Marshal(aiReq)
	if err != nil {
		log.Printf("[chat] 序列化请求失败: %v", err)
		h.reportError(client, "内部错误,请稍后重试")
		return
	}

	// Cache the user message before starting the goroutine so it is always
	// stored ahead of the assistant reply.
	userMsg := ws.Message{
		ID:        "msg_" + generateID(),
		Role:      "user",
		Content:   msg.Content,
		Timestamp: time.Now().UnixMilli(),
	}
	if len(msg.Attachments) > 0 {
		userMsg.Attachments = msg.Attachments
	}
	h.hub.CacheMessage(client.UserID, client.SessionID, userMsg)

	go h.streamResponse(client, mode, reqBody, msg.Content)
}

// reportError marks the client's session as "error" and sends the client an
// "error" message carrying errMsg.
func (h *ChatHandler) reportError(client *ws.Client, errMsg string) {
	h.hub.UpdateSessionState(client.SessionID, "error")
	client.SendMessage(ws.ServerMessage{
		Type:      "error",
		MessageID: "msg_" + generateID(),
		Error:     errMsg,
		Timestamp: time.Now().UnixMilli(),
	})
}

// streamResponse POSTs reqBody to AI-Core's SSE chat endpoint and relays the
// reply to the client:
//   - each payload with "segments" -> "stream_segments" (all segments so far)
//   - each other payload with a non-empty "delta" -> "stream_chunk"
//   - if the full reply splits into several blank-line-separated parts -> "multi_message"
//   - finally "stream_end", after which the reply is cached/recorded and the
//     session state is set to "idle".
//
// Any failure sets the session state to "error" and sends an "error" message.
// mode and userMsg are currently unused.
func (h *ChatHandler) streamResponse(client *ws.Client, mode string, reqBody []byte, userMsg string) {
	// Trim a trailing slash so a configured ".../" base does not yield "//api".
	aiCoreURL := strings.TrimSuffix(h.cfg.AICoreURL, "/") + "/api/v1/chat"
	httpReq, err := http.NewRequest(http.MethodPost, aiCoreURL, bytes.NewReader(reqBody))
	if err != nil {
		log.Printf("[chat] 创建 AI-Core 请求失败: %v", err)
		h.reportError(client, "服务暂不可用")
		return
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Accept", "text/event-stream")

	resp, err := chatStreamHTTPClient.Do(httpReq)
	if err != nil {
		log.Printf("[chat] AI-Core 调用失败: %v", err)
		h.reportError(client, fmt.Sprintf("AI-Core 调用失败: %v", err))
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best-effort read, capped at 4 KiB: the body is only used for logging.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 4<<10))
		log.Printf("[chat] AI-Core 返回错误 [%d]: %s", resp.StatusCode, string(body))
		h.reportError(client, fmt.Sprintf("AI-Core 错误 (%d)", resp.StatusCode))
		return
	}

	// streamChunk is the JSON payload of one SSE "data:" line from AI-Core.
	type streamChunk struct {
		Delta     string `json:"delta"`
		Error     string `json:"error,omitempty"`
		MessageID string `json:"message_id,omitempty"`
		Mode      string `json:"mode,omitempty"`
		Done      bool   `json:"done,omitempty"`
		// Sentence segmentation, sent by newer AI-Core versions.
		Segments []struct {
			Index int    `json:"index"`
			Text  string `json:"text"`
		} `json:"segments,omitempty"`
	}

	scanner := bufio.NewScanner(resp.Body)
	// Allow SSE lines of up to 1 MiB (bufio's default limit is 64 KiB).
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)

	var (
		fullText strings.Builder
		msgID    string
		segments []ws.VoiceSegment
	)

	for scanner.Scan() {
		line := scanner.Text()

		// Only "data:" fields carry payloads. Per the SSE spec, one space
		// after the colon is optional and is not part of the value.
		if !strings.HasPrefix(line, "data:") {
			continue
		}
		data := strings.TrimPrefix(strings.TrimPrefix(line, "data:"), " ")
		if data == "" {
			continue
		}
		if data == "[DONE]" {
			break
		}

		// Declared per iteration so fields absent from this payload are zero.
		var chunk streamChunk
		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
			log.Printf("[chat] 解析 SSE delta 失败: %v, raw=%s", err, data)
			continue
		}

		if chunk.Error != "" {
			log.Printf("[chat] AI-Core 流式错误: %s", chunk.Error)
			h.reportError(client, chunk.Error)
			return
		}

		if chunk.MessageID != "" {
			msgID = chunk.MessageID
		}

		switch {
		case len(chunk.Segments) > 0:
			for _, seg := range chunk.Segments {
				segments = append(segments, ws.VoiceSegment{
					Index: seg.Index,
					Text:  seg.Text,
				})
			}
			// Each event carries every segment received so far, not only the new ones.
			client.SendMessage(ws.ServerMessage{
				Type:      "stream_segments",
				MessageID: msgID,
				Segments:  segments,
				SessionID: client.SessionID,
				Timestamp: time.Now().UnixMilli(),
			})
		case chunk.Delta != "":
			fullText.WriteString(chunk.Delta)
			client.SendMessage(ws.ServerMessage{
				Type:      "stream_chunk",
				MessageID: msgID,
				Content:   chunk.Delta,
				Role:      "assistant",
				SessionID: client.SessionID,
				Timestamp: time.Now().UnixMilli(),
			})
		}

		// Honor done only after forwarding this payload, so a final chunk that
		// carries both content and done:true does not lose its content.
		if chunk.Done {
			break
		}
	}

	if err := scanner.Err(); err != nil {
		log.Printf("[chat] SSE 读取错误: %v", err)
		h.reportError(client, fmt.Sprintf("流读取错误: %v", err))
		return
	}

	if msgID == "" {
		msgID = "msg_" + generateID()
	}
	reply := fullText.String()

	// A reply made of several blank-line-separated parts is also announced
	// as a multi_message so the frontend can render separate bubbles.
	if parts := parseMultiMessage(reply); len(parts) > 1 {
		items := make([]ws.MultiMessageItem, 0, len(parts))
		for i, part := range parts {
			items = append(items, ws.MultiMessageItem{
				Index:   i,
				Content: part,
			})
		}
		client.SendMessage(ws.ServerMessage{
			Type:      "multi_message",
			MessageID: msgID,
			SessionID: client.SessionID,
			MultiMessage: &ws.MultiMessagePayload{
				Messages: items,
			},
			Timestamp: time.Now().UnixMilli(),
		})
	}

	client.SendMessage(ws.ServerMessage{
		Type:      "stream_end",
		MessageID: msgID,
		SessionID: client.SessionID,
		Timestamp: time.Now().UnixMilli(),
	})

	if reply != "" {
		h.hub.CacheMessage(client.UserID, client.SessionID, ws.Message{
			ID:        msgID,
			Role:      "assistant",
			Content:   reply,
			Timestamp: time.Now().UnixMilli(),
		})
	}
	h.hub.RecordMessage(client.SessionID, "assistant", reply)

	h.hub.UpdateSessionState(client.SessionID, "idle")
}

// handleVoiceInput handles voice input. MVP: voice processing is not
// implemented yet, so an "error" message saying so is sent back.
func (h *ChatHandler) handleVoiceInput(client *ws.Client, msg ws.ClientMessage) {
	response := ws.ServerMessage{
		Type:      "error",
		MessageID: "msg_" + generateID(),
		Error:     "语音处理功能将在后续版本中启用",
		Timestamp: time.Now().UnixMilli(),
	}
	client.SendMessage(response)
}

// handleHistoryRequest sends the cached conversation for the requested
// session (or the client's own session when none is given) as a
// "history_response". A missing conversation is sent as an empty list.
func (h *ChatHandler) handleHistoryRequest(client *ws.Client, msg ws.ClientMessage) {
	sessionID := msg.SessionID
	if sessionID == "" {
		sessionID = client.SessionID
	}

	// Lookup is keyed by the authenticated user, so only that user's
	// conversations are reachable.
	messages := h.hub.GetConversation(client.UserID, sessionID)
	if messages == nil {
		// Non-nil so it is sent as [] rather than null.
		messages = []ws.Message{}
	}

	response := ws.ServerMessage{
		Type:      "history_response",
		MessageID: "hist_" + generateID(),
		Messages:  messages,
		Timestamp: time.Now().UnixMilli(),
	}
	if err := client.SendMessage(response); err != nil {
		log.Printf("[WS] 发送历史消息失败: %v", err)
	}
}

// SendSystemMessage pushes a system "response" message with the given text
// to the user's session (used for proactive notifications). It returns an
// error only if the message cannot be JSON-encoded.
func (h *ChatHandler) SendSystemMessage(userID, sessionID, text string) error {
	msg := ws.ServerMessage{
		Type:      "response",
		MessageID: "sys_" + generateID(),
		Text:      text,
		Timestamp: time.Now().UnixMilli(),
	}
	data, err := json.Marshal(msg)
	if err != nil {
		return err
	}
	h.hub.SendToSession(userID, sessionID, data)
	return nil
}

// generateID returns a local-time timestamp (yyyyMMddHHmmss) followed by six
// random [a-z0-9] characters.
func generateID() string {
	return time.Now().Format("20060102150405") + randomStr(6)
}

// randomStr returns n characters drawn uniformly from [a-z0-9] via the
// math/rand global functions, which are safe for concurrent use. The result
// is not suitable for secrets. It returns "" when n <= 0.
//
// The previous implementation indexed by time.Now().UnixNano() for every
// character. Consecutive iterations see nearly the same clock value, so the
// suffix was largely one repeated character and IDs from the same second
// collided easily.
func randomStr(n int) string {
	const letters = "abcdefghijklmnopqrstuvwxyz0123456789"
	if n <= 0 {
		return ""
	}
	b := make([]byte, n)
	for i := range b {
		b[i] = letters[rand.Intn(len(letters))]
	}
	return string(b)
}

// parseMultiMessage splits text on blank lines ("\n\n") into trimmed,
// non-empty parts. It returns nil when text is empty or yields at most one
// part, i.e. when the text is not in multi-message format.
func parseMultiMessage(text string) []string {
	if text == "" {
		return nil
	}

	parts := strings.Split(text, "\n\n")
	result := make([]string, 0, len(parts))
	for _, p := range parts {
		if p = strings.TrimSpace(p); p != "" {
			result = append(result, p)
		}
	}

	if len(result) <= 1 {
		return nil
	}
	return result
}