package handler

import (
	"crypto/rand"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strconv"
	"strings"
	"time"

	"github.com/gin-gonic/gin"

	"git.yeij.top/AskaEth/Cyrene/gateway/internal/middleware"
	"git.yeij.top/AskaEth/Cyrene/gateway/internal/store"
	"git.yeij.top/AskaEth/Cyrene/gateway/internal/ws"
	"git.yeij.top/AskaEth/Cyrene/pkg/logger"
)

// SessionHandler serves the session-management HTTP endpoints.
//
// When the persistent store is usable, sessions and messages are read from and
// written to it; otherwise most endpoints degrade to the in-memory ws.Hub cache
// or return empty results.
type SessionHandler struct {
	store *store.SessionStore // PostgreSQL-backed persistent store (may be nil)
	hub   *ws.Hub
	useDB bool // whether the database store is usable
}

// NewSessionHandler creates a SessionHandler. s may be nil; the database is
// only used when s is non-nil and s.IsAvailable() reports true at construction
// time.
func NewSessionHandler(hub *ws.Hub, s *store.SessionStore) *SessionHandler {
	return &SessionHandler{
		store: s,
		hub:   hub,
		useDB: s != nil && s.IsAvailable(),
	}
}

// canAccess reports whether the current user may access a resource owned by
// resourceOwnerID. Admin users bypass the ownership check.
func (h *SessionHandler) canAccess(c *gin.Context, resourceOwnerID string) bool {
	if middleware.GetIsAdmin(c) {
		return true
	}
	return middleware.GetUserID(c) == resourceOwnerID
}

// resolveTargetUserID determines which user a per-user endpoint should act on:
// the `user_id` query parameter if present, otherwise the authenticated caller.
// Non-admin callers may only target themselves; on a mismatch it writes a 403
// response carrying forbiddenMsg and returns ok=false.
func (h *SessionHandler) resolveTargetUserID(c *gin.Context, forbiddenMsg string) (userID string, ok bool) {
	userID = c.Query("user_id")
	if userID == "" {
		return middleware.GetUserID(c), true
	}
	if !h.canAccess(c, userID) {
		c.JSON(http.StatusForbidden, gin.H{"error": forbiddenMsg})
		return "", false
	}
	return userID, true
}

// parseNonNegativeInt parses a decimal query-parameter value. It returns
// ok=false for empty, malformed, negative, or out-of-range input so callers can
// keep their default instead of using a silently truncated or overflowed value.
func parseNonNegativeInt(s string) (int, bool) {
	if s == "" {
		return 0, false
	}
	v, err := strconv.Atoi(s)
	if err != nil || v < 0 {
		return 0, false
	}
	return v, true
}

// attachmentDisposition builds the Content-Disposition header for a session
// export download. Session IDs may be chosen by the client in Create, so every
// character outside [A-Za-z0-9._-] is replaced with '_' to keep the quoted
// filename well-formed.
func attachmentDisposition(sessionID, ext string) string {
	safe := strings.Map(func(r rune) rune {
		switch {
		case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9',
			r == '-', r == '_', r == '.':
			return r
		}
		return '_'
	}, sessionID)
	return fmt.Sprintf(`attachment; filename="session_%s.%s"`, safe, ext)
}

// ========== POST /api/v1/sessions — create a session ==========

// createSessionRequest is the optional JSON body of POST /api/v1/sessions.
// UserID is currently ignored: the owner is always the authenticated caller.
type createSessionRequest struct {
	UserID    string `json:"user_id"`
	SessionID string `json:"session_id"`
	Title     string `json:"title"`
	IsMain    bool   `json:"is_main"`
}

