Files
Cyrene/backend/gateway/internal/handler/chat_handler.go
T
AskaEth a058b0ab8e fix: 第一轮修复 - 记忆管理/IoT操控/历史消息持久化/动作消息/链路优化/安全配置
- 修复记忆管理数据库连接不可用 (ai-core重编译+Unicode修复)
- 修复IoT子会话工具调用链路日志缺失
- 新增最终审查子会话(review_provider) 支持消息格式解析拆分
- 实现历史消息持久化(后端存储+前端分页加载)
- 前端新增动作消息(ActionMessage)类型和渲染
- 优化对话链路速度(非阻塞子会话+快速问候通道)
- JWT密钥环境变量化(无默认值启动panic)
- Token自动刷新机制(401拦截器+refresh接口)
- WebSocket指数退避重连(jitter+最大10次)
- localStorage清理一致性(cyrene_前缀+版本检查)
- IoT环境变量统一为IOT_SERVICE_URL
2026-05-21 23:10:07 +08:00

487 lines
13 KiB
Go

package handler
import (
"bufio"
"bytes"
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/yourname/cyrene-ai/gateway/internal/config"
"github.com/yourname/cyrene-ai/gateway/internal/store"
"github.com/yourname/cyrene-ai/gateway/internal/ws"
)
// ChatHandler serves the chat WebSocket endpoint: it authenticates and
// upgrades incoming connections, routes client frames, and relays chat
// messages to the AI-Core service.
type ChatHandler struct {
// cfg supplies token validation (ValidateToken) and the AI-Core base URL
// (AICoreURL).
cfg *config.Config
// hub is the WebSocket hub used for client registration, session state
// updates, message recording/caching, and session-targeted delivery.
hub *ws.Hub
// sessionStore persists chat messages. It may be nil; callers in this file
// check for nil and IsAvailable() before using it.
sessionStore *store.SessionStore
// upgrader turns an HTTP request into a WebSocket connection.
upgrader websocket.Upgrader
}
// NewChatHandler builds a ChatHandler wired to the given configuration, hub
// and session store. The embedded WebSocket upgrader uses 1 KiB read and write
// buffers.
func NewChatHandler(cfg *config.Config, hub *ws.Hub, sessionStore *store.SessionStore) *ChatHandler {
	// Development-stage policy: every Origin is accepted.
	// NOTE(review): tighten this before exposing the gateway publicly.
	allowAnyOrigin := func(r *http.Request) bool {
		return true
	}

	handler := &ChatHandler{cfg: cfg, hub: hub, sessionStore: sessionStore}
	handler.upgrader = websocket.Upgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
		CheckOrigin:     allowAnyOrigin,
	}
	return handler
}
// HandleWebSocket authenticates the request, upgrades it to a WebSocket
// connection and starts the client's read and write pumps.
//
// The token is taken from the "token" query parameter, falling back to an
// "Authorization: Bearer <token>" header. Requests without a valid token get
// 401; any user other than "admin" gets 403. When no "session_id" query
// parameter is supplied a new session ID is generated.
func (h *ChatHandler) HandleWebSocket(c *gin.Context) {
	sessionID := c.Query("session_id")

	// Prefer the query-string token; otherwise accept a Bearer header.
	token := c.Query("token")
	if token == "" {
		if header := c.GetHeader("Authorization"); strings.HasPrefix(header, "Bearer ") {
			token = strings.TrimPrefix(header, "Bearer ")
		}
	}
	if token == "" {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "需要认证令牌"})
		return
	}

	userID, err := h.cfg.ValidateToken(token)
	if err != nil {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "认证令牌无效"})
		return
	}

	// The main conversation is restricted to the administrator account.
	if userID != "admin" {
		c.JSON(http.StatusForbidden, gin.H{
			"error":     "主对话仅限管理员使用",
			"errorType": "admin_only",
			"hint":      "请使用管理员账号登录以访问主对话功能",
		})
		return
	}

	if sessionID == "" {
		sessionID = "session_" + generateID()
	}

	conn, err := h.upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		log.Printf("[WS] 升级连接失败: %v", err)
		return
	}

	// Register the new client with the hub, then start its pumps; inbound
	// frames are routed through handleMessage.
	wsClient := ws.NewClient(h.hub, conn, userID, sessionID)
	h.hub.Register(wsClient)
	go wsClient.WritePump()
	go wsClient.ReadPump(h.handleMessage)
}
// handleMessage routes an inbound WebSocket frame to the handler matching its
// Type ("message", "voice_input" or "history"). Unknown types are logged and
// otherwise ignored.
func (h *ChatHandler) handleMessage(client *ws.Client, msg ws.ClientMessage) {
	var route func(*ws.Client, ws.ClientMessage)
	switch msg.Type {
	case "message":
		route = h.handleChatMessage
	case "voice_input":
		route = h.handleVoiceInput
	case "history":
		route = h.handleHistoryRequest
	}

	if route == nil {
		log.Printf("[WS] 未知消息类型: %s from user=%s", msg.Type, client.UserID)
		return
	}
	route(client, msg)
}
// handleChatMessage processes a text chat frame: it persists and records the
// user's message, marks the session as "thinking", builds the AI-Core request
// and hands the streaming call off to a goroutine. An empty Mode defaults to
// "text".
func (h *ChatHandler) handleChatMessage(client *ws.Client, msg ws.ClientMessage) {
	mode := msg.Mode
	if mode == "" {
		mode = "text"
	}
	hasAttachments := len(msg.Attachments) > 0

	// Persist the user message first; a storage failure is logged but does
	// not stop the chat flow.
	if persist := h.sessionStore; persist != nil && persist.IsAvailable() {
		if err := persist.AddMessage(client.SessionID, "user", msg.Content); err != nil {
			log.Printf("[chat] 持久化用户消息失败: %v", err)
		}
	}

	h.hub.RecordMessage(client.SessionID, "user", msg.Content)
	h.hub.UpdateSessionState(client.SessionID, "thinking")

	// Assemble the AI-Core request payload.
	payload := map[string]interface{}{
		"user_id":    client.UserID,
		"session_id": client.SessionID,
		"message":    msg.Content,
		"mode":       mode,
	}
	if hasAttachments {
		payload["attachments"] = msg.Attachments
	}

	reqBody, err := json.Marshal(payload)
	if err != nil {
		log.Printf("[chat] 序列化请求失败: %v", err)
		h.hub.UpdateSessionState(client.SessionID, "error")
		client.SendMessage(ws.ServerMessage{
			Type:      "error",
			MessageID: "msg_" + generateID(),
			Error:     "内部错误,请稍后重试",
			Timestamp: time.Now().UnixMilli(),
		})
		return
	}

	// Cache the user message before the goroutine starts so it cannot race
	// with the assistant reply being cached.
	cached := ws.Message{
		ID:        "msg_" + generateID(),
		Role:      "user",
		Content:   msg.Content,
		Timestamp: time.Now().UnixMilli(),
	}
	if hasAttachments {
		cached.Attachments = msg.Attachments
	}
	h.hub.CacheMessage(client.UserID, client.SessionID, cached)

	// Run the AI-Core call and streaming off the read loop so ReadPump is
	// not blocked.
	go h.streamResponse(client, mode, reqBody, msg.Content)
}
// streamResponse calls the AI-Core SSE chat endpoint with reqBody and relays
// the reply to the client over its WebSocket.
//
// For every "data: " line of the SSE stream it decodes a JSON chunk and:
//   - on a non-empty "error" field, reports the error to the client and stops;
//   - on "segments", appends them to the collected list and sends a
//     "stream_segments" frame carrying everything collected so far;
//   - on "delta", appends the text to the reply and sends a "stream_chunk".
//
// The stream ends at "[DONE]" or a chunk with "done": true. Afterwards it
// optionally sends a "multi_message" frame (when the reply contains several
// blank-line-separated parts), sends "stream_end", persists and caches the
// assistant reply, and sets the session back to "idle".
//
// The mode and userMsg parameters are accepted for the caller's convenience
// but are not read by this function.
func (h *ChatHandler) streamResponse(client *ws.Client, mode string, reqBody []byte, userMsg string) {
	// Upper bound on how much of an AI-Core error body is read for logging.
	const maxErrorBodyLog = 4 << 10

	aiCoreURL := h.cfg.AICoreURL + "/api/v1/chat"
	httpReq, err := http.NewRequest("POST", aiCoreURL, bytes.NewReader(reqBody))
	if err != nil {
		log.Printf("[chat] 创建 AI-Core 请求失败: %v", err)
		h.failStream(client, "服务暂不可用")
		return
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Accept", "text/event-stream")

	// Client.Timeout covers the whole exchange including reading the body,
	// so a stream that runs longer than 120s is cut off with a read error.
	httpClient := &http.Client{Timeout: 120 * time.Second}
	resp, err := httpClient.Do(httpReq)
	if err != nil {
		log.Printf("[chat] AI-Core 调用失败: %v", err)
		h.failStream(client, fmt.Sprintf("AI-Core 调用失败: %v", err))
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// The body is only used for the log line, so the read is bounded and
		// a read error is deliberately ignored (best effort).
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyLog))
		log.Printf("[chat] AI-Core 返回错误 [%d]: %s", resp.StatusCode, string(body))
		h.failStream(client, fmt.Sprintf("AI-Core 错误 (%d)", resp.StatusCode))
		return
	}

	// sseChunk is the JSON payload of a single SSE data line.
	type sseChunk struct {
		Delta     string `json:"delta"`
		Error     string `json:"error,omitempty"`
		MessageID string `json:"message_id,omitempty"`
		Mode      string `json:"mode,omitempty"`
		Done      bool   `json:"done,omitempty"`
		// Sentence segmentation entries (newer AI-Core format).
		Segments []struct {
			Index int    `json:"index"`
			Text  string `json:"text"`
		} `json:"segments,omitempty"`
	}

	// Read the SSE response line by line; the scanner buffer is enlarged so
	// that data lines of up to 1 MiB are accepted.
	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)

	var (
		reply    strings.Builder   // accumulated assistant text
		msgID    string            // last message_id reported by AI-Core
		segments []ws.VoiceSegment // all segments collected so far
	)

