package ws import ( "encoding/json" "fmt" "log" "net/http" "os" "sync" "time" "github.com/yourname/cyrene-ai/gateway/internal/store" ) // SessionState 会话状态 type SessionState struct { SessionID string `json:"session_id"` UserID string `json:"user_id"` State string `json:"state"` // idle, thinking, streaming, error ConnectedAt time.Time `json:"connected_at"` LastActivity time.Time `json:"last_activity"` MessageCount int `json:"message_count"` RecentMessages []SessionMessage `json:"recent_messages,omitempty"` } // SessionMessage 会话消息记录 type SessionMessage struct { Role string `json:"role"` // user, assistant, system Content string `json:"content"` // 截断到 100 字符 Timestamp int64 `json:"timestamp"` } // Message 完整对话消息(用于缓存) type Message struct { ID string `json:"id"` Role string `json:"role"` Content string `json:"content"` Attachments []MessageAttachment `json:"attachments,omitempty"` Timestamp int64 `json:"timestamp"` } const maxRecentMessages = 20 // Hub WebSocket连接池 type Hub struct { mu sync.RWMutex clients map[*Client]bool broadcast chan []byte register chan *Client unregister chan *Client // 按用户ID索引的客户端映射 userClients map[string]map[*Client]bool // 会话状态追踪 (sessionID -> SessionState) sessions map[string]*SessionState // 对话缓存:key = "userID:sessionID", value = []Message conversationCache sync.Map convCacheMu sync.Mutex // IoT 设备广播 iotServiceURL string iotStopCh chan struct{} iotPollRunning bool // 持久化存储 (可选,数据库连接失败时为 nil) store *store.SessionStore // 闲置超时时间 idleTimeout time.Duration } // SetStore 设置持久化存储 (可选) func (h *Hub) SetStore(s *store.SessionStore) { h.store = s } // SetIdleTimeout 设置闲置超时时间 func (h *Hub) SetIdleTimeout(minutes int) { h.idleTimeout = time.Duration(minutes) * time.Minute } // NewHub 创建WebSocket Hub func NewHub() *Hub { return &Hub{ clients: make(map[*Client]bool), broadcast: make(chan []byte, 256), register: make(chan *Client), unregister: make(chan *Client), userClients: make(map[string]map[*Client]bool), sessions: make(map[string]*SessionState), iotStopCh: make(chan struct{}), idleTimeout: 30 * time.Minute, // 默认30分钟 } } // StartIdleCleanup 启动闲置会话清理 goroutine // 每5分钟检查一次,将超过 idleTimeout 无活动的会话标记为 idle func (h *Hub) StartIdleCleanup() { go func() { ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() for range ticker.C { h.cleanupIdleSessions() } }() log.Printf("[WS] 闲置会话清理已启动 (超时: %v)", h.idleTimeout) } // cleanupIdleSessions 标记超时会话为 idle(不删除状态) func (h *Hub) cleanupIdleSessions() { h.mu.Lock() defer h.mu.Unlock() now := time.Now() idleCount := 0 for sessionID, s := range h.sessions { // 检查该 session 是否还有活跃连接 hasActiveConn := false for _, clients := range h.userClients { for c := range clients { if c.SessionID == sessionID { hasActiveConn = true break } } if hasActiveConn { break } } // 如果没有活跃连接且超过闲置超时,标记为 idle if !hasActiveConn && now.Sub(s.LastActivity) > h.idleTimeout { if s.State != "idle" { s.State = "idle" idleCount++ } } } if idleCount > 0 { log.Printf("[WS] 闲置清理: %d 个会话标记为 idle", idleCount) } } // GetAllActiveSessions 返回所有会话状态(包括 idle),供 DevTools 监看使用 func (h *Hub) GetAllActiveSessions() []*SessionState { h.mu.RLock() defer h.mu.RUnlock() if h.sessions == nil || len(h.sessions) == 0 { return []*SessionState{} } result := make([]*SessionState, 0, len(h.sessions)) for _, s := range h.sessions { cp := *s cp.RecentMessages = nil result = append(result, &cp) } return result } // Run 启动Hub主循环 func (h *Hub) Run() { for { select { case client := <-h.register: h.mu.Lock() h.clients[client] = true // 用户索引 if h.userClients[client.UserID] == nil { h.userClients[client.UserID] = make(map[*Client]bool) } h.userClients[client.UserID][client] = true // 会话状态追踪:如果该session尚未存在则创建 if _, exists := h.sessions[client.SessionID]; !exists { h.sessions[client.SessionID] = &SessionState{ SessionID: client.SessionID, UserID: client.UserID, State: "idle", ConnectedAt: time.Now(), LastActivity: time.Now(), MessageCount: 0, } } h.mu.Unlock() log.Printf("[WS] 客户端连接: user=%s session=%s (当前连接数: %d)", client.UserID, client.SessionID, len(h.clients)) case client := <-h.unregister: h.mu.Lock() if _, ok := h.clients[client]; ok { delete(h.clients, client) close(client.Send) // 清理用户索引 if h.userClients[client.UserID] != nil { delete(h.userClients[client.UserID], client) if len(h.userClients[client.UserID]) == 0 { delete(h.userClients, client.UserID) } } // 检查该session是否还有其他连接,没有则标记为 idle 而非删除 hasOtherConn := false if clients, ok := h.userClients[client.UserID]; ok { for c := range clients { if c.SessionID == client.SessionID { hasOtherConn = true break } } } if !hasOtherConn { // 不再删除 session 状态,而是标记为 idle 保留在内存中 if s, ok := h.sessions[client.SessionID]; ok { s.State = "idle" } } } h.mu.Unlock() log.Printf("[WS] 客户端断开: user=%s session=%s (当前连接数: %d)", client.UserID, client.SessionID, len(h.clients)) case message := <-h.broadcast: h.mu.RLock() for client := range h.clients { select { case client.Send <- message: default: // 客户端发送通道已满,跳过 close(client.Send) delete(h.clients, client) } } h.mu.RUnlock() } } } // Register 注册客户端到Hub(供外部包使用) func (h *Hub) Register(client *Client) { h.register <- client } // SendToUser 向指定用户的所有连接发送消息 func (h *Hub) SendToUser(userID string, message []byte) { h.mu.RLock() defer h.mu.RUnlock() if clients, ok := h.userClients[userID]; ok { for client := range clients { select { case client.Send <- message: default: // 跳过阻塞的客户端 } } } } // SendToSession 向指定会话的连接发送消息 func (h *Hub) SendToSession(userID, sessionID string, message []byte) { h.mu.RLock() defer h.mu.RUnlock() if clients, ok := h.userClients[userID]; ok { for client := range clients { if client.SessionID == sessionID { select { case client.Send <- message: default: } } } } } // BroadcastToAll 向所有连接的客户端广播消息 func (h *Hub) BroadcastToAll(message []byte) { h.mu.RLock() defer h.mu.RUnlock() for client := range h.clients { select { case client.Send <- message: default: // 跳过阻塞的客户端 } } } // ClientCount 获取当前连接数 func (h *Hub) ClientCount() int { h.mu.RLock() defer h.mu.RUnlock() return len(h.clients) } // UserClientCount 获取指定用户的连接数 func (h *Hub) UserClientCount(userID string) int { h.mu.RLock() defer h.mu.RUnlock() if clients, ok := h.userClients[userID]; ok { return len(clients) } return 0 } // GetActiveSessions 返回所有活跃会话的列表 func (h *Hub) GetActiveSessions() []*SessionState { h.mu.RLock() defer h.mu.RUnlock() // 即使没有活跃连接也返回空列表而非 nil if h.sessions == nil || len(h.sessions) == 0 { return []*SessionState{} } result := make([]*SessionState, 0, len(h.sessions)) for _, s := range h.sessions { // 返回副本避免外部修改 cp := *s // 不包含 recent_messages 在列表接口中 cp.RecentMessages = nil result = append(result, &cp) } return result } // GetActiveSessionsByUser 返回按用户分组的活跃会话列表 func (h *Hub) GetActiveSessionsByUser() map[string][]*SessionState { h.mu.RLock() defer h.mu.RUnlock() // 即使没有活跃连接也返回空 map 而非 nil result := make(map[string][]*SessionState) if h.sessions == nil { return result } for _, s := range h.sessions { cp := *s cp.RecentMessages = nil result[s.UserID] = append(result[s.UserID], &cp) } return result } // GetUserSessions 获取某用户的所有活跃会话 func (h *Hub) GetUserSessions(userID string) []*SessionState { h.mu.RLock() defer h.mu.RUnlock() var result []*SessionState if clients, ok := h.userClients[userID]; ok { seen := make(map[string]bool) for c := range clients { if !seen[c.SessionID] { if s, ok := h.sessions[c.SessionID]; ok { cp := *s cp.RecentMessages = nil result = append(result, &cp) seen[c.SessionID] = true } } } } return result } // GetSession 返回指定会话的详细信息(含最近消息) func (h *Hub) GetSession(sessionID string) *SessionState { h.mu.RLock() defer h.mu.RUnlock() s, ok := h.sessions[sessionID] if !ok { return nil } // 返回副本 cp := *s if s.RecentMessages != nil { cp.RecentMessages = make([]SessionMessage, len(s.RecentMessages)) copy(cp.RecentMessages, s.RecentMessages) } return &cp } // UpdateSessionState 更新会话状态 func (h *Hub) UpdateSessionState(sessionID, state string) { h.mu.Lock() defer h.mu.Unlock() if s, ok := h.sessions[sessionID]; ok { s.State = state s.LastActivity = time.Now() } } // RecordMessage 记录消息到会话 func (h *Hub) RecordMessage(sessionID, role, content string) { h.mu.Lock() defer h.mu.Unlock() s, ok := h.sessions[sessionID] if !ok { return } s.MessageCount++ s.LastActivity = time.Now() // 截断内容到 100 字符 runes := []rune(content) if len(runes) > 100 { content = string(runes[:100]) + "..." } s.RecentMessages = append(s.RecentMessages, SessionMessage{ Role: role, Content: content, Timestamp: time.Now().UnixMilli(), }) // 只保留最近 N 条消息 if len(s.RecentMessages) > maxRecentMessages { s.RecentMessages = s.RecentMessages[len(s.RecentMessages)-maxRecentMessages:] } } // ========== IoT 设备广播 ========== // StartIoTBroadcast 启动 IoT 设备轮询并向所有客户端广播 func (h *Hub) StartIoTBroadcast(iotServiceURL string) { h.mu.Lock() if h.iotPollRunning { h.mu.Unlock() return } h.iotServiceURL = iotServiceURL h.iotStopCh = make(chan struct{}) h.iotPollRunning = true h.mu.Unlock() go h.iotPollLoop() log.Printf("[IoT广播] 已启动 (IoT服务地址: %s)", iotServiceURL) } // StopIoTBroadcast 停止 IoT 设备广播 func (h *Hub) StopIoTBroadcast() { h.mu.Lock() defer h.mu.Unlock() if !h.iotPollRunning { return } close(h.iotStopCh) h.iotPollRunning = false log.Println("[IoT广播] 已停止") } // iotPollLoop IoT 设备轮询循环 func (h *Hub) iotPollLoop() { // 首次立即推送 h.pollAndBroadcastIoT() ticker := time.NewTicker(10 * time.Second) // 每10秒轮询一次 defer ticker.Stop() for { select { case <-h.iotStopCh: return case <-ticker.C: h.pollAndBroadcastIoT() } } } // pollAndBroadcastIoT 从 IoT 调试服务获取设备状态并广播 func (h *Hub) pollAndBroadcastIoT() { if h.ClientCount() == 0 { return // 没有客户端连接时跳过 } h.mu.RLock() url := h.iotServiceURL h.mu.RUnlock() if url == "" { url = getEnv("IOT_DEBUG_SERVICE_URL", "http://localhost:8083") } devices, err := fetchIoTDevices(url) if err != nil { log.Printf("[IoT广播] 获取设备失败: %v", err) // 即使失败也发送空列表,让前端知道 IoT 服务状态 devices = []IotDeviceInfo{} } msg := ServerMessage{ Type: "device_update", Devices: devices, Timestamp: time.Now().UnixMilli(), } data, err := json.Marshal(msg) if err != nil { log.Printf("[IoT广播] 消息序列化失败: %v", err) return } h.BroadcastToAll(data) deviceNames := make([]string, 0, len(devices)) for _, d := range devices { deviceNames = append(deviceNames, d.Name) } log.Printf("[IoT广播] 已推送 %d 个设备状态到 %d 个客户端: %v", len(devices), h.ClientCount(), deviceNames) } // fetchIoTDevices 从 IoT 调试服务获取设备列表 func fetchIoTDevices(serviceURL string) ([]IotDeviceInfo, error) { client := &http.Client{Timeout: 5 * time.Second} resp, err := client.Get(serviceURL + "/api/v1/devices") if err != nil { return nil, fmt.Errorf("请求IoT服务失败: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("IoT服务返回状态码 %d", resp.StatusCode) } var result struct { Devices []IotDeviceInfo `json:"devices"` Total int `json:"total"` } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { return nil, fmt.Errorf("解析IoT设备列表失败: %w", err) } return result.Devices, nil } func getEnv(key, fallback string) string { if v := os.Getenv(key); v != "" { return v } return fallback } // ========== 对话缓存方法 ========== const maxConversationCache = 50 // cacheKey 生成对话缓存 key func cacheKey(userID, sessionID string) string { return fmt.Sprintf("%s:%s", userID, sessionID) } // CacheMessage 缓存单条消息到对话历史 // 应对外暴露:由 gateway chat handler 在收到用户消息和 AI 回复时调用 func (h *Hub) CacheMessage(userID, sessionID string, msg Message) { key := cacheKey(userID, sessionID) h.convCacheMu.Lock() defer h.convCacheMu.Unlock() existing, _ := h.conversationCache.Load(key) var messages []Message if existing != nil { messages = existing.([]Message) } messages = append(messages, msg) // 限制缓存消息数量上限 if len(messages) > maxConversationCache { messages = messages[len(messages)-maxConversationCache:] } h.conversationCache.Store(key, messages) } // GetConversation 获取完整对话历史 func (h *Hub) GetConversation(userID, sessionID string) []Message { key := cacheKey(userID, sessionID) val, ok := h.conversationCache.Load(key) if !ok { return []Message{} } messages, ok := val.([]Message) if !ok { return []Message{} } return messages } // GetSessionHistory 获取会话历史消息(限制条数) func (h *Hub) GetSessionHistory(userID, sessionID string, limit int) []Message { messages := h.GetConversation(userID, sessionID) if len(messages) == 0 { return []Message{} } if limit > 0 && len(messages) > limit { start := len(messages) - limit return messages[start:] } return messages } // DeleteConversation 删除对话缓存 func (h *Hub) DeleteConversation(userID, sessionID string) { key := cacheKey(userID, sessionID) h.conversationCache.Delete(key) }