package handler

import (
	"encoding/json"
	"log"
	"net/http"
	"strconv"
	"strings"

	"github.com/yourname/cyrene-ai/memory-service/internal/model"
	"github.com/yourname/cyrene-ai/memory-service/internal/service"
)

// MemoryHandler exposes the memory service over an HTTP JSON API.
type MemoryHandler struct {
	svc *service.MemoryService
}

// NewMemoryHandler creates a MemoryHandler backed by svc.
func NewMemoryHandler(svc *service.MemoryService) *MemoryHandler {
	return &MemoryHandler{svc: svc}
}

// RegisterRoutes registers all memory and thinking-log routes on mux.
//
// The exact sub-paths (query, consolidate, decay, categories, stats) are
// registered next to the trailing-slash "by ID" patterns. http.ServeMux
// dispatches to the most specific matching pattern, so requests for those
// exact paths are served by their dedicated handlers, not the by-ID ones.
func (h *MemoryHandler) RegisterRoutes(mux *http.ServeMux) {
	// GET (list) / POST (create): /api/v1/memories
	mux.HandleFunc("/api/v1/memories", h.handleMemories)
	// GET / PUT / DELETE a single memory: /api/v1/memories/{id}
	mux.HandleFunc("/api/v1/memories/", h.handleMemoryByID)
	// POST /api/v1/memories/query - semantic query
	mux.HandleFunc("/api/v1/memories/query", h.handleQuery)
	// POST /api/v1/memories/consolidate - merge similar memories
	mux.HandleFunc("/api/v1/memories/consolidate", h.handleConsolidate)
	// POST /api/v1/memories/decay - decay old memories
	mux.HandleFunc("/api/v1/memories/decay", h.handleDecay)
	// GET /api/v1/memories/categories - per-category statistics
	mux.HandleFunc("/api/v1/memories/categories", h.handleCategories)

	// Autonomous thinking-log API.
	mux.HandleFunc("/api/v1/thinking", h.handleThinking)
	mux.HandleFunc("/api/v1/thinking/", h.handleThinkingByID)
	mux.HandleFunc("/api/v1/thinking/stats", h.handleThinkingStats)
}

// handleMemories serves /api/v1/memories.
//
//	GET  - list a user's memories (?user_id=&category=&min_importance=&limit=)
//	POST - create a memory
func (h *MemoryHandler) handleMemories(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet:
		h.listMemories(w, r)
	case http.MethodPost:
		h.createMemory(w, r)
	default:
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
	}
}

// listMemories handles GET /api/v1/memories?user_id=...
//
// user_id is required. category is passed through as-is; limit defaults to
// 50 and min_importance to 0 when absent or invalid (see queryInt).
func (h *MemoryHandler) listMemories(w http.ResponseWriter, r *http.Request) {
	userID := r.URL.Query().Get("user_id")
	if userID == "" {
		writeError(w, http.StatusBadRequest, "缺少 user_id 参数")
		return
	}
	category := r.URL.Query().Get("category")
	limit := queryInt(r, "limit", 50)
	minImportance := queryInt(r, "min_importance", 0)

	memories, err := h.svc.ListMemories(r.Context(), userID, category, minImportance, limit)
	if err != nil {
		log.Printf("[memory-handler] 列出记忆失败: %v", err)
		writeError(w, http.StatusInternalServerError, err.Error())
		return
	}
	// Encode an empty JSON array rather than null when nothing matched.
	if memories == nil {
		memories = []model.MemoryEntry{}
	}
	writeJSON(w, http.StatusOK, map[string]interface{}{
		"user_id":  userID,
		"memories": memories,
		"total":    len(memories),
	})
}

// createMemory handles POST /api/v1/memories.
//
// user_id and content are required. Responds 201 with the entry as it
// stands after svc.CreateMemory returns.
func (h *MemoryHandler) createMemory(w http.ResponseWriter, r *http.Request) {
	var req struct {
		UserID     string   `json:"user_id"`
		Content    string   `json:"content"`
		Summary    string   `json:"summary"`
		Category   string   `json:"category"`
		Priority   int      `json:"priority"`
		Importance int      `json:"importance"`
		Keywords   []string `json:"keywords"`
		SessionID  string   `json:"session_id"`
		Source     string   `json:"source"`
	}
	if !decodeRequestJSON(w, r, &req) {
		return
	}
	if req.UserID == "" || req.Content == "" {
		writeError(w, http.StatusBadRequest, "缺少 user_id 或 content")
		return
	}

	entry := &model.MemoryEntry{
		UserID:     req.UserID,
		Content:    req.Content,
		Summary:    req.Summary,
		Category:   model.MemoryCategory(req.Category),
		Priority:   model.MemoryPriority(req.Priority),
		Importance: req.Importance,
		Keywords:   req.Keywords,
		SessionID:  req.SessionID,
		Source:     req.Source,
	}
	if err := h.svc.CreateMemory(r.Context(), entry); err != nil {
		log.Printf("[memory-handler] 创建记忆失败: %v", err)
		writeError(w, http.StatusInternalServerError, err.Error())
		return
	}
	writeJSON(w, http.StatusCreated, map[string]interface{}{
		"status": "saved",
		"memory": entry,
	})
}