// Create creates a new session owned by the authenticated user.
//
// An empty body is allowed; a missing title defaults to "新的对话" and a
// missing session ID is generated. A body that is present but not valid JSON
// is rejected with 400. Without a database the session is not persisted, but
// the response is still 201.
func (h *SessionHandler) Create(c *gin.Context) {
	userID := middleware.GetUserID(c)

	var req createSessionRequest
	// io.EOF means the body was empty, which is allowed ("all defaults").
	// Any other decode error is a client mistake and must not be silently
	// turned into a default session.
	if err := c.ShouldBindJSON(&req); err != nil && !errors.Is(err, io.EOF) {
		c.JSON(http.StatusBadRequest, gin.H{"error": "请求体格式错误", "errorType": "invalid_request"})
		return
	}
	if req.Title == "" {
		req.Title = "新的对话"
	}
	if req.SessionID == "" {
		req.SessionID = "session_" + randomID(12)
	}

	if h.useDB {
		if err := h.store.CreateSession(userID, req.SessionID, req.Title, req.IsMain); err != nil {
			logger.Printf("[SessionHandler] 创建会话失败: %v", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "创建会话失败", "errorType": "db_error"})
			return
		}
	}

	// One timestamp so created_at and updated_at agree for a fresh session.
	now := time.Now().UnixMilli()
	c.JSON(http.StatusCreated, gin.H{
		"id":         req.SessionID,
		"user_id":    userID,
		"title":      req.Title,
		"is_main":    req.IsMain,
		"created_at": now,
		"updated_at": now,
	})
}

// ========== GET /api/v1/sessions?user_id=xxx — list a user's sessions ==========

// List returns a user's sessions in the order provided by the store (the
// route is documented as updated_at DESC). Non-admins may only list their own
// sessions. Without a database an empty list is returned.
func (h *SessionHandler) List(c *gin.Context) {
	userID, ok := h.resolveTargetUserID(c, "无权访问其他用户的会话")
	if !ok {
		return
	}

	if !h.useDB {
		// Degraded mode: nothing is persisted, so report an empty list.
		c.JSON(http.StatusOK, gin.H{"sessions": []gin.H{}})
		return
	}

	sessions, err := h.store.GetUserSessions(userID)
	if err != nil {
		logger.Printf("[SessionHandler] 查询会话列表失败: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询会话失败", "errorType": "db_error"})
		return
	}

	result := make([]gin.H, 0, len(sessions))
	for _, s := range sessions {
		result = append(result, gin.H{
			"id":         s.ID,
			"user_id":    s.UserID,
			"title":      s.Title,
			"is_main":    s.IsMain,
			"created_at": s.CreatedAt.UnixMilli(),
			"updated_at": s.UpdatedAt.UnixMilli(),
		})
	}
	c.JSON(http.StatusOK, gin.H{"sessions": result})
}

// ========== GET /api/v1/sessions/:id — fetch one session ==========

// Get returns a single session's metadata. Non-admins may only read their own
// sessions. Without a database it responds 404 with errorType
// "store_unavailable".
func (h *SessionHandler) Get(c *gin.Context) {
	sessionID := c.Param("id")

	if !h.useDB {
		c.JSON(http.StatusNotFound, gin.H{
			"error":     "会话存储不可用",
			"errorType": "store_unavailable",
			"hint":      "数据库连接未建立,Gateway 运行在仅内存模式",
		})
		return
	}

	session, err := h.store.GetSession(sessionID)
	if err != nil {
		logger.Printf("[SessionHandler] 查询会话失败: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询会话失败", "errorType": "db_error"})
		return
	}
	if session == nil {
		c.JSON(http.StatusNotFound, gin.H{
			"error":     "会话不存在",
			"errorType": "session_not_found",
			"hint":      "该会话可能已被删除或尚未创建",
		})
		return
	}
	if !h.canAccess(c, session.UserID) {
		c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此会话"})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"id":         session.ID,
		"user_id":    session.UserID,
		"title":      session.Title,
		"is_main":    session.IsMain,
		"created_at": session.CreatedAt.UnixMilli(),
		"updated_at": session.UpdatedAt.UnixMilli(),
	})
}

// ========== DELETE /api/v1/sessions/:id — delete a session ==========