scan:
	for scanner.Scan() {
		line := scanner.Text()
		// Only "data: " lines carry payload; everything else is skipped.
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		data := strings.TrimPrefix(line, "data: ")
		// End-of-stream sentinel.
		if data == "[DONE]" {
			break scan
		}

		// A fresh value per line so fields absent from this chunk are zero.
		var chunk sseChunk
		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
			log.Printf("[chat] 解析 SSE delta 失败: %v, raw=%s", err, data)
			continue
		}

		if chunk.Error != "" {
			log.Printf("[chat] AI-Core 流式错误: %s", chunk.Error)
			h.failStream(client, chunk.Error)
			return
		}

		if chunk.MessageID != "" {
			msgID = chunk.MessageID
		}

		// A chunk flagged done terminates the stream.
		if chunk.Done {
			break scan
		}

		// Segmentation event: extend the collected list and forward the
		// cumulative list to the front end.
		if len(chunk.Segments) > 0 {
			for _, seg := range chunk.Segments {
				segments = append(segments, ws.VoiceSegment{
					Index: seg.Index,
					Text:  seg.Text,
				})
			}
			client.SendMessage(ws.ServerMessage{
				Type:      "stream_segments",
				MessageID: msgID,
				Segments:  segments,
				SessionID: client.SessionID,
				Timestamp: time.Now().UnixMilli(),
			})
			continue
		}

		// Forward each text delta as it arrives.
		if chunk.Delta != "" {
			reply.WriteString(chunk.Delta)
			client.SendMessage(ws.ServerMessage{
				Type:      "stream_chunk",
				MessageID: msgID,
				Content:   chunk.Delta,
				Role:      "assistant",
				SessionID: client.SessionID,
				Timestamp: time.Now().UnixMilli(),
			})
		}
	}

	if err := scanner.Err(); err != nil {
		log.Printf("[chat] SSE 读取错误: %v", err)
		h.failStream(client, fmt.Sprintf("流读取错误: %v", err))
		return
	}

	if msgID == "" {
		msgID = "msg_" + generateID()
	}
	fullText := reply.String()

	// If the reply consists of several blank-line-separated parts, announce
	// them as a multi_message event.
	if multiParts := parseMultiMessage(fullText); len(multiParts) > 1 {
		items := make([]ws.MultiMessageItem, 0, len(multiParts))
		for i, part := range multiParts {
			items = append(items, ws.MultiMessageItem{
				Index:   i,
				Content: part,
			})
		}
		client.SendMessage(ws.ServerMessage{
			Type:      "multi_message",
			MessageID: msgID,
			SessionID: client.SessionID,
			MultiMessage: &ws.MultiMessagePayload{
				Messages: items,
			},
			Timestamp: time.Now().UnixMilli(),
		})
	}

	client.SendMessage(ws.ServerMessage{
		Type:      "stream_end",
		MessageID: msgID,
		SessionID: client.SessionID,
		Timestamp: time.Now().UnixMilli(),
	})

	// Persist and cache the assistant reply. This runs after stream_end has
	// been handed to the client; a storage failure is logged only.
	if fullText != "" {
		if h.sessionStore != nil && h.sessionStore.IsAvailable() {
			if err := h.sessionStore.AddMessage(client.SessionID, "assistant", fullText); err != nil {
				log.Printf("[chat] 持久化 AI 回复失败: %v", err)
			}
		}
		h.hub.CacheMessage(client.UserID, client.SessionID, ws.Message{
			ID:        msgID,
			Role:      "assistant",
			Content:   fullText,
			Timestamp: time.Now().UnixMilli(),
		})
	}
	h.hub.RecordMessage(client.SessionID, "assistant", fullText)

	h.hub.UpdateSessionState(client.SessionID, "idle")
}