// handleMemoryByID serves /api/v1/memories/{id}.
//
//	GET    - fetch a single memory
//	PUT    - update a memory
//	DELETE - delete a memory
func (h *MemoryHandler) handleMemoryByID(w http.ResponseWriter, r *http.Request) {
	id := strings.TrimPrefix(r.URL.Path, "/api/v1/memories/")

	switch id {
	case "":
		writeError(w, http.StatusBadRequest, "缺少记忆 ID")
		return
	case "query", "consolidate", "decay", "categories":
		// These sub-paths have their own exact-match handlers. If a request
		// still lands here, answer explicitly instead of sending an empty 200.
		writeError(w, http.StatusNotFound, "not found")
		return
	}

	switch r.Method {
	case http.MethodGet:
		h.getMemory(w, r, id)
	case http.MethodPut:
		h.updateMemory(w, r, id)
	case http.MethodDelete:
		h.deleteMemory(w, r, id)
	default:
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
	}
}

// getMemory handles GET /api/v1/memories/:id. A nil entry with no error
// from the service is reported as 404.
func (h *MemoryHandler) getMemory(w http.ResponseWriter, r *http.Request, id string) {
	entry, err := h.svc.GetMemory(r.Context(), id)
	if err != nil {
		log.Printf("[memory-handler] 获取记忆失败: %v", err)
		writeError(w, http.StatusInternalServerError, err.Error())
		return
	}
	if entry == nil {
		writeError(w, http.StatusNotFound, "记忆不存在")
		return
	}
	writeJSON(w, http.StatusOK, map[string]interface{}{
		"memory": entry,
	})
}

// updateMemory handles PUT /api/v1/memories/:id. The decoded fields are
// passed to svc.UpdateMemory together with the path ID; how zero-valued
// fields are treated is up to the service (NOTE(review): confirm whether
// this is a full replace or a partial update).
func (h *MemoryHandler) updateMemory(w http.ResponseWriter, r *http.Request, id string) {
	var req struct {
		Content    string   `json:"content"`
		Summary    string   `json:"summary"`
		Category   string   `json:"category"`
		Priority   int      `json:"priority"`
		Importance int      `json:"importance"`
		Keywords   []string `json:"keywords"`
		Source     string   `json:"source"`
	}
	if !decodeRequestJSON(w, r, &req) {
		return
	}

	entry := &model.MemoryEntry{
		ID:         id,
		Content:    req.Content,
		Summary:    req.Summary,
		Category:   model.MemoryCategory(req.Category),
		Priority:   model.MemoryPriority(req.Priority),
		Importance: req.Importance,
		Keywords:   req.Keywords,
		Source:     req.Source,
	}
	if err := h.svc.UpdateMemory(r.Context(), entry); err != nil {
		log.Printf("[memory-handler] 更新记忆失败: %v", err)
		writeError(w, http.StatusInternalServerError, err.Error())
		return
	}
	writeJSON(w, http.StatusOK, map[string]interface{}{
		"status":    "updated",
		"memory_id": id,
	})
}

// deleteMemory handles DELETE /api/v1/memories/:id.
func (h *MemoryHandler) deleteMemory(w http.ResponseWriter, r *http.Request, id string) {
	if err := h.svc.DeleteMemory(r.Context(), id); err != nil {
		log.Printf("[memory-handler] 删除记忆失败: %v", err)
		writeError(w, http.StatusInternalServerError, err.Error())
		return
	}
	writeJSON(w, http.StatusOK, map[string]interface{}{
		"status":    "deleted",
		"memory_id": id,
	})
}

