package ws import ( "fmt" "log" "sync" "time" ) // SessionState 会话状态 type SessionState struct { SessionID string `json:"session_id"` UserID string `json:"user_id"` State string `json:"state"` // idle, thinking, streaming, error ConnectedAt time.Time `json:"connected_at"` LastActivity time.Time `json:"last_activity"` MessageCount int `json:"message_count"` RecentMessages []SessionMessage `json:"recent_messages,omitempty"` } // SessionMessage 会话消息记录 type SessionMessage struct { Role string `json:"role"` // user, assistant, system Content string `json:"content"` // 截断到 100 字符 Timestamp int64 `json:"timestamp"` } // Message 完整对话消息(用于缓存) type Message struct { Role string `json:"role"` Content string `json:"content"` Timestamp int64 `json:"timestamp"` } const maxRecentMessages = 20 // Hub WebSocket连接池 type Hub struct { mu sync.RWMutex clients map[*Client]bool broadcast chan []byte register chan *Client unregister chan *Client // 按用户ID索引的客户端映射 userClients map[string]map[*Client]bool // 会话状态追踪 (sessionID -> SessionState) sessions map[string]*SessionState // 对话缓存:key = "userID:sessionID", value = []Message conversationCache sync.Map convCacheMu sync.Mutex } // NewHub 创建WebSocket Hub func NewHub() *Hub { return &Hub{ clients: make(map[*Client]bool), broadcast: make(chan []byte, 256), register: make(chan *Client), unregister: make(chan *Client), userClients: make(map[string]map[*Client]bool), sessions: make(map[string]*SessionState), } } // Run 启动Hub主循环 func (h *Hub) Run() { for { select { case client := <-h.register: h.mu.Lock() h.clients[client] = true // 用户索引 if h.userClients[client.UserID] == nil { h.userClients[client.UserID] = make(map[*Client]bool) } h.userClients[client.UserID][client] = true // 会话状态追踪:如果该session尚未存在则创建 if _, exists := h.sessions[client.SessionID]; !exists { h.sessions[client.SessionID] = &SessionState{ SessionID: client.SessionID, UserID: client.UserID, State: "idle", ConnectedAt: time.Now(), LastActivity: time.Now(), MessageCount: 0, } } h.mu.Unlock() log.Printf("[WS] 客户端连接: user=%s session=%s (当前连接数: %d)", client.UserID, client.SessionID, len(h.clients)) case client := <-h.unregister: h.mu.Lock() if _, ok := h.clients[client]; ok { delete(h.clients, client) close(client.Send) // 清理用户索引 if h.userClients[client.UserID] != nil { delete(h.userClients[client.UserID], client) if len(h.userClients[client.UserID]) == 0 { delete(h.userClients, client.UserID) } } // 检查该session是否还有其他连接,没有则移除会话状态 hasOtherConn := false if clients, ok := h.userClients[client.UserID]; ok { for c := range clients { if c.SessionID == client.SessionID { hasOtherConn = true break } } } if !hasOtherConn { delete(h.sessions, client.SessionID) } } h.mu.Unlock() log.Printf("[WS] 客户端断开: user=%s session=%s (当前连接数: %d)", client.UserID, client.SessionID, len(h.clients)) case message := <-h.broadcast: h.mu.RLock() for client := range h.clients { select { case client.Send <- message: default: // 客户端发送通道已满,跳过 close(client.Send) delete(h.clients, client) } } h.mu.RUnlock() } } } // Register 注册客户端到Hub(供外部包使用) func (h *Hub) Register(client *Client) { h.register <- client } // SendToUser 向指定用户的所有连接发送消息 func (h *Hub) SendToUser(userID string, message []byte) { h.mu.RLock() defer h.mu.RUnlock() if clients, ok := h.userClients[userID]; ok { for client := range clients { select { case client.Send <- message: default: // 跳过阻塞的客户端 } } } } // SendToSession 向指定会话的连接发送消息 func (h *Hub) SendToSession(userID, sessionID string, message []byte) { h.mu.RLock() defer h.mu.RUnlock() if clients, ok := h.userClients[userID]; ok { for client := range clients { if client.SessionID == sessionID { select { case client.Send <- message: default: } } } } } // ClientCount 获取当前连接数 func (h *Hub) ClientCount() int { h.mu.RLock() defer h.mu.RUnlock() return len(h.clients) } // UserClientCount 获取指定用户的连接数 func (h *Hub) UserClientCount(userID string) int { h.mu.RLock() defer h.mu.RUnlock() if clients, ok := h.userClients[userID]; ok { return len(clients) } return 0 } // GetActiveSessions 返回所有活跃会话的列表 func (h *Hub) GetActiveSessions() []*SessionState { h.mu.RLock() defer h.mu.RUnlock() result := make([]*SessionState, 0, len(h.sessions)) for _, s := range h.sessions { // 返回副本避免外部修改 cp := *s // 不包含 recent_messages 在列表接口中 cp.RecentMessages = nil result = append(result, &cp) } return result } // GetActiveSessionsByUser 返回按用户分组的活跃会话列表 func (h *Hub) GetActiveSessionsByUser() map[string][]*SessionState { h.mu.RLock() defer h.mu.RUnlock() result := make(map[string][]*SessionState) for _, s := range h.sessions { cp := *s cp.RecentMessages = nil result[s.UserID] = append(result[s.UserID], &cp) } return result } // GetUserSessions 获取某用户的所有活跃会话 func (h *Hub) GetUserSessions(userID string) []*SessionState { h.mu.RLock() defer h.mu.RUnlock() var result []*SessionState if clients, ok := h.userClients[userID]; ok { seen := make(map[string]bool) for c := range clients { if !seen[c.SessionID] { if s, ok := h.sessions[c.SessionID]; ok { cp := *s cp.RecentMessages = nil result = append(result, &cp) seen[c.SessionID] = true } } } } return result } // GetSession 返回指定会话的详细信息(含最近消息) func (h *Hub) GetSession(sessionID string) *SessionState { h.mu.RLock() defer h.mu.RUnlock() s, ok := h.sessions[sessionID] if !ok { return nil } // 返回副本 cp := *s if s.RecentMessages != nil { cp.RecentMessages = make([]SessionMessage, len(s.RecentMessages)) copy(cp.RecentMessages, s.RecentMessages) } return &cp } // UpdateSessionState 更新会话状态 func (h *Hub) UpdateSessionState(sessionID, state string) { h.mu.Lock() defer h.mu.Unlock() if s, ok := h.sessions[sessionID]; ok { s.State = state s.LastActivity = time.Now() } } // RecordMessage 记录消息到会话 func (h *Hub) RecordMessage(sessionID, role, content string) { h.mu.Lock() defer h.mu.Unlock() s, ok := h.sessions[sessionID] if !ok { return } s.MessageCount++ s.LastActivity = time.Now() // 截断内容到 100 字符 runes := []rune(content) if len(runes) > 100 { content = string(runes[:100]) + "..." } s.RecentMessages = append(s.RecentMessages, SessionMessage{ Role: role, Content: content, Timestamp: time.Now().UnixMilli(), }) // 只保留最近 N 条消息 if len(s.RecentMessages) > maxRecentMessages { s.RecentMessages = s.RecentMessages[len(s.RecentMessages)-maxRecentMessages:] } } // ========== 对话缓存方法 ========== // cacheKey 生成对话缓存 key func cacheKey(userID, sessionID string) string { return fmt.Sprintf("%s:%s", userID, sessionID) } // CacheMessage 缓存单条消息到对话历史 func (h *Hub) CacheMessage(userID, sessionID string, msg Message) { key := cacheKey(userID, sessionID) h.convCacheMu.Lock() defer h.convCacheMu.Unlock() existing, _ := h.conversationCache.Load(key) var messages []Message if existing != nil { messages = existing.([]Message) } messages = append(messages, msg) h.conversationCache.Store(key, messages) } // GetConversation 获取完整对话历史 func (h *Hub) GetConversation(userID, sessionID string) []Message { key := cacheKey(userID, sessionID) val, ok := h.conversationCache.Load(key) if !ok { return []Message{} } messages, ok := val.([]Message) if !ok { return []Message{} } return messages } // DeleteConversation 删除对话缓存 func (h *Hub) DeleteConversation(userID, sessionID string) { key := cacheKey(userID, sessionID) h.conversationCache.Delete(key) }