Files
Cyrene/backend/gateway/internal/ws/hub.go
T
AskaEth 186513f381 feat: 多功能升级 — 流式逐字渲染、对话缓存、会话组织优化、记忆管理修复、性能仪表盘
- 前端消息流式逐字渲染 (AI-Core ChatStream → SSE → Gateway → WebSocket stream_chunk → fadeInUp + cursorBlink)
- 后端对话缓存 (conversationCache sync.Map, GET /sessions/:id/messages)
- 前端侧边栏历史多轮对话显示
- DevTools 性能监控图标移至首页仪表盘
- DevTools 用户记忆查询/删减功能修复 (补全 DELETE 数据链路)
- 后端和 DevTools 按用户分类组织实时活动会话 (map[userID]map[sessionID]*Client)
- 新增 docs/api-reference/ 路由参考文档
- 新增 docs/message-flow-architecture.md 消息链路架构文档
2026-05-16 17:44:03 +08:00

359 lines
8.4 KiB
Go

package ws
import (
"fmt"
"log"
"sync"
"time"
)
// SessionState is the tracked state of a single session as exposed through
// its JSON tags.
type SessionState struct {
	SessionID string `json:"session_id"`
	UserID    string `json:"user_id"`
	// State is one of: idle, thinking, streaming, error.
	State        string    `json:"state"`
	ConnectedAt  time.Time `json:"connected_at"`
	LastActivity time.Time `json:"last_activity"`
	MessageCount int       `json:"message_count"`
	// RecentMessages holds short previews of the latest messages; it is
	// omitted from JSON when empty.
	RecentMessages []SessionMessage `json:"recent_messages,omitempty"`
}
// SessionMessage is one preview entry in a session's recent-message log.
type SessionMessage struct {
	// Role is the speaker: user, assistant or system.
	Role string `json:"role"`
	// Content is a preview of the message text (RecordMessage cuts it to
	// 100 runes and appends "..." when longer).
	Content string `json:"content"`
	// Timestamp is the Unix time in milliseconds when the entry was recorded.
	Timestamp int64 `json:"timestamp"`
}
// Message is a complete (untruncated) conversation message, as stored in the
// Hub's conversation cache.
type Message struct {
	Role      string `json:"role"`
	Content   string `json:"content"`
	Timestamp int64  `json:"timestamp"`
}
// maxRecentMessages is how many SessionMessage previews RecordMessage keeps
// per session; once exceeded, the oldest entries are dropped.
const maxRecentMessages = 20
// Hub is the WebSocket connection pool. It indexes live clients both globally
// and per user, tracks per-session state, and holds an in-memory cache of
// conversation history.
//
// Locking: mu guards clients, userClients and sessions. conversationCache is
// a sync.Map; convCacheMu additionally serializes the load-append-store
// sequence performed by CacheMessage.
type Hub struct {
	mu         sync.RWMutex
	clients    map[*Client]bool
	broadcast  chan []byte
	register   chan *Client
	unregister chan *Client

	// userClients indexes connections by user ID.
	userClients map[string]map[*Client]bool

	// sessions tracks session state, keyed by session ID.
	sessions map[string]*SessionState

	// conversationCache maps "userID:sessionID" to that conversation's
	// []Message history.
	conversationCache sync.Map
	convCacheMu       sync.Mutex
}
// NewHub returns a Hub with all indexes initialised. The broadcast channel is
// buffered (256 messages); register and unregister are unbuffered, so
// senders on them wait for Run to receive.
func NewHub() *Hub {
	hub := &Hub{
		broadcast:  make(chan []byte, 256),
		register:   make(chan *Client),
		unregister: make(chan *Client),
	}
	hub.clients = make(map[*Client]bool)
	hub.userClients = make(map[string]map[*Client]bool)
	hub.sessions = make(map[string]*SessionState)
	return hub
}
// Run is the Hub's event loop. It serializes client registration,
// unregistration and broadcast fan-out. It never returns, so it is meant to
// be started in its own goroutine.
func (h *Hub) Run() {
	for {
		select {
		case client := <-h.register:
			h.mu.Lock()
			h.clients[client] = true
			// Index the connection by user.
			if h.userClients[client.UserID] == nil {
				h.userClients[client.UserID] = make(map[*Client]bool)
			}
			h.userClients[client.UserID][client] = true
			// Start tracking the session unless another connection already did.
			if _, exists := h.sessions[client.SessionID]; !exists {
				now := time.Now()
				h.sessions[client.SessionID] = &SessionState{
					SessionID:    client.SessionID,
					UserID:       client.UserID,
					State:        "idle",
					ConnectedAt:  now,
					LastActivity: now,
					MessageCount: 0,
				}
			}
			h.mu.Unlock()
			log.Printf("[WS] 客户端连接: user=%s session=%s (当前连接数: %d)",
				client.UserID, client.SessionID, len(h.clients))

		case client := <-h.unregister:
			h.mu.Lock()
			h.removeClientLocked(client)
			h.mu.Unlock()
			log.Printf("[WS] 客户端断开: user=%s session=%s (当前连接数: %d)",
				client.UserID, client.SessionID, len(h.clients))

		case message := <-h.broadcast:
			// A write lock is required: clients that cannot accept the message
			// are removed from the maps during iteration (deleting the current
			// key inside a range loop is permitted in Go).
			h.mu.Lock()
			for client := range h.clients {
				select {
				case client.Send <- message:
				default:
					// Send would block: drop the client. Going through
					// removeClientLocked also removes it from userClients, so
					// SendToUser/SendToSession can never send on the closed
					// channel, and a later unregister of it becomes a no-op.
					h.removeClientLocked(client)
				}
			}
			h.mu.Unlock()
		}
	}
}