// Delete deletes one session (memories are not deleted). With a database the
// session must exist and belong to the caller (or the caller must be admin).
// The hub's in-memory conversation cache is cleared in every case.
func (h *SessionHandler) Delete(c *gin.Context) {
	sessionID := c.Param("id")

	if h.useDB {
		// Load the session first so ownership can be verified before deleting.
		session, err := h.store.GetSession(sessionID)
		if err != nil {
			logger.Printf("[SessionHandler] 查询会话失败: %v", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "删除会话失败", "errorType": "db_error"})
			return
		}
		if session == nil {
			c.JSON(http.StatusNotFound, gin.H{
				"error":     "会话不存在",
				"errorType": "session_not_found",
			})
			return
		}
		if !h.canAccess(c, session.UserID) {
			c.JSON(http.StatusForbidden, gin.H{"error": "无权删除此会话"})
			return
		}
		if err := h.store.DeleteSession(sessionID); err != nil {
			logger.Printf("[SessionHandler] 删除会话失败: %v", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "删除会话失败", "errorType": "db_error"})
			return
		}
	}

	// NOTE(review): GetMessages' no-DB fallback reads the hub cache with the
	// caller's user ID, but this passes "" — confirm that
	// ws.Hub.DeleteConversation treats an empty user ID as "any user".
	// Also note that without a DB there is no ownership check on this path.
	h.hub.DeleteConversation("", sessionID)
	c.JSON(http.StatusOK, gin.H{"status": "deleted"})
}

// ========== DELETE /api/v1/sessions?user_id=xxx — delete all of a user's sessions ==========

// DeleteAll deletes every session of a user (memories are not deleted).
// Non-admins may only delete their own sessions. Without a database this is a
// no-op that still reports success.
func (h *SessionHandler) DeleteAll(c *gin.Context) {
	userID, ok := h.resolveTargetUserID(c, "无权删除其他用户的会话")
	if !ok {
		return
	}

	if h.useDB {
		if err := h.store.DeleteAllUserSessions(userID); err != nil {
			logger.Printf("[SessionHandler] 删除用户所有会话失败: %v", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "删除会话失败", "errorType": "db_error"})
			return
		}
	}

	c.JSON(http.StatusOK, gin.H{"status": "deleted"})
}

// ========== GET /api/v1/sessions/:id/messages?limit=50&offset=0 — fetch session messages ==========

// GetMessages returns a page of a session's messages.
//
// Query parameters: limit (positive integer, default 50) and offset
// (non-negative integer, default 0); malformed values fall back to the default.
// With a database, ownership is checked and messages come from the store;
// otherwise they come from the hub cache keyed by the caller's user ID.
func (h *SessionHandler) GetMessages(c *gin.Context) {
	sessionID := c.Param("id")

	if h.useDB {
		session, err := h.store.GetSession(sessionID)
		if err != nil {
			logger.Printf("[SessionHandler] 查询会话失败: %v", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "查询消息失败", "errorType": "db_error"})
			return
		}
		if session == nil {
			c.JSON(http.StatusNotFound, gin.H{
				"error":     "会话不存在",
				"errorType": "session_not_found",
			})
			return
		}
		if !h.canAccess(c, session.UserID) {
			c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此会话的消息"})
			return
		}
	}

	limit := 50
	if v, ok := parseNonNegativeInt(c.Query("limit")); ok && v > 0 {
		limit = v
	}
	offset := 0
	if v, ok := parseNonNegativeInt(c.Query("offset")); ok {
		offset = v
	}

	if h.useDB {
		messages, err := h.store.GetMessages(sessionID, limit, offset)
		if err != nil {
			logger.Printf("[SessionHandler] 查询消息失败: %v", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "查询消息失败", "errorType": "db_error"})
			return
		}

		result := make([]gin.H, 0, len(messages))
		for _, m := range messages {
			result = append(result, gin.H{
				"id":         m.ID,
				"session_id": m.SessionID,
				"role":       m.Role,
				"msg_type":   m.MsgType,
				"content":    m.Content,
				"created_at": m.CreatedAt.UnixMilli(),
			})
		}
		c.JSON(http.StatusOK, gin.H{"messages": result})
		return
	}

	// Degraded mode: read the hub's in-memory cache, using the authenticated
	// user's ID as the cache-key prefix. limit/offset are not applied here.
	userID := middleware.GetUserID(c)
	messages := h.hub.GetConversation(userID, sessionID)
	if messages == nil {
		// Encode as [] rather than null.
		messages = []ws.Message{}
	}
	c.JSON(http.StatusOK, gin.H{"messages": messages})
}

// ========== DELETE /api/v1/sessions/:id/messages — clear session messages ==========

