feat: Round 5 - Memory Service, Tool Engine, Call Records, Thinking Logs
- Fix: Session history flash (race condition + WS guard) - Fix: Chat background overlay + sidebar transparency - Fix: IoT device control (Chinese action names, status field) - Feat: Independent memory-service (port 8091, 13 endpoints) - Feat: Independent tool-engine service (port 8092, 13 tools) - Feat: Tool call logs with paginated DevTools panel - Feat: Thinking log records with DevTools panel - Feat: Future development roadmap document - Chore: Updated .gitignore, go.work, DevTools config - Chore: 5-service health check, project review docs
This commit is contained in:
@@ -0,0 +1,19 @@
|
||||
# Build stage
|
||||
FROM golang:1.24-alpine AS builder
|
||||
|
||||
WORKDIR /app
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -o /memory-service ./cmd/
|
||||
|
||||
# Runtime stage
|
||||
FROM alpine:3.21
|
||||
|
||||
RUN apk --no-cache add ca-certificates
|
||||
WORKDIR /app
|
||||
COPY --from=builder /memory-service .
|
||||
|
||||
EXPOSE 8091
|
||||
|
||||
ENTRYPOINT ["./memory-service"]
|
||||
@@ -0,0 +1,76 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/yourname/cyrene-ai/memory-service/internal/config"
|
||||
"github.com/yourname/cyrene-ai/memory-service/internal/handler"
|
||||
"github.com/yourname/cyrene-ai/memory-service/internal/service"
|
||||
"github.com/yourname/cyrene-ai/memory-service/internal/store"
|
||||
)
|
||||
|
||||
func main() {
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
log.Println("🧠 Memory-Service 启动中...")
|
||||
|
||||
// 加载配置
|
||||
cfg := config.Load()
|
||||
|
||||
log.Printf("配置: 端口=%s, 数据库=%s...", cfg.Port, maskDBURL(cfg.DatabaseURL))
|
||||
|
||||
// 初始化数据库存储
|
||||
memStore := store.NewStore(cfg.DatabaseURL)
|
||||
defer memStore.Close()
|
||||
|
||||
// 初始化服务层
|
||||
svc := service.NewMemoryService(memStore)
|
||||
|
||||
// 初始化 HTTP 处理器
|
||||
h := handler.NewMemoryHandler(svc)
|
||||
|
||||
// 注册路由
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
// 健康检查端点
|
||||
mux.HandleFunc("/api/v1/health", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
status := "ok"
|
||||
if !memStore.IsReady() {
|
||||
status = "degraded"
|
||||
}
|
||||
w.Write([]byte(`{"status":"` + status + `","service":"memory-service"}`))
|
||||
})
|
||||
|
||||
// 启动 HTTP 服务
|
||||
srv := &http.Server{
|
||||
Addr: ":" + cfg.Port,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Printf("🚀 Memory-Service 已启动在端口 %s", cfg.Port)
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("服务启动失败: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 优雅关闭
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
log.Println("正在关闭 Memory-Service...")
|
||||
srv.Close()
|
||||
log.Println("Memory-Service 已关闭")
|
||||
}
|
||||
|
||||
func maskDBURL(url string) string {
|
||||
if len(url) > 30 {
|
||||
return url[:30] + "..."
|
||||
}
|
||||
return url
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
module github.com/yourname/cyrene-ai/memory-service
|
||||
|
||||
go 1.26.2
|
||||
|
||||
require github.com/lib/pq v1.10.9
|
||||
@@ -0,0 +1,2 @@
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
@@ -0,0 +1,45 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Config 记忆服务配置
|
||||
type Config struct {
|
||||
Port string
|
||||
DatabaseURL string
|
||||
}
|
||||
|
||||
// Load 从环境变量加载配置
|
||||
func Load() *Config {
|
||||
return &Config{
|
||||
Port: getEnv("PORT", "8091"),
|
||||
DatabaseURL: buildDatabaseURL(),
|
||||
}
|
||||
}
|
||||
|
||||
// buildDatabaseURL 构建 PostgreSQL 连接字符串
|
||||
func buildDatabaseURL() string {
|
||||
// 优先使用 DB_URL 环境变量(简化模式)
|
||||
if url := os.Getenv("DB_URL"); url != "" {
|
||||
return url
|
||||
}
|
||||
|
||||
host := getEnv("POSTGRES_HOST", "localhost")
|
||||
port := getEnv("POSTGRES_PORT", "5432")
|
||||
user := getEnv("POSTGRES_USER", "cyrene")
|
||||
password := getEnv("POSTGRES_PASSWORD", "change_me")
|
||||
dbname := getEnv("POSTGRES_DB", "cyrene_ai")
|
||||
sslmode := getEnv("POSTGRES_SSLMODE", "disable")
|
||||
|
||||
return fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=%s",
|
||||
user, password, host, port, dbname, sslmode)
|
||||
}
|
||||
|
||||
func getEnv(key, fallback string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
@@ -0,0 +1,542 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/yourname/cyrene-ai/memory-service/internal/model"
|
||||
"github.com/yourname/cyrene-ai/memory-service/internal/service"
|
||||
)
|
||||
|
||||
// MemoryHandler HTTP API 处理器
|
||||
type MemoryHandler struct {
|
||||
svc *service.MemoryService
|
||||
}
|
||||
|
||||
// NewMemoryHandler 创建记忆处理器
|
||||
func NewMemoryHandler(svc *service.MemoryService) *MemoryHandler {
|
||||
return &MemoryHandler{svc: svc}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册所有路由到 mux
|
||||
func (h *MemoryHandler) RegisterRoutes(mux *http.ServeMux) {
|
||||
// POST /api/v1/memories - 创建/保存记忆
|
||||
mux.HandleFunc("/api/v1/memories", h.handleMemories)
|
||||
// GET/DELETE/PUT /api/v1/memories/... (带 ID)
|
||||
mux.HandleFunc("/api/v1/memories/", h.handleMemoryByID)
|
||||
// POST /api/v1/memories/query - 语义查询
|
||||
mux.HandleFunc("/api/v1/memories/query", h.handleQuery)
|
||||
// POST /api/v1/memories/consolidate - 合并相似记忆
|
||||
mux.HandleFunc("/api/v1/memories/consolidate", h.handleConsolidate)
|
||||
// POST /api/v1/memories/decay - 衰减旧记忆
|
||||
mux.HandleFunc("/api/v1/memories/decay", h.handleDecay)
|
||||
// GET /api/v1/memories/categories - 获取类别统计
|
||||
mux.HandleFunc("/api/v1/memories/categories", h.handleCategories)
|
||||
// 自主思考日志 API
|
||||
mux.HandleFunc("/api/v1/thinking", h.handleThinking)
|
||||
mux.HandleFunc("/api/v1/thinking/", h.handleThinkingByID)
|
||||
mux.HandleFunc("/api/v1/thinking/stats", h.handleThinkingStats)
|
||||
}
|
||||
|
||||
// handleMemories 处理 /api/v1/memories
|
||||
// GET - 列出用户记忆 (?user_id=xxx&category=xxx&min_importance=xxx&limit=xxx)
|
||||
// POST - 创建记忆
|
||||
func (h *MemoryHandler) handleMemories(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
h.listMemories(w, r)
|
||||
case http.MethodPost:
|
||||
h.createMemory(w, r)
|
||||
default:
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
// listMemories GET /api/v1/memories?user_id=xxx
|
||||
func (h *MemoryHandler) listMemories(w http.ResponseWriter, r *http.Request) {
|
||||
userID := r.URL.Query().Get("user_id")
|
||||
if userID == "" {
|
||||
writeError(w, http.StatusBadRequest, "缺少 user_id 参数")
|
||||
return
|
||||
}
|
||||
|
||||
category := r.URL.Query().Get("category")
|
||||
limit := queryInt(r, "limit", 50)
|
||||
minImportance := queryInt(r, "min_importance", 0)
|
||||
|
||||
memories, err := h.svc.ListMemories(r.Context(), userID, category, minImportance, limit)
|
||||
if err != nil {
|
||||
log.Printf("[memory-handler] 列出记忆失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if memories == nil {
|
||||
memories = []model.MemoryEntry{}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"memories": memories,
|
||||
"total": len(memories),
|
||||
})
|
||||
}
|
||||
|
||||
// createMemory POST /api/v1/memories
|
||||
func (h *MemoryHandler) createMemory(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
UserID string `json:"user_id"`
|
||||
Content string `json:"content"`
|
||||
Summary string `json:"summary"`
|
||||
Category string `json:"category"`
|
||||
Priority int `json:"priority"`
|
||||
Importance int `json:"importance"`
|
||||
Keywords []string `json:"keywords"`
|
||||
SessionID string `json:"session_id"`
|
||||
Source string `json:"source"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.UserID == "" || req.Content == "" {
|
||||
writeError(w, http.StatusBadRequest, "缺少 user_id 或 content")
|
||||
return
|
||||
}
|
||||
|
||||
entry := &model.MemoryEntry{
|
||||
UserID: req.UserID,
|
||||
Content: req.Content,
|
||||
Summary: req.Summary,
|
||||
Category: model.MemoryCategory(req.Category),
|
||||
Priority: model.MemoryPriority(req.Priority),
|
||||
Importance: req.Importance,
|
||||
Keywords: req.Keywords,
|
||||
SessionID: req.SessionID,
|
||||
Source: req.Source,
|
||||
}
|
||||
|
||||
if err := h.svc.CreateMemory(r.Context(), entry); err != nil {
|
||||
log.Printf("[memory-handler] 创建记忆失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, map[string]interface{}{
|
||||
"status": "saved",
|
||||
"memory": entry,
|
||||
})
|
||||
}
|
||||
|
||||
// handleMemoryByID 处理 /api/v1/memories/{id}
|
||||
// GET - 获取单个记忆
|
||||
// PUT - 更新记忆
|
||||
// DELETE - 删除记忆
|
||||
func (h *MemoryHandler) handleMemoryByID(w http.ResponseWriter, r *http.Request) {
|
||||
id := strings.TrimPrefix(r.URL.Path, "/api/v1/memories/")
|
||||
// 排除子路径 (query, consolidate, decay, categories)
|
||||
switch id {
|
||||
case "query", "consolidate", "decay", "categories":
|
||||
return // 这些有自己独立的处理器
|
||||
}
|
||||
|
||||
if id == "" {
|
||||
writeError(w, http.StatusBadRequest, "缺少记忆 ID")
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
h.getMemory(w, r, id)
|
||||
case http.MethodPut:
|
||||
h.updateMemory(w, r, id)
|
||||
case http.MethodDelete:
|
||||
h.deleteMemory(w, r, id)
|
||||
default:
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
// getMemory GET /api/v1/memories/:id
|
||||
func (h *MemoryHandler) getMemory(w http.ResponseWriter, r *http.Request, id string) {
|
||||
entry, err := h.svc.GetMemory(r.Context(), id)
|
||||
if err != nil {
|
||||
log.Printf("[memory-handler] 获取记忆失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if entry == nil {
|
||||
writeError(w, http.StatusNotFound, "记忆不存在")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"memory": entry,
|
||||
})
|
||||
}
|
||||
|
||||
// updateMemory PUT /api/v1/memories/:id
|
||||
func (h *MemoryHandler) updateMemory(w http.ResponseWriter, r *http.Request, id string) {
|
||||
var req struct {
|
||||
Content string `json:"content"`
|
||||
Summary string `json:"summary"`
|
||||
Category string `json:"category"`
|
||||
Priority int `json:"priority"`
|
||||
Importance int `json:"importance"`
|
||||
Keywords []string `json:"keywords"`
|
||||
Source string `json:"source"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
entry := &model.MemoryEntry{
|
||||
ID: id,
|
||||
Content: req.Content,
|
||||
Summary: req.Summary,
|
||||
Category: model.MemoryCategory(req.Category),
|
||||
Priority: model.MemoryPriority(req.Priority),
|
||||
Importance: req.Importance,
|
||||
Keywords: req.Keywords,
|
||||
Source: req.Source,
|
||||
}
|
||||
|
||||
if err := h.svc.UpdateMemory(r.Context(), entry); err != nil {
|
||||
log.Printf("[memory-handler] 更新记忆失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"status": "updated",
|
||||
"memory_id": id,
|
||||
})
|
||||
}
|
||||
|
||||
// deleteMemory DELETE /api/v1/memories/:id
|
||||
func (h *MemoryHandler) deleteMemory(w http.ResponseWriter, r *http.Request, id string) {
|
||||
if err := h.svc.DeleteMemory(r.Context(), id); err != nil {
|
||||
log.Printf("[memory-handler] 删除记忆失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"status": "deleted",
|
||||
"memory_id": id,
|
||||
})
|
||||
}
|
||||
|
||||
// handleQuery POST /api/v1/memories/query
|
||||
func (h *MemoryHandler) handleQuery(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
UserID string `json:"user_id"`
|
||||
QueryText string `json:"query_text"`
|
||||
Category string `json:"category"`
|
||||
MinImportance int `json:"min_importance"`
|
||||
Limit int `json:"limit"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.UserID == "" {
|
||||
writeError(w, http.StatusBadRequest, "缺少 user_id")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Limit <= 0 {
|
||||
req.Limit = 10
|
||||
}
|
||||
|
||||
memories, err := h.svc.QueryMemories(r.Context(), req.UserID, req.QueryText, req.Category, req.MinImportance, req.Limit)
|
||||
if err != nil {
|
||||
log.Printf("[memory-handler] 查询记忆失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if memories == nil {
|
||||
memories = []model.MemoryEntry{}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"user_id": req.UserID,
|
||||
"query": req.QueryText,
|
||||
"memories": memories,
|
||||
"total": len(memories),
|
||||
})
|
||||
}
|
||||
|
||||
// handleConsolidate POST /api/v1/memories/consolidate
|
||||
func (h *MemoryHandler) handleConsolidate(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.UserID == "" {
|
||||
writeError(w, http.StatusBadRequest, "缺少 user_id")
|
||||
return
|
||||
}
|
||||
|
||||
merged, err := h.svc.ConsolidateMemories(r.Context(), req.UserID)
|
||||
if err != nil {
|
||||
log.Printf("[memory-handler] 合并记忆失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"status": "consolidated",
|
||||
"user_id": req.UserID,
|
||||
"merged": merged,
|
||||
"message": "记忆整理完成",
|
||||
})
|
||||
}
|
||||
|
||||
// handleDecay POST /api/v1/memories/decay
|
||||
func (h *MemoryHandler) handleDecay(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.UserID == "" {
|
||||
writeError(w, http.StatusBadRequest, "缺少 user_id")
|
||||
return
|
||||
}
|
||||
|
||||
decayed, deleted, err := h.svc.DecayMemories(r.Context(), req.UserID)
|
||||
if err != nil {
|
||||
log.Printf("[memory-handler] 衰减记忆失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"status": "decayed",
|
||||
"user_id": req.UserID,
|
||||
"decayed": decayed,
|
||||
"deleted": deleted,
|
||||
"message": "记忆衰减完成",
|
||||
})
|
||||
}
|
||||
|
||||
// handleCategories GET /api/v1/memories/categories?user_id=xxx
|
||||
func (h *MemoryHandler) handleCategories(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
userID := r.URL.Query().Get("user_id")
|
||||
if userID == "" {
|
||||
writeError(w, http.StatusBadRequest, "缺少 user_id 参数")
|
||||
return
|
||||
}
|
||||
|
||||
categories, err := h.svc.GetCategories(r.Context(), userID)
|
||||
if err != nil {
|
||||
log.Printf("[memory-handler] 获取分类统计失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"categories": categories,
|
||||
})
|
||||
}
|
||||
|
||||
// --- 辅助函数 ---
|
||||
|
||||
func writeJSON(w http.ResponseWriter, status int, data interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
||||
func writeError(w http.ResponseWriter, status int, message string) {
|
||||
writeJSON(w, status, map[string]interface{}{
|
||||
"error": message,
|
||||
})
|
||||
}
|
||||
|
||||
func queryInt(r *http.Request, key string, fallback int) int {
|
||||
v := r.URL.Query().Get(key)
|
||||
if v == "" {
|
||||
return fallback
|
||||
}
|
||||
var result int
|
||||
for _, c := range v {
|
||||
if c < '0' || c > '9' {
|
||||
return fallback
|
||||
}
|
||||
result = result*10 + int(c-'0')
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// handleThinking 处理 /api/v1/thinking
|
||||
// GET - 分页查询思考日志 (?user_id=xxx&limit=xxx&offset=xxx)
|
||||
// POST - 保存思考日志
|
||||
func (h *MemoryHandler) handleThinking(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
h.listThinkingLogs(w, r)
|
||||
case http.MethodPost:
|
||||
h.createThinkingLog(w, r)
|
||||
default:
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
// createThinkingLog POST /api/v1/thinking
|
||||
func (h *MemoryHandler) createThinkingLog(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
UserID string `json:"user_id"`
|
||||
Content string `json:"content"`
|
||||
ToolCalls string `json:"tool_calls"`
|
||||
ToolCallCount int `json:"tool_call_count"`
|
||||
ContentLength int `json:"content_length"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.Content == "" {
|
||||
writeError(w, http.StatusBadRequest, "缺少 content")
|
||||
return
|
||||
}
|
||||
|
||||
tl := &model.ThinkingLog{
|
||||
UserID: req.UserID,
|
||||
Content: req.Content,
|
||||
ToolCalls: req.ToolCalls,
|
||||
ToolCallCount: req.ToolCallCount,
|
||||
ContentLength: req.ContentLength,
|
||||
}
|
||||
|
||||
if err := h.svc.SaveThinkingLog(r.Context(), tl); err != nil {
|
||||
log.Printf("[memory-handler] 保存思考日志失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, map[string]interface{}{
|
||||
"status": "saved",
|
||||
"thinking": tl,
|
||||
})
|
||||
}
|
||||
|
||||
// listThinkingLogs GET /api/v1/thinking?user_id=xxx&limit=xxx&offset=xxx
|
||||
func (h *MemoryHandler) listThinkingLogs(w http.ResponseWriter, r *http.Request) {
|
||||
userID := r.URL.Query().Get("user_id")
|
||||
limit := queryInt(r, "limit", 20)
|
||||
offset := queryInt(r, "offset", 0)
|
||||
|
||||
logs, err := h.svc.QueryThinkingLogs(r.Context(), model.ThinkingQuery{
|
||||
UserID: userID,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("[memory-handler] 查询思考日志失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if logs == nil {
|
||||
logs = []model.ThinkingLog{}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"logs": logs,
|
||||
"total": len(logs),
|
||||
})
|
||||
}
|
||||
|
||||
// handleThinkingByID 处理 /api/v1/thinking/{id}
|
||||
// GET - 获取单条思考日志
|
||||
func (h *MemoryHandler) handleThinkingByID(w http.ResponseWriter, r *http.Request) {
|
||||
id := strings.TrimPrefix(r.URL.Path, "/api/v1/thinking/")
|
||||
// 排除子路径
|
||||
if id == "stats" || id == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method != http.MethodGet {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
thinkingLog, err := h.svc.GetThinkingLogByID(r.Context(), id)
|
||||
if err != nil {
|
||||
log.Printf("[memory-handler] 获取思考日志失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if thinkingLog == nil {
|
||||
writeError(w, http.StatusNotFound, "思考日志不存在")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"thinking": thinkingLog,
|
||||
})
|
||||
}
|
||||
|
||||
// handleThinkingStats GET /api/v1/thinking/stats
|
||||
func (h *MemoryHandler) handleThinkingStats(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.svc.GetThinkingStats(r.Context())
|
||||
if err != nil {
|
||||
log.Printf("[memory-handler] 获取思考日志统计失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"stats": stats,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,227 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MemoryPriority 记忆优先级
|
||||
type MemoryPriority int
|
||||
|
||||
const (
|
||||
MemoryTemp MemoryPriority = 0 // 临时记忆 (会话内)
|
||||
MemoryNormal MemoryPriority = 1 // 普通记忆
|
||||
MemoryImportant MemoryPriority = 2 // 重要记忆
|
||||
MemoryCore MemoryPriority = 3 // 核心记忆 (永远保留)
|
||||
)
|
||||
|
||||
// String 返回优先级的中文描述
|
||||
func (p MemoryPriority) String() string {
|
||||
switch p {
|
||||
case MemoryCore:
|
||||
return "核心"
|
||||
case MemoryImportant:
|
||||
return "重要"
|
||||
case MemoryNormal:
|
||||
return "普通"
|
||||
case MemoryTemp:
|
||||
return "临时"
|
||||
default:
|
||||
return "未知"
|
||||
}
|
||||
}
|
||||
|
||||
// MemoryCategory 记忆分类
|
||||
type MemoryCategory string
|
||||
|
||||
const (
|
||||
CategoryUserPreference MemoryCategory = "user_preference" // 用户偏好 (食物、颜色、习惯)
|
||||
CategoryPersonalInfo MemoryCategory = "personal_info" // 个人信息 (姓名、年龄、职业)
|
||||
CategoryConversation MemoryCategory = "conversation" // 对话摘要
|
||||
CategoryKnowledge MemoryCategory = "knowledge" // 知识性信息
|
||||
CategoryEvent MemoryCategory = "event" // 事件记录
|
||||
CategoryTask MemoryCategory = "task" // 任务/计划
|
||||
CategoryRelationship MemoryCategory = "relationship" // 关系信息
|
||||
|
||||
// 向后兼容的旧分类别名
|
||||
CategoryPreference = CategoryUserPreference
|
||||
CategoryFact = CategoryPersonalInfo
|
||||
CategoryHabit = CategoryUserPreference
|
||||
CategoryOther = CategoryKnowledge
|
||||
)
|
||||
|
||||
// CategoryDisplayName 返回分类的中文显示名
|
||||
func (c MemoryCategory) DisplayName() string {
|
||||
switch c {
|
||||
case CategoryUserPreference:
|
||||
return "用户偏好"
|
||||
case CategoryPersonalInfo:
|
||||
return "个人信息"
|
||||
case CategoryConversation:
|
||||
return "对话摘要"
|
||||
case CategoryKnowledge:
|
||||
return "知识信息"
|
||||
case CategoryEvent:
|
||||
return "事件记录"
|
||||
case CategoryTask:
|
||||
return "任务计划"
|
||||
case CategoryRelationship:
|
||||
return "关系情感"
|
||||
default:
|
||||
return "其他"
|
||||
}
|
||||
}
|
||||
|
||||
// MemoryEntry 记忆条目
|
||||
type MemoryEntry struct {
|
||||
ID string `json:"id" db:"id"`
|
||||
UserID string `json:"user_id" db:"user_id"`
|
||||
Content string `json:"content" db:"content"`
|
||||
Summary string `json:"summary" db:"summary"` // 简短摘要
|
||||
Category MemoryCategory `json:"category" db:"category"`
|
||||
Priority MemoryPriority `json:"priority" db:"priority"`
|
||||
Importance int `json:"importance" db:"importance"` // 重要程度 1-10
|
||||
Keywords []string `json:"keywords" db:"keywords"` // 关键词标签
|
||||
SessionID string `json:"session_id" db:"session_id"` // 来源会话
|
||||
Source string `json:"source" db:"source"` // 来源 (conversation/thinking)
|
||||
Embedding []float32 `json:"-" db:"embedding"` // 向量 (pgvector)
|
||||
AccessCount int `json:"access_count" db:"access_count"`
|
||||
LastAccess time.Time `json:"last_access" db:"last_access"`
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at" db:"updated_at"` // 最后更新时间
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty" db:"expires_at"` // 临时记忆过期时间
|
||||
}
|
||||
|
||||
// KeywordsJSON 将关键词序列化为 JSON 字符串(用于数据库存储)
|
||||
func (e *MemoryEntry) KeywordsJSON() string {
|
||||
if len(e.Keywords) == 0 {
|
||||
return "[]"
|
||||
}
|
||||
data, _ := json.Marshal(e.Keywords)
|
||||
return string(data)
|
||||
}
|
||||
|
||||
// ParseKeywords 从 JSON 字符串解析关键词
|
||||
func ParseKeywords(raw string) []string {
|
||||
if raw == "" || raw == "[]" {
|
||||
return nil
|
||||
}
|
||||
var keywords []string
|
||||
if err := json.Unmarshal([]byte(raw), &keywords); err != nil {
|
||||
return nil
|
||||
}
|
||||
return keywords
|
||||
}
|
||||
|
||||
// SimilarityScore 计算两个记忆条目的简单文本相似度(基于词汇重叠)
|
||||
// 返回值 0.0 - 1.0
|
||||
func (e *MemoryEntry) SimilarityScore(other *MemoryEntry) float64 {
|
||||
if e.Content == other.Content {
|
||||
return 1.0
|
||||
}
|
||||
|
||||
// 基于关键词的重叠度
|
||||
if len(e.Keywords) > 0 && len(other.Keywords) > 0 {
|
||||
keywordSet := make(map[string]bool, len(e.Keywords))
|
||||
for _, k := range e.Keywords {
|
||||
keywordSet[k] = true
|
||||
}
|
||||
overlap := 0
|
||||
for _, k := range other.Keywords {
|
||||
if keywordSet[k] {
|
||||
overlap++
|
||||
}
|
||||
}
|
||||
keywordScore := float64(overlap) / float64(max(len(e.Keywords), len(other.Keywords)))
|
||||
if keywordScore > 0.6 {
|
||||
return keywordScore
|
||||
}
|
||||
}
|
||||
|
||||
// 基于内容的字符级 Jaccard 相似度
|
||||
return jaccardSimilarity(e.Content, other.Content)
|
||||
}
|
||||
|
||||
// jaccardSimilarity 计算两个字符串的 Jaccard 相似度
|
||||
func jaccardSimilarity(a, b string) float64 {
|
||||
if a == b {
|
||||
return 1.0
|
||||
}
|
||||
if len(a) == 0 || len(b) == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// 使用 bigram 分词
|
||||
bigramsA := make(map[string]int)
|
||||
runesA := []rune(a)
|
||||
for i := 0; i < len(runesA)-1; i++ {
|
||||
bigramsA[string(runesA[i:i+2])]++
|
||||
}
|
||||
|
||||
bigramsB := make(map[string]int)
|
||||
runesB := []rune(b)
|
||||
for i := 0; i < len(runesB)-1; i++ {
|
||||
bigramsB[string(runesB[i:i+2])]++
|
||||
}
|
||||
|
||||
intersection := 0
|
||||
for bg, countA := range bigramsA {
|
||||
if countB, ok := bigramsB[bg]; ok {
|
||||
intersection += min(countA, countB)
|
||||
}
|
||||
}
|
||||
|
||||
union := 0
|
||||
allBigrams := make(map[string]bool)
|
||||
for bg := range bigramsA {
|
||||
allBigrams[bg] = true
|
||||
}
|
||||
for bg := range bigramsB {
|
||||
allBigrams[bg] = true
|
||||
}
|
||||
for bg := range allBigrams {
|
||||
union += max(bigramsA[bg], bigramsB[bg])
|
||||
}
|
||||
|
||||
if union == 0 {
|
||||
return 0.0
|
||||
}
|
||||
return float64(intersection) / float64(union)
|
||||
}
|
||||
|
||||
// MemoryQuery 记忆查询参数
|
||||
type MemoryQuery struct {
|
||||
UserID string
|
||||
Query string // 查询文本
|
||||
Category MemoryCategory
|
||||
Priority MemoryPriority
|
||||
MinImportance int // 最低重要程度筛选
|
||||
Limit int
|
||||
Offset int
|
||||
}
|
||||
|
||||
// ThinkingLog 自主思考日志
|
||||
type ThinkingLog struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Content string `json:"content"`
|
||||
ToolCalls string `json:"tool_calls"` // JSON 数组
|
||||
ToolCallCount int `json:"tool_call_count"`
|
||||
ContentLength int `json:"content_length"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// ThinkingQuery 思考日志查询参数
|
||||
type ThinkingQuery struct {
|
||||
UserID string
|
||||
Limit int
|
||||
Offset int
|
||||
}
|
||||
|
||||
// ThinkingStats 思考日志统计
|
||||
type ThinkingStats struct {
|
||||
TotalLogs int `json:"total_logs"`
|
||||
TotalToolCalls int `json:"total_tool_calls"`
|
||||
AvgContentLen float64 `json:"avg_content_length"`
|
||||
LatestAt string `json:"latest_at"`
|
||||
}
|
||||
@@ -0,0 +1,316 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/yourname/cyrene-ai/memory-service/internal/model"
|
||||
"github.com/yourname/cyrene-ai/memory-service/internal/store"
|
||||
)
|
||||
|
||||
// deDupThreshold 去重相似度阈值
|
||||
const deDupThreshold = 0.75
|
||||
|
||||
// MemoryService 记忆业务逻辑
|
||||
type MemoryService struct {
|
||||
store *store.Store
|
||||
}
|
||||
|
||||
// NewMemoryService 创建记忆服务
|
||||
func NewMemoryService(s *store.Store) *MemoryService {
|
||||
return &MemoryService{store: s}
|
||||
}
|
||||
|
||||
// CreateMemory 创建/保存记忆
|
||||
func (svc *MemoryService) CreateMemory(ctx context.Context, entry *model.MemoryEntry) error {
|
||||
if entry.UserID == "" {
|
||||
return fmt.Errorf("user_id 不能为空")
|
||||
}
|
||||
if entry.Content == "" {
|
||||
return fmt.Errorf("content 不能为空")
|
||||
}
|
||||
if entry.Category == "" {
|
||||
entry.Category = model.CategoryKnowledge
|
||||
}
|
||||
if entry.Importance < 1 {
|
||||
entry.Importance = 5
|
||||
}
|
||||
if entry.Priority < 0 || entry.Priority > 3 {
|
||||
entry.Priority = model.MemoryNormal
|
||||
}
|
||||
if entry.Source == "" {
|
||||
entry.Source = "manual"
|
||||
}
|
||||
|
||||
// 去重检查
|
||||
similar, err := svc.findSimilar(ctx, entry.UserID, entry)
|
||||
if err == nil && similar != nil {
|
||||
// 合并到已有记忆
|
||||
return svc.mergeMemory(ctx, similar, entry)
|
||||
}
|
||||
|
||||
return svc.store.Save(ctx, entry)
|
||||
}
|
||||
|
||||
// GetMemory 获取单个记忆
|
||||
func (svc *MemoryService) GetMemory(ctx context.Context, id string) (*model.MemoryEntry, error) {
|
||||
return svc.store.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// ListMemories 列出用户所有记忆
|
||||
func (svc *MemoryService) ListMemories(ctx context.Context, userID string, category string, minImportance int, limit int) ([]model.MemoryEntry, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
|
||||
q := model.MemoryQuery{
|
||||
UserID: userID,
|
||||
MinImportance: minImportance,
|
||||
Limit: limit,
|
||||
}
|
||||
if category != "" {
|
||||
q.Category = model.MemoryCategory(category)
|
||||
}
|
||||
|
||||
return svc.store.Query(ctx, q)
|
||||
}
|
||||
|
||||
// UpdateMemory 更新记忆
|
||||
func (svc *MemoryService) UpdateMemory(ctx context.Context, entry *model.MemoryEntry) error {
|
||||
if entry.ID == "" {
|
||||
return fmt.Errorf("id 不能为空")
|
||||
}
|
||||
return svc.store.Update(ctx, entry)
|
||||
}
|
||||
|
||||
// DeleteMemory 删除记忆
|
||||
func (svc *MemoryService) DeleteMemory(ctx context.Context, id string) error {
|
||||
return svc.store.Delete(ctx, id)
|
||||
}
|
||||
|
||||
// QueryMemories 语义查询 + 关键词匹配
|
||||
func (svc *MemoryService) QueryMemories(ctx context.Context, userID string, queryText string, category string, minImportance int, limit int) ([]model.MemoryEntry, error) {
|
||||
if limit <= 0 {
|
||||
limit = 10
|
||||
}
|
||||
|
||||
var allEntries []model.MemoryEntry
|
||||
seen := make(map[string]bool)
|
||||
|
||||
// 1. 关键词匹配检索
|
||||
keywordEntries, err := svc.store.SearchByKeyword(ctx, userID, queryText, limit*2)
|
||||
if err == nil {
|
||||
for _, e := range keywordEntries {
|
||||
if !seen[e.ID] {
|
||||
seen[e.ID] = true
|
||||
allEntries = append(allEntries, e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 补充按分类/重要性查询
|
||||
q := model.MemoryQuery{
|
||||
UserID: userID,
|
||||
MinImportance: minImportance,
|
||||
Limit: limit,
|
||||
}
|
||||
if category != "" {
|
||||
q.Category = model.MemoryCategory(category)
|
||||
}
|
||||
categoryEntries, err := svc.store.Query(ctx, q)
|
||||
if err == nil {
|
||||
for _, e := range categoryEntries {
|
||||
if !seen[e.ID] {
|
||||
seen[e.ID] = true
|
||||
allEntries = append(allEntries, e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 在应用层做内容匹配过滤
|
||||
queryLower := strings.ToLower(queryText)
|
||||
var matched []model.MemoryEntry
|
||||
for _, entry := range allEntries {
|
||||
contentLower := strings.ToLower(entry.Content)
|
||||
summaryLower := strings.ToLower(entry.Summary)
|
||||
|
||||
if strings.Contains(contentLower, queryLower) || strings.Contains(summaryLower, queryLower) {
|
||||
matched = append(matched, entry)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, kw := range entry.Keywords {
|
||||
if strings.Contains(queryLower, strings.ToLower(kw)) ||
|
||||
strings.Contains(strings.ToLower(kw), queryLower) {
|
||||
matched = append(matched, entry)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(matched) == 0 {
|
||||
matched = allEntries
|
||||
}
|
||||
|
||||
// 4. 去重合并
|
||||
matched = svc.deduplicate(matched)
|
||||
|
||||
// 5. 按重要性降序
|
||||
sortByImportance(matched)
|
||||
|
||||
if len(matched) > limit {
|
||||
matched = matched[:limit]
|
||||
}
|
||||
|
||||
return matched, nil
|
||||
}
|
||||
|
||||
// ConsolidateMemories 合并相似记忆
|
||||
func (svc *MemoryService) ConsolidateMemories(ctx context.Context, userID string) (int, error) {
|
||||
return svc.store.ConsolidateMemories(ctx, userID)
|
||||
}
|
||||
|
||||
// DecayMemories 衰减旧记忆
|
||||
func (svc *MemoryService) DecayMemories(ctx context.Context, userID string) (int, int, error) {
|
||||
return svc.store.DecayMemories(ctx, userID)
|
||||
}
|
||||
|
||||
// GetCategories 获取用户分类统计
|
||||
func (svc *MemoryService) GetCategories(ctx context.Context, userID string) (map[string]int, error) {
|
||||
return svc.store.GetCategories(ctx, userID)
|
||||
}
|
||||
|
||||
// findSimilar 查找相似记忆
|
||||
func (svc *MemoryService) findSimilar(ctx context.Context, userID string, newMem *model.MemoryEntry) (*model.MemoryEntry, error) {
|
||||
existing, err := svc.store.Query(ctx, model.MemoryQuery{
|
||||
UserID: userID,
|
||||
Limit: 100,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range existing {
|
||||
score := existing[i].SimilarityScore(newMem)
|
||||
if score >= deDupThreshold {
|
||||
return &existing[i], nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// mergeMemory 合并新记忆到已有记忆
|
||||
func (svc *MemoryService) mergeMemory(ctx context.Context, existing *model.MemoryEntry, newMem *model.MemoryEntry) error {
|
||||
// 更新内容(如果新内容更有价值)
|
||||
if newMem.Importance > existing.Importance || len(newMem.Content) > len(existing.Content) {
|
||||
existing.Content = newMem.Content
|
||||
existing.Summary = newMem.Summary
|
||||
}
|
||||
|
||||
// 合并关键词
|
||||
keywordSet := make(map[string]bool)
|
||||
for _, k := range existing.Keywords {
|
||||
keywordSet[k] = true
|
||||
}
|
||||
for _, k := range newMem.Keywords {
|
||||
keywordSet[k] = true
|
||||
}
|
||||
mergedKeywords := make([]string, 0, len(keywordSet))
|
||||
for k := range keywordSet {
|
||||
mergedKeywords = append(mergedKeywords, k)
|
||||
}
|
||||
existing.Keywords = mergedKeywords
|
||||
|
||||
// 取最高重要性
|
||||
if newMem.Importance > existing.Importance {
|
||||
existing.Importance = newMem.Importance
|
||||
}
|
||||
|
||||
// 取最高优先级
|
||||
if newMem.Priority > existing.Priority {
|
||||
existing.Priority = newMem.Priority
|
||||
}
|
||||
|
||||
existing.AccessCount++
|
||||
existing.Source = "merged"
|
||||
|
||||
log.Printf("[memory-service] 合并记忆 [%s|%d★]: %s (去重)", existing.Category, existing.Importance, existing.Summary)
|
||||
return svc.store.Update(ctx, existing)
|
||||
}
|
||||
|
||||
// deduplicate 去重合并
|
||||
func (svc *MemoryService) deduplicate(entries []model.MemoryEntry) []model.MemoryEntry {
|
||||
if len(entries) < 2 {
|
||||
return entries
|
||||
}
|
||||
|
||||
result := make([]model.MemoryEntry, 0, len(entries))
|
||||
discarded := make(map[int]bool)
|
||||
|
||||
for i := 0; i < len(entries); i++ {
|
||||
if discarded[i] {
|
||||
continue
|
||||
}
|
||||
for j := i + 1; j < len(entries); j++ {
|
||||
if discarded[j] {
|
||||
continue
|
||||
}
|
||||
score := entries[i].SimilarityScore(&entries[j])
|
||||
if score >= deDupThreshold {
|
||||
if entries[j].Importance > entries[i].Importance ||
|
||||
(entries[j].Importance == entries[i].Importance && entries[j].Priority > entries[i].Priority) {
|
||||
discarded[i] = true
|
||||
break
|
||||
} else {
|
||||
discarded[j] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if !discarded[i] {
|
||||
result = append(result, entries[i])
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// sortByImportance 按 Importance 降序, Priority 降序排列
|
||||
func sortByImportance(entries []model.MemoryEntry) {
|
||||
for i := 0; i < len(entries); i++ {
|
||||
for j := i + 1; j < len(entries); j++ {
|
||||
if entries[j].Importance > entries[i].Importance ||
|
||||
(entries[j].Importance == entries[i].Importance && entries[j].Priority > entries[i].Priority) {
|
||||
entries[i], entries[j] = entries[j], entries[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SaveThinkingLog 保存自主思考日志
|
||||
func (svc *MemoryService) SaveThinkingLog(ctx context.Context, tl *model.ThinkingLog) error {
|
||||
if tl.Content == "" {
|
||||
return fmt.Errorf("content 不能为空")
|
||||
}
|
||||
if tl.UserID == "" {
|
||||
tl.UserID = "admin_admin"
|
||||
}
|
||||
return svc.store.SaveThinkingLog(ctx, tl)
|
||||
}
|
||||
|
||||
// QueryThinkingLogs 分页查询思考日志
|
||||
func (svc *MemoryService) QueryThinkingLogs(ctx context.Context, q model.ThinkingQuery) ([]model.ThinkingLog, error) {
|
||||
return svc.store.QueryThinkingLogs(ctx, q)
|
||||
}
|
||||
|
||||
// GetThinkingLogByID 获取单条思考日志
|
||||
func (svc *MemoryService) GetThinkingLogByID(ctx context.Context, id string) (*model.ThinkingLog, error) {
|
||||
return svc.store.GetThinkingLogByID(ctx, id)
|
||||
}
|
||||
|
||||
// GetThinkingStats 获取思考日志统计
|
||||
func (svc *MemoryService) GetThinkingStats(ctx context.Context) (*model.ThinkingStats, error) {
|
||||
return svc.store.GetThinkingStats(ctx)
|
||||
}
|
||||
@@ -0,0 +1,765 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/yourname/cyrene-ai/memory-service/internal/model"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
// deDupThreshold 去重相似度阈值
|
||||
const deDupThreshold = 0.75
|
||||
|
||||
const reconnectInterval = 30 * time.Second
|
||||
|
||||
// Store 记忆持久化存储(PostgreSQL + pgvector)
|
||||
type Store struct {
|
||||
databaseURL string
|
||||
mu sync.RWMutex
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// errDBNotReady 数据库未就绪时返回的友好错误
|
||||
var errDBNotReady = fmt.Errorf("记忆系统未就绪: 数据库连接不可用,正在后台重试连接")
|
||||
|
||||
// NewStore 创建记忆存储
|
||||
// 连接失败时不返回 error,而是启动后台重连循环
|
||||
func NewStore(connStr string) *Store {
|
||||
s := &Store{
|
||||
databaseURL: connStr,
|
||||
}
|
||||
|
||||
// 尝试初始连接
|
||||
if err := s.Reconnect(); err != nil {
|
||||
log.Printf("[memory-service] ⚠ 记忆存储初始化: 数据库连接失败 (%v),将在后台每30秒重试", err)
|
||||
} else {
|
||||
log.Println("[memory-service] 记忆存储已就绪")
|
||||
}
|
||||
|
||||
// 启动后台重连 goroutine
|
||||
go s.reconnectLoop()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// reconnectLoop 后台重连循环
|
||||
func (s *Store) reconnectLoop() {
|
||||
ticker := time.NewTicker(reconnectInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
s.mu.RLock()
|
||||
ready := s.db != nil
|
||||
s.mu.RUnlock()
|
||||
|
||||
if ready {
|
||||
// 数据库已连接,检查连接是否仍然有效
|
||||
s.mu.RLock()
|
||||
db := s.db
|
||||
s.mu.RUnlock()
|
||||
if db != nil {
|
||||
if err := db.Ping(); err != nil {
|
||||
log.Printf("[memory-service] ⚠ 数据库连接丢失: %v,开始重连", err)
|
||||
s.mu.Lock()
|
||||
if s.db != nil {
|
||||
s.db.Close()
|
||||
s.db = nil
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !s.IsReady() {
|
||||
if err := s.Reconnect(); err != nil {
|
||||
log.Printf("[memory-service] ⚠ 数据库重连失败: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reconnect 尝试重连数据库并执行迁移
|
||||
func (s *Store) Reconnect() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// 如果已有有效连接,先检查
|
||||
if s.db != nil {
|
||||
if err := s.db.Ping(); err == nil {
|
||||
return nil // 仍然有效
|
||||
}
|
||||
// 连接已失效,关闭旧连接
|
||||
s.db.Close()
|
||||
s.db = nil
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", s.databaseURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("连接数据库失败: %w", err)
|
||||
}
|
||||
|
||||
db.SetMaxOpenConns(25)
|
||||
db.SetMaxIdleConns(5)
|
||||
db.SetConnMaxLifetime(5 * time.Minute)
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
db.Close()
|
||||
return fmt.Errorf("数据库ping失败: %w", err)
|
||||
}
|
||||
|
||||
s.db = db
|
||||
|
||||
// 执行建表迁移
|
||||
if err := s.migrate(); err != nil {
|
||||
log.Printf("[memory-service] ⚠ 数据库迁移失败: %v", err)
|
||||
s.db.Close()
|
||||
s.db = nil
|
||||
return fmt.Errorf("数据库迁移失败: %w", err)
|
||||
}
|
||||
|
||||
log.Println("[memory-service] ✅ 数据库重连成功,记忆系统已就绪")
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsReady 返回数据库是否可用
|
||||
func (s *Store) IsReady() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.db != nil
|
||||
}
|
||||
|
||||
// getDB 获取当前数据库连接(带读锁保护)
|
||||
func (s *Store) getDB() *sql.DB {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.db
|
||||
}
|
||||
|
||||
// migrate 创建表结构
|
||||
func (s *Store) migrate() error {
|
||||
queries := []string{
|
||||
`CREATE EXTENSION IF NOT EXISTS vector`,
|
||||
`CREATE TABLE IF NOT EXISTS memory_entries (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id VARCHAR(64) NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
summary TEXT DEFAULT '',
|
||||
category VARCHAR(32) DEFAULT 'knowledge',
|
||||
priority INT DEFAULT 1,
|
||||
importance INT DEFAULT 5,
|
||||
keywords TEXT DEFAULT '[]',
|
||||
session_id VARCHAR(64) DEFAULT '',
|
||||
source TEXT DEFAULT 'conversation',
|
||||
embedding vector(1536),
|
||||
access_count INT DEFAULT 0,
|
||||
last_access TIMESTAMPTZ DEFAULT NOW(),
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
expires_at TIMESTAMPTZ
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_me_user_id ON memory_entries(user_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_me_category ON memory_entries(category)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_me_priority ON memory_entries(priority)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_me_importance ON memory_entries(importance)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_me_user_priority ON memory_entries(user_id, priority DESC)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_me_user_importance ON memory_entries(user_id, importance DESC)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_me_source ON memory_entries(source)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_me_category_importance ON memory_entries(category, importance DESC)`,
|
||||
`CREATE TABLE IF NOT EXISTS thinking_logs (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id VARCHAR(64) NOT NULL DEFAULT 'admin_admin',
|
||||
content TEXT NOT NULL,
|
||||
tool_calls TEXT DEFAULT '[]',
|
||||
tool_call_count INT DEFAULT 0,
|
||||
content_length INT DEFAULT 0,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_tl_user_id ON thinking_logs(user_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_tl_created_at ON thinking_logs(created_at DESC)`,
|
||||
}
|
||||
|
||||
for _, q := range queries {
|
||||
if _, err := s.db.Exec(q); err != nil {
|
||||
return fmt.Errorf("执行迁移 '%s' 失败: %w", q[:min(50, len(q))], err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Save 保存记忆
|
||||
func (s *Store) Save(ctx context.Context, entry *model.MemoryEntry) error {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return errDBNotReady
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
if entry.Source == "" {
|
||||
entry.Source = "conversation"
|
||||
}
|
||||
if entry.Importance == 0 {
|
||||
entry.Importance = 5
|
||||
}
|
||||
|
||||
query := `INSERT INTO memory_entries (user_id, content, summary, category, priority, importance, keywords, session_id, source, embedding, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
RETURNING id, created_at`
|
||||
|
||||
var embedding interface{}
|
||||
if len(entry.Embedding) > 0 {
|
||||
vec := make([]float64, len(entry.Embedding))
|
||||
for i, v := range entry.Embedding {
|
||||
vec[i] = float64(v)
|
||||
}
|
||||
embedding = fmt.Sprintf("[%s]", joinFloats(vec))
|
||||
}
|
||||
|
||||
return db.QueryRowContext(ctx, query,
|
||||
entry.UserID, entry.Content, entry.Summary,
|
||||
string(entry.Category), int(entry.Priority),
|
||||
entry.Importance, entry.KeywordsJSON(),
|
||||
entry.SessionID, entry.Source, embedding, entry.ExpiresAt,
|
||||
).Scan(&entry.ID, &entry.CreatedAt)
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取记忆
|
||||
func (s *Store) GetByID(ctx context.Context, id string) (*model.MemoryEntry, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return nil, errDBNotReady
|
||||
}
|
||||
|
||||
query := `SELECT id, user_id, content, summary, category, priority, importance, keywords,
|
||||
session_id, source, access_count, last_access, created_at, updated_at, expires_at
|
||||
FROM memory_entries WHERE id = $1`
|
||||
|
||||
entry := &model.MemoryEntry{}
|
||||
var category, keywordsRaw string
|
||||
err := db.QueryRowContext(ctx, query, id).Scan(
|
||||
&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
|
||||
&category, &entry.Priority, &entry.Importance, &keywordsRaw,
|
||||
&entry.SessionID, &entry.Source, &entry.AccessCount, &entry.LastAccess,
|
||||
&entry.CreatedAt, &entry.UpdatedAt, &entry.ExpiresAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询记忆失败: %w", err)
|
||||
}
|
||||
entry.Category = model.MemoryCategory(category)
|
||||
entry.Keywords = model.ParseKeywords(keywordsRaw)
|
||||
|
||||
// 更新访问计数
|
||||
go s.incrementAccess(context.Background(), id)
|
||||
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
// Query 按条件查询记忆
|
||||
func (s *Store) Query(ctx context.Context, q model.MemoryQuery) ([]model.MemoryEntry, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return nil, errDBNotReady
|
||||
}
|
||||
|
||||
if q.Limit <= 0 {
|
||||
q.Limit = 10
|
||||
}
|
||||
|
||||
query := `SELECT id, user_id, content, summary, category, priority, importance, keywords,
|
||||
session_id, source, access_count, last_access, created_at, updated_at, expires_at
|
||||
FROM memory_entries WHERE user_id = $1`
|
||||
args := []interface{}{q.UserID}
|
||||
argIdx := 2
|
||||
|
||||
if q.Category != "" {
|
||||
query += fmt.Sprintf(" AND category = $%d", argIdx)
|
||||
args = append(args, string(q.Category))
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if q.Priority >= 0 {
|
||||
query += fmt.Sprintf(" AND priority >= $%d", argIdx)
|
||||
args = append(args, int(q.Priority))
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if q.MinImportance > 0 {
|
||||
query += fmt.Sprintf(" AND importance >= $%d", argIdx)
|
||||
args = append(args, q.MinImportance)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
query += fmt.Sprintf(" ORDER BY priority DESC, importance DESC, created_at DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1)
|
||||
args = append(args, q.Limit, q.Offset)
|
||||
|
||||
rows, err := db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询记忆失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanMemoryRows(rows)
|
||||
}
|
||||
|
||||
// Delete 删除记忆
|
||||
func (s *Store) Delete(ctx context.Context, id string) error {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return errDBNotReady
|
||||
}
|
||||
_, err := db.ExecContext(ctx, `DELETE FROM memory_entries WHERE id = $1`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// PurgeExpired 清理过期记忆
|
||||
func (s *Store) PurgeExpired(ctx context.Context) (int64, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return 0, errDBNotReady
|
||||
}
|
||||
result, err := db.ExecContext(ctx,
|
||||
`DELETE FROM memory_entries WHERE expires_at IS NOT NULL AND expires_at < NOW()`)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
// SearchByVector 向量相似度搜索
|
||||
func (s *Store) SearchByVector(ctx context.Context, userID string, embedding []float64, limit int) ([]model.MemoryEntry, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return nil, errDBNotReady
|
||||
}
|
||||
|
||||
if limit <= 0 {
|
||||
limit = 5
|
||||
}
|
||||
|
||||
vecStr := fmt.Sprintf("[%s]", joinFloats(embedding))
|
||||
query := `SELECT id, user_id, content, summary, category, priority, importance, keywords,
|
||||
session_id, source, access_count, last_access, created_at, updated_at, expires_at,
|
||||
1 - (embedding <=> $1) AS similarity
|
||||
FROM memory_entries
|
||||
WHERE user_id = $2 AND embedding IS NOT NULL
|
||||
ORDER BY embedding <=> $1
|
||||
LIMIT $3`
|
||||
|
||||
rows, err := db.QueryContext(ctx, query, vecStr, userID, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("向量搜索失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var entries []model.MemoryEntry
|
||||
for rows.Next() {
|
||||
var entry model.MemoryEntry
|
||||
var category, keywordsRaw string
|
||||
var similarity float64
|
||||
if err := rows.Scan(
|
||||
&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
|
||||
&category, &entry.Priority, &entry.Importance, &keywordsRaw,
|
||||
&entry.SessionID, &entry.Source, &entry.AccessCount, &entry.LastAccess,
|
||||
&entry.CreatedAt, &entry.UpdatedAt, &entry.ExpiresAt,
|
||||
&similarity,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("扫描向量搜索结果失败: %w", err)
|
||||
}
|
||||
entry.Category = model.MemoryCategory(category)
|
||||
entry.Keywords = model.ParseKeywords(keywordsRaw)
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
// SearchByKeyword 关键词匹配查询
|
||||
func (s *Store) SearchByKeyword(ctx context.Context, userID, keyword string, limit int) ([]model.MemoryEntry, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return nil, errDBNotReady
|
||||
}
|
||||
|
||||
if limit <= 0 {
|
||||
limit = 20
|
||||
}
|
||||
|
||||
query := `SELECT id, user_id, content, summary, category, priority, importance, keywords,
|
||||
session_id, source, access_count, last_access, created_at, updated_at, expires_at
|
||||
FROM memory_entries
|
||||
WHERE user_id = $1 AND (content ILIKE $2 OR summary ILIKE $2 OR keywords ILIKE $2)
|
||||
ORDER BY priority DESC, importance DESC
|
||||
LIMIT $3`
|
||||
|
||||
likePattern := "%" + keyword + "%"
|
||||
rows, err := db.QueryContext(ctx, query, userID, likePattern, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("关键词搜索失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanMemoryRows(rows)
|
||||
}
|
||||
|
||||
// Update 更新记忆
|
||||
func (s *Store) Update(ctx context.Context, entry *model.MemoryEntry) error {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return errDBNotReady
|
||||
}
|
||||
|
||||
query := `UPDATE memory_entries SET content = $1, summary = $2, category = $3, priority = $4,
|
||||
importance = $5, keywords = $6, source = $7, updated_at = NOW()
|
||||
WHERE id = $8`
|
||||
|
||||
_, err := db.ExecContext(ctx, query,
|
||||
entry.Content, entry.Summary, string(entry.Category), int(entry.Priority),
|
||||
entry.Importance, entry.KeywordsJSON(), entry.Source, entry.ID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetCategories 获取用户所有分类及计数
|
||||
func (s *Store) GetCategories(ctx context.Context, userID string) (map[string]int, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return nil, errDBNotReady
|
||||
}
|
||||
|
||||
rows, err := db.QueryContext(ctx,
|
||||
`SELECT category, COUNT(*) FROM memory_entries WHERE user_id = $1 GROUP BY category ORDER BY category`,
|
||||
userID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询分类统计失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
categories := make(map[string]int)
|
||||
for rows.Next() {
|
||||
var cat string
|
||||
var count int
|
||||
if err := rows.Scan(&cat, &count); err != nil {
|
||||
return nil, fmt.Errorf("扫描分类统计失败: %w", err)
|
||||
}
|
||||
categories[cat] = count
|
||||
}
|
||||
|
||||
return categories, rows.Err()
|
||||
}
|
||||
|
||||
// Count 获取用户的记忆总数
|
||||
func (s *Store) Count(ctx context.Context, userID string) (int, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return 0, errDBNotReady
|
||||
}
|
||||
|
||||
var count int
|
||||
err := db.QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM memory_entries WHERE user_id = $1`,
|
||||
userID,
|
||||
).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("统计记忆失败: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// ConsolidateMemories 记忆整理:合并相似记忆
|
||||
func (s *Store) ConsolidateMemories(ctx context.Context, userID string) (int, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return 0, errDBNotReady
|
||||
}
|
||||
|
||||
// 获取用户所有记忆
|
||||
allMems, err := s.Query(ctx, model.MemoryQuery{
|
||||
UserID: userID,
|
||||
Limit: 500,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("查询记忆失败: %w", err)
|
||||
}
|
||||
|
||||
if len(allMems) < 2 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
merged := 0
|
||||
for i := 0; i < len(allMems); i++ {
|
||||
if allMems[i].ID == "" {
|
||||
continue
|
||||
}
|
||||
for j := i + 1; j < len(allMems); j++ {
|
||||
if allMems[j].ID == "" {
|
||||
continue
|
||||
}
|
||||
score := allMems[i].SimilarityScore(&allMems[j])
|
||||
if score >= deDupThreshold {
|
||||
keep, discard := &allMems[i], &allMems[j]
|
||||
if discard.Importance > keep.Importance || discard.Priority > keep.Priority {
|
||||
keep, discard = discard, keep
|
||||
}
|
||||
|
||||
// 合并关键词
|
||||
keywordSet := make(map[string]bool)
|
||||
for _, k := range keep.Keywords {
|
||||
keywordSet[k] = true
|
||||
}
|
||||
for _, k := range discard.Keywords {
|
||||
keywordSet[k] = true
|
||||
}
|
||||
mergedKeywords := make([]string, 0, len(keywordSet))
|
||||
for k := range keywordSet {
|
||||
mergedKeywords = append(mergedKeywords, k)
|
||||
}
|
||||
keep.Keywords = mergedKeywords
|
||||
|
||||
if keep.Importance < 10 {
|
||||
keep.Importance++
|
||||
}
|
||||
keep.Source = "consolidated"
|
||||
|
||||
if err := s.Update(ctx, keep); err != nil {
|
||||
log.Printf("[memory-service] 合并更新记忆 %s 失败: %v", keep.ID, err)
|
||||
continue
|
||||
}
|
||||
if err := s.Delete(ctx, discard.ID); err != nil {
|
||||
log.Printf("[memory-service] 合并删除记忆 %s 失败: %v", discard.ID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
discard.ID = ""
|
||||
merged++
|
||||
log.Printf("[memory-service] 合并相似记忆: %s <- %s (相似度 %.0f%%)",
|
||||
keep.ID[:min(8, len(keep.ID))], discard.ID[:min(8, len(discard.ID))], score*100)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if merged > 0 {
|
||||
log.Printf("[memory-service] 记忆整理完成: 用户 %s 合并 %d 条相似记忆", userID, merged)
|
||||
}
|
||||
return merged, nil
|
||||
}
|
||||
|
||||
// DecayMemories 记忆衰减:降低长期未访问的低重要性记忆
|
||||
func (s *Store) DecayMemories(ctx context.Context, userID string) (int, int, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return 0, 0, errDBNotReady
|
||||
}
|
||||
|
||||
result1, err := db.ExecContext(ctx, `
|
||||
UPDATE memory_entries SET priority = GREATEST(priority - 1, 0), updated_at = NOW()
|
||||
WHERE user_id = $1
|
||||
AND access_count < 3
|
||||
AND last_access < NOW() - INTERVAL '30 days'
|
||||
AND importance < 3
|
||||
AND priority > 0
|
||||
AND category NOT IN ('personal_info', 'user_preference')
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("衰减低活跃记忆失败: %w", err)
|
||||
}
|
||||
|
||||
decayed1, _ := result1.RowsAffected()
|
||||
|
||||
result2, err := db.ExecContext(ctx, `
|
||||
DELETE FROM memory_entries
|
||||
WHERE user_id = $1
|
||||
AND priority = 0
|
||||
AND access_count = 0
|
||||
AND last_access < NOW() - INTERVAL '14 days'
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("清理临时记忆失败: %w", err)
|
||||
}
|
||||
|
||||
deleted2, _ := result2.RowsAffected()
|
||||
|
||||
total := decayed1 + deleted2
|
||||
if total > 0 {
|
||||
log.Printf("[memory-service] 记忆衰减完成: 用户 %s 降级 %d 条, 删除 %d 条过期临时记忆",
|
||||
userID, decayed1, deleted2)
|
||||
}
|
||||
|
||||
return int(decayed1), int(deleted2), nil
|
||||
}
|
||||
|
||||
func (s *Store) incrementAccess(ctx context.Context, id string) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
db.ExecContext(ctx,
|
||||
`UPDATE memory_entries SET access_count = access_count + 1, last_access = NOW() WHERE id = $1`, id)
|
||||
}
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func (s *Store) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.db != nil {
|
||||
return s.db.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// scanMemoryRows 扫描记忆行(通用方法)
|
||||
func scanMemoryRows(rows *sql.Rows) ([]model.MemoryEntry, error) {
|
||||
var entries []model.MemoryEntry
|
||||
for rows.Next() {
|
||||
var entry model.MemoryEntry
|
||||
var category, keywordsRaw string
|
||||
if err := rows.Scan(
|
||||
&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
|
||||
&category, &entry.Priority, &entry.Importance, &keywordsRaw,
|
||||
&entry.SessionID, &entry.Source, &entry.AccessCount, &entry.LastAccess,
|
||||
&entry.CreatedAt, &entry.UpdatedAt, &entry.ExpiresAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("扫描记忆行失败: %w", err)
|
||||
}
|
||||
entry.Category = model.MemoryCategory(category)
|
||||
entry.Keywords = model.ParseKeywords(keywordsRaw)
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
// joinFloats 将 float64 切片转为逗号分隔字符串
|
||||
func joinFloats(vec []float64) string {
|
||||
if len(vec) == 0 {
|
||||
return ""
|
||||
}
|
||||
s := fmt.Sprintf("%f", vec[0])
|
||||
for i := 1; i < len(vec); i++ {
|
||||
s += fmt.Sprintf(",%f", vec[i])
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// SaveThinkingLog 保存自主思考日志
|
||||
func (s *Store) SaveThinkingLog(ctx context.Context, log *model.ThinkingLog) error {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return errDBNotReady
|
||||
}
|
||||
|
||||
if log.UserID == "" {
|
||||
log.UserID = "admin_admin"
|
||||
}
|
||||
if log.ToolCalls == "" {
|
||||
log.ToolCalls = "[]"
|
||||
}
|
||||
|
||||
query := `INSERT INTO thinking_logs (user_id, content, tool_calls, tool_call_count, content_length)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
RETURNING id, created_at`
|
||||
|
||||
return db.QueryRowContext(ctx, query,
|
||||
log.UserID, log.Content, log.ToolCalls,
|
||||
log.ToolCallCount, log.ContentLength,
|
||||
).Scan(&log.ID, &log.CreatedAt)
|
||||
}
|
||||
|
||||
// QueryThinkingLogs 分页查询思考日志
|
||||
func (s *Store) QueryThinkingLogs(ctx context.Context, q model.ThinkingQuery) ([]model.ThinkingLog, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return nil, errDBNotReady
|
||||
}
|
||||
|
||||
if q.Limit <= 0 {
|
||||
q.Limit = 20
|
||||
}
|
||||
|
||||
query := `SELECT id, user_id, content, tool_calls, tool_call_count, content_length, created_at
|
||||
FROM thinking_logs`
|
||||
args := []interface{}{}
|
||||
argIdx := 1
|
||||
|
||||
if q.UserID != "" {
|
||||
query += fmt.Sprintf(" WHERE user_id = $%d", argIdx)
|
||||
args = append(args, q.UserID)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
query += fmt.Sprintf(" ORDER BY created_at DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1)
|
||||
args = append(args, q.Limit, q.Offset)
|
||||
|
||||
rows, err := db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询思考日志失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var logs []model.ThinkingLog
|
||||
for rows.Next() {
|
||||
var tl model.ThinkingLog
|
||||
if err := rows.Scan(&tl.ID, &tl.UserID, &tl.Content, &tl.ToolCalls,
|
||||
&tl.ToolCallCount, &tl.ContentLength, &tl.CreatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描思考日志行失败: %w", err)
|
||||
}
|
||||
logs = append(logs, tl)
|
||||
}
|
||||
return logs, rows.Err()
|
||||
}
|
||||
|
||||
// GetThinkingLogByID 根据ID获取单条思考日志
|
||||
func (s *Store) GetThinkingLogByID(ctx context.Context, id string) (*model.ThinkingLog, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return nil, errDBNotReady
|
||||
}
|
||||
|
||||
query := `SELECT id, user_id, content, tool_calls, tool_call_count, content_length, created_at
|
||||
FROM thinking_logs WHERE id = $1`
|
||||
|
||||
tl := &model.ThinkingLog{}
|
||||
err := db.QueryRowContext(ctx, query, id).Scan(
|
||||
&tl.ID, &tl.UserID, &tl.Content, &tl.ToolCalls,
|
||||
&tl.ToolCallCount, &tl.ContentLength, &tl.CreatedAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询思考日志失败: %w", err)
|
||||
}
|
||||
return tl, nil
|
||||
}
|
||||
|
||||
// GetThinkingStats 获取思考日志统计信息
|
||||
func (s *Store) GetThinkingStats(ctx context.Context) (*model.ThinkingStats, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return nil, errDBNotReady
|
||||
}
|
||||
|
||||
query := `SELECT
|
||||
COALESCE(COUNT(*), 0),
|
||||
COALESCE(SUM(tool_call_count), 0),
|
||||
COALESCE(AVG(content_length), 0),
|
||||
COALESCE(MAX(created_at)::TEXT, '')
|
||||
FROM thinking_logs`
|
||||
|
||||
stats := &model.ThinkingStats{}
|
||||
err := db.QueryRowContext(ctx, query).Scan(
|
||||
&stats.TotalLogs, &stats.TotalToolCalls,
|
||||
&stats.AvgContentLen, &stats.LatestAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询思考日志统计失败: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
Reference in New Issue
Block a user