// failStream marks the client's session as "error" and sends the client an
// "error" frame carrying errText under a freshly generated message ID.
func (h *ChatHandler) failStream(client *ws.Client, errText string) {
	h.hub.UpdateSessionState(client.SessionID, "error")
	client.SendMessage(ws.ServerMessage{
		Type:      "error",
		MessageID: "msg_" + generateID(),
		Error:     errText,
		Timestamp: time.Now().UnixMilli(),
	})
}
// handleVoiceInput answers a voice-input frame. Voice processing is not
// available at the MVP stage, so the client always receives an "error" frame
// saying the feature will be enabled in a later version.
func (h *ChatHandler) handleVoiceInput(client *ws.Client, msg ws.ClientMessage) {
	client.SendMessage(ws.ServerMessage{
		Type:      "error",
		MessageID: "msg_" + generateID(),
		Error:     "语音处理功能将在后续版本中启用",
		Timestamp: time.Now().UnixMilli(),
	})
}
// handleHistoryRequest replies to a history frame with the conversation the
// hub holds for the requesting user. The session named in the request is
// used when present; otherwise the client's own session is used.
func (h *ChatHandler) handleHistoryRequest(client *ws.Client, msg ws.ClientMessage) {
	sessionID := client.SessionID
	if msg.SessionID != "" {
		sessionID = msg.SessionID
	}

	history := h.hub.GetConversation(client.UserID, sessionID)
	if history == nil {
		// Send an empty list rather than a nil slice.
		history = []ws.Message{}
	}

	err := client.SendMessage(ws.ServerMessage{
		Type:      "history_response",
		MessageID: "hist_" + generateID(),
		Messages:  history,
		Timestamp: time.Now().UnixMilli(),
	})
	if err != nil {
		log.Printf("[WS] 发送历史消息失败: %v", err)
	}
}
// SendSystemMessage pushes a system-originated "response" frame containing
// text to the given user's session (used for proactive notifications). It
// returns an error only when the frame cannot be JSON-encoded.
func (h *ChatHandler) SendSystemMessage(userID, sessionID, text string) error {
	payload, err := json.Marshal(ws.ServerMessage{
		Type:      "response",
		MessageID: "sys_" + generateID(),
		Text:      text,
		Timestamp: time.Now().UnixMilli(),
	})
	if err != nil {
		return err
	}

	h.hub.SendToSession(userID, sessionID, payload)
	return nil
}
// generateID returns an identifier made of the current local time formatted
// as yyyymmddHHMMSS followed by a 6-character random suffix.
func generateID() string {
	const stampLayout = "20060102150405"
	stamp := time.Now().Format(stampLayout)
	return stamp + randomStr(6)
}
// randomStr returns a string of exactly n lowercase hexadecimal characters
// drawn from crypto/rand. It returns "" when n <= 0.
func randomStr(n int) string {
	if n <= 0 {
		return ""
	}

	// Each byte encodes to two hex characters, so (n+1)/2 bytes are enough.
	buf := make([]byte, (n+1)/2)
	if _, err := rand.Read(buf); err != nil {
		// Fallback when the system entropy source fails: a clock-seeded
		// xorshift generator, so that every byte differs. This is NOT
		// cryptographically secure; it only keeps IDs reasonably unique.
		state := uint64(time.Now().UnixNano()) | 1
		for i := range buf {
			state ^= state << 13
			state ^= state >> 7
			state ^= state << 17
			buf[i] = byte(state >> 24)
		}
	}
	return hex.EncodeToString(buf)[:n]
}
// parseMultiMessage detects the multi-message format: several messages
// separated by a blank line ("\n\n"). Each part is whitespace-trimmed and
// empty parts are dropped. It returns the parts when at least two remain, and
// nil otherwise (including for empty input), meaning "not multi-message".
func parseMultiMessage(text string) []string {
	if len(text) == 0 {
		return nil
	}

	var messages []string
	chunks := strings.Split(text, "\n\n")
	for i := range chunks {
		trimmed := strings.TrimSpace(chunks[i])
		if trimmed == "" {
			continue
		}
		messages = append(messages, trimmed)
	}

	// Fewer than two parts is an ordinary single message.
	if len(messages) < 2 {
		return nil
	}
	return messages
}