feat: Round 5 - Memory Service, Tool Engine, Call Records, Thinking Logs

- Fix: Session history flash (race condition + WS guard)
- Fix: Chat background overlay + sidebar transparency
- Fix: IoT device control (Chinese action names, status field)
- Feat: Independent memory-service (port 8091, 13 endpoints)
- Feat: Independent tool-engine service (port 8092, 13 tools)
- Feat: Tool call logs with paginated DevTools panel
- Feat: Thinking log records with DevTools panel
- Feat: Future development roadmap document
- Chore: Updated .gitignore, go.work, DevTools config
- Chore: 5-service health check, project review docs
This commit is contained in:
2026-05-18 20:05:14 +08:00
parent b6ec36886c
commit 78e3f450c2
54 changed files with 7846 additions and 106 deletions
@@ -0,0 +1,45 @@
package config
import (
"fmt"
"os"
)
// Config 记忆服务配置
type Config struct {
Port string
DatabaseURL string
}
// Load 从环境变量加载配置
func Load() *Config {
return &Config{
Port: getEnv("PORT", "8091"),
DatabaseURL: buildDatabaseURL(),
}
}
// buildDatabaseURL 构建 PostgreSQL 连接字符串
func buildDatabaseURL() string {
// 优先使用 DB_URL 环境变量(简化模式)
if url := os.Getenv("DB_URL"); url != "" {
return url
}
host := getEnv("POSTGRES_HOST", "localhost")
port := getEnv("POSTGRES_PORT", "5432")
user := getEnv("POSTGRES_USER", "cyrene")
password := getEnv("POSTGRES_PASSWORD", "change_me")
dbname := getEnv("POSTGRES_DB", "cyrene_ai")
sslmode := getEnv("POSTGRES_SSLMODE", "disable")
return fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=%s",
user, password, host, port, dbname, sslmode)
}
func getEnv(key, fallback string) string {
if v := os.Getenv(key); v != "" {
return v
}
return fallback
}
@@ -0,0 +1,542 @@
package handler
import (
"encoding/json"
"log"
"net/http"
"strings"
"github.com/yourname/cyrene-ai/memory-service/internal/model"
"github.com/yourname/cyrene-ai/memory-service/internal/service"
)
// MemoryHandler HTTP API 处理器
type MemoryHandler struct {
svc *service.MemoryService
}
// NewMemoryHandler 创建记忆处理器
func NewMemoryHandler(svc *service.MemoryService) *MemoryHandler {
return &MemoryHandler{svc: svc}
}
// RegisterRoutes 注册所有路由到 mux
func (h *MemoryHandler) RegisterRoutes(mux *http.ServeMux) {
// POST /api/v1/memories - 创建/保存记忆
mux.HandleFunc("/api/v1/memories", h.handleMemories)
// GET/DELETE/PUT /api/v1/memories/... (带 ID)
mux.HandleFunc("/api/v1/memories/", h.handleMemoryByID)
// POST /api/v1/memories/query - 语义查询
mux.HandleFunc("/api/v1/memories/query", h.handleQuery)
// POST /api/v1/memories/consolidate - 合并相似记忆
mux.HandleFunc("/api/v1/memories/consolidate", h.handleConsolidate)
// POST /api/v1/memories/decay - 衰减旧记忆
mux.HandleFunc("/api/v1/memories/decay", h.handleDecay)
// GET /api/v1/memories/categories - 获取类别统计
mux.HandleFunc("/api/v1/memories/categories", h.handleCategories)
// 自主思考日志 API
mux.HandleFunc("/api/v1/thinking", h.handleThinking)
mux.HandleFunc("/api/v1/thinking/", h.handleThinkingByID)
mux.HandleFunc("/api/v1/thinking/stats", h.handleThinkingStats)
}
// handleMemories 处理 /api/v1/memories
// GET - 列出用户记忆 (?user_id=xxx&category=xxx&min_importance=xxx&limit=xxx)
// POST - 创建记忆
func (h *MemoryHandler) handleMemories(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
h.listMemories(w, r)
case http.MethodPost:
h.createMemory(w, r)
default:
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
}
}
// listMemories GET /api/v1/memories?user_id=xxx
func (h *MemoryHandler) listMemories(w http.ResponseWriter, r *http.Request) {
userID := r.URL.Query().Get("user_id")
if userID == "" {
writeError(w, http.StatusBadRequest, "缺少 user_id 参数")
return
}
category := r.URL.Query().Get("category")
limit := queryInt(r, "limit", 50)
minImportance := queryInt(r, "min_importance", 0)
memories, err := h.svc.ListMemories(r.Context(), userID, category, minImportance, limit)
if err != nil {
log.Printf("[memory-handler] 列出记忆失败: %v", err)
writeError(w, http.StatusInternalServerError, err.Error())
return
}
if memories == nil {
memories = []model.MemoryEntry{}
}
writeJSON(w, http.StatusOK, map[string]interface{}{
"user_id": userID,
"memories": memories,
"total": len(memories),
})
}
// createMemory POST /api/v1/memories
func (h *MemoryHandler) createMemory(w http.ResponseWriter, r *http.Request) {
var req struct {
UserID string `json:"user_id"`
Content string `json:"content"`
Summary string `json:"summary"`
Category string `json:"category"`
Priority int `json:"priority"`
Importance int `json:"importance"`
Keywords []string `json:"keywords"`
SessionID string `json:"session_id"`
Source string `json:"source"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error())
return
}
if req.UserID == "" || req.Content == "" {
writeError(w, http.StatusBadRequest, "缺少 user_id 或 content")
return
}
entry := &model.MemoryEntry{
UserID: req.UserID,
Content: req.Content,
Summary: req.Summary,
Category: model.MemoryCategory(req.Category),
Priority: model.MemoryPriority(req.Priority),
Importance: req.Importance,
Keywords: req.Keywords,
SessionID: req.SessionID,
Source: req.Source,
}
if err := h.svc.CreateMemory(r.Context(), entry); err != nil {
log.Printf("[memory-handler] 创建记忆失败: %v", err)
writeError(w, http.StatusInternalServerError, err.Error())
return
}
writeJSON(w, http.StatusCreated, map[string]interface{}{
"status": "saved",
"memory": entry,
})
}
// handleMemoryByID 处理 /api/v1/memories/{id}
// GET - 获取单个记忆
// PUT - 更新记忆
// DELETE - 删除记忆
func (h *MemoryHandler) handleMemoryByID(w http.ResponseWriter, r *http.Request) {
id := strings.TrimPrefix(r.URL.Path, "/api/v1/memories/")
// 排除子路径 (query, consolidate, decay, categories)
switch id {
case "query", "consolidate", "decay", "categories":
return // 这些有自己独立的处理器
}
if id == "" {
writeError(w, http.StatusBadRequest, "缺少记忆 ID")
return
}
switch r.Method {
case http.MethodGet:
h.getMemory(w, r, id)
case http.MethodPut:
h.updateMemory(w, r, id)
case http.MethodDelete:
h.deleteMemory(w, r, id)
default:
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
}
}
// getMemory GET /api/v1/memories/:id
func (h *MemoryHandler) getMemory(w http.ResponseWriter, r *http.Request, id string) {
entry, err := h.svc.GetMemory(r.Context(), id)
if err != nil {
log.Printf("[memory-handler] 获取记忆失败: %v", err)
writeError(w, http.StatusInternalServerError, err.Error())
return
}
if entry == nil {
writeError(w, http.StatusNotFound, "记忆不存在")
return
}
writeJSON(w, http.StatusOK, map[string]interface{}{
"memory": entry,
})
}
// updateMemory PUT /api/v1/memories/:id
func (h *MemoryHandler) updateMemory(w http.ResponseWriter, r *http.Request, id string) {
var req struct {
Content string `json:"content"`
Summary string `json:"summary"`
Category string `json:"category"`
Priority int `json:"priority"`
Importance int `json:"importance"`
Keywords []string `json:"keywords"`
Source string `json:"source"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error())
return
}
entry := &model.MemoryEntry{
ID: id,
Content: req.Content,
Summary: req.Summary,
Category: model.MemoryCategory(req.Category),
Priority: model.MemoryPriority(req.Priority),
Importance: req.Importance,
Keywords: req.Keywords,
Source: req.Source,
}
if err := h.svc.UpdateMemory(r.Context(), entry); err != nil {
log.Printf("[memory-handler] 更新记忆失败: %v", err)
writeError(w, http.StatusInternalServerError, err.Error())
return
}
writeJSON(w, http.StatusOK, map[string]interface{}{
"status": "updated",
"memory_id": id,
})
}
// deleteMemory DELETE /api/v1/memories/:id
func (h *MemoryHandler) deleteMemory(w http.ResponseWriter, r *http.Request, id string) {
if err := h.svc.DeleteMemory(r.Context(), id); err != nil {
log.Printf("[memory-handler] 删除记忆失败: %v", err)
writeError(w, http.StatusInternalServerError, err.Error())
return
}
writeJSON(w, http.StatusOK, map[string]interface{}{
"status": "deleted",
"memory_id": id,
})
}
// handleQuery POST /api/v1/memories/query
func (h *MemoryHandler) handleQuery(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
var req struct {
UserID string `json:"user_id"`
QueryText string `json:"query_text"`
Category string `json:"category"`
MinImportance int `json:"min_importance"`
Limit int `json:"limit"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error())
return
}
if req.UserID == "" {
writeError(w, http.StatusBadRequest, "缺少 user_id")
return
}
if req.Limit <= 0 {
req.Limit = 10
}
memories, err := h.svc.QueryMemories(r.Context(), req.UserID, req.QueryText, req.Category, req.MinImportance, req.Limit)
if err != nil {
log.Printf("[memory-handler] 查询记忆失败: %v", err)
writeError(w, http.StatusInternalServerError, err.Error())
return
}
if memories == nil {
memories = []model.MemoryEntry{}
}
writeJSON(w, http.StatusOK, map[string]interface{}{
"user_id": req.UserID,
"query": req.QueryText,
"memories": memories,
"total": len(memories),
})
}
// handleConsolidate POST /api/v1/memories/consolidate
func (h *MemoryHandler) handleConsolidate(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
var req struct {
UserID string `json:"user_id"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error())
return
}
if req.UserID == "" {
writeError(w, http.StatusBadRequest, "缺少 user_id")
return
}
merged, err := h.svc.ConsolidateMemories(r.Context(), req.UserID)
if err != nil {
log.Printf("[memory-handler] 合并记忆失败: %v", err)
writeError(w, http.StatusInternalServerError, err.Error())
return
}
writeJSON(w, http.StatusOK, map[string]interface{}{
"status": "consolidated",
"user_id": req.UserID,
"merged": merged,
"message": "记忆整理完成",
})
}
// handleDecay POST /api/v1/memories/decay
func (h *MemoryHandler) handleDecay(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
var req struct {
UserID string `json:"user_id"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error())
return
}
if req.UserID == "" {
writeError(w, http.StatusBadRequest, "缺少 user_id")
return
}
decayed, deleted, err := h.svc.DecayMemories(r.Context(), req.UserID)
if err != nil {
log.Printf("[memory-handler] 衰减记忆失败: %v", err)
writeError(w, http.StatusInternalServerError, err.Error())
return
}
writeJSON(w, http.StatusOK, map[string]interface{}{
"status": "decayed",
"user_id": req.UserID,
"decayed": decayed,
"deleted": deleted,
"message": "记忆衰减完成",
})
}
// handleCategories GET /api/v1/memories/categories?user_id=xxx
func (h *MemoryHandler) handleCategories(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
userID := r.URL.Query().Get("user_id")
if userID == "" {
writeError(w, http.StatusBadRequest, "缺少 user_id 参数")
return
}
categories, err := h.svc.GetCategories(r.Context(), userID)
if err != nil {
log.Printf("[memory-handler] 获取分类统计失败: %v", err)
writeError(w, http.StatusInternalServerError, err.Error())
return
}
writeJSON(w, http.StatusOK, map[string]interface{}{
"user_id": userID,
"categories": categories,
})
}
// --- 辅助函数 ---
func writeJSON(w http.ResponseWriter, status int, data interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
json.NewEncoder(w).Encode(data)
}
func writeError(w http.ResponseWriter, status int, message string) {
writeJSON(w, status, map[string]interface{}{
"error": message,
})
}
func queryInt(r *http.Request, key string, fallback int) int {
v := r.URL.Query().Get(key)
if v == "" {
return fallback
}
var result int
for _, c := range v {
if c < '0' || c > '9' {
return fallback
}
result = result*10 + int(c-'0')
}
return result
}
// handleThinking 处理 /api/v1/thinking
// GET - 分页查询思考日志 (?user_id=xxx&limit=xxx&offset=xxx)
// POST - 保存思考日志
func (h *MemoryHandler) handleThinking(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
h.listThinkingLogs(w, r)
case http.MethodPost:
h.createThinkingLog(w, r)
default:
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
}
}
// createThinkingLog POST /api/v1/thinking
func (h *MemoryHandler) createThinkingLog(w http.ResponseWriter, r *http.Request) {
var req struct {
UserID string `json:"user_id"`
Content string `json:"content"`
ToolCalls string `json:"tool_calls"`
ToolCallCount int `json:"tool_call_count"`
ContentLength int `json:"content_length"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error())
return
}
if req.Content == "" {
writeError(w, http.StatusBadRequest, "缺少 content")
return
}
tl := &model.ThinkingLog{
UserID: req.UserID,
Content: req.Content,
ToolCalls: req.ToolCalls,
ToolCallCount: req.ToolCallCount,
ContentLength: req.ContentLength,
}
if err := h.svc.SaveThinkingLog(r.Context(), tl); err != nil {
log.Printf("[memory-handler] 保存思考日志失败: %v", err)
writeError(w, http.StatusInternalServerError, err.Error())
return
}
writeJSON(w, http.StatusCreated, map[string]interface{}{
"status": "saved",
"thinking": tl,
})
}
// listThinkingLogs GET /api/v1/thinking?user_id=xxx&limit=xxx&offset=xxx
func (h *MemoryHandler) listThinkingLogs(w http.ResponseWriter, r *http.Request) {
userID := r.URL.Query().Get("user_id")
limit := queryInt(r, "limit", 20)
offset := queryInt(r, "offset", 0)
logs, err := h.svc.QueryThinkingLogs(r.Context(), model.ThinkingQuery{
UserID: userID,
Limit: limit,
Offset: offset,
})
if err != nil {
log.Printf("[memory-handler] 查询思考日志失败: %v", err)
writeError(w, http.StatusInternalServerError, err.Error())
return
}
if logs == nil {
logs = []model.ThinkingLog{}
}
writeJSON(w, http.StatusOK, map[string]interface{}{
"logs": logs,
"total": len(logs),
})
}
// handleThinkingByID 处理 /api/v1/thinking/{id}
// GET - 获取单条思考日志
func (h *MemoryHandler) handleThinkingByID(w http.ResponseWriter, r *http.Request) {
id := strings.TrimPrefix(r.URL.Path, "/api/v1/thinking/")
// 排除子路径
if id == "stats" || id == "" {
return
}
if r.Method != http.MethodGet {
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
thinkingLog, err := h.svc.GetThinkingLogByID(r.Context(), id)
if err != nil {
log.Printf("[memory-handler] 获取思考日志失败: %v", err)
writeError(w, http.StatusInternalServerError, err.Error())
return
}
if thinkingLog == nil {
writeError(w, http.StatusNotFound, "思考日志不存在")
return
}
writeJSON(w, http.StatusOK, map[string]interface{}{
"thinking": thinkingLog,
})
}
// handleThinkingStats GET /api/v1/thinking/stats
func (h *MemoryHandler) handleThinkingStats(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
stats, err := h.svc.GetThinkingStats(r.Context())
if err != nil {
log.Printf("[memory-handler] 获取思考日志统计失败: %v", err)
writeError(w, http.StatusInternalServerError, err.Error())
return
}
writeJSON(w, http.StatusOK, map[string]interface{}{
"stats": stats,
})
}
@@ -0,0 +1,227 @@
package model
import (
"encoding/json"
"time"
)
// MemoryPriority 记忆优先级
type MemoryPriority int
const (
MemoryTemp MemoryPriority = 0 // 临时记忆 (会话内)
MemoryNormal MemoryPriority = 1 // 普通记忆
MemoryImportant MemoryPriority = 2 // 重要记忆
MemoryCore MemoryPriority = 3 // 核心记忆 (永远保留)
)
// String 返回优先级的中文描述
func (p MemoryPriority) String() string {
switch p {
case MemoryCore:
return "核心"
case MemoryImportant:
return "重要"
case MemoryNormal:
return "普通"
case MemoryTemp:
return "临时"
default:
return "未知"
}
}
// MemoryCategory 记忆分类
type MemoryCategory string
const (
CategoryUserPreference MemoryCategory = "user_preference" // 用户偏好 (食物、颜色、习惯)
CategoryPersonalInfo MemoryCategory = "personal_info" // 个人信息 (姓名、年龄、职业)
CategoryConversation MemoryCategory = "conversation" // 对话摘要
CategoryKnowledge MemoryCategory = "knowledge" // 知识性信息
CategoryEvent MemoryCategory = "event" // 事件记录
CategoryTask MemoryCategory = "task" // 任务/计划
CategoryRelationship MemoryCategory = "relationship" // 关系信息
// 向后兼容的旧分类别名
CategoryPreference = CategoryUserPreference
CategoryFact = CategoryPersonalInfo
CategoryHabit = CategoryUserPreference
CategoryOther = CategoryKnowledge
)
// CategoryDisplayName 返回分类的中文显示名
func (c MemoryCategory) DisplayName() string {
switch c {
case CategoryUserPreference:
return "用户偏好"
case CategoryPersonalInfo:
return "个人信息"
case CategoryConversation:
return "对话摘要"
case CategoryKnowledge:
return "知识信息"
case CategoryEvent:
return "事件记录"
case CategoryTask:
return "任务计划"
case CategoryRelationship:
return "关系情感"
default:
return "其他"
}
}
// MemoryEntry 记忆条目
type MemoryEntry struct {
ID string `json:"id" db:"id"`
UserID string `json:"user_id" db:"user_id"`
Content string `json:"content" db:"content"`
Summary string `json:"summary" db:"summary"` // 简短摘要
Category MemoryCategory `json:"category" db:"category"`
Priority MemoryPriority `json:"priority" db:"priority"`
Importance int `json:"importance" db:"importance"` // 重要程度 1-10
Keywords []string `json:"keywords" db:"keywords"` // 关键词标签
SessionID string `json:"session_id" db:"session_id"` // 来源会话
Source string `json:"source" db:"source"` // 来源 (conversation/thinking)
Embedding []float32 `json:"-" db:"embedding"` // 向量 (pgvector)
AccessCount int `json:"access_count" db:"access_count"`
LastAccess time.Time `json:"last_access" db:"last_access"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"` // 最后更新时间
ExpiresAt *time.Time `json:"expires_at,omitempty" db:"expires_at"` // 临时记忆过期时间
}
// KeywordsJSON 将关键词序列化为 JSON 字符串(用于数据库存储)
func (e *MemoryEntry) KeywordsJSON() string {
if len(e.Keywords) == 0 {
return "[]"
}
data, _ := json.Marshal(e.Keywords)
return string(data)
}
// ParseKeywords 从 JSON 字符串解析关键词
func ParseKeywords(raw string) []string {
if raw == "" || raw == "[]" {
return nil
}
var keywords []string
if err := json.Unmarshal([]byte(raw), &keywords); err != nil {
return nil
}
return keywords
}
// SimilarityScore 计算两个记忆条目的简单文本相似度(基于词汇重叠)
// 返回值 0.0 - 1.0
func (e *MemoryEntry) SimilarityScore(other *MemoryEntry) float64 {
if e.Content == other.Content {
return 1.0
}
// 基于关键词的重叠度
if len(e.Keywords) > 0 && len(other.Keywords) > 0 {
keywordSet := make(map[string]bool, len(e.Keywords))
for _, k := range e.Keywords {
keywordSet[k] = true
}
overlap := 0
for _, k := range other.Keywords {
if keywordSet[k] {
overlap++
}
}
keywordScore := float64(overlap) / float64(max(len(e.Keywords), len(other.Keywords)))
if keywordScore > 0.6 {
return keywordScore
}
}
// 基于内容的字符级 Jaccard 相似度
return jaccardSimilarity(e.Content, other.Content)
}
// jaccardSimilarity 计算两个字符串的 Jaccard 相似度
func jaccardSimilarity(a, b string) float64 {
if a == b {
return 1.0
}
if len(a) == 0 || len(b) == 0 {
return 0.0
}
// 使用 bigram 分词
bigramsA := make(map[string]int)
runesA := []rune(a)
for i := 0; i < len(runesA)-1; i++ {
bigramsA[string(runesA[i:i+2])]++
}
bigramsB := make(map[string]int)
runesB := []rune(b)
for i := 0; i < len(runesB)-1; i++ {
bigramsB[string(runesB[i:i+2])]++
}
intersection := 0
for bg, countA := range bigramsA {
if countB, ok := bigramsB[bg]; ok {
intersection += min(countA, countB)
}
}
union := 0
allBigrams := make(map[string]bool)
for bg := range bigramsA {
allBigrams[bg] = true
}
for bg := range bigramsB {
allBigrams[bg] = true
}
for bg := range allBigrams {
union += max(bigramsA[bg], bigramsB[bg])
}
if union == 0 {
return 0.0
}
return float64(intersection) / float64(union)
}
// MemoryQuery 记忆查询参数
type MemoryQuery struct {
UserID string
Query string // 查询文本
Category MemoryCategory
Priority MemoryPriority
MinImportance int // 最低重要程度筛选
Limit int
Offset int
}
// ThinkingLog 自主思考日志
type ThinkingLog struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Content string `json:"content"`
ToolCalls string `json:"tool_calls"` // JSON 数组
ToolCallCount int `json:"tool_call_count"`
ContentLength int `json:"content_length"`
CreatedAt time.Time `json:"created_at"`
}
// ThinkingQuery 思考日志查询参数
type ThinkingQuery struct {
UserID string
Limit int
Offset int
}
// ThinkingStats 思考日志统计
type ThinkingStats struct {
TotalLogs int `json:"total_logs"`
TotalToolCalls int `json:"total_tool_calls"`
AvgContentLen float64 `json:"avg_content_length"`
LatestAt string `json:"latest_at"`
}
@@ -0,0 +1,316 @@
package service
import (
"context"
"fmt"
"log"
"strings"
"github.com/yourname/cyrene-ai/memory-service/internal/model"
"github.com/yourname/cyrene-ai/memory-service/internal/store"
)
// deDupThreshold 去重相似度阈值
const deDupThreshold = 0.75
// MemoryService 记忆业务逻辑
type MemoryService struct {
store *store.Store
}
// NewMemoryService 创建记忆服务
func NewMemoryService(s *store.Store) *MemoryService {
return &MemoryService{store: s}
}
// CreateMemory 创建/保存记忆
func (svc *MemoryService) CreateMemory(ctx context.Context, entry *model.MemoryEntry) error {
if entry.UserID == "" {
return fmt.Errorf("user_id 不能为空")
}
if entry.Content == "" {
return fmt.Errorf("content 不能为空")
}
if entry.Category == "" {
entry.Category = model.CategoryKnowledge
}
if entry.Importance < 1 {
entry.Importance = 5
}
if entry.Priority < 0 || entry.Priority > 3 {
entry.Priority = model.MemoryNormal
}
if entry.Source == "" {
entry.Source = "manual"
}
// 去重检查
similar, err := svc.findSimilar(ctx, entry.UserID, entry)
if err == nil && similar != nil {
// 合并到已有记忆
return svc.mergeMemory(ctx, similar, entry)
}
return svc.store.Save(ctx, entry)
}
// GetMemory 获取单个记忆
func (svc *MemoryService) GetMemory(ctx context.Context, id string) (*model.MemoryEntry, error) {
return svc.store.GetByID(ctx, id)
}
// ListMemories 列出用户所有记忆
func (svc *MemoryService) ListMemories(ctx context.Context, userID string, category string, minImportance int, limit int) ([]model.MemoryEntry, error) {
if limit <= 0 {
limit = 50
}
q := model.MemoryQuery{
UserID: userID,
MinImportance: minImportance,
Limit: limit,
}
if category != "" {
q.Category = model.MemoryCategory(category)
}
return svc.store.Query(ctx, q)
}
// UpdateMemory 更新记忆
func (svc *MemoryService) UpdateMemory(ctx context.Context, entry *model.MemoryEntry) error {
if entry.ID == "" {
return fmt.Errorf("id 不能为空")
}
return svc.store.Update(ctx, entry)
}
// DeleteMemory 删除记忆
func (svc *MemoryService) DeleteMemory(ctx context.Context, id string) error {
return svc.store.Delete(ctx, id)
}
// QueryMemories 语义查询 + 关键词匹配
func (svc *MemoryService) QueryMemories(ctx context.Context, userID string, queryText string, category string, minImportance int, limit int) ([]model.MemoryEntry, error) {
if limit <= 0 {
limit = 10
}
var allEntries []model.MemoryEntry
seen := make(map[string]bool)
// 1. 关键词匹配检索
keywordEntries, err := svc.store.SearchByKeyword(ctx, userID, queryText, limit*2)
if err == nil {
for _, e := range keywordEntries {
if !seen[e.ID] {
seen[e.ID] = true
allEntries = append(allEntries, e)
}
}
}
// 2. 补充按分类/重要性查询
q := model.MemoryQuery{
UserID: userID,
MinImportance: minImportance,
Limit: limit,
}
if category != "" {
q.Category = model.MemoryCategory(category)
}
categoryEntries, err := svc.store.Query(ctx, q)
if err == nil {
for _, e := range categoryEntries {
if !seen[e.ID] {
seen[e.ID] = true
allEntries = append(allEntries, e)
}
}
}
// 3. 在应用层做内容匹配过滤
queryLower := strings.ToLower(queryText)
var matched []model.MemoryEntry
for _, entry := range allEntries {
contentLower := strings.ToLower(entry.Content)
summaryLower := strings.ToLower(entry.Summary)
if strings.Contains(contentLower, queryLower) || strings.Contains(summaryLower, queryLower) {
matched = append(matched, entry)
continue
}
for _, kw := range entry.Keywords {
if strings.Contains(queryLower, strings.ToLower(kw)) ||
strings.Contains(strings.ToLower(kw), queryLower) {
matched = append(matched, entry)
break
}
}
}
if len(matched) == 0 {
matched = allEntries
}
// 4. 去重合并
matched = svc.deduplicate(matched)
// 5. 按重要性降序
sortByImportance(matched)
if len(matched) > limit {
matched = matched[:limit]
}
return matched, nil
}
// ConsolidateMemories 合并相似记忆
func (svc *MemoryService) ConsolidateMemories(ctx context.Context, userID string) (int, error) {
return svc.store.ConsolidateMemories(ctx, userID)
}
// DecayMemories 衰减旧记忆
func (svc *MemoryService) DecayMemories(ctx context.Context, userID string) (int, int, error) {
return svc.store.DecayMemories(ctx, userID)
}
// GetCategories 获取用户分类统计
func (svc *MemoryService) GetCategories(ctx context.Context, userID string) (map[string]int, error) {
return svc.store.GetCategories(ctx, userID)
}
// findSimilar 查找相似记忆
func (svc *MemoryService) findSimilar(ctx context.Context, userID string, newMem *model.MemoryEntry) (*model.MemoryEntry, error) {
existing, err := svc.store.Query(ctx, model.MemoryQuery{
UserID: userID,
Limit: 100,
})
if err != nil {
return nil, err
}
for i := range existing {
score := existing[i].SimilarityScore(newMem)
if score >= deDupThreshold {
return &existing[i], nil
}
}
return nil, nil
}
// mergeMemory 合并新记忆到已有记忆
func (svc *MemoryService) mergeMemory(ctx context.Context, existing *model.MemoryEntry, newMem *model.MemoryEntry) error {
// 更新内容(如果新内容更有价值)
if newMem.Importance > existing.Importance || len(newMem.Content) > len(existing.Content) {
existing.Content = newMem.Content
existing.Summary = newMem.Summary
}
// 合并关键词
keywordSet := make(map[string]bool)
for _, k := range existing.Keywords {
keywordSet[k] = true
}
for _, k := range newMem.Keywords {
keywordSet[k] = true
}
mergedKeywords := make([]string, 0, len(keywordSet))
for k := range keywordSet {
mergedKeywords = append(mergedKeywords, k)
}
existing.Keywords = mergedKeywords
// 取最高重要性
if newMem.Importance > existing.Importance {
existing.Importance = newMem.Importance
}
// 取最高优先级
if newMem.Priority > existing.Priority {
existing.Priority = newMem.Priority
}
existing.AccessCount++
existing.Source = "merged"
log.Printf("[memory-service] 合并记忆 [%s|%d★]: %s (去重)", existing.Category, existing.Importance, existing.Summary)
return svc.store.Update(ctx, existing)
}
// deduplicate 去重合并
func (svc *MemoryService) deduplicate(entries []model.MemoryEntry) []model.MemoryEntry {
if len(entries) < 2 {
return entries
}
result := make([]model.MemoryEntry, 0, len(entries))
discarded := make(map[int]bool)
for i := 0; i < len(entries); i++ {
if discarded[i] {
continue
}
for j := i + 1; j < len(entries); j++ {
if discarded[j] {
continue
}
score := entries[i].SimilarityScore(&entries[j])
if score >= deDupThreshold {
if entries[j].Importance > entries[i].Importance ||
(entries[j].Importance == entries[i].Importance && entries[j].Priority > entries[i].Priority) {
discarded[i] = true
break
} else {
discarded[j] = true
}
}
}
if !discarded[i] {
result = append(result, entries[i])
}
}
return result
}
// sortByImportance 按 Importance 降序, Priority 降序排列
func sortByImportance(entries []model.MemoryEntry) {
for i := 0; i < len(entries); i++ {
for j := i + 1; j < len(entries); j++ {
if entries[j].Importance > entries[i].Importance ||
(entries[j].Importance == entries[i].Importance && entries[j].Priority > entries[i].Priority) {
entries[i], entries[j] = entries[j], entries[i]
}
}
}
}
// SaveThinkingLog 保存自主思考日志
func (svc *MemoryService) SaveThinkingLog(ctx context.Context, tl *model.ThinkingLog) error {
if tl.Content == "" {
return fmt.Errorf("content 不能为空")
}
if tl.UserID == "" {
tl.UserID = "admin_admin"
}
return svc.store.SaveThinkingLog(ctx, tl)
}
// QueryThinkingLogs 分页查询思考日志
func (svc *MemoryService) QueryThinkingLogs(ctx context.Context, q model.ThinkingQuery) ([]model.ThinkingLog, error) {
return svc.store.QueryThinkingLogs(ctx, q)
}
// GetThinkingLogByID 获取单条思考日志
func (svc *MemoryService) GetThinkingLogByID(ctx context.Context, id string) (*model.ThinkingLog, error) {
return svc.store.GetThinkingLogByID(ctx, id)
}
// GetThinkingStats 获取思考日志统计
func (svc *MemoryService) GetThinkingStats(ctx context.Context) (*model.ThinkingStats, error) {
return svc.store.GetThinkingStats(ctx)
}
@@ -0,0 +1,765 @@
package store
import (
"context"
"database/sql"
"fmt"
"log"
"sync"
"time"
"github.com/yourname/cyrene-ai/memory-service/internal/model"
_ "github.com/lib/pq"
)
// deDupThreshold 去重相似度阈值
const deDupThreshold = 0.75
const reconnectInterval = 30 * time.Second
// Store 记忆持久化存储(PostgreSQL + pgvector
type Store struct {
databaseURL string
mu sync.RWMutex
db *sql.DB
}
// errDBNotReady 数据库未就绪时返回的友好错误
var errDBNotReady = fmt.Errorf("记忆系统未就绪: 数据库连接不可用,正在后台重试连接")
// NewStore 创建记忆存储
// 连接失败时不返回 error,而是启动后台重连循环
func NewStore(connStr string) *Store {
s := &Store{
databaseURL: connStr,
}
// 尝试初始连接
if err := s.Reconnect(); err != nil {
log.Printf("[memory-service] ⚠ 记忆存储初始化: 数据库连接失败 (%v),将在后台每30秒重试", err)
} else {
log.Println("[memory-service] 记忆存储已就绪")
}
// 启动后台重连 goroutine
go s.reconnectLoop()
return s
}
// reconnectLoop 后台重连循环
func (s *Store) reconnectLoop() {
ticker := time.NewTicker(reconnectInterval)
defer ticker.Stop()
for range ticker.C {
s.mu.RLock()
ready := s.db != nil
s.mu.RUnlock()
if ready {
// 数据库已连接,检查连接是否仍然有效
s.mu.RLock()
db := s.db
s.mu.RUnlock()
if db != nil {
if err := db.Ping(); err != nil {
log.Printf("[memory-service] ⚠ 数据库连接丢失: %v,开始重连", err)
s.mu.Lock()
if s.db != nil {
s.db.Close()
s.db = nil
}
s.mu.Unlock()
}
}
}
if !s.IsReady() {
if err := s.Reconnect(); err != nil {
log.Printf("[memory-service] ⚠ 数据库重连失败: %v", err)
}
}
}
}
// Reconnect 尝试重连数据库并执行迁移
func (s *Store) Reconnect() error {
s.mu.Lock()
defer s.mu.Unlock()
// 如果已有有效连接,先检查
if s.db != nil {
if err := s.db.Ping(); err == nil {
return nil // 仍然有效
}
// 连接已失效,关闭旧连接
s.db.Close()
s.db = nil
}
db, err := sql.Open("postgres", s.databaseURL)
if err != nil {
return fmt.Errorf("连接数据库失败: %w", err)
}
db.SetMaxOpenConns(25)
db.SetMaxIdleConns(5)
db.SetConnMaxLifetime(5 * time.Minute)
if err := db.Ping(); err != nil {
db.Close()
return fmt.Errorf("数据库ping失败: %w", err)
}
s.db = db
// 执行建表迁移
if err := s.migrate(); err != nil {
log.Printf("[memory-service] ⚠ 数据库迁移失败: %v", err)
s.db.Close()
s.db = nil
return fmt.Errorf("数据库迁移失败: %w", err)
}
log.Println("[memory-service] ✅ 数据库重连成功,记忆系统已就绪")
return nil
}
// IsReady 返回数据库是否可用
func (s *Store) IsReady() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.db != nil
}
// getDB 获取当前数据库连接(带读锁保护)
func (s *Store) getDB() *sql.DB {
s.mu.RLock()
defer s.mu.RUnlock()
return s.db
}
// migrate 创建表结构
func (s *Store) migrate() error {
queries := []string{
`CREATE EXTENSION IF NOT EXISTS vector`,
`CREATE TABLE IF NOT EXISTS memory_entries (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id VARCHAR(64) NOT NULL,
content TEXT NOT NULL,
summary TEXT DEFAULT '',
category VARCHAR(32) DEFAULT 'knowledge',
priority INT DEFAULT 1,
importance INT DEFAULT 5,
keywords TEXT DEFAULT '[]',
session_id VARCHAR(64) DEFAULT '',
source TEXT DEFAULT 'conversation',
embedding vector(1536),
access_count INT DEFAULT 0,
last_access TIMESTAMPTZ DEFAULT NOW(),
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW(),
expires_at TIMESTAMPTZ
)`,
`CREATE INDEX IF NOT EXISTS idx_me_user_id ON memory_entries(user_id)`,
`CREATE INDEX IF NOT EXISTS idx_me_category ON memory_entries(category)`,
`CREATE INDEX IF NOT EXISTS idx_me_priority ON memory_entries(priority)`,
`CREATE INDEX IF NOT EXISTS idx_me_importance ON memory_entries(importance)`,
`CREATE INDEX IF NOT EXISTS idx_me_user_priority ON memory_entries(user_id, priority DESC)`,
`CREATE INDEX IF NOT EXISTS idx_me_user_importance ON memory_entries(user_id, importance DESC)`,
`CREATE INDEX IF NOT EXISTS idx_me_source ON memory_entries(source)`,
`CREATE INDEX IF NOT EXISTS idx_me_category_importance ON memory_entries(category, importance DESC)`,
`CREATE TABLE IF NOT EXISTS thinking_logs (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id VARCHAR(64) NOT NULL DEFAULT 'admin_admin',
content TEXT NOT NULL,
tool_calls TEXT DEFAULT '[]',
tool_call_count INT DEFAULT 0,
content_length INT DEFAULT 0,
created_at TIMESTAMPTZ DEFAULT NOW()
)`,
`CREATE INDEX IF NOT EXISTS idx_tl_user_id ON thinking_logs(user_id)`,
`CREATE INDEX IF NOT EXISTS idx_tl_created_at ON thinking_logs(created_at DESC)`,
}
for _, q := range queries {
if _, err := s.db.Exec(q); err != nil {
return fmt.Errorf("执行迁移 '%s' 失败: %w", q[:min(50, len(q))], err)
}
}
return nil
}
// Save 保存记忆
func (s *Store) Save(ctx context.Context, entry *model.MemoryEntry) error {
db := s.getDB()
if db == nil {
return errDBNotReady
}
// 设置默认值
if entry.Source == "" {
entry.Source = "conversation"
}
if entry.Importance == 0 {
entry.Importance = 5
}
query := `INSERT INTO memory_entries (user_id, content, summary, category, priority, importance, keywords, session_id, source, embedding, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING id, created_at`
var embedding interface{}
if len(entry.Embedding) > 0 {
vec := make([]float64, len(entry.Embedding))
for i, v := range entry.Embedding {
vec[i] = float64(v)
}
embedding = fmt.Sprintf("[%s]", joinFloats(vec))
}
return db.QueryRowContext(ctx, query,
entry.UserID, entry.Content, entry.Summary,
string(entry.Category), int(entry.Priority),
entry.Importance, entry.KeywordsJSON(),
entry.SessionID, entry.Source, embedding, entry.ExpiresAt,
).Scan(&entry.ID, &entry.CreatedAt)
}
// GetByID 根据ID获取记忆
func (s *Store) GetByID(ctx context.Context, id string) (*model.MemoryEntry, error) {
db := s.getDB()
if db == nil {
return nil, errDBNotReady
}
query := `SELECT id, user_id, content, summary, category, priority, importance, keywords,
session_id, source, access_count, last_access, created_at, updated_at, expires_at
FROM memory_entries WHERE id = $1`
entry := &model.MemoryEntry{}
var category, keywordsRaw string
err := db.QueryRowContext(ctx, query, id).Scan(
&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
&category, &entry.Priority, &entry.Importance, &keywordsRaw,
&entry.SessionID, &entry.Source, &entry.AccessCount, &entry.LastAccess,
&entry.CreatedAt, &entry.UpdatedAt, &entry.ExpiresAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("查询记忆失败: %w", err)
}
entry.Category = model.MemoryCategory(category)
entry.Keywords = model.ParseKeywords(keywordsRaw)
// 更新访问计数
go s.incrementAccess(context.Background(), id)
return entry, nil
}
// Query 按条件查询记忆
func (s *Store) Query(ctx context.Context, q model.MemoryQuery) ([]model.MemoryEntry, error) {
db := s.getDB()
if db == nil {
return nil, errDBNotReady
}
if q.Limit <= 0 {
q.Limit = 10
}
query := `SELECT id, user_id, content, summary, category, priority, importance, keywords,
session_id, source, access_count, last_access, created_at, updated_at, expires_at
FROM memory_entries WHERE user_id = $1`
args := []interface{}{q.UserID}
argIdx := 2
if q.Category != "" {
query += fmt.Sprintf(" AND category = $%d", argIdx)
args = append(args, string(q.Category))
argIdx++
}
if q.Priority >= 0 {
query += fmt.Sprintf(" AND priority >= $%d", argIdx)
args = append(args, int(q.Priority))
argIdx++
}
if q.MinImportance > 0 {
query += fmt.Sprintf(" AND importance >= $%d", argIdx)
args = append(args, q.MinImportance)
argIdx++
}
query += fmt.Sprintf(" ORDER BY priority DESC, importance DESC, created_at DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1)
args = append(args, q.Limit, q.Offset)
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("查询记忆失败: %w", err)
}
defer rows.Close()
return scanMemoryRows(rows)
}
// Delete 删除记忆
func (s *Store) Delete(ctx context.Context, id string) error {
db := s.getDB()
if db == nil {
return errDBNotReady
}
_, err := db.ExecContext(ctx, `DELETE FROM memory_entries WHERE id = $1`, id)
return err
}
// PurgeExpired 清理过期记忆
func (s *Store) PurgeExpired(ctx context.Context) (int64, error) {
db := s.getDB()
if db == nil {
return 0, errDBNotReady
}
result, err := db.ExecContext(ctx,
`DELETE FROM memory_entries WHERE expires_at IS NOT NULL AND expires_at < NOW()`)
if err != nil {
return 0, err
}
return result.RowsAffected()
}
// SearchByVector 向量相似度搜索
func (s *Store) SearchByVector(ctx context.Context, userID string, embedding []float64, limit int) ([]model.MemoryEntry, error) {
db := s.getDB()
if db == nil {
return nil, errDBNotReady
}
if limit <= 0 {
limit = 5
}
vecStr := fmt.Sprintf("[%s]", joinFloats(embedding))
query := `SELECT id, user_id, content, summary, category, priority, importance, keywords,
session_id, source, access_count, last_access, created_at, updated_at, expires_at,
1 - (embedding <=> $1) AS similarity
FROM memory_entries
WHERE user_id = $2 AND embedding IS NOT NULL
ORDER BY embedding <=> $1
LIMIT $3`
rows, err := db.QueryContext(ctx, query, vecStr, userID, limit)
if err != nil {
return nil, fmt.Errorf("向量搜索失败: %w", err)
}
defer rows.Close()
var entries []model.MemoryEntry
for rows.Next() {
var entry model.MemoryEntry
var category, keywordsRaw string
var similarity float64
if err := rows.Scan(
&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
&category, &entry.Priority, &entry.Importance, &keywordsRaw,
&entry.SessionID, &entry.Source, &entry.AccessCount, &entry.LastAccess,
&entry.CreatedAt, &entry.UpdatedAt, &entry.ExpiresAt,
&similarity,
); err != nil {
return nil, fmt.Errorf("扫描向量搜索结果失败: %w", err)
}
entry.Category = model.MemoryCategory(category)
entry.Keywords = model.ParseKeywords(keywordsRaw)
entries = append(entries, entry)
}
return entries, rows.Err()
}
// SearchByKeyword 关键词匹配查询
func (s *Store) SearchByKeyword(ctx context.Context, userID, keyword string, limit int) ([]model.MemoryEntry, error) {
db := s.getDB()
if db == nil {
return nil, errDBNotReady
}
if limit <= 0 {
limit = 20
}
query := `SELECT id, user_id, content, summary, category, priority, importance, keywords,
session_id, source, access_count, last_access, created_at, updated_at, expires_at
FROM memory_entries
WHERE user_id = $1 AND (content ILIKE $2 OR summary ILIKE $2 OR keywords ILIKE $2)
ORDER BY priority DESC, importance DESC
LIMIT $3`
likePattern := "%" + keyword + "%"
rows, err := db.QueryContext(ctx, query, userID, likePattern, limit)
if err != nil {
return nil, fmt.Errorf("关键词搜索失败: %w", err)
}
defer rows.Close()
return scanMemoryRows(rows)
}
// Update 更新记忆
func (s *Store) Update(ctx context.Context, entry *model.MemoryEntry) error {
db := s.getDB()
if db == nil {
return errDBNotReady
}
query := `UPDATE memory_entries SET content = $1, summary = $2, category = $3, priority = $4,
importance = $5, keywords = $6, source = $7, updated_at = NOW()
WHERE id = $8`
_, err := db.ExecContext(ctx, query,
entry.Content, entry.Summary, string(entry.Category), int(entry.Priority),
entry.Importance, entry.KeywordsJSON(), entry.Source, entry.ID,
)
return err
}
// GetCategories 获取用户所有分类及计数
func (s *Store) GetCategories(ctx context.Context, userID string) (map[string]int, error) {
db := s.getDB()
if db == nil {
return nil, errDBNotReady
}
rows, err := db.QueryContext(ctx,
`SELECT category, COUNT(*) FROM memory_entries WHERE user_id = $1 GROUP BY category ORDER BY category`,
userID,
)
if err != nil {
return nil, fmt.Errorf("查询分类统计失败: %w", err)
}
defer rows.Close()
categories := make(map[string]int)
for rows.Next() {
var cat string
var count int
if err := rows.Scan(&cat, &count); err != nil {
return nil, fmt.Errorf("扫描分类统计失败: %w", err)
}
categories[cat] = count
}
return categories, rows.Err()
}
// Count 获取用户的记忆总数
func (s *Store) Count(ctx context.Context, userID string) (int, error) {
db := s.getDB()
if db == nil {
return 0, errDBNotReady
}
var count int
err := db.QueryRowContext(ctx,
`SELECT COUNT(*) FROM memory_entries WHERE user_id = $1`,
userID,
).Scan(&count)
if err != nil {
return 0, fmt.Errorf("统计记忆失败: %w", err)
}
return count, nil
}
// ConsolidateMemories 记忆整理:合并相似记忆
func (s *Store) ConsolidateMemories(ctx context.Context, userID string) (int, error) {
db := s.getDB()
if db == nil {
return 0, errDBNotReady
}
// 获取用户所有记忆
allMems, err := s.Query(ctx, model.MemoryQuery{
UserID: userID,
Limit: 500,
})
if err != nil {
return 0, fmt.Errorf("查询记忆失败: %w", err)
}
if len(allMems) < 2 {
return 0, nil
}
merged := 0
for i := 0; i < len(allMems); i++ {
if allMems[i].ID == "" {
continue
}
for j := i + 1; j < len(allMems); j++ {
if allMems[j].ID == "" {
continue
}
score := allMems[i].SimilarityScore(&allMems[j])
if score >= deDupThreshold {
keep, discard := &allMems[i], &allMems[j]
if discard.Importance > keep.Importance || discard.Priority > keep.Priority {
keep, discard = discard, keep
}
// 合并关键词
keywordSet := make(map[string]bool)
for _, k := range keep.Keywords {
keywordSet[k] = true
}
for _, k := range discard.Keywords {
keywordSet[k] = true
}
mergedKeywords := make([]string, 0, len(keywordSet))
for k := range keywordSet {
mergedKeywords = append(mergedKeywords, k)
}
keep.Keywords = mergedKeywords
if keep.Importance < 10 {
keep.Importance++
}
keep.Source = "consolidated"
if err := s.Update(ctx, keep); err != nil {
log.Printf("[memory-service] 合并更新记忆 %s 失败: %v", keep.ID, err)
continue
}
if err := s.Delete(ctx, discard.ID); err != nil {
log.Printf("[memory-service] 合并删除记忆 %s 失败: %v", discard.ID, err)
continue
}
discard.ID = ""
merged++
log.Printf("[memory-service] 合并相似记忆: %s <- %s (相似度 %.0f%%)",
keep.ID[:min(8, len(keep.ID))], discard.ID[:min(8, len(discard.ID))], score*100)
}
}
}
if merged > 0 {
log.Printf("[memory-service] 记忆整理完成: 用户 %s 合并 %d 条相似记忆", userID, merged)
}
return merged, nil
}
// DecayMemories 记忆衰减:降低长期未访问的低重要性记忆
func (s *Store) DecayMemories(ctx context.Context, userID string) (int, int, error) {
db := s.getDB()
if db == nil {
return 0, 0, errDBNotReady
}
result1, err := db.ExecContext(ctx, `
UPDATE memory_entries SET priority = GREATEST(priority - 1, 0), updated_at = NOW()
WHERE user_id = $1
AND access_count < 3
AND last_access < NOW() - INTERVAL '30 days'
AND importance < 3
AND priority > 0
AND category NOT IN ('personal_info', 'user_preference')
`, userID)
if err != nil {
return 0, 0, fmt.Errorf("衰减低活跃记忆失败: %w", err)
}
decayed1, _ := result1.RowsAffected()
result2, err := db.ExecContext(ctx, `
DELETE FROM memory_entries
WHERE user_id = $1
AND priority = 0
AND access_count = 0
AND last_access < NOW() - INTERVAL '14 days'
`, userID)
if err != nil {
return 0, 0, fmt.Errorf("清理临时记忆失败: %w", err)
}
deleted2, _ := result2.RowsAffected()
total := decayed1 + deleted2
if total > 0 {
log.Printf("[memory-service] 记忆衰减完成: 用户 %s 降级 %d 条, 删除 %d 条过期临时记忆",
userID, decayed1, deleted2)
}
return int(decayed1), int(deleted2), nil
}
func (s *Store) incrementAccess(ctx context.Context, id string) {
db := s.getDB()
if db == nil {
return
}
db.ExecContext(ctx,
`UPDATE memory_entries SET access_count = access_count + 1, last_access = NOW() WHERE id = $1`, id)
}
// Close 关闭数据库连接
func (s *Store) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.db != nil {
return s.db.Close()
}
return nil
}
// scanMemoryRows 扫描记忆行(通用方法)
func scanMemoryRows(rows *sql.Rows) ([]model.MemoryEntry, error) {
var entries []model.MemoryEntry
for rows.Next() {
var entry model.MemoryEntry
var category, keywordsRaw string
if err := rows.Scan(
&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
&category, &entry.Priority, &entry.Importance, &keywordsRaw,
&entry.SessionID, &entry.Source, &entry.AccessCount, &entry.LastAccess,
&entry.CreatedAt, &entry.UpdatedAt, &entry.ExpiresAt,
); err != nil {
return nil, fmt.Errorf("扫描记忆行失败: %w", err)
}
entry.Category = model.MemoryCategory(category)
entry.Keywords = model.ParseKeywords(keywordsRaw)
entries = append(entries, entry)
}
return entries, rows.Err()
}
// joinFloats 将 float64 切片转为逗号分隔字符串
func joinFloats(vec []float64) string {
if len(vec) == 0 {
return ""
}
s := fmt.Sprintf("%f", vec[0])
for i := 1; i < len(vec); i++ {
s += fmt.Sprintf(",%f", vec[i])
}
return s
}
// SaveThinkingLog 保存自主思考日志
func (s *Store) SaveThinkingLog(ctx context.Context, log *model.ThinkingLog) error {
db := s.getDB()
if db == nil {
return errDBNotReady
}
if log.UserID == "" {
log.UserID = "admin_admin"
}
if log.ToolCalls == "" {
log.ToolCalls = "[]"
}
query := `INSERT INTO thinking_logs (user_id, content, tool_calls, tool_call_count, content_length)
VALUES ($1, $2, $3, $4, $5)
RETURNING id, created_at`
return db.QueryRowContext(ctx, query,
log.UserID, log.Content, log.ToolCalls,
log.ToolCallCount, log.ContentLength,
).Scan(&log.ID, &log.CreatedAt)
}
// QueryThinkingLogs 分页查询思考日志
func (s *Store) QueryThinkingLogs(ctx context.Context, q model.ThinkingQuery) ([]model.ThinkingLog, error) {
db := s.getDB()
if db == nil {
return nil, errDBNotReady
}
if q.Limit <= 0 {
q.Limit = 20
}
query := `SELECT id, user_id, content, tool_calls, tool_call_count, content_length, created_at
FROM thinking_logs`
args := []interface{}{}
argIdx := 1
if q.UserID != "" {
query += fmt.Sprintf(" WHERE user_id = $%d", argIdx)
args = append(args, q.UserID)
argIdx++
}
query += fmt.Sprintf(" ORDER BY created_at DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1)
args = append(args, q.Limit, q.Offset)
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("查询思考日志失败: %w", err)
}
defer rows.Close()
var logs []model.ThinkingLog
for rows.Next() {
var tl model.ThinkingLog
if err := rows.Scan(&tl.ID, &tl.UserID, &tl.Content, &tl.ToolCalls,
&tl.ToolCallCount, &tl.ContentLength, &tl.CreatedAt); err != nil {
return nil, fmt.Errorf("扫描思考日志行失败: %w", err)
}
logs = append(logs, tl)
}
return logs, rows.Err()
}
// GetThinkingLogByID 根据ID获取单条思考日志
func (s *Store) GetThinkingLogByID(ctx context.Context, id string) (*model.ThinkingLog, error) {
db := s.getDB()
if db == nil {
return nil, errDBNotReady
}
query := `SELECT id, user_id, content, tool_calls, tool_call_count, content_length, created_at
FROM thinking_logs WHERE id = $1`
tl := &model.ThinkingLog{}
err := db.QueryRowContext(ctx, query, id).Scan(
&tl.ID, &tl.UserID, &tl.Content, &tl.ToolCalls,
&tl.ToolCallCount, &tl.ContentLength, &tl.CreatedAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("查询思考日志失败: %w", err)
}
return tl, nil
}
// GetThinkingStats 获取思考日志统计信息
func (s *Store) GetThinkingStats(ctx context.Context) (*model.ThinkingStats, error) {
db := s.getDB()
if db == nil {
return nil, errDBNotReady
}
query := `SELECT
COALESCE(COUNT(*), 0),
COALESCE(SUM(tool_call_count), 0),
COALESCE(AVG(content_length), 0),
COALESCE(MAX(created_at)::TEXT, '')
FROM thinking_logs`
stats := &model.ThinkingStats{}
err := db.QueryRowContext(ctx, query).Scan(
&stats.TotalLogs, &stats.TotalToolCalls,
&stats.AvgContentLen, &stats.LatestAt,
)
if err != nil {
return nil, fmt.Errorf("查询思考日志统计失败: %w", err)
}
return stats, nil
}