// ClearMessages removes all messages of a session but keeps the session
// itself. With a database, ownership is checked first. The hub's in-memory
// cache is cleared in every case.
func (h *SessionHandler) ClearMessages(c *gin.Context) {
	sessionID := c.Param("id")

	if h.useDB {
		session, err := h.store.GetSession(sessionID)
		if err != nil {
			logger.Printf("[SessionHandler] 查询会话失败: %v", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "清空消息失败", "errorType": "db_error"})
			return
		}
		if session == nil {
			c.JSON(http.StatusNotFound, gin.H{
				"error":     "会话不存在",
				"errorType": "session_not_found",
			})
			return
		}
		if !h.canAccess(c, session.UserID) {
			c.JSON(http.StatusForbidden, gin.H{"error": "无权清空此会话的消息"})
			return
		}
		if err := h.store.ClearSessionMessages(sessionID); err != nil {
			logger.Printf("[SessionHandler] 清空消息失败: %v", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "清空消息失败", "errorType": "db_error"})
			return
		}
	}

	// NOTE(review): same empty-user-ID cache key as Delete — confirm it matches
	// the key used by GetConversation.
	h.hub.DeleteConversation("", sessionID)
	c.JSON(http.StatusOK, gin.H{"status": "cleared"})
}

// ========== Admin endpoints ==========

// ListActiveSessions returns every active WebSocket session known to the hub
// (admin). A nil result is reported as an empty list.
func (h *SessionHandler) ListActiveSessions(c *gin.Context) {
	sessions := h.hub.GetAllActiveSessions()
	if sessions == nil {
		sessions = []*ws.SessionState{}
	}
	c.JSON(http.StatusOK, gin.H{
		"sessions": sessions,
		"total":    len(sessions),
	})
}

// GetActiveSessions returns active sessions grouped by user ID. A nil result
// is reported as an empty object.
func (h *SessionHandler) GetActiveSessions(c *gin.Context) {
	sessionsByUser := h.hub.GetActiveSessionsByUser()
	if sessionsByUser == nil {
		sessionsByUser = make(map[string][]*ws.SessionState)
	}
	c.JSON(http.StatusOK, gin.H{
		"users": sessionsByUser,
	})
}

// GetSession returns the hub's in-memory state for one session (admin).
// Responds 400 when the ID is missing and 404 when the hub has no such session.
func (h *SessionHandler) GetSession(c *gin.Context) {
	sessionID := c.Param("id")
	if sessionID == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "缺少会话ID"})
		return
	}

	session := h.hub.GetSession(sessionID)
	if session == nil {
		c.JSON(http.StatusNotFound, gin.H{
			"error":     "会话不存在",
			"errorType": "session_not_found",
			"hint":      "该会话可能已断开,或 Gateway 重启后内存数据已清空",
		})
		return
	}

	c.JSON(http.StatusOK, session)
}

// ========== GET /api/v1/messages/search?q=xxx&user_id=xxx&limit=50&offset=0 — full-text message search ==========

// SearchMessages runs a full-text search over a user's messages.
//
// q is required. user_id defaults to the caller; non-admins may only search
// their own messages. limit must be in [1, 200] (otherwise 50) and offset must
// be a non-negative integer (otherwise 0). Without a database an empty result
// set is returned.
func (h *SessionHandler) SearchMessages(c *gin.Context) {
	const (
		defaultLimit = 50
		maxLimit     = 200
	)

	query := c.Query("q")
	if query == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "缺少搜索关键词参数 q", "errorType": "missing_query"})
		return
	}

	// Without this check any authenticated user could search another user's
	// messages simply by passing their user_id.
	userID, ok := h.resolveTargetUserID(c, "无权搜索其他用户的消息")
	if !ok {
		return
	}

	limit := defaultLimit
	if v, ok := parseNonNegativeInt(c.Query("limit")); ok && v > 0 && v <= maxLimit {
		limit = v
	}
	offset := 0
	if v, ok := parseNonNegativeInt(c.Query("offset")); ok {
		offset = v
	}

	if !h.useDB {
		c.JSON(http.StatusOK, gin.H{
			"results": []gin.H{},
			"total":   0,
			"query":   query,
		})
		return
	}

	results, total, err := h.store.SearchMessages(userID, query, limit, offset)
	if err != nil {
		logger.Printf("[SessionHandler] 搜索消息失败: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "搜索失败", "errorType": "db_error"})
		return
	}

	items := make([]gin.H, 0, len(results))
	for _, r := range results {
		items = append(items, gin.H{
			"message_id":    r.MessageID,
			"session_id":    r.SessionID,
			"session_title": r.SessionTitle,
			"role":          r.Role,
			"content":       r.Content,
			"created_at":    r.CreatedAt.UnixMilli(),
		})
	}

	c.JSON(http.StatusOK, gin.H{
		"results": items,
		"total":   total,
		"query":   query,
		"limit":   limit,
		"offset":  offset,
	})
}