// handleQuery handles POST /api/v1/memories/query. user_id is required;
// a non-positive limit defaults to 10.
func (h *MemoryHandler) handleQuery(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}

	var req struct {
		UserID        string `json:"user_id"`
		QueryText     string `json:"query_text"`
		Category      string `json:"category"`
		MinImportance int    `json:"min_importance"`
		Limit         int    `json:"limit"`
	}
	if !decodeRequestJSON(w, r, &req) {
		return
	}
	if req.UserID == "" {
		writeError(w, http.StatusBadRequest, "缺少 user_id")
		return
	}
	if req.Limit <= 0 {
		req.Limit = 10
	}

	memories, err := h.svc.QueryMemories(r.Context(), req.UserID, req.QueryText, req.Category, req.MinImportance, req.Limit)
	if err != nil {
		log.Printf("[memory-handler] 查询记忆失败: %v", err)
		writeError(w, http.StatusInternalServerError, err.Error())
		return
	}
	// Encode an empty JSON array rather than null when nothing matched.
	if memories == nil {
		memories = []model.MemoryEntry{}
	}
	writeJSON(w, http.StatusOK, map[string]interface{}{
		"user_id":  req.UserID,
		"query":    req.QueryText,
		"memories": memories,
		"total":    len(memories),
	})
}

// handleConsolidate handles POST /api/v1/memories/consolidate for the
// user_id in the JSON body, reporting the service's "merged" result.
func (h *MemoryHandler) handleConsolidate(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}

	var req struct {
		UserID string `json:"user_id"`
	}
	if !decodeRequestJSON(w, r, &req) {
		return
	}
	if req.UserID == "" {
		writeError(w, http.StatusBadRequest, "缺少 user_id")
		return
	}

	merged, err := h.svc.ConsolidateMemories(r.Context(), req.UserID)
	if err != nil {
		log.Printf("[memory-handler] 合并记忆失败: %v", err)
		writeError(w, http.StatusInternalServerError, err.Error())
		return
	}
	writeJSON(w, http.StatusOK, map[string]interface{}{
		"status":  "consolidated",
		"user_id": req.UserID,
		"merged":  merged,
		"message": "记忆整理完成",
	})
}

// handleDecay handles POST /api/v1/memories/decay for the user_id in the
// JSON body, reporting the service's "decayed" and "deleted" results.
func (h *MemoryHandler) handleDecay(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}

	var req struct {
		UserID string `json:"user_id"`
	}
	if !decodeRequestJSON(w, r, &req) {
		return
	}
	if req.UserID == "" {
		writeError(w, http.StatusBadRequest, "缺少 user_id")
		return
	}

	decayed, deleted, err := h.svc.DecayMemories(r.Context(), req.UserID)
	if err != nil {
		log.Printf("[memory-handler] 衰减记忆失败: %v", err)
		writeError(w, http.StatusInternalServerError, err.Error())
		return
	}
	writeJSON(w, http.StatusOK, map[string]interface{}{
		"status":  "decayed",
		"user_id": req.UserID,
		"decayed": decayed,
		"deleted": deleted,
		"message": "记忆衰减完成",
	})
}

// handleCategories handles GET /api/v1/memories/categories?user_id=...
func (h *MemoryHandler) handleCategories(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}

	userID := r.URL.Query().Get("user_id")
	if userID == "" {
		writeError(w, http.StatusBadRequest, "缺少 user_id 参数")
		return
	}

	categories, err := h.svc.GetCategories(r.Context(), userID)
	if err != nil {
		log.Printf("[memory-handler] 获取分类统计失败: %v", err)
		writeError(w, http.StatusInternalServerError, err.Error())
		return
	}
	writeJSON(w, http.StatusOK, map[string]interface{}{
		"user_id":    userID,
		"categories": categories,
	})
}

// --- helpers ---

// writeJSON writes data as a JSON response with the given status code.
func writeJSON(w http.ResponseWriter, status int, data interface{}) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	if err := json.NewEncoder(w).Encode(data); err != nil {
		// The status line is already sent, so the response cannot be
		// changed; log so encode failures or client disconnects aren't lost.
		log.Printf("[memory-handler] 写入响应失败: %v", err)
	}
}

// writeError writes a JSON body of the form {"error": message}.
func writeError(w http.ResponseWriter, status int, message string) {
	writeJSON(w, status, map[string]interface{}{
		"error": message,
	})
}

// decodeRequestJSON decodes the request body into dst. On failure it writes
// a 400 response and returns false, so callers can simply return.
func decodeRequestJSON(w http.ResponseWriter, r *http.Request, dst interface{}) bool {
	if err := json.NewDecoder(r.Body).Decode(dst); err != nil {
		writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error())
		return false
	}
	return true
}

