4b35736f73
P0 (5): crypto/rand session ID, TTS fallback可达性, goroutine defer recover, adminAuth前缀修正 P1 (5): 普通用户密码验证, context传递, priority clamp, 超时重试, 自主思考速率限制 P2 (4): Briefing AI降级, 前端消息类型渲染, Docker Compose补全, PWA 192图标 P3 (5): goroutine错误处理, .gitignore完善, reminder created_at, voice Dockerfile, Go版本更新
644 lines
15 KiB
Go
644 lines
15 KiB
Go
package ws
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/yourname/cyrene-ai/gateway/internal/store"
|
|
)
|
|
|
|
// SessionState 会话状态
|
|
type SessionState struct {
|
|
SessionID string `json:"session_id"`
|
|
UserID string `json:"user_id"`
|
|
State string `json:"state"` // idle, thinking, streaming, error
|
|
ConnectedAt time.Time `json:"connected_at"`
|
|
LastActivity time.Time `json:"last_activity"`
|
|
MessageCount int `json:"message_count"`
|
|
RecentMessages []SessionMessage `json:"recent_messages,omitempty"`
|
|
}
|
|
|
|
// SessionMessage is one entry in a session's recent-message log.
type SessionMessage struct {
	// Role is the author of the message: user, assistant or system.
	Role string `json:"role"`

	// Content is the message text, truncated to 100 characters.
	Content string `json:"content"`

	// Timestamp is the time the entry was recorded.
	Timestamp int64 `json:"timestamp"`
}
|
|
|
|
// Message 完整对话消息(用于缓存)
|
|
type Message struct {
|
|
ID string `json:"id"`
|
|
Role string `json:"role"`
|
|
Content string `json:"content"`
|
|
Attachments []MessageAttachment `json:"attachments,omitempty"`
|
|
Timestamp int64 `json:"timestamp"`
|
|
}
|
|
|
|
// maxRecentMessages caps how many entries RecordMessage keeps in
// SessionState.RecentMessages; older entries are dropped first.
const maxRecentMessages = 20
|
|
|
|
// Hub WebSocket连接池
|
|
type Hub struct {
|
|
mu sync.RWMutex
|
|
clients map[*Client]bool
|
|
broadcast chan []byte
|
|
register chan *Client
|
|
unregister chan *Client
|
|
|
|
// 按用户ID索引的客户端映射
|
|
userClients map[string]map[*Client]bool
|
|
|
|
// 会话状态追踪 (sessionID -> SessionState)
|
|
sessions map[string]*SessionState
|
|
|
|
// 对话缓存:key = "userID:sessionID", value = []Message
|
|
conversationCache sync.Map
|
|
convCacheMu sync.Mutex
|
|
|
|
// IoT 设备广播
|
|
iotServiceURL string
|
|
iotStopCh chan struct{}
|
|
iotPollRunning bool
|
|
|
|
// 持久化存储 (可选,数据库连接失败时为 nil)
|
|
store *store.SessionStore
|
|
|
|
// 闲置超时时间
|
|
idleTimeout time.Duration
|
|
}
|
|
|
|
// SetStore 设置持久化存储 (可选)
|
|
func (h *Hub) SetStore(s *store.SessionStore) {
|
|
h.store = s
|
|
}
|
|
|
|
// SetIdleTimeout 设置闲置超时时间
|
|
func (h *Hub) SetIdleTimeout(minutes int) {
|
|
h.idleTimeout = time.Duration(minutes) * time.Minute
|
|
}
|
|
|
|
// NewHub 创建WebSocket Hub
|
|
func NewHub() *Hub {
|
|
return &Hub{
|
|
clients: make(map[*Client]bool),
|
|
broadcast: make(chan []byte, 256),
|
|
register: make(chan *Client),
|
|
unregister: make(chan *Client),
|
|
userClients: make(map[string]map[*Client]bool),
|
|
sessions: make(map[string]*SessionState),
|
|
iotStopCh: make(chan struct{}),
|
|
idleTimeout: 30 * time.Minute, // 默认30分钟
|
|
}
|
|
}
|
|
|
|
// StartIdleCleanup 启动闲置会话清理 goroutine
|
|
// 每5分钟检查一次,将超过 idleTimeout 无活动的会话标记为 idle
|
|
func (h *Hub) StartIdleCleanup() {
|
|
go func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
log.Printf("[WS] 闲置会话清理 panic 恢复: %v", r)
|
|
}
|
|
}()
|
|
ticker := time.NewTicker(5 * time.Minute)
|
|
defer ticker.Stop()
|
|
|
|
for range ticker.C {
|
|
h.cleanupIdleSessions()
|
|
}
|
|
}()
|
|
log.Printf("[WS] 闲置会话清理已启动 (超时: %v)", h.idleTimeout)
|
|
}
|
|
|
|
// cleanupIdleSessions 标记超时会话为 idle(不删除状态)
|
|
func (h *Hub) cleanupIdleSessions() {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
idleCount := 0
|
|
|
|
for sessionID, s := range h.sessions {
|
|
// 检查该 session 是否还有活跃连接
|
|
hasActiveConn := false
|
|
for _, clients := range h.userClients {
|
|
for c := range clients {
|
|
if c.SessionID == sessionID {
|
|
hasActiveConn = true
|
|
break
|
|
}
|
|
}
|
|
if hasActiveConn {
|
|
break
|
|
}
|
|
}
|
|
|
|
// 如果没有活跃连接且超过闲置超时,标记为 idle
|
|
if !hasActiveConn && now.Sub(s.LastActivity) > h.idleTimeout {
|
|
if s.State != "idle" {
|
|
s.State = "idle"
|
|
idleCount++
|
|
}
|
|
}
|
|
}
|
|
|
|
if idleCount > 0 {
|
|
log.Printf("[WS] 闲置清理: %d 个会话标记为 idle", idleCount)
|
|
}
|
|
}
|
|
|
|
// GetAllActiveSessions 返回所有会话状态(包括 idle),供 DevTools 监看使用
|
|
func (h *Hub) GetAllActiveSessions() []*SessionState {
|
|
h.mu.RLock()
|
|
defer h.mu.RUnlock()
|
|
|
|
if h.sessions == nil || len(h.sessions) == 0 {
|
|
return []*SessionState{}
|
|
}
|
|
|
|
result := make([]*SessionState, 0, len(h.sessions))
|
|
for _, s := range h.sessions {
|
|
cp := *s
|
|
cp.RecentMessages = nil
|
|
result = append(result, &cp)
|
|
}
|
|
return result
|
|
}
|
|
|
|
// Run 启动Hub主循环
|
|
func (h *Hub) Run() {
|
|
for {
|
|
select {
|
|
case client := <-h.register:
|
|
h.mu.Lock()
|
|
h.clients[client] = true
|
|
|
|
// 用户索引
|
|
if h.userClients[client.UserID] == nil {
|
|
h.userClients[client.UserID] = make(map[*Client]bool)
|
|
}
|
|
h.userClients[client.UserID][client] = true
|
|
|
|
// 会话状态追踪:如果该session尚未存在则创建
|
|
if _, exists := h.sessions[client.SessionID]; !exists {
|
|
h.sessions[client.SessionID] = &SessionState{
|
|
SessionID: client.SessionID,
|
|
UserID: client.UserID,
|
|
State: "idle",
|
|
ConnectedAt: time.Now(),
|
|
LastActivity: time.Now(),
|
|
MessageCount: 0,
|
|
}
|
|
}
|
|
h.mu.Unlock()
|
|
|
|
log.Printf("[WS] 客户端连接: user=%s session=%s (当前连接数: %d)",
|
|
client.UserID, client.SessionID, len(h.clients))
|
|
|
|
case client := <-h.unregister:
|
|
h.mu.Lock()
|
|
if _, ok := h.clients[client]; ok {
|
|
delete(h.clients, client)
|
|
close(client.Send)
|
|
|
|
// 清理用户索引
|
|
if h.userClients[client.UserID] != nil {
|
|
delete(h.userClients[client.UserID], client)
|
|
if len(h.userClients[client.UserID]) == 0 {
|
|
delete(h.userClients, client.UserID)
|
|
}
|
|
}
|
|
|
|
// 检查该session是否还有其他连接,没有则标记为 idle 而非删除
|
|
hasOtherConn := false
|
|
if clients, ok := h.userClients[client.UserID]; ok {
|
|
for c := range clients {
|
|
if c.SessionID == client.SessionID {
|
|
hasOtherConn = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
if !hasOtherConn {
|
|
// 不再删除 session 状态,而是标记为 idle 保留在内存中
|
|
if s, ok := h.sessions[client.SessionID]; ok {
|
|
s.State = "idle"
|
|
}
|
|
}
|
|
}
|
|
h.mu.Unlock()
|
|
|
|
log.Printf("[WS] 客户端断开: user=%s session=%s (当前连接数: %d)",
|
|
client.UserID, client.SessionID, len(h.clients))
|
|
|
|
case message := <-h.broadcast:
|
|
h.mu.RLock()
|
|
for client := range h.clients {
|
|
select {
|
|
case client.Send <- message:
|
|
default:
|
|
// 客户端发送通道已满,跳过
|
|
close(client.Send)
|
|
delete(h.clients, client)
|
|
}
|
|
}
|
|
h.mu.RUnlock()
|
|
}
|
|
}
|
|
}
|
|
|
|
// Register 注册客户端到Hub(供外部包使用)
|
|
func (h *Hub) Register(client *Client) {
|
|
h.register <- client
|
|
}
|
|
|
|
// SendToUser 向指定用户的所有连接发送消息
|
|
func (h *Hub) SendToUser(userID string, message []byte) {
|
|
h.mu.RLock()
|
|
defer h.mu.RUnlock()
|
|
|
|
if clients, ok := h.userClients[userID]; ok {
|
|
for client := range clients {
|
|
select {
|
|
case client.Send <- message:
|
|
default:
|
|
// 跳过阻塞的客户端
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// SendToSession 向指定会话的连接发送消息
|
|
func (h *Hub) SendToSession(userID, sessionID string, message []byte) {
|
|
h.mu.RLock()
|
|
defer h.mu.RUnlock()
|
|
|
|
if clients, ok := h.userClients[userID]; ok {
|
|
for client := range clients {
|
|
if client.SessionID == sessionID {
|
|
select {
|
|
case client.Send <- message:
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// BroadcastToAll 向所有连接的客户端广播消息
|
|
func (h *Hub) BroadcastToAll(message []byte) {
|
|
h.mu.RLock()
|
|
defer h.mu.RUnlock()
|
|
|
|
for client := range h.clients {
|
|
select {
|
|
case client.Send <- message:
|
|
default:
|
|
// 跳过阻塞的客户端
|
|
}
|
|
}
|
|
}
|
|
|
|
// ClientCount 获取当前连接数
|
|
func (h *Hub) ClientCount() int {
|
|
h.mu.RLock()
|
|
defer h.mu.RUnlock()
|
|
return len(h.clients)
|
|
}
|
|
|
|
// UserClientCount 获取指定用户的连接数
|
|
func (h *Hub) UserClientCount(userID string) int {
|
|
h.mu.RLock()
|
|
defer h.mu.RUnlock()
|
|
if clients, ok := h.userClients[userID]; ok {
|
|
return len(clients)
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// GetActiveSessions 返回所有活跃会话的列表
|
|
func (h *Hub) GetActiveSessions() []*SessionState {
|
|
h.mu.RLock()
|
|
defer h.mu.RUnlock()
|
|
|
|
// 即使没有活跃连接也返回空列表而非 nil
|
|
if h.sessions == nil || len(h.sessions) == 0 {
|
|
return []*SessionState{}
|
|
}
|
|
|
|
result := make([]*SessionState, 0, len(h.sessions))
|
|
for _, s := range h.sessions {
|
|
// 返回副本避免外部修改
|
|
cp := *s
|
|
// 不包含 recent_messages 在列表接口中
|
|
cp.RecentMessages = nil
|
|
result = append(result, &cp)
|
|
}
|
|
return result
|
|
}
|
|
|
|
// GetActiveSessionsByUser 返回按用户分组的活跃会话列表
|
|
func (h *Hub) GetActiveSessionsByUser() map[string][]*SessionState {
|
|
h.mu.RLock()
|
|
defer h.mu.RUnlock()
|
|
|
|
// 即使没有活跃连接也返回空 map 而非 nil
|
|
result := make(map[string][]*SessionState)
|
|
|
|
if h.sessions == nil {
|
|
return result
|
|
}
|
|
|
|
for _, s := range h.sessions {
|
|
cp := *s
|
|
cp.RecentMessages = nil
|
|
result[s.UserID] = append(result[s.UserID], &cp)
|
|
}
|
|
return result
|
|
}
|
|
|
|
// GetUserSessions 获取某用户的所有活跃会话
|
|
func (h *Hub) GetUserSessions(userID string) []*SessionState {
|
|
h.mu.RLock()
|
|
defer h.mu.RUnlock()
|
|
|
|
var result []*SessionState
|
|
if clients, ok := h.userClients[userID]; ok {
|
|
seen := make(map[string]bool)
|
|
for c := range clients {
|
|
if !seen[c.SessionID] {
|
|
if s, ok := h.sessions[c.SessionID]; ok {
|
|
cp := *s
|
|
cp.RecentMessages = nil
|
|
result = append(result, &cp)
|
|
seen[c.SessionID] = true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// GetSession 返回指定会话的详细信息(含最近消息)
|
|
func (h *Hub) GetSession(sessionID string) *SessionState {
|
|
h.mu.RLock()
|
|
defer h.mu.RUnlock()
|
|
|
|
s, ok := h.sessions[sessionID]
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
// 返回副本
|
|
cp := *s
|
|
if s.RecentMessages != nil {
|
|
cp.RecentMessages = make([]SessionMessage, len(s.RecentMessages))
|
|
copy(cp.RecentMessages, s.RecentMessages)
|
|
}
|
|
return &cp
|
|
}
|
|
|
|
// UpdateSessionState 更新会话状态
|
|
func (h *Hub) UpdateSessionState(sessionID, state string) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
|
|
if s, ok := h.sessions[sessionID]; ok {
|
|
s.State = state
|
|
s.LastActivity = time.Now()
|
|
}
|
|
}
|
|
|
|
// RecordMessage 记录消息到会话
|
|
func (h *Hub) RecordMessage(sessionID, role, content string) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
|
|
s, ok := h.sessions[sessionID]
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
s.MessageCount++
|
|
s.LastActivity = time.Now()
|
|
|
|
// 截断内容到 100 字符
|
|
runes := []rune(content)
|
|
if len(runes) > 100 {
|
|
content = string(runes[:100]) + "..."
|
|
}
|
|
|
|
s.RecentMessages = append(s.RecentMessages, SessionMessage{
|
|
Role: role,
|
|
Content: content,
|
|
Timestamp: time.Now().UnixMilli(),
|
|
})
|
|
|
|
// 只保留最近 N 条消息
|
|
if len(s.RecentMessages) > maxRecentMessages {
|
|
s.RecentMessages = s.RecentMessages[len(s.RecentMessages)-maxRecentMessages:]
|
|
}
|
|
}
|
|
|
|
// ========== IoT 设备广播 ==========
|
|
|
|
// StartIoTBroadcast 启动 IoT 设备轮询并向所有客户端广播
|
|
func (h *Hub) StartIoTBroadcast(iotServiceURL string) {
|
|
h.mu.Lock()
|
|
if h.iotPollRunning {
|
|
h.mu.Unlock()
|
|
return
|
|
}
|
|
h.iotServiceURL = iotServiceURL
|
|
h.iotStopCh = make(chan struct{})
|
|
h.iotPollRunning = true
|
|
h.mu.Unlock()
|
|
|
|
go func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
log.Printf("[IoT广播] 轮询循环 panic 恢复: %v", r)
|
|
}
|
|
}()
|
|
h.iotPollLoop()
|
|
}()
|
|
log.Printf("[IoT广播] 已启动 (IoT服务地址: %s)", iotServiceURL)
|
|
}
|
|
|
|
// StopIoTBroadcast 停止 IoT 设备广播
|
|
func (h *Hub) StopIoTBroadcast() {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
|
|
if !h.iotPollRunning {
|
|
return
|
|
}
|
|
close(h.iotStopCh)
|
|
h.iotPollRunning = false
|
|
log.Println("[IoT广播] 已停止")
|
|
}
|
|
|
|
// iotPollLoop IoT 设备轮询循环
|
|
func (h *Hub) iotPollLoop() {
|
|
// 首次立即推送
|
|
h.pollAndBroadcastIoT()
|
|
|
|
ticker := time.NewTicker(10 * time.Second) // 每10秒轮询一次
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-h.iotStopCh:
|
|
return
|
|
case <-ticker.C:
|
|
h.pollAndBroadcastIoT()
|
|
}
|
|
}
|
|
}
|
|
|
|
// pollAndBroadcastIoT 从 IoT 调试服务获取设备状态并广播
|
|
func (h *Hub) pollAndBroadcastIoT() {
|
|
if h.ClientCount() == 0 {
|
|
return // 没有客户端连接时跳过
|
|
}
|
|
|
|
h.mu.RLock()
|
|
url := h.iotServiceURL
|
|
h.mu.RUnlock()
|
|
|
|
if url == "" {
|
|
url = getEnv("IOT_DEBUG_SERVICE_URL", "http://localhost:8083")
|
|
}
|
|
|
|
devices, err := fetchIoTDevices(url)
|
|
if err != nil {
|
|
log.Printf("[IoT广播] 获取设备失败: %v", err)
|
|
// 即使失败也发送空列表,让前端知道 IoT 服务状态
|
|
devices = []IotDeviceInfo{}
|
|
}
|
|
|
|
msg := ServerMessage{
|
|
Type: "device_update",
|
|
Devices: devices,
|
|
Timestamp: time.Now().UnixMilli(),
|
|
}
|
|
|
|
data, err := json.Marshal(msg)
|
|
if err != nil {
|
|
log.Printf("[IoT广播] 消息序列化失败: %v", err)
|
|
return
|
|
}
|
|
|
|
h.BroadcastToAll(data)
|
|
|
|
deviceNames := make([]string, 0, len(devices))
|
|
for _, d := range devices {
|
|
deviceNames = append(deviceNames, d.Name)
|
|
}
|
|
log.Printf("[IoT广播] 已推送 %d 个设备状态到 %d 个客户端: %v", len(devices), h.ClientCount(), deviceNames)
|
|
}
|
|
|
|
// fetchIoTDevices 从 IoT 调试服务获取设备列表
|
|
func fetchIoTDevices(serviceURL string) ([]IotDeviceInfo, error) {
|
|
client := &http.Client{Timeout: 5 * time.Second}
|
|
|
|
resp, err := client.Get(serviceURL + "/api/v1/devices")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("请求IoT服务失败: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("IoT服务返回状态码 %d", resp.StatusCode)
|
|
}
|
|
|
|
var result struct {
|
|
Devices []IotDeviceInfo `json:"devices"`
|
|
Total int `json:"total"`
|
|
}
|
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
return nil, fmt.Errorf("解析IoT设备列表失败: %w", err)
|
|
}
|
|
|
|
return result.Devices, nil
|
|
}
|
|
|
|
// getEnv returns the value of the environment variable key, or fallback
// when the variable is unset or empty.
func getEnv(key, fallback string) string {
	value := os.Getenv(key)
	if value == "" {
		return fallback
	}
	return value
}
|
|
|
|
// ========== 对话缓存方法 ==========
|
|
|
|
// maxConversationCache caps how many messages CacheMessage keeps per
// conversation; the oldest messages are dropped first.
const maxConversationCache = 50
|
|
|
|
// cacheKey builds the conversation cache key "userID:sessionID".
//
// Plain concatenation yields the same string as formatting with
// "%s:%s" without going through fmt.
func cacheKey(userID, sessionID string) string {
	return userID + ":" + sessionID
}
|
|
|
|
// CacheMessage 缓存单条消息到对话历史
|
|
// 应对外暴露:由 gateway chat handler 在收到用户消息和 AI 回复时调用
|
|
func (h *Hub) CacheMessage(userID, sessionID string, msg Message) {
|
|
key := cacheKey(userID, sessionID)
|
|
|
|
h.convCacheMu.Lock()
|
|
defer h.convCacheMu.Unlock()
|
|
|
|
existing, _ := h.conversationCache.Load(key)
|
|
var messages []Message
|
|
if existing != nil {
|
|
messages = existing.([]Message)
|
|
}
|
|
messages = append(messages, msg)
|
|
|
|
// 限制缓存消息数量上限
|
|
if len(messages) > maxConversationCache {
|
|
messages = messages[len(messages)-maxConversationCache:]
|
|
}
|
|
|
|
h.conversationCache.Store(key, messages)
|
|
}
|
|
|
|
// GetConversation 获取完整对话历史
|
|
func (h *Hub) GetConversation(userID, sessionID string) []Message {
|
|
key := cacheKey(userID, sessionID)
|
|
|
|
val, ok := h.conversationCache.Load(key)
|
|
if !ok {
|
|
return []Message{}
|
|
}
|
|
|
|
messages, ok := val.([]Message)
|
|
if !ok {
|
|
return []Message{}
|
|
}
|
|
return messages
|
|
}
|
|
|
|
// GetSessionHistory 获取会话历史消息(限制条数)
|
|
func (h *Hub) GetSessionHistory(userID, sessionID string, limit int) []Message {
|
|
messages := h.GetConversation(userID, sessionID)
|
|
if len(messages) == 0 {
|
|
return []Message{}
|
|
}
|
|
|
|
if limit > 0 && len(messages) > limit {
|
|
start := len(messages) - limit
|
|
return messages[start:]
|
|
}
|
|
return messages
|
|
}
|
|
|
|
// DeleteConversation 删除对话缓存
|
|
func (h *Hub) DeleteConversation(userID, sessionID string) {
|
|
key := cacheKey(userID, sessionID)
|
|
h.conversationCache.Delete(key)
|
|
}
|