a058b0ab8e
- 修复记忆管理数据库连接不可用 (ai-core重编译+Unicode修复) - 修复IoT子会话工具调用链路日志缺失 - 新增最终审查子会话(review_provider) 支持消息格式解析拆分 - 实现历史消息持久化(后端存储+前端分页加载) - 前端新增动作消息(ActionMessage)类型和渲染 - 优化对话链路速度(非阻塞子会话+快速问候通道) - JWT密钥环境变量化(无默认值启动panic) - Token自动刷新机制(401拦截器+refresh接口) - WebSocket指数退避重连(jitter+最大10次) - localStorage清理一致性(cyrene_前缀+版本检查) - IoT环境变量统一为IOT_SERVICE_URL
487 lines
13 KiB
Go
487 lines
13 KiB
Go
package handler
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"crypto/rand"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/gorilla/websocket"
|
|
|
|
"github.com/yourname/cyrene-ai/gateway/internal/config"
|
|
"github.com/yourname/cyrene-ai/gateway/internal/store"
|
|
"github.com/yourname/cyrene-ai/gateway/internal/ws"
|
|
)
|
|
|
|
// ChatHandler 聊天处理器
|
|
type ChatHandler struct {
|
|
cfg *config.Config
|
|
hub *ws.Hub
|
|
sessionStore *store.SessionStore
|
|
upgrader websocket.Upgrader
|
|
}
|
|
|
|
// NewChatHandler 创建聊天处理器
|
|
func NewChatHandler(cfg *config.Config, hub *ws.Hub, sessionStore *store.SessionStore) *ChatHandler {
|
|
return &ChatHandler{
|
|
cfg: cfg,
|
|
hub: hub,
|
|
sessionStore: sessionStore,
|
|
upgrader: websocket.Upgrader{
|
|
ReadBufferSize: 1024,
|
|
WriteBufferSize: 1024,
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
return true // 开发阶段允许所有来源
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
// HandleWebSocket 处理WebSocket升级和消息路由
|
|
func (h *ChatHandler) HandleWebSocket(c *gin.Context) {
|
|
// 从query参数获取token和session_id
|
|
token := c.Query("token")
|
|
sessionID := c.Query("session_id")
|
|
|
|
if token == "" {
|
|
// 也尝试从Authorization头读取
|
|
authHeader := c.GetHeader("Authorization")
|
|
if len(authHeader) > 7 && authHeader[:7] == "Bearer " {
|
|
token = authHeader[7:]
|
|
}
|
|
}
|
|
|
|
if token == "" {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "需要认证令牌"})
|
|
return
|
|
}
|
|
|
|
// 验证token
|
|
userID, err := h.cfg.ValidateToken(token)
|
|
if err != nil {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "认证令牌无效"})
|
|
return
|
|
}
|
|
|
|
// 主对话仅限管理员访问
|
|
if userID != "admin" {
|
|
c.JSON(http.StatusForbidden, gin.H{
|
|
"error": "主对话仅限管理员使用",
|
|
"errorType": "admin_only",
|
|
"hint": "请使用管理员账号登录以访问主对话功能",
|
|
})
|
|
return
|
|
}
|
|
|
|
if sessionID == "" {
|
|
sessionID = "session_" + generateID()
|
|
}
|
|
|
|
// 升级WebSocket连接
|
|
conn, err := h.upgrader.Upgrade(c.Writer, c.Request, nil)
|
|
if err != nil {
|
|
log.Printf("[WS] 升级连接失败: %v", err)
|
|
return
|
|
}
|
|
|
|
// 创建客户端
|
|
client := ws.NewClient(h.hub, conn, userID, sessionID)
|
|
|
|
// 注册到Hub
|
|
h.hub.Register(client)
|
|
|
|
// 启动读写协程
|
|
go client.WritePump()
|
|
go client.ReadPump(func(client *ws.Client, msg ws.ClientMessage) {
|
|
h.handleMessage(client, msg)
|
|
})
|
|
}
|
|
|
|
// handleMessage 处理WebSocket消息
|
|
func (h *ChatHandler) handleMessage(client *ws.Client, msg ws.ClientMessage) {
|
|
switch msg.Type {
|
|
case "message":
|
|
h.handleChatMessage(client, msg)
|
|
case "voice_input":
|
|
h.handleVoiceInput(client, msg)
|
|
case "history":
|
|
h.handleHistoryRequest(client, msg)
|
|
default:
|
|
log.Printf("[WS] 未知消息类型: %s from user=%s", msg.Type, client.UserID)
|
|
}
|
|
}
|
|
|
|
// handleChatMessage 处理文字聊天消息 - 转发到 AI-Core(流式发送)
|
|
func (h *ChatHandler) handleChatMessage(client *ws.Client, msg ws.ClientMessage) {
|
|
mode := msg.Mode
|
|
if mode == "" {
|
|
mode = "text"
|
|
}
|
|
|
|
// 持久化用户消息到数据库(在 WebSocket 发送之前)
|
|
if h.sessionStore != nil && h.sessionStore.IsAvailable() {
|
|
if err := h.sessionStore.AddMessage(client.SessionID, "user", msg.Content); err != nil {
|
|
log.Printf("[chat] 持久化用户消息失败: %v", err)
|
|
}
|
|
}
|
|
|
|
// 记录用户消息
|
|
h.hub.RecordMessage(client.SessionID, "user", msg.Content)
|
|
|
|
// 设置会话状态为 thinking
|
|
h.hub.UpdateSessionState(client.SessionID, "thinking")
|
|
|
|
// 构建 AI-Core 请求
|
|
aiReq := map[string]interface{}{
|
|
"user_id": client.UserID,
|
|
"session_id": client.SessionID,
|
|
"message": msg.Content,
|
|
"mode": mode,
|
|
}
|
|
if len(msg.Attachments) > 0 {
|
|
aiReq["attachments"] = msg.Attachments
|
|
}
|
|
reqBody, err := json.Marshal(aiReq)
|
|
if err != nil {
|
|
log.Printf("[chat] 序列化请求失败: %v", err)
|
|
h.hub.UpdateSessionState(client.SessionID, "error")
|
|
client.SendMessage(ws.ServerMessage{
|
|
Type: "error",
|
|
MessageID: "msg_" + generateID(),
|
|
Error: "内部错误,请稍后重试",
|
|
Timestamp: time.Now().UnixMilli(),
|
|
})
|
|
return
|
|
}
|
|
|
|
// 缓存用户消息(在 goroutine 前完成,避免竞态)
|
|
userMsg := ws.Message{
|
|
ID: "msg_" + generateID(),
|
|
Role: "user",
|
|
Content: msg.Content,
|
|
Timestamp: time.Now().UnixMilli(),
|
|
}
|
|
if len(msg.Attachments) > 0 {
|
|
userMsg.Attachments = msg.Attachments
|
|
}
|
|
h.hub.CacheMessage(client.UserID, client.SessionID, userMsg)
|
|
|
|
// 在 goroutine 中进行 AI-Core 调用和流式发送,避免阻塞 ReadPump
|
|
go h.streamResponse(client, mode, reqBody, msg.Content)
|
|
}
|
|
|
|
// streamResponse 调用 AI-Core SSE 流式接口并逐 delta 转发给客户端
|
|
func (h *ChatHandler) streamResponse(client *ws.Client, mode string, reqBody []byte, userMsg string) {
|
|
aiCoreURL := h.cfg.AICoreURL + "/api/v1/chat"
|
|
httpReq, err := http.NewRequest("POST", aiCoreURL, bytes.NewReader(reqBody))
|
|
if err != nil {
|
|
log.Printf("[chat] 创建 AI-Core 请求失败: %v", err)
|
|
h.hub.UpdateSessionState(client.SessionID, "error")
|
|
client.SendMessage(ws.ServerMessage{
|
|
Type: "error",
|
|
MessageID: "msg_" + generateID(),
|
|
Error: "服务暂不可用",
|
|
Timestamp: time.Now().UnixMilli(),
|
|
})
|
|
return
|
|
}
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
httpReq.Header.Set("Accept", "text/event-stream")
|
|
|
|
httpClient := &http.Client{Timeout: 120 * time.Second}
|
|
resp, err := httpClient.Do(httpReq)
|
|
if err != nil {
|
|
log.Printf("[chat] AI-Core 调用失败: %v", err)
|
|
h.hub.UpdateSessionState(client.SessionID, "error")
|
|
client.SendMessage(ws.ServerMessage{
|
|
Type: "error",
|
|
MessageID: "msg_" + generateID(),
|
|
Error: fmt.Sprintf("AI-Core 调用失败: %v", err),
|
|
Timestamp: time.Now().UnixMilli(),
|
|
})
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
log.Printf("[chat] AI-Core 返回错误 [%d]: %s", resp.StatusCode, string(body))
|
|
h.hub.UpdateSessionState(client.SessionID, "error")
|
|
client.SendMessage(ws.ServerMessage{
|
|
Type: "error",
|
|
MessageID: "msg_" + generateID(),
|
|
Error: fmt.Sprintf("AI-Core 错误 (%d)", resp.StatusCode),
|
|
Timestamp: time.Now().UnixMilli(),
|
|
})
|
|
return
|
|
}
|
|
|
|
// 使用 bufio.Scanner 逐行读取 SSE 响应
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
// 增大 scanner buffer 以处理大块 SSE 数据
|
|
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
|
|
|
var fullText string
|
|
var msgID string
|
|
var segments []ws.VoiceSegment // 收集断句信息
|
|
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
|
|
// 跳过非 data 行
|
|
if !strings.HasPrefix(line, "data: ") {
|
|
continue
|
|
}
|
|
|
|
data := strings.TrimPrefix(line, "data: ")
|
|
|
|
// SSE 流结束标记
|
|
if data == "[DONE]" {
|
|
break
|
|
}
|
|
|
|
// 解析 delta 数据
|
|
var chunk struct {
|
|
Delta string `json:"delta"`
|
|
Error string `json:"error,omitempty"`
|
|
MessageID string `json:"message_id,omitempty"`
|
|
Mode string `json:"mode,omitempty"`
|
|
Done bool `json:"done,omitempty"`
|
|
// 断句相关 (来自 AI-Core 新格式)
|
|
Segments []struct {
|
|
Index int `json:"index"`
|
|
Text string `json:"text"`
|
|
} `json:"segments,omitempty"`
|
|
}
|
|
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
|
log.Printf("[chat] 解析 SSE delta 失败: %v, raw=%s", err, data)
|
|
continue
|
|
}
|
|
|
|
// 错误处理
|
|
if chunk.Error != "" {
|
|
log.Printf("[chat] AI-Core 流式错误: %s", chunk.Error)
|
|
h.hub.UpdateSessionState(client.SessionID, "error")
|
|
client.SendMessage(ws.ServerMessage{
|
|
Type: "error",
|
|
MessageID: "msg_" + generateID(),
|
|
Error: chunk.Error,
|
|
Timestamp: time.Now().UnixMilli(),
|
|
})
|
|
return
|
|
}
|
|
|
|
// 记录 message_id
|
|
if chunk.MessageID != "" {
|
|
msgID = chunk.MessageID
|
|
}
|
|
|
|
// 如果是结束标记(含 done: true),跳出
|
|
if chunk.Done {
|
|
break
|
|
}
|
|
|
|
// 处理断句事件 (stream_segments)
|
|
if len(chunk.Segments) > 0 {
|
|
for _, seg := range chunk.Segments {
|
|
segments = append(segments, ws.VoiceSegment{
|
|
Index: seg.Index,
|
|
Text: seg.Text,
|
|
})
|
|
}
|
|
// 发送断句事件给前端
|
|
client.SendMessage(ws.ServerMessage{
|
|
Type: "stream_segments",
|
|
MessageID: msgID,
|
|
Segments: segments,
|
|
SessionID: client.SessionID,
|
|
Timestamp: time.Now().UnixMilli(),
|
|
})
|
|
continue
|
|
}
|
|
|
|
// 逐 delta 转发
|
|
if chunk.Delta != "" {
|
|
fullText += chunk.Delta
|
|
|
|
client.SendMessage(ws.ServerMessage{
|
|
Type: "stream_chunk",
|
|
MessageID: msgID,
|
|
Content: chunk.Delta,
|
|
Role: "assistant",
|
|
SessionID: client.SessionID,
|
|
Timestamp: time.Now().UnixMilli(),
|
|
})
|
|
}
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
log.Printf("[chat] SSE 读取错误: %v", err)
|
|
h.hub.UpdateSessionState(client.SessionID, "error")
|
|
client.SendMessage(ws.ServerMessage{
|
|
Type: "error",
|
|
MessageID: "msg_" + generateID(),
|
|
Error: fmt.Sprintf("流读取错误: %v", err),
|
|
Timestamp: time.Now().UnixMilli(),
|
|
})
|
|
return
|
|
}
|
|
|
|
if msgID == "" {
|
|
msgID = "msg_" + generateID()
|
|
}
|
|
|
|
// 检测是否为多消息格式(包含空行分隔的多条消息)
|
|
multiParts := parseMultiMessage(fullText)
|
|
if len(multiParts) > 1 {
|
|
// 发送 multi_message 事件
|
|
var items []ws.MultiMessageItem
|
|
for i, part := range multiParts {
|
|
items = append(items, ws.MultiMessageItem{
|
|
Index: i,
|
|
Content: part,
|
|
})
|
|
}
|
|
client.SendMessage(ws.ServerMessage{
|
|
Type: "multi_message",
|
|
MessageID: msgID,
|
|
SessionID: client.SessionID,
|
|
MultiMessage: &ws.MultiMessagePayload{
|
|
Messages: items,
|
|
},
|
|
Timestamp: time.Now().UnixMilli(),
|
|
})
|
|
}
|
|
|
|
// 发送 stream_end
|
|
client.SendMessage(ws.ServerMessage{
|
|
Type: "stream_end",
|
|
MessageID: msgID,
|
|
SessionID: client.SessionID,
|
|
Timestamp: time.Now().UnixMilli(),
|
|
})
|
|
|
|
// 持久化 AI 回复到数据库(在 WebSocket 发送之前)
|
|
if fullText != "" {
|
|
if h.sessionStore != nil && h.sessionStore.IsAvailable() {
|
|
if err := h.sessionStore.AddMessage(client.SessionID, "assistant", fullText); err != nil {
|
|
log.Printf("[chat] 持久化 AI 回复失败: %v", err)
|
|
}
|
|
}
|
|
|
|
h.hub.CacheMessage(client.UserID, client.SessionID, ws.Message{
|
|
ID: msgID,
|
|
Role: "assistant",
|
|
Content: fullText,
|
|
Timestamp: time.Now().UnixMilli(),
|
|
})
|
|
}
|
|
h.hub.RecordMessage(client.SessionID, "assistant", fullText)
|
|
|
|
// 设置会话状态为 idle
|
|
h.hub.UpdateSessionState(client.SessionID, "idle")
|
|
}
|
|
|
|
// handleVoiceInput 处理语音输入
|
|
func (h *ChatHandler) handleVoiceInput(client *ws.Client, msg ws.ClientMessage) {
|
|
// MVP阶段:返回提示
|
|
response := ws.ServerMessage{
|
|
Type: "error",
|
|
MessageID: "msg_" + generateID(),
|
|
Error: "语音处理功能将在后续版本中启用",
|
|
Timestamp: time.Now().UnixMilli(),
|
|
}
|
|
client.SendMessage(response)
|
|
}
|
|
|
|
|
|
// handleHistoryRequest 处理历史消息请求
|
|
func (h *ChatHandler) handleHistoryRequest(client *ws.Client, msg ws.ClientMessage) {
|
|
// 优先使用请求中的 session_id,否则使用客户端的 session_id
|
|
sessionID := msg.SessionID
|
|
if sessionID == "" {
|
|
sessionID = client.SessionID
|
|
}
|
|
|
|
messages := h.hub.GetConversation(client.UserID, sessionID)
|
|
if messages == nil {
|
|
messages = []ws.Message{}
|
|
}
|
|
|
|
response := ws.ServerMessage{
|
|
Type: "history_response",
|
|
MessageID: "hist_" + generateID(),
|
|
Messages: messages,
|
|
Timestamp: time.Now().UnixMilli(),
|
|
}
|
|
|
|
if err := client.SendMessage(response); err != nil {
|
|
log.Printf("[WS] 发送历史消息失败: %v", err)
|
|
}
|
|
}
|
|
|
|
// SendSystemMessage 向用户发送系统消息(用于主动通知)
|
|
func (h *ChatHandler) SendSystemMessage(userID, sessionID, text string) error {
|
|
msg := ws.ServerMessage{
|
|
Type: "response",
|
|
MessageID: "sys_" + generateID(),
|
|
Text: text,
|
|
Timestamp: time.Now().UnixMilli(),
|
|
}
|
|
|
|
data, err := json.Marshal(msg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
h.hub.SendToSession(userID, sessionID, data)
|
|
return nil
|
|
}
|
|
|
|
func generateID() string {
|
|
return time.Now().Format("20060102150405") + randomStr(6)
|
|
}
|
|
|
|
// randomStr returns n lowercase hexadecimal characters taken from
// crypto/rand. It returns "" when n <= 0.
//
// Each random byte yields two hex digits, so only ceil(n/2) bytes are drawn.
// If crypto/rand fails, the bytes are derived from the current nanosecond
// clock instead. That fallback is predictable: it only keeps IDs varying and
// must not be relied on for anything security-sensitive.
func randomStr(n int) string {
	if n <= 0 {
		return ""
	}

	b := make([]byte, (n+1)/2)
	if _, err := rand.Read(b); err != nil {
		ts := uint64(time.Now().UnixNano())
		for i := range b {
			// Take a different byte of the timestamp for each position so
			// the output does not repeat a single value.
			b[i] = byte(ts >> (8 * uint(i%8)))
		}
	}
	return hex.EncodeToString(b)[:n]
}
|
|
|
|
// parseMultiMessage splits a reply into separate messages at blank lines.
//
// A blank line is any line that is empty or holds only whitespace. So
// "\n\n", "\r\n\r\n" and "\n \n" all separate messages. Each message is
// trimmed of surrounding whitespace and empty ones are dropped.
//
// It returns nil when text yields fewer than two messages, meaning the reply
// is not in multi-message form. Otherwise it returns the messages in order.
func parseMultiMessage(text string) []string {
	var (
		result  []string
		current []string
	)
	// flush closes the message accumulated so far, if it has any content.
	flush := func() {
		if part := strings.TrimSpace(strings.Join(current, "\n")); part != "" {
			result = append(result, part)
		}
		current = current[:0]
	}

	for _, line := range strings.Split(text, "\n") {
		if strings.TrimSpace(line) == "" {
			flush()
			continue
		}
		current = append(current, line)
	}
	flush()

	if len(result) < 2 {
		return nil
	}
	return result
}
|
|
|