// ========== GET /api/v1/sessions/:id/export?format=json|markdown|txt — export a session ==========

// ExportSession exports a session and its messages as a downloadable file.
//
// format is one of "json" (default), "markdown" or "txt"; anything else is a
// 400. Requires the database (503 otherwise). Non-admins may only export their
// own sessions.
func (h *SessionHandler) ExportSession(c *gin.Context) {
	sessionID := c.Param("id")
	format := c.Query("format")
	if format == "" {
		format = "json"
	}

	switch format {
	case "json", "markdown", "txt":
		// supported
	default:
		c.JSON(http.StatusBadRequest, gin.H{
			"error":     "不支持的导出格式",
			"errorType": "invalid_format",
			"hint":      "支持的格式: json, markdown, txt",
		})
		return
	}

	if !h.useDB {
		c.JSON(http.StatusServiceUnavailable, gin.H{
			"error":     "会话存储不可用",
			"errorType": "store_unavailable",
			"hint":      "数据库连接未建立,无法导出会话",
		})
		return
	}

	session, err := h.store.GetSession(sessionID)
	if err != nil {
		logger.Printf("[SessionHandler] 查询会话失败: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询会话失败", "errorType": "db_error"})
		return
	}
	if session == nil {
		c.JSON(http.StatusNotFound, gin.H{
			"error":     "会话不存在",
			"errorType": "session_not_found",
		})
		return
	}
	if !h.canAccess(c, session.UserID) {
		c.JSON(http.StatusForbidden, gin.H{"error": "无权导出此会话"})
		return
	}

	// limit=0, offset=0 is intended to mean "all messages" (per the original
	// comment) — NOTE(review): confirm this against store.GetMessages.
	messages, err := h.store.GetMessages(sessionID, 0, 0)
	if err != nil {
		logger.Printf("[SessionHandler] 查询消息失败: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询消息失败", "errorType": "db_error"})
		return
	}
	if messages == nil {
		messages = []store.Message{}
	}

	now := time.Now()
	switch format {
	case "json":
		h.exportJSON(c, session, messages, now)
	case "markdown":
		h.exportMarkdown(c, session, messages, now)
	case "txt":
		h.exportTXT(c, session, messages, now)
	}
}

