package store

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"log"
	"strings"
	"sync"
	"time"

	_ "github.com/lib/pq" // registers the "postgres" database/sql driver
	"github.com/yourname/cyrene-ai/memory-service/internal/model"
)

const (
	// deDupThreshold is the similarity score at or above which two memories
	// are treated as duplicates during consolidation.
	deDupThreshold = 0.75

	// reconnectInterval is the period of the background health-check/reconnect loop.
	reconnectInterval = 30 * time.Second
)

// Store persists memory entries and thinking logs in PostgreSQL (with the
// pgvector extension for embeddings).
//
// The database handle may be nil while the database is unreachable; every
// access to db goes through mu.
type Store struct {
	databaseURL string

	mu sync.RWMutex
	db *sql.DB // nil while no usable connection exists
}

// errDBNotReady is the user-facing error returned by every operation while no
// database connection is available.
var errDBNotReady = errors.New("记忆系统未就绪: 数据库连接不可用,正在后台重试连接")

// likeEscaper escapes the characters that are special inside a LIKE/ILIKE
// pattern (backslash is the default escape character in PostgreSQL).
var likeEscaper = strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`)

// NewStore creates a memory store for the given connection string.
//
// A failed initial connection is not returned as an error: the store is
// returned in a "not ready" state and a background goroutine keeps retrying
// every reconnectInterval.
//
// NOTE(review): the background goroutine has no stop signal, so it runs for
// the lifetime of the process.
func NewStore(connStr string) *Store {
	s := &Store{databaseURL: connStr}

	// Best-effort initial connection; failure is only logged.
	if err := s.Reconnect(); err != nil {
		log.Printf("[memory-service] ⚠ 记忆存储初始化: 数据库连接失败 (%v),将在后台每30秒重试", err)
	} else {
		log.Println("[memory-service] 记忆存储已就绪")
	}

	go s.reconnectLoop()
	return s
}

// reconnectLoop periodically pings the current connection, drops it when the
// ping fails, and reconnects whenever the store has no connection.
func (s *Store) reconnectLoop() {
	ticker := time.NewTicker(reconnectInterval)
	defer ticker.Stop()

	for range ticker.C {
		// Health-check outside the lock so a slow ping does not block readers.
		if db := s.getDB(); db != nil {
			if err := db.Ping(); err != nil {
				log.Printf("[memory-service] ⚠ 数据库连接丢失: %v,开始重连", err)
				s.mu.Lock()
				// Only drop the handle that actually failed; a concurrent
				// Reconnect may already have installed a healthy replacement.
				if s.db == db {
					_ = db.Close() // the handle is being discarded anyway
					s.db = nil
				}
				s.mu.Unlock()
			}
		}

		if !s.IsReady() {
			if err := s.Reconnect(); err != nil {
				log.Printf("[memory-service] ⚠ 数据库重连失败: %v", err)
			}
		}
	}
}

// Reconnect (re)establishes the database connection and runs the schema
// migration. If a connection already exists and still answers a ping, it is
// kept and nil is returned.
//
// The write lock is held for the whole call, so other Store operations wait
// until the attempt finishes.
func (s *Store) Reconnect() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Keep an existing connection if it is still alive.
	if s.db != nil {
		if err := s.db.Ping(); err == nil {
			return nil
		}
		// Stale connection: discard it before opening a new one.
		_ = s.db.Close()
		s.db = nil
	}

	db, err := sql.Open("postgres", s.databaseURL)
	if err != nil {
		return fmt.Errorf("连接数据库失败: %w", err)
	}
	db.SetMaxOpenConns(25)
	db.SetMaxIdleConns(5)
	db.SetConnMaxLifetime(5 * time.Minute)

	// sql.Open does not connect; Ping forces a real connection attempt.
	if err := db.Ping(); err != nil {
		_ = db.Close()
		return fmt.Errorf("数据库ping失败: %w", err)
	}

	// migrate operates on s.db, so publish the handle first and roll it back
	// if the migration fails.
	s.db = db
	if err := s.migrate(); err != nil {
		log.Printf("[memory-service] ⚠ 数据库迁移失败: %v", err)
		_ = s.db.Close()
		s.db = nil
		return fmt.Errorf("数据库迁移失败: %w", err)
	}

	log.Println("[memory-service] ✅ 数据库重连成功,记忆系统已就绪")
	return nil
}

// IsReady reports whether a database connection is currently available.
func (s *Store) IsReady() bool {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.db != nil
}

// getDB returns the current database handle under the read lock; the result
// is nil when the store is not ready.
func (s *Store) getDB() *sql.DB {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.db
}

// migrate creates the extension, tables and indexes if they do not exist yet.
// It uses s.db directly and is called by Reconnect with the write lock held.
func (s *Store) migrate() error {
	queries := []string{
		`CREATE EXTENSION IF NOT EXISTS vector`,
		`CREATE TABLE IF NOT EXISTS memory_entries (
			id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
			user_id VARCHAR(64) NOT NULL,
			content TEXT NOT NULL,
			summary TEXT DEFAULT '',
			category VARCHAR(32) DEFAULT 'knowledge',
			priority INT DEFAULT 1,
			importance INT DEFAULT 5,
			keywords TEXT DEFAULT '[]',
			session_id VARCHAR(64) DEFAULT '',
			source TEXT DEFAULT 'conversation',
			embedding vector(1536),
			access_count INT DEFAULT 0,
			last_access TIMESTAMPTZ DEFAULT NOW(),
			created_at TIMESTAMPTZ DEFAULT NOW(),
			updated_at TIMESTAMPTZ DEFAULT NOW(),
			expires_at TIMESTAMPTZ
		)`,
		`CREATE INDEX IF NOT EXISTS idx_me_user_id ON memory_entries(user_id)`,
		`CREATE INDEX IF NOT EXISTS idx_me_category ON memory_entries(category)`,
		`CREATE INDEX IF NOT EXISTS idx_me_priority ON memory_entries(priority)`,
		`CREATE INDEX IF NOT EXISTS idx_me_importance ON memory_entries(importance)`,
		`CREATE INDEX IF NOT EXISTS idx_me_user_priority ON memory_entries(user_id, priority DESC)`,
		`CREATE INDEX IF NOT EXISTS idx_me_user_importance ON memory_entries(user_id, importance DESC)`,
		`CREATE INDEX IF NOT EXISTS idx_me_source ON memory_entries(source)`,
		`CREATE INDEX IF NOT EXISTS idx_me_category_importance ON memory_entries(category, importance DESC)`,
		`CREATE TABLE IF NOT EXISTS thinking_logs (
			id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
			user_id VARCHAR(64) NOT NULL DEFAULT 'admin_admin',
			content TEXT NOT NULL,
			tool_calls TEXT DEFAULT '[]',
			tool_call_count INT DEFAULT 0,
			content_length INT DEFAULT 0,
			created_at TIMESTAMPTZ DEFAULT NOW()
		)`,
		`CREATE INDEX IF NOT EXISTS idx_tl_user_id ON thinking_logs(user_id)`,
		`CREATE INDEX IF NOT EXISTS idx_tl_created_at ON thinking_logs(created_at DESC)`,
	}

	for _, q := range queries {
		if _, err := s.db.Exec(q); err != nil {
			// Only the first 50 bytes of the statement go into the message.
			return fmt.Errorf("执行迁移 '%s' 失败: %w", q[:min(50, len(q))], err)
		}
	}
	return nil
}

// Save inserts a memory entry and fills entry.ID and entry.CreatedAt from the
// inserted row. An empty Source defaults to "conversation" and a zero
// Importance defaults to 5; both defaults are written back into entry.
func (s *Store) Save(ctx context.Context, entry *model.MemoryEntry) error {
	db := s.getDB()
	if db == nil {
		return errDBNotReady
	}

	if entry.Source == "" {
		entry.Source = "conversation"
	}
	if entry.Importance == 0 {
		entry.Importance = 5
	}

	// A nil value is stored as NULL; otherwise the embedding is sent in the
	// "[v1,v2,...]" text form.
	var embedding any
	if len(entry.Embedding) > 0 {
		vec := make([]float64, len(entry.Embedding))
		for i, v := range entry.Embedding {
			vec[i] = float64(v)
		}
		embedding = fmt.Sprintf("[%s]", joinFloats(vec))
	}

	const query = `INSERT INTO memory_entries
		(user_id, content, summary, category, priority, importance, keywords, session_id, source, embedding, expires_at)
		VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
		RETURNING id, created_at`

	row := db.QueryRowContext(ctx, query,
		entry.UserID,
		entry.Content,
		entry.Summary,
		string(entry.Category),
		int(entry.Priority),
		entry.Importance,
		entry.KeywordsJSON(),
		entry.SessionID,
		entry.Source,
		embedding,
		entry.ExpiresAt,
	)
	return row.Scan(&entry.ID, &entry.CreatedAt)
}

// GetByID returns the memory with the given ID, or (nil, nil) when no such row
// exists. On a hit, the access counter is bumped asynchronously.
func (s *Store) GetByID(ctx context.Context, id string) (*model.MemoryEntry, error) {
	db := s.getDB()
	if db == nil {
		return nil, errDBNotReady
	}

	const query = `SELECT id, user_id, content, summary, category, priority, importance,
		keywords, session_id, source, access_count, last_access, created_at, updated_at, expires_at
		FROM memory_entries WHERE id = $1`

	entry := &model.MemoryEntry{}
	var category, keywordsRaw string
	err := db.QueryRowContext(ctx, query, id).Scan(
		&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
		&category, &entry.Priority, &entry.Importance, &keywordsRaw,
		&entry.SessionID, &entry.Source, &entry.AccessCount, &entry.LastAccess,
		&entry.CreatedAt, &entry.UpdatedAt, &entry.ExpiresAt,
	)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("查询记忆失败: %w", err)
	}

	entry.Category = model.MemoryCategory(category)
	entry.Keywords = model.ParseKeywords(keywordsRaw)

	// Fire-and-forget with a background context so the counter update is not
	// cancelled together with the caller's request.
	go s.incrementAccess(context.Background(), id)

	return entry, nil
}

// Query returns a user's memories filtered by the optional category, priority
// and minimum-importance conditions, ordered by priority, importance and
// creation time (all descending). A non-positive Limit defaults to 10.
func (s *Store) Query(ctx context.Context, q model.MemoryQuery) ([]model.MemoryEntry, error) {
	db := s.getDB()
	if db == nil {
		return nil, errDBNotReady
	}

	if q.Limit <= 0 {
		q.Limit = 10
	}

	query := `SELECT id, user_id, content, summary, category, priority, importance,
		keywords, session_id, source, access_count, last_access, created_at, updated_at, expires_at
		FROM memory_entries WHERE user_id = $1`
	args := []any{q.UserID}

	// addFilter appends "AND <clause> $n", where n is the new argument's position.
	addFilter := func(clause string, arg any) {
		args = append(args, arg)
		query += fmt.Sprintf(" AND %s $%d", clause, len(args))
	}

	if q.Category != "" {
		addFilter("category =", string(q.Category))
	}
	// NOTE(review): a zero-valued Priority also satisfies this condition, so
	// the filter is applied whenever Priority is not negative.
	if q.Priority >= 0 {
		addFilter("priority >=", int(q.Priority))
	}
	if q.MinImportance > 0 {
		addFilter("importance >=", q.MinImportance)
	}

	query += fmt.Sprintf(" ORDER BY priority DESC, importance DESC, created_at DESC LIMIT $%d OFFSET $%d",
		len(args)+1, len(args)+2)
	args = append(args, q.Limit, q.Offset)

	rows, err := db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("查询记忆失败: %w", err)
	}
	defer rows.Close()

	return scanMemoryRows(rows)
}

// Delete removes the memory with the given ID. Deleting a non-existent ID is
// not an error.
func (s *Store) Delete(ctx context.Context, id string) error {
	db := s.getDB()
	if db == nil {
		return errDBNotReady
	}

	if _, err := db.ExecContext(ctx, `DELETE FROM memory_entries WHERE id = $1`, id); err != nil {
		return fmt.Errorf("删除记忆失败: %w", err)
	}
	return nil
}

// PurgeExpired deletes every memory whose expires_at lies in the past and
// returns the number of rows removed.
func (s *Store) PurgeExpired(ctx context.Context) (int64, error) {
	db := s.getDB()
	if db == nil {
		return 0, errDBNotReady
	}

	result, err := db.ExecContext(ctx,
		`DELETE FROM memory_entries WHERE expires_at IS NOT NULL AND expires_at < NOW()`)
	if err != nil {
		return 0, err
	}
	return result.RowsAffected()
}

// SearchByVector returns up to limit of the user's memories that have an
// embedding, ordered by the `<=>` distance to the given embedding (closest
// first). A non-positive limit defaults to 5.
func (s *Store) SearchByVector(ctx context.Context, userID string, embedding []float64, limit int) ([]model.MemoryEntry, error) {
	db := s.getDB()
	if db == nil {
		return nil, errDBNotReady
	}

	if limit <= 0 {
		limit = 5
	}

	vecStr := fmt.Sprintf("[%s]", joinFloats(embedding))

	const query = `SELECT id, user_id, content, summary, category, priority, importance,
		keywords, session_id, source, access_count, last_access, created_at, updated_at, expires_at,
		1 - (embedding <=> $1) AS similarity
		FROM memory_entries
		WHERE user_id = $2 AND embedding IS NOT NULL
		ORDER BY embedding <=> $1
		LIMIT $3`

	rows, err := db.QueryContext(ctx, query, vecStr, userID, limit)
	if err != nil {
		return nil, fmt.Errorf("向量搜索失败: %w", err)
	}
	defer rows.Close()

	var entries []model.MemoryEntry
	for rows.Next() {
		var (
			entry                 model.MemoryEntry
			category, keywordsRaw string
			similarity            float64 // selected by the query but not returned to the caller
		)
		if err := rows.Scan(
			&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
			&category, &entry.Priority, &entry.Importance, &keywordsRaw,
			&entry.SessionID, &entry.Source, &entry.AccessCount, &entry.LastAccess,
			&entry.CreatedAt, &entry.UpdatedAt, &entry.ExpiresAt,
			&similarity,
		); err != nil {
			return nil, fmt.Errorf("扫描向量搜索结果失败: %w", err)
		}
		entry.Category = model.MemoryCategory(category)
		entry.Keywords = model.ParseKeywords(keywordsRaw)
		entries = append(entries, entry)
	}
	return entries, rows.Err()
}

// SearchByKeyword returns up to limit of the user's memories whose content,
// summary or keywords contain keyword (case-insensitive substring match),
// ordered by priority and importance. A non-positive limit defaults to 20.
//
// LIKE metacharacters in keyword ('%', '_' and '\') are escaped so they match
// literally instead of acting as wildcards.
func (s *Store) SearchByKeyword(ctx context.Context, userID, keyword string, limit int) ([]model.MemoryEntry, error) {
	db := s.getDB()
	if db == nil {
		return nil, errDBNotReady
	}

	if limit <= 0 {
		limit = 20
	}

	const query = `SELECT id, user_id, content, summary, category, priority, importance,
		keywords, session_id, source, access_count, last_access, created_at, updated_at, expires_at
		FROM memory_entries
		WHERE user_id = $1 AND (content ILIKE $2 OR summary ILIKE $2 OR keywords ILIKE $2)
		ORDER BY priority DESC, importance DESC
		LIMIT $3`

	likePattern := "%" + likeEscaper.Replace(keyword) + "%"
	rows, err := db.QueryContext(ctx, query, userID, likePattern, limit)
	if err != nil {
		return nil, fmt.Errorf("关键词搜索失败: %w", err)
	}
	defer rows.Close()

	return scanMemoryRows(rows)
}

// Update overwrites the editable fields (content, summary, category, priority,
// importance, keywords, source) of the memory identified by entry.ID and sets
// updated_at to the current time. Updating a non-existent ID is not an error.
func (s *Store) Update(ctx context.Context, entry *model.MemoryEntry) error {
	db := s.getDB()
	if db == nil {
		return errDBNotReady
	}

	const query = `UPDATE memory_entries SET
		content = $1, summary = $2, category = $3, priority = $4,
		importance = $5, keywords = $6, source = $7, updated_at = NOW()
		WHERE id = $8`

	_, err := db.ExecContext(ctx, query,
		entry.Content,
		entry.Summary,
		string(entry.Category),
		int(entry.Priority),
		entry.Importance,
		entry.KeywordsJSON(),
		entry.Source,
		entry.ID,
	)
	if err != nil {
		return fmt.Errorf("更新记忆失败: %w", err)
	}
	return nil
}

// GetCategories returns, for one user, the number of memories in each category.
func (s *Store) GetCategories(ctx context.Context, userID string) (map[string]int, error) {
	db := s.getDB()
	if db == nil {
		return nil, errDBNotReady
	}

	rows, err := db.QueryContext(ctx,
		`SELECT category, COUNT(*) FROM memory_entries WHERE user_id = $1 GROUP BY category ORDER BY category`,
		userID,
	)
	if err != nil {
		return nil, fmt.Errorf("查询分类统计失败: %w", err)
	}
	defer rows.Close()

	categories := make(map[string]int)
	for rows.Next() {
		var (
			name  string
			count int
		)
		if err := rows.Scan(&name, &count); err != nil {
			return nil, fmt.Errorf("扫描分类统计失败: %w", err)
		}
		categories[name] = count
	}
	return categories, rows.Err()
}

// Count returns the total number of memories stored for the given user.
func (s *Store) Count(ctx context.Context, userID string) (int, error) { db := s.getDB() if db == nil { return 0, errDBNotReady } var count int err := db.QueryRowContext(ctx, `SELECT COUNT(*) FROM memory_entries WHERE user_id = $1`, userID, ).Scan(&count) if err != nil { return 0, fmt.Errorf("统计记忆失败: %w", err) } return count, nil } // ConsolidateMemories 记忆整理:合并相似记忆 func (s *Store) ConsolidateMemories(ctx context.Context, userID string) (int, error) { db := s.getDB() if db == nil { return 0, errDBNotReady } // 获取用户所有记忆 allMems, err := s.Query(ctx, model.MemoryQuery{ UserID: userID, Limit: 500, }) if err != nil { return 0, fmt.Errorf("查询记忆失败: %w", err) } if len(allMems) < 2 { return 0, nil } merged := 0 for i := 0; i < len(allMems); i++ { if allMems[i].ID == "" { continue } for j := i + 1; j < len(allMems); j++ { if allMems[j].ID == "" { continue } score := allMems[i].SimilarityScore(&allMems[j]) if score >= deDupThreshold { keep, discard := &allMems[i], &allMems[j] if discard.Importance > keep.Importance || discard.Priority > keep.Priority { keep, discard = discard, keep } // 合并关键词 keywordSet := make(map[string]bool) for _, k := range keep.Keywords { keywordSet[k] = true } for _, k := range discard.Keywords { keywordSet[k] = true } mergedKeywords := make([]string, 0, len(keywordSet)) for k := range keywordSet { mergedKeywords = append(mergedKeywords, k) } keep.Keywords = mergedKeywords if keep.Importance < 10 { keep.Importance++ } keep.Source = "consolidated" if err := s.Update(ctx, keep); err != nil { log.Printf("[memory-service] 合并更新记忆 %s 失败: %v", keep.ID, err) continue } if err := s.Delete(ctx, discard.ID); err != nil { log.Printf("[memory-service] 合并删除记忆 %s 失败: %v", discard.ID, err) continue } discard.ID = "" merged++ log.Printf("[memory-service] 合并相似记忆: %s <- %s (相似度 %.0f%%)", keep.ID[:min(8, len(keep.ID))], discard.ID[:min(8, len(discard.ID))], score*100) } } } if merged > 0 { log.Printf("[memory-service] 记忆整理完成: 用户 %s 合并 %d 条相似记忆", userID, merged) } return 
merged, nil } // DecayMemories 记忆衰减:降低长期未访问的低重要性记忆 func (s *Store) DecayMemories(ctx context.Context, userID string) (int, int, error) { db := s.getDB() if db == nil { return 0, 0, errDBNotReady } result1, err := db.ExecContext(ctx, ` UPDATE memory_entries SET priority = GREATEST(priority - 1, 0), updated_at = NOW() WHERE user_id = $1 AND access_count < 3 AND last_access < NOW() - INTERVAL '30 days' AND importance < 3 AND priority > 0 AND category NOT IN ('personal_info', 'user_preference') `, userID) if err != nil { return 0, 0, fmt.Errorf("衰减低活跃记忆失败: %w", err) } decayed1, _ := result1.RowsAffected() result2, err := db.ExecContext(ctx, ` DELETE FROM memory_entries WHERE user_id = $1 AND priority = 0 AND access_count = 0 AND last_access < NOW() - INTERVAL '14 days' `, userID) if err != nil { return 0, 0, fmt.Errorf("清理临时记忆失败: %w", err) } deleted2, _ := result2.RowsAffected() total := decayed1 + deleted2 if total > 0 { log.Printf("[memory-service] 记忆衰减完成: 用户 %s 降级 %d 条, 删除 %d 条过期临时记忆", userID, decayed1, deleted2) } return int(decayed1), int(deleted2), nil } func (s *Store) incrementAccess(ctx context.Context, id string) { db := s.getDB() if db == nil { return } db.ExecContext(ctx, `UPDATE memory_entries SET access_count = access_count + 1, last_access = NOW() WHERE id = $1`, id) } // Close 关闭数据库连接 func (s *Store) Close() error { s.mu.Lock() defer s.mu.Unlock() if s.db != nil { return s.db.Close() } return nil } // scanMemoryRows 扫描记忆行(通用方法) func scanMemoryRows(rows *sql.Rows) ([]model.MemoryEntry, error) { var entries []model.MemoryEntry for rows.Next() { var entry model.MemoryEntry var category, keywordsRaw string if err := rows.Scan( &entry.ID, &entry.UserID, &entry.Content, &entry.Summary, &category, &entry.Priority, &entry.Importance, &keywordsRaw, &entry.SessionID, &entry.Source, &entry.AccessCount, &entry.LastAccess, &entry.CreatedAt, &entry.UpdatedAt, &entry.ExpiresAt, ); err != nil { return nil, fmt.Errorf("扫描记忆行失败: %w", err) } entry.Category = 
model.MemoryCategory(category) entry.Keywords = model.ParseKeywords(keywordsRaw) entries = append(entries, entry) } return entries, rows.Err() } // joinFloats 将 float64 切片转为逗号分隔字符串 func joinFloats(vec []float64) string { if len(vec) == 0 { return "" } s := fmt.Sprintf("%f", vec[0]) for i := 1; i < len(vec); i++ { s += fmt.Sprintf(",%f", vec[i]) } return s } // SaveThinkingLog 保存自主思考日志 func (s *Store) SaveThinkingLog(ctx context.Context, log *model.ThinkingLog) error { db := s.getDB() if db == nil { return errDBNotReady } if log.UserID == "" { log.UserID = "admin_admin" } if log.ToolCalls == "" { log.ToolCalls = "[]" } query := `INSERT INTO thinking_logs (user_id, content, tool_calls, tool_call_count, content_length) VALUES ($1, $2, $3, $4, $5) RETURNING id, created_at` return db.QueryRowContext(ctx, query, log.UserID, log.Content, log.ToolCalls, log.ToolCallCount, log.ContentLength, ).Scan(&log.ID, &log.CreatedAt) } // QueryThinkingLogs 分页查询思考日志 func (s *Store) QueryThinkingLogs(ctx context.Context, q model.ThinkingQuery) ([]model.ThinkingLog, error) { db := s.getDB() if db == nil { return nil, errDBNotReady } if q.Limit <= 0 { q.Limit = 20 } query := `SELECT id, user_id, content, tool_calls, tool_call_count, content_length, created_at FROM thinking_logs` args := []interface{}{} argIdx := 1 if q.UserID != "" { query += fmt.Sprintf(" WHERE user_id = $%d", argIdx) args = append(args, q.UserID) argIdx++ } query += fmt.Sprintf(" ORDER BY created_at DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1) args = append(args, q.Limit, q.Offset) rows, err := db.QueryContext(ctx, query, args...) 
if err != nil { return nil, fmt.Errorf("查询思考日志失败: %w", err) } defer rows.Close() var logs []model.ThinkingLog for rows.Next() { var tl model.ThinkingLog if err := rows.Scan(&tl.ID, &tl.UserID, &tl.Content, &tl.ToolCalls, &tl.ToolCallCount, &tl.ContentLength, &tl.CreatedAt); err != nil { return nil, fmt.Errorf("扫描思考日志行失败: %w", err) } logs = append(logs, tl) } return logs, rows.Err() } // GetThinkingLogByID 根据ID获取单条思考日志 func (s *Store) GetThinkingLogByID(ctx context.Context, id string) (*model.ThinkingLog, error) { db := s.getDB() if db == nil { return nil, errDBNotReady } query := `SELECT id, user_id, content, tool_calls, tool_call_count, content_length, created_at FROM thinking_logs WHERE id = $1` tl := &model.ThinkingLog{} err := db.QueryRowContext(ctx, query, id).Scan( &tl.ID, &tl.UserID, &tl.Content, &tl.ToolCalls, &tl.ToolCallCount, &tl.ContentLength, &tl.CreatedAt, ) if err == sql.ErrNoRows { return nil, nil } if err != nil { return nil, fmt.Errorf("查询思考日志失败: %w", err) } return tl, nil } // GetThinkingStats 获取思考日志统计信息 func (s *Store) GetThinkingStats(ctx context.Context) (*model.ThinkingStats, error) { db := s.getDB() if db == nil { return nil, errDBNotReady } query := `SELECT COALESCE(COUNT(*), 0), COALESCE(SUM(tool_call_count), 0), COALESCE(AVG(content_length), 0), COALESCE(MAX(created_at)::TEXT, '') FROM thinking_logs` stats := &model.ThinkingStats{} err := db.QueryRowContext(ctx, query).Scan( &stats.TotalLogs, &stats.TotalToolCalls, &stats.AvgContentLen, &stats.LatestAt, ) if err != nil { return nil, fmt.Errorf("查询思考日志统计失败: %w", err) } return stats, nil }