// queryInt reads the URL query parameter key as a non-negative decimal
// integer. It returns fallback when the parameter is absent, is not a valid
// integer, is negative, or does not fit in an int (the previous hand-rolled
// parser silently overflowed on long digit strings).
func queryInt(r *http.Request, key string, fallback int) int {
	v := r.URL.Query().Get(key)
	if v == "" {
		return fallback
	}
	n, err := strconv.Atoi(v)
	if err != nil || n < 0 {
		return fallback
	}
	return n
}

// handleThinking serves /api/v1/thinking.
//
//	GET  - page through thinking logs (?user_id=&limit=&offset=)
//	POST - save a thinking log
func (h *MemoryHandler) handleThinking(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet:
		h.listThinkingLogs(w, r)
	case http.MethodPost:
		h.createThinkingLog(w, r)
	default:
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
	}
}

// createThinkingLog handles POST /api/v1/thinking. Only content is
// required; user_id may be empty.
func (h *MemoryHandler) createThinkingLog(w http.ResponseWriter, r *http.Request) {
	var req struct {
		UserID        string `json:"user_id"`
		Content       string `json:"content"`
		ToolCalls     string `json:"tool_calls"`
		ToolCallCount int    `json:"tool_call_count"`
		ContentLength int    `json:"content_length"`
	}
	if !decodeRequestJSON(w, r, &req) {
		return
	}
	if req.Content == "" {
		writeError(w, http.StatusBadRequest, "缺少 content")
		return
	}

	tl := &model.ThinkingLog{
		UserID:        req.UserID,
		Content:       req.Content,
		ToolCalls:     req.ToolCalls,
		ToolCallCount: req.ToolCallCount,
		ContentLength: req.ContentLength,
	}
	if err := h.svc.SaveThinkingLog(r.Context(), tl); err != nil {
		log.Printf("[memory-handler] 保存思考日志失败: %v", err)
		writeError(w, http.StatusInternalServerError, err.Error())
		return
	}
	writeJSON(w, http.StatusCreated, map[string]interface{}{
		"status":   "saved",
		"thinking": tl,
	})
}

// listThinkingLogs handles GET /api/v1/thinking?user_id=&limit=&offset=.
// limit defaults to 20 and offset to 0. "total" is the number of logs in
// this page, not the overall count.
func (h *MemoryHandler) listThinkingLogs(w http.ResponseWriter, r *http.Request) {
	userID := r.URL.Query().Get("user_id")
	limit := queryInt(r, "limit", 20)
	offset := queryInt(r, "offset", 0)

	logs, err := h.svc.QueryThinkingLogs(r.Context(), model.ThinkingQuery{
		UserID: userID,
		Limit:  limit,
		Offset: offset,
	})
	if err != nil {
		log.Printf("[memory-handler] 查询思考日志失败: %v", err)
		writeError(w, http.StatusInternalServerError, err.Error())
		return
	}
	// Encode an empty JSON array rather than null when nothing matched.
	if logs == nil {
		logs = []model.ThinkingLog{}
	}
	writeJSON(w, http.StatusOK, map[string]interface{}{
		"logs":  logs,
		"total": len(logs),
	})
}

// handleThinkingByID serves GET /api/v1/thinking/{id}.
func (h *MemoryHandler) handleThinkingByID(w http.ResponseWriter, r *http.Request) {
	id := strings.TrimPrefix(r.URL.Path, "/api/v1/thinking/")

	switch id {
	case "":
		// Previously returned without writing anything (an empty 200).
		writeError(w, http.StatusBadRequest, "缺少思考日志 ID")
		return
	case "stats":
		// Served by handleThinkingStats via its exact-match pattern.
		writeError(w, http.StatusNotFound, "not found")
		return
	}

	if r.Method != http.MethodGet {
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}

	thinkingLog, err := h.svc.GetThinkingLogByID(r.Context(), id)
	if err != nil {
		log.Printf("[memory-handler] 获取思考日志失败: %v", err)
		writeError(w, http.StatusInternalServerError, err.Error())
		return
	}
	if thinkingLog == nil {
		writeError(w, http.StatusNotFound, "思考日志不存在")
		return
	}
	writeJSON(w, http.StatusOK, map[string]interface{}{
		"thinking": thinkingLog,
	})
}

// handleThinkingStats handles GET /api/v1/thinking/stats.
func (h *MemoryHandler) handleThinkingStats(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}

	stats, err := h.svc.GetThinkingStats(r.Context())
	if err != nil {
		log.Printf("[memory-handler] 获取思考日志统计失败: %v", err)
		writeError(w, http.StatusInternalServerError, err.Error())
		return
	}
	writeJSON(w, http.StatusOK, map[string]interface{}{
		"stats": stats,
	})
}