package ws

import (
	"encoding/json"
	"fmt"
	"git.yeij.top/AskaEth/Cyrene/pkg/logger"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"

	"git.yeij.top/AskaEth/Cyrene/gateway/internal/store"
)

// SessionState is the Hub's in-memory snapshot of one chat session.
// A session is created when its first connection registers and is marked
// "idle" (never deleted) when its last connection leaves or it times out.
type SessionState struct {
	SessionID      string           `json:"session_id"`
	UserID         string           `json:"user_id"`
	State          string           `json:"state"` // idle, thinking, streaming, error
	ConnectedAt    time.Time        `json:"connected_at"`
	LastActivity   time.Time        `json:"last_activity"`
	MessageCount   int              `json:"message_count"`             // total messages passed to RecordMessage (not capped)
	RecentMessages []SessionMessage `json:"recent_messages,omitempty"` // newest previews, at most maxRecentMessages
}

// SessionMessage is a short preview of one message recorded on a session.
type SessionMessage struct {
	Role      string `json:"role"`      // user, assistant, system
	Content   string `json:"content"`   // truncated to 100 characters (runes); "..." is appended when cut
	Timestamp int64  `json:"timestamp"` // Unix milliseconds
}

// Message is one complete conversation message, as kept in the Hub's
// per-session conversation cache (see CacheMessage).
type Message struct {
	ID          string              `json:"id"`
	Role        string              `json:"role"`
	Content     string              `json:"content"`
	MsgType     string              `json:"msg_type,omitempty"`
	Attachments []MessageAttachment `json:"attachments,omitempty"`
	Timestamp   int64               `json:"timestamp"` // NOTE(review): unit is set by callers, presumably Unix ms like SessionMessage — confirm
	ClientInfo  *ClientInfo         `json:"client_info,omitempty"`
}

// maxRecentMessages caps SessionState.RecentMessages; RecordMessage keeps only
// the newest entries.
const maxRecentMessages = 20