// removeClientLocked removes client from every index, closes its Send
// channel, and forgets its session state when no other connection of the
// same user still uses that session. It does nothing for a client that is
// not (or no longer) registered, so Send is closed at most once.
// The caller must hold h.mu for writing.
func (h *Hub) removeClientLocked(client *Client) {
	if _, ok := h.clients[client]; !ok {
		return
	}
	delete(h.clients, client)
	close(client.Send)

	if conns, ok := h.userClients[client.UserID]; ok {
		delete(conns, client)
		if len(conns) == 0 {
			delete(h.userClients, client.UserID)
		}
	}

	// Keep the session state while any other connection of this user is
	// still attached to the same session.
	for other := range h.userClients[client.UserID] {
		if other.SessionID == client.SessionID {
			return
		}
	}
	delete(h.sessions, client.SessionID)
}
// Register hands client to the Run loop for registration; it is the entry
// point for code outside this package. With the unbuffered register channel
// created by NewHub, this call blocks until Run receives the client.
func (h *Hub) Register(client *Client) {
	h.register <- client
}
// SendToUser delivers message to every connection belonging to userID.
// Delivery is best-effort: a connection that cannot accept the message
// immediately is skipped instead of blocking the caller.
func (h *Hub) SendToUser(userID string, message []byte) {
	h.mu.RLock()
	defer h.mu.RUnlock()

	conns, found := h.userClients[userID]
	if !found {
		return
	}
	for conn := range conns {
		select {
		case conn.Send <- message:
		default:
			// Would block: drop the message for this connection only.
		}
	}
}
// SendToSession delivers message to each connection userID has open on
// sessionID. Like SendToUser it never blocks: a connection that cannot take
// the message right away is skipped.
func (h *Hub) SendToSession(userID, sessionID string, message []byte) {
	h.mu.RLock()
	defer h.mu.RUnlock()

	for conn := range h.userClients[userID] {
		if conn.SessionID != sessionID {
			continue
		}
		select {
		case conn.Send <- message:
		default:
		}
	}
}
// ClientCount reports how many connections are currently registered.
func (h *Hub) ClientCount() int {
	h.mu.RLock()
	total := len(h.clients)
	h.mu.RUnlock()
	return total
}
// UserClientCount reports how many connections userID currently has; it is
// 0 for an unknown user.
func (h *Hub) UserClientCount(userID string) int {
	h.mu.RLock()
	defer h.mu.RUnlock()
	// len of a missing (nil) inner map is 0, covering the unknown-user case.
	return len(h.userClients[userID])
}
// GetActiveSessions returns a snapshot of every tracked session. Each entry
// is a copy, so callers cannot mutate hub state, and RecentMessages is
// cleared because list views do not carry message previews. The result is
// never nil.
func (h *Hub) GetActiveSessions() []*SessionState {
	h.mu.RLock()
	defer h.mu.RUnlock()

	snapshots := make([]*SessionState, len(h.sessions))
	i := 0
	for _, state := range h.sessions {
		snapshot := *state
		snapshot.RecentMessages = nil
		snapshots[i] = &snapshot
		i++
	}
	return snapshots
}
// GetActiveSessionsByUser returns copies of all tracked sessions grouped by
// their owning user ID, with RecentMessages cleared.
func (h *Hub) GetActiveSessionsByUser() map[string][]*SessionState {
	h.mu.RLock()
	defer h.mu.RUnlock()

	grouped := map[string][]*SessionState{}
	for _, state := range h.sessions {
		owner := state.UserID
		snapshot := *state
		snapshot.RecentMessages = nil
		grouped[owner] = append(grouped[owner], &snapshot)
	}
	return grouped
}
// GetUserSessions returns copies of the distinct sessions that userID has
// live connections on, with RecentMessages cleared.
//
// The result is never nil, so it JSON-encodes as [] rather than null, which
// matches GetActiveSessions.
func (h *Hub) GetUserSessions(userID string) []*SessionState {
	h.mu.RLock()
	defer h.mu.RUnlock()

	result := make([]*SessionState, 0)
	seen := make(map[string]bool)
	for c := range h.userClients[userID] {
		// Several connections may share one session; report it once.
		if seen[c.SessionID] {
			continue
		}
		seen[c.SessionID] = true

		s, ok := h.sessions[c.SessionID]
		if !ok {
			continue
		}
		cp := *s
		cp.RecentMessages = nil
		result = append(result, &cp)
	}
	return result
}
// GetSession returns a copy of the session's state including its recent
// message previews, or nil when sessionID is not tracked. The previews are
// copied into a new slice so the caller never shares the hub's storage.
func (h *Hub) GetSession(sessionID string) *SessionState {
	h.mu.RLock()
	defer h.mu.RUnlock()

	if state, ok := h.sessions[sessionID]; ok {
		detail := *state
		if state.RecentMessages != nil {
			previews := make([]SessionMessage, len(state.RecentMessages))
			copy(previews, state.RecentMessages)
			detail.RecentMessages = previews
		}
		return &detail
	}
	return nil
}
// UpdateSessionState sets the session's State and bumps its LastActivity.
// Unknown session IDs are ignored.
func (h *Hub) UpdateSessionState(sessionID, state string) {
	h.mu.Lock()
	defer h.mu.Unlock()

	session, found := h.sessions[sessionID]
	if !found {
		return
	}
	session.State = state
	session.LastActivity = time.Now()
}
// RecordMessage counts a message against the session, bumps LastActivity,
// and appends a preview to RecentMessages. The preview is cut to 100 runes
// plus "..." when longer. Only the newest maxRecentMessages previews are kept.
// Unknown session IDs are ignored.
func (h *Hub) RecordMessage(sessionID, role, content string) {
	// Build the preview before taking the lock; it depends only on content.
	preview := content
	if r := []rune(content); len(r) > 100 {
		preview = string(r[:100]) + "..."
	}

	h.mu.Lock()
	defer h.mu.Unlock()

	session, found := h.sessions[sessionID]
	if !found {
		return
	}

	session.MessageCount++
	session.LastActivity = time.Now()
	session.RecentMessages = append(session.RecentMessages, SessionMessage{
		Role:      role,
		Content:   preview,
		Timestamp: time.Now().UnixMilli(),
	})

	if excess := len(session.RecentMessages) - maxRecentMessages; excess > 0 {
		session.RecentMessages = session.RecentMessages[excess:]
	}
}
// ========== Conversation cache ==========