// exportJSON writes the session and its messages as indented JSON. now is
// accepted for signature parity with the other exporters and is unused.
func (h *SessionHandler) exportJSON(c *gin.Context, session *store.Session, messages []store.Message, now time.Time) {
	type msgOut struct {
		Role      string `json:"role"`
		Content   string `json:"content"`
		CreatedAt int64  `json:"created_at"`
	}
	type sessionOut struct {
		ID        string `json:"id"`
		Title     string `json:"title"`
		CreatedAt int64  `json:"created_at"`
		UpdatedAt int64  `json:"updated_at"`
	}
	type export struct {
		Session  sessionOut `json:"session"`
		Messages []msgOut   `json:"messages"`
	}

	msgs := make([]msgOut, 0, len(messages))
	for _, m := range messages {
		msgs = append(msgs, msgOut{
			Role:      m.Role,
			Content:   m.Content,
			CreatedAt: m.CreatedAt.UnixMilli(),
		})
	}

	data := export{
		Session: sessionOut{
			ID:        session.ID,
			Title:     session.Title,
			CreatedAt: session.CreatedAt.UnixMilli(),
			UpdatedAt: session.UpdatedAt.UnixMilli(),
		},
		Messages: msgs,
	}

	jsonBytes, err := json.MarshalIndent(data, "", "  ")
	if err != nil {
		logger.Printf("[SessionHandler] JSON序列化失败: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "导出失败", "errorType": "serialization_error"})
		return
	}

	// c.Data sets Content-Type itself.
	c.Header("Content-Disposition", attachmentDisposition(session.ID, "json"))
	c.Data(http.StatusOK, "application/json; charset=utf-8", jsonBytes)
}

// exportMarkdown writes the session as a Markdown document: a header with the
// title, ID, export time and message count, then one section per message
// labelled by role.
func (h *SessionHandler) exportMarkdown(c *gin.Context, session *store.Session, messages []store.Message, now time.Time) {
	var sb strings.Builder
	fmt.Fprintf(&sb, "# 对话导出: %s\n", session.Title)
	fmt.Fprintf(&sb, "**会话 ID**: %s\n", session.ID)
	fmt.Fprintf(&sb, "**导出时间**: %s\n", now.Format("2006-01-02 15:04:05"))
	fmt.Fprintf(&sb, "**消息数量**: %d\n", len(messages))
	sb.WriteString("\n---\n\n")

	for _, m := range messages {
		timeStr := m.CreatedAt.Format("2006-01-02 15:04:05")
		switch m.Role {
		case "user":
			fmt.Fprintf(&sb, "### 👤 用户 (%s)\n\n", timeStr)
		case "assistant":
			fmt.Fprintf(&sb, "### 💫 昔涟 (%s)\n\n", timeStr)
		case "system":
			fmt.Fprintf(&sb, "### ⚙️ 系统 (%s)\n\n", timeStr)
		default:
			fmt.Fprintf(&sb, "### %s (%s)\n\n", m.Role, timeStr)
		}
		sb.WriteString(m.Content)
		sb.WriteString("\n\n---\n\n")
	}

	c.Header("Content-Disposition", attachmentDisposition(session.ID, "md"))
	c.Data(http.StatusOK, "text/markdown; charset=utf-8", []byte(sb.String()))
}

// exportTXT writes the session as plain text: a header block, a 50-character
// "=" rule, then one "[time] role:\ncontent" entry per message.
func (h *SessionHandler) exportTXT(c *gin.Context, session *store.Session, messages []store.Message, now time.Time) {
	var sb strings.Builder
	fmt.Fprintf(&sb, "对话导出: %s\n", session.Title)
	fmt.Fprintf(&sb, "会话 ID: %s\n", session.ID)
	fmt.Fprintf(&sb, "导出时间: %s\n", now.Format("2006-01-02 15:04:05"))
	fmt.Fprintf(&sb, "消息数量: %d\n", len(messages))
	sb.WriteString(strings.Repeat("=", 50))
	sb.WriteString("\n\n")

	for _, m := range messages {
		timeStr := m.CreatedAt.Format("2006-01-02 15:04:05")
		roleLabel := m.Role
		switch m.Role {
		case "user":
			roleLabel = "用户"
		case "assistant":
			roleLabel = "昔涟"
		case "system":
			roleLabel = "系统"
		}
		fmt.Fprintf(&sb, "[%s] %s:\n%s\n\n", timeStr, roleLabel, m.Content)
	}

	c.Header("Content-Disposition", attachmentDisposition(session.ID, "txt"))
	c.Data(http.StatusOK, "text/plain; charset=utf-8", []byte(sb.String()))
}

// randomID returns an n-character identifier drawn uniformly from [a-z0-9]
// using crypto/rand; n <= 0 yields "". Bytes that would introduce modulo bias
// are discarded. If crypto/rand fails, the remaining characters are filled
// deterministically, so IDs produced on that path are NOT unique.
func randomID(n int) string {
	const letters = "abcdefghijklmnopqrstuvwxyz0123456789"
	// Largest multiple of len(letters) not exceeding 256; bytes at or above it
	// are rejected so every letter is equally likely.
	const unbiasedLimit = 256 - 256%len(letters)

	if n <= 0 {
		return ""
	}

	out := make([]byte, 0, n)
	buf := make([]byte, n)
	for len(out) < n {
		if _, err := rand.Read(buf); err != nil {
			// Fallback only if crypto/rand fails: deterministic filler.
			for i := len(out); i < n; i++ {
				out = append(out, letters[i%len(letters)])
			}
			return string(out)
		}
		for _, b := range buf {
			if int(b) >= unbiasedLimit {
				continue
			}
			out = append(out, letters[int(b)%len(letters)])
			if len(out) == n {
				break
			}
		}
	}
	return string(out)
}