// KnownClient tracks a device that has ever connected (online or offline).
type KnownClient struct { ClientID string `json:"client_id"` UserID string `json:"user_id"` DeviceName string `json:"device_name"` UserAgent string `json:"user_agent"` Note string `json:"note"` // user-assigned label Online bool `json:"online"` LastSeenAt time.Time `json:"last_seen_at"` FirstSeenAt time.Time `json:"first_seen_at"` } // Hub WebSocket连接池 type Hub struct { mu sync.RWMutex clients map[*Client]bool broadcast chan []byte register chan *Client unregister chan *Client // 按用户ID索引的客户端映射 userClients map[string]map[*Client]bool // 会话状态追踪 (sessionID -> SessionState) sessions map[string]*SessionState // 对话缓存:key = "userID:sessionID", value = []Message conversationCache sync.Map convCacheMu sync.Mutex // IoT 设备广播 iotServiceURL string iotStopCh chan struct{} iotPollRunning bool // 持久化存储 (可选,数据库连接失败时为 nil) store *store.SessionStore // 闲置超时时间 idleTimeout time.Duration // Phase 2: 离线主动消息队列 + 在线状态通知 pendingProactive map[string][]json.RawMessage // userID -> queued messages aiCoreURL string internalToken string // 多端客户端追踪: clientID -> KnownClient (在线+离线) knownClients map[string]*KnownClient } // SetStore 设置持久化存储 (可选) func (h *Hub) SetStore(s *store.SessionStore) { h.store = s } // SetIdleTimeout 设置闲置超时时间 func (h *Hub) SetIdleTimeout(minutes int) { h.idleTimeout = time.Duration(minutes) * time.Minute } // NewHub 创建WebSocket Hub func NewHub() *Hub { return &Hub{ clients: make(map[*Client]bool), broadcast: make(chan []byte, 256), register: make(chan *Client), unregister: make(chan *Client), userClients: make(map[string]map[*Client]bool), sessions: make(map[string]*SessionState), iotStopCh: make(chan struct{}), idleTimeout: 30 * time.Minute, // 默认30分钟 pendingProactive: make(map[string][]json.RawMessage), knownClients: make(map[string]*KnownClient), } } // StartIdleCleanup 启动闲置会话清理 goroutine // 每5分钟检查一次,将超过 idleTimeout 无活动的会话标记为 idle func (h *Hub) StartIdleCleanup() { go func() { defer func() { if r := recover(); r != nil { logger.Printf("[WS] 闲置会话清理 panic 恢复: %v", r) } }() 
ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() for range ticker.C { h.cleanupIdleSessions() } }() logger.Printf("[WS] 闲置会话清理已启动 (超时: %v)", h.idleTimeout) } // cleanupIdleSessions 标记超时会话为 idle(不删除状态) func (h *Hub) cleanupIdleSessions() { h.mu.Lock() defer h.mu.Unlock() now := time.Now() idleCount := 0 for sessionID, s := range h.sessions { // 检查该 session 是否还有活跃连接 hasActiveConn := false for _, clients := range h.userClients { for c := range clients { if c.SessionID == sessionID { hasActiveConn = true break } } if hasActiveConn { break } } // 如果没有活跃连接且超过闲置超时,标记为 idle if !hasActiveConn && now.Sub(s.LastActivity) > h.idleTimeout { if s.State != "idle" { s.State = "idle" idleCount++ } } } if idleCount > 0 { logger.Printf("[WS] 闲置清理: %d 个会话标记为 idle", idleCount) } } // GetAllActiveSessions 返回所有会话状态(包括 idle),供 ethend 监看使用 func (h *Hub) GetAllActiveSessions() []*SessionState { h.mu.RLock() defer h.mu.RUnlock() if h.sessions == nil || len(h.sessions) == 0 { return []*SessionState{} } result := make([]*SessionState, 0, len(h.sessions)) for _, s := range h.sessions { cp := *s cp.RecentMessages = nil result = append(result, &cp) } return result } // SetAICoreConfig sets the ai-core URL and internal token for presence notifications. func (h *Hub) SetAICoreConfig(url, token string) { h.aiCoreURL = url h.internalToken = token } // QueueProactiveMessage queues a proactive message for offline delivery. func (h *Hub) QueueProactiveMessage(userID string, msg json.RawMessage) { h.mu.Lock() defer h.mu.Unlock() h.pendingProactive[userID] = append(h.pendingProactive[userID], msg) // Keep only the most recent 3 messages if len(h.pendingProactive[userID]) > 3 { h.pendingProactive[userID] = h.pendingProactive[userID][1:] } } // FlushPendingProactive returns and clears queued proactive messages for a user. 
func (h *Hub) FlushPendingProactive(userID string) []json.RawMessage { h.mu.Lock() defer h.mu.Unlock() msgs := h.pendingProactive[userID] delete(h.pendingProactive, userID) return msgs } // notifyAICorePresence sends a presence update to ai-core. func (h *Hub) notifyAICorePresence(userID, status, sessionID string) { if h.aiCoreURL == "" || h.internalToken == "" { return } body, _ := json.Marshal(map[string]string{ "user_id": userID, "status": status, "session_id": sessionID, "timestamp": fmt.Sprintf("%d", time.Now().Unix()), }) go func() { req, _ := http.NewRequest("POST", h.aiCoreURL+"/api/v1/internal/presence", strings.NewReader(string(body))) req.Header.Set("Content-Type", "application/json") req.Header.Set("X-Internal-Token", h.internalToken) client := &http.Client{Timeout: 5 * time.Second} resp, err := client.Do(req) if err != nil { logger.Printf("[presence] 通知 ai-core 失败: %v", err) return } resp.Body.Close() logger.Printf("[presence] 通知 ai-core: user=%s status=%s", userID, status) }() } // Run 启动Hub主循环 func (h *Hub) Run() { for { select { case client := <-h.register: h.mu.Lock() h.clients[client] = true // 用户索引 if h.userClients[client.UserID] == nil { h.userClients[client.UserID] = make(map[*Client]bool) } h.userClients[client.UserID][client] = true // 会话状态追踪:如果该session尚未存在则创建 if _, exists := h.sessions[client.SessionID]; !exists { h.sessions[client.SessionID] = &SessionState{ SessionID: client.SessionID, UserID: client.UserID, State: "idle", ConnectedAt: time.Now(), LastActivity: time.Now(), MessageCount: 0, } } // 多端客户端追踪 if client.ClientID != "" { now := time.Now() if kc, ok := h.knownClients[client.ClientID]; ok { kc.Online = true kc.LastSeenAt = now kc.DeviceName = client.DeviceName kc.UserAgent = client.UserAgent } else { h.knownClients[client.ClientID] = &KnownClient{ ClientID: client.ClientID, UserID: client.UserID, DeviceName: client.DeviceName, UserAgent: client.UserAgent, Online: true, LastSeenAt: now, FirstSeenAt: now, } } } // 持久化客户端记录到数据库 if 
client.ClientID != "" && h.store != nil && h.store.IsAvailable() { if err := h.store.UpsertClient(client.ClientID, client.UserID, client.DeviceName, client.UserAgent); err != nil { logger.Printf("[WS] 持久化客户端记录失败: %v", err) } } // Phase 2: 检测是否为重连 (之前处于离线状态) wasOffline := len(h.userClients[client.UserID]) == 1 // 刚加入,之前为0 h.mu.Unlock() // 重连后推送积压的主动消息 if wasOffline { pending := h.FlushPendingProactive(client.UserID) if len(pending) > 0 { logger.Printf("[proactive] 推送 %d 条积压消息给重连用户 %s", len(pending), client.UserID) // 只推送最新的一条 go func() { // small delay for WS connection to stabilize time.Sleep(500 * time.Millisecond) h.SendToUser(client.UserID, pending[len(pending)-1]) }() } } // 通知 ai-core 用户上线 h.notifyAICorePresence(client.UserID, "online", client.SessionID) logger.Printf("[WS] 客户端连接: user=%s session=%s (当前连接数: %d)", client.UserID, client.SessionID, len(h.clients)) case client := <-h.unregister: h.mu.Lock() if _, ok := h.clients[client]; ok { delete(h.clients, client) close(client.Send) // 清理用户索引 if h.userClients[client.UserID] != nil { delete(h.userClients[client.UserID], client) if len(h.userClients[client.UserID]) == 0 { delete(h.userClients, client.UserID) } } // 检查该session是否还有其他连接,没有则标记为 idle 而非删除 hasOtherConn := false if clients, ok := h.userClients[client.UserID]; ok { for c := range clients { if c.SessionID == client.SessionID { hasOtherConn = true break } } } if !hasOtherConn { // 不再删除 session 状态,而是标记为 idle 保留在内存中 if s, ok := h.sessions[client.SessionID]; ok { s.State = "idle" } } // 多端客户端追踪: 检查同一 clientID 是否还有其他连接 if client.ClientID != "" { hasOtherClientConn := false for c := range h.clients { if c.ClientID == client.ClientID { hasOtherClientConn = true break } } if !hasOtherClientConn { if kc, ok := h.knownClients[client.ClientID]; ok { kc.Online = false kc.LastSeenAt = time.Now() } } } } h.mu.Unlock() logger.Printf("[WS] 客户端断开: user=%s session=%s (当前连接数: %d)", client.UserID, client.SessionID, len(h.clients)) case message := <-h.broadcast: // 
两阶段广播:Phase 1 在 RLock 下收集失效客户端,Phase 2 在 Lock 下清理 var staleClients []*Client h.mu.RLock() for client := range h.clients { select { case client.Send <- message: default: // 客户端发送通道已满,标记为失效 staleClients = append(staleClients, client) } } h.mu.RUnlock() // Phase 2: 在写锁下清理失效客户端 if len(staleClients) > 0 { h.mu.Lock() for _, client := range staleClients { // 二次检查:客户端可能已被 unregister 移除 if _, ok := h.clients[client]; !ok { continue } delete(h.clients, client) close(client.Send) // 清理用户索引 if h.userClients[client.UserID] != nil { delete(h.userClients[client.UserID], client) if len(h.userClients[client.UserID]) == 0 { delete(h.userClients, client.UserID) } } // 检查该 session 是否还有其他连接 hasOtherConn := false if clients, ok := h.userClients[client.UserID]; ok { for c := range clients { if c.SessionID == client.SessionID { hasOtherConn = true break } } } if !hasOtherConn { if s, ok := h.sessions[client.SessionID]; ok { s.State = "idle" } } // 多端客户端追踪 if client.ClientID != "" { hasOtherClientConn := false for c := range h.clients { if c.ClientID == client.ClientID { hasOtherClientConn = true break } } if !hasOtherClientConn { if kc, ok := h.knownClients[client.ClientID]; ok { kc.Online = false kc.LastSeenAt = time.Now() } } } } h.mu.Unlock() logger.Printf("[WS] 广播清理 %d 个失效客户端 (当前连接数: %d)", len(staleClients), len(h.clients)) } } } } // Register 注册客户端到Hub(供外部包使用) func (h *Hub) Register(client *Client) { h.register <- client } // SendToUser 向指定用户的所有连接发送消息 func (h *Hub) SendToUser(userID string, message []byte) { h.mu.RLock() defer h.mu.RUnlock() if clients, ok := h.userClients[userID]; ok { for client := range clients { select { case client.Send <- message: default: // 跳过阻塞的客户端 } } } } // SendToUserExcept 向指定用户的所有连接发送消息,排除指定 clientID func (h *Hub) SendToUserExcept(userID, excludeClientID string, message []byte) { h.mu.RLock() defer h.mu.RUnlock() if clients, ok := h.userClients[userID]; ok { for client := range clients { if client.ClientID == excludeClientID { continue } select { case 
client.Send <- message: default: // 跳过阻塞的客户端 } } } } // SendToSession 向指定会话的连接发送消息 func (h *Hub) SendToSession(userID, sessionID string, message []byte) { h.mu.RLock() defer h.mu.RUnlock() if clients, ok := h.userClients[userID]; ok { for client := range clients { if client.SessionID == sessionID { select { case client.Send <- message: default: } } } } } // BroadcastToAll 向所有连接的客户端广播消息 func (h *Hub) BroadcastToAll(message []byte) { h.mu.RLock() defer h.mu.RUnlock() for client := range h.clients { select { case client.Send <- message: default: // 跳过阻塞的客户端 } } } // ClientCount 获取当前连接数 func (h *Hub) ClientCount() int { h.mu.RLock() defer h.mu.RUnlock() return len(h.clients) } // UserClientCount 获取指定用户的连接数 func (h *Hub) UserClientCount(userID string) int { h.mu.RLock() defer h.mu.RUnlock() if clients, ok := h.userClients[userID]; ok { return len(clients) } return 0 } // GetActiveSessions 返回所有活跃会话的列表 func (h *Hub) GetActiveSessions() []*SessionState { h.mu.RLock() defer h.mu.RUnlock() // 即使没有活跃连接也返回空列表而非 nil if h.sessions == nil || len(h.sessions) == 0 { return []*SessionState{} } result := make([]*SessionState, 0, len(h.sessions)) for _, s := range h.sessions { // 返回副本避免外部修改 cp := *s // 不包含 recent_messages 在列表接口中 cp.RecentMessages = nil result = append(result, &cp) } return result } // GetActiveSessionsByUser 返回按用户分组的活跃会话列表 func (h *Hub) GetActiveSessionsByUser() map[string][]*SessionState { h.mu.RLock() defer h.mu.RUnlock() // 即使没有活跃连接也返回空 map 而非 nil result := make(map[string][]*SessionState) if h.sessions == nil { return result } for _, s := range h.sessions { cp := *s cp.RecentMessages = nil result[s.UserID] = append(result[s.UserID], &cp) } return result } // GetUserSessions 获取某用户的所有活跃会话 func (h *Hub) GetUserSessions(userID string) []*SessionState { h.mu.RLock() defer h.mu.RUnlock() var result []*SessionState if clients, ok := h.userClients[userID]; ok { seen := make(map[string]bool) for c := range clients { if !seen[c.SessionID] { if s, ok := 
h.sessions[c.SessionID]; ok { cp := *s cp.RecentMessages = nil result = append(result, &cp) seen[c.SessionID] = true } } } } return result } // GetSession 返回指定会话的详细信息(含最近消息) func (h *Hub) GetSession(sessionID string) *SessionState { h.mu.RLock() defer h.mu.RUnlock() s, ok := h.sessions[sessionID] if !ok { return nil } // 返回副本 cp := *s if s.RecentMessages != nil { cp.RecentMessages = make([]SessionMessage, len(s.RecentMessages)) copy(cp.RecentMessages, s.RecentMessages) } return &cp } // UpdateSessionState 更新会话状态 func (h *Hub) UpdateSessionState(sessionID, state string) { h.mu.Lock() defer h.mu.Unlock() if s, ok := h.sessions[sessionID]; ok { s.State = state s.LastActivity = time.Now() } } // RecordMessage 记录消息到会话 func (h *Hub) RecordMessage(sessionID, role, content string) { h.mu.Lock() defer h.mu.Unlock() s, ok := h.sessions[sessionID] if !ok { return } s.MessageCount++ s.LastActivity = time.Now() // 截断内容到 100 字符 runes := []rune(content) if len(runes) > 100 { content = string(runes[:100]) + "..." 
} s.RecentMessages = append(s.RecentMessages, SessionMessage{ Role: role, Content: content, Timestamp: time.Now().UnixMilli(), }) // 只保留最近 N 条消息 if len(s.RecentMessages) > maxRecentMessages { s.RecentMessages = s.RecentMessages[len(s.RecentMessages)-maxRecentMessages:] } } // ========== IoT 设备广播 ========== // StartIoTBroadcast 启动 IoT 设备轮询并向所有客户端广播 func (h *Hub) StartIoTBroadcast(iotServiceURL string) { h.mu.Lock() if h.iotPollRunning { h.mu.Unlock() return } h.iotServiceURL = iotServiceURL h.iotStopCh = make(chan struct{}) h.iotPollRunning = true h.mu.Unlock() go func() { defer func() { if r := recover(); r != nil { logger.Printf("[IoT广播] 轮询循环 panic 恢复: %v", r) } }() h.iotPollLoop() }() logger.Printf("[IoT广播] 已启动 (IoT服务地址: %s)", iotServiceURL) } // StopIoTBroadcast 停止 IoT 设备广播 func (h *Hub) StopIoTBroadcast() { h.mu.Lock() defer h.mu.Unlock() if !h.iotPollRunning { return } close(h.iotStopCh) h.iotPollRunning = false logger.Println("[IoT广播] 已停止") } // iotPollLoop IoT 设备轮询循环 func (h *Hub) iotPollLoop() { // 首次立即推送 h.pollAndBroadcastIoT() ticker := time.NewTicker(10 * time.Second) // 每10秒轮询一次 defer ticker.Stop() for { select { case <-h.iotStopCh: return case <-ticker.C: h.pollAndBroadcastIoT() } } } // pollAndBroadcastIoT 从 IoT 调试服务获取设备状态并广播 func (h *Hub) pollAndBroadcastIoT() { if h.ClientCount() == 0 { return // 没有客户端连接时跳过 } h.mu.RLock() url := h.iotServiceURL h.mu.RUnlock() if url == "" { // 向后兼容:优先使用 IOT_SERVICE_URL,回退到 IOT_DEBUG_SERVICE_URL url = getEnv("IOT_SERVICE_URL", "") if url == "" { url = getEnv("IOT_DEBUG_SERVICE_URL", "http://localhost:8083") } } devices, err := fetchIoTDevices(url) if err != nil { logger.Printf("[IoT广播] 获取设备失败: %v", err) // 即使失败也发送空列表,让前端知道 IoT 服务状态 devices = []IotDeviceInfo{} } msg := ServerMessage{ Type: "device_update", Devices: devices, Timestamp: time.Now().UnixMilli(), } data, err := json.Marshal(msg) if err != nil { logger.Printf("[IoT广播] 消息序列化失败: %v", err) return } h.BroadcastToAll(data) deviceNames := make([]string, 0, 
len(devices)) for _, d := range devices { deviceNames = append(deviceNames, d.Name) } logger.Printf("[IoT广播] 已推送 %d 个设备状态到 %d 个客户端: %v", len(devices), h.ClientCount(), deviceNames) } // fetchIoTDevices 从 IoT 调试服务获取设备列表 func fetchIoTDevices(serviceURL string) ([]IotDeviceInfo, error) { client := &http.Client{Timeout: 5 * time.Second} resp, err := client.Get(serviceURL + "/api/v1/devices") if err != nil { return nil, fmt.Errorf("请求IoT服务失败: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("IoT服务返回状态码 %d", resp.StatusCode) } var result struct { Devices []IotDeviceInfo `json:"devices"` Total int `json:"total"` } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { return nil, fmt.Errorf("解析IoT设备列表失败: %w", err) } return result.Devices, nil } func getEnv(key, fallback string) string { if v := os.Getenv(key); v != "" { return v } return fallback } // ========== 对话缓存方法 ========== const maxConversationCache = 50 // cacheKey 生成对话缓存 key func cacheKey(userID, sessionID string) string { return fmt.Sprintf("%s:%s", userID, sessionID) } // CacheMessage 缓存单条消息到对话历史 // 应对外暴露:由 gateway chat handler 在收到用户消息和 AI 回复时调用 func (h *Hub) CacheMessage(userID, sessionID string, msg Message) { key := cacheKey(userID, sessionID) h.convCacheMu.Lock() defer h.convCacheMu.Unlock() existing, _ := h.conversationCache.Load(key) var messages []Message if existing != nil { messages = existing.([]Message) } messages = append(messages, msg) // 限制缓存消息数量上限 if len(messages) > maxConversationCache { messages = messages[len(messages)-maxConversationCache:] } h.conversationCache.Store(key, messages) } // GetConversation 获取完整对话历史 func (h *Hub) GetConversation(userID, sessionID string) []Message { key := cacheKey(userID, sessionID) val, ok := h.conversationCache.Load(key) if !ok { return []Message{} } messages, ok := val.([]Message) if !ok { return []Message{} } return messages } // GetSessionHistory 获取会话历史消息(限制条数) func (h *Hub) 
GetSessionHistory(userID, sessionID string, limit int) []Message { messages := h.GetConversation(userID, sessionID) if len(messages) == 0 { return []Message{} } if limit > 0 && len(messages) > limit { start := len(messages) - limit return messages[start:] } return messages } // DeleteConversation 删除对话缓存 func (h *Hub) DeleteConversation(userID, sessionID string) { key := cacheKey(userID, sessionID) h.conversationCache.Delete(key) } // ========== 多端客户端追踪 ========== // GetKnownClients returns all known clients (online + offline). func (h *Hub) GetKnownClients(userID string) []KnownClient { h.mu.RLock() defer h.mu.RUnlock() result := make([]KnownClient, 0) for _, kc := range h.knownClients { if userID == "" || kc.UserID == userID { cp := *kc cp.UserAgent = "" // don't leak UA in list result = append(result, cp) } } return result } // UpdateClientNote sets a user-defined note/label on a client. func (h *Hub) UpdateClientNote(clientID, note string) bool { h.mu.Lock() defer h.mu.Unlock() kc, ok := h.knownClients[clientID] if !ok { return false } kc.Note = note return true } // ClientInfo returns the ClientInfo for a given client. func (h *Hub) ClientInfo(clientID string) *ClientInfo { h.mu.RLock() defer h.mu.RUnlock() kc, ok := h.knownClients[clientID] if !ok { return nil } return &ClientInfo{ ClientID: kc.ClientID, DeviceName: kc.DeviceName, UserAgent: kc.UserAgent, } } // buildClientInfo builds a ClientInfo from a Client. func buildClientInfo(c *Client) *ClientInfo { if c.ClientID == "" { return nil } return &ClientInfo{ ClientID: c.ClientID, DeviceName: c.DeviceName, UserAgent: c.UserAgent, } }