// cacheKey builds the conversation-cache key "userID:sessionID".
// NOTE(review): the ':' separator is not escaped, so IDs that themselves
// contain ':' could map to the same key — confirm IDs never contain ':'.
func cacheKey(userID, sessionID string) string {
	// All operands are strings, so fmt.Sprint inserts no spaces between them.
	return fmt.Sprint(userID, ":", sessionID)
}
// CacheMessage appends msg to the cached history for (userID, sessionID),
// creating the history on first use. convCacheMu serializes the
// load-append-store so concurrent callers cannot lose each other's messages.
func (h *Hub) CacheMessage(userID, sessionID string, msg Message) {
	h.convCacheMu.Lock()
	defer h.convCacheMu.Unlock()

	key := cacheKey(userID, sessionID)
	var history []Message
	if prev, _ := h.conversationCache.Load(key); prev != nil {
		history = prev.([]Message)
	}
	h.conversationCache.Store(key, append(history, msg))
}
// GetConversation returns the full cached history for (userID, sessionID),
// or an empty, non-nil slice when nothing is cached.
//
// The returned slice is a private copy. The cached slice keeps growing in
// place through append in CacheMessage, so handing it out directly would let
// callers corrupt the cache by editing elements. Appending to it would also
// race with CacheMessage on the shared backing array.
func (h *Hub) GetConversation(userID, sessionID string) []Message {
	key := cacheKey(userID, sessionID)
	val, ok := h.conversationCache.Load(key)
	if !ok {
		return []Message{}
	}
	messages, ok := val.([]Message)
	if !ok {
		return []Message{}
	}
	out := make([]Message, len(messages))
	copy(out, messages)
	return out
}
// DeleteConversation drops the cached history for (userID, sessionID); it is
// a no-op when nothing is cached.
func (h *Hub) DeleteConversation(userID, sessionID string) {
	h.conversationCache.Delete(cacheKey(userID, sessionID))
}