Files
Cyrene/backend/memory-service/internal/store/store.go
T
AskaEth 78e3f450c2 feat: Round 5 - Memory Service, Tool Engine, Call Records, Thinking Logs
- Fix: Session history flash (race condition + WS guard)
- Fix: Chat background overlay + sidebar transparency
- Fix: IoT device control (Chinese action names, status field)
- Feat: Independent memory-service (port 8091, 13 endpoints)
- Feat: Independent tool-engine service (port 8092, 13 tools)
- Feat: Tool call logs with paginated DevTools panel
- Feat: Thinking log records with DevTools panel
- Feat: Future development roadmap document
- Chore: Updated .gitignore, go.work, DevTools config
- Chore: 5-service health check, project review docs
2026-05-18 20:05:14 +08:00

766 lines
20 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package store
import (
"context"
"database/sql"
"fmt"
"log"
"sync"
"time"
"github.com/yourname/cyrene-ai/memory-service/internal/model"
_ "github.com/lib/pq"
)
// Tuning constants for the memory store.
const (
	// deDupThreshold is the minimum SimilarityScore at which
	// ConsolidateMemories treats two memories as duplicates and merges them.
	deDupThreshold = 0.75

	// reconnectInterval is the tick period of the background reconnect loop.
	reconnectInterval = 30 * time.Second
)
// Store is the persistent memory store (PostgreSQL + pgvector).
// The db handle is replaced by Reconnect and the background reconnect loop,
// so it is only read or written while holding mu.
type Store struct {
// databaseURL is the connection string passed to sql.Open on every (re)connect.
databaseURL string
// mu guards db: readers take RLock, (re)connect and Close take Lock.
mu sync.RWMutex
// db is the live connection pool, or nil while the database is unavailable.
db *sql.DB
}
// errDBNotReady is the user-facing error that every data-access method
// returns while the store holds no database connection.
var errDBNotReady error = fmt.Errorf("记忆系统未就绪: 数据库连接不可用,正在后台重试连接")
// NewStore builds a Store for connStr and never fails.
//
// It makes one immediate connection attempt and logs the outcome. Whatever
// the outcome, it then starts a background goroutine (reconnectLoop) that
// health-checks and re-establishes the connection every reconnectInterval.
//
// NOTE(review): the background goroutine has no stop signal, so it outlives Close.
func NewStore(connStr string) *Store {
	store := &Store{databaseURL: connStr}

	if err := store.Reconnect(); err == nil {
		log.Println("[memory-service] 记忆存储已就绪")
	} else {
		log.Printf("[memory-service] ⚠ 记忆存储初始化: 数据库连接失败 (%v),将在后台每30秒重试", err)
	}

	go store.reconnectLoop()
	return store
}
// reconnectLoop runs for the lifetime of the process; NewStore starts it.
// On every tick of reconnectInterval it does two things:
//  1. If a connection is held, it health-checks it with a bounded Ping and
//     drops it when the Ping fails.
//  2. If no connection is held, it calls Reconnect and logs any failure.
//
// NOTE(review): there is no stop signal, so the loop keeps running (and will
// reconnect) after Close. A done channel on Store is needed to end it.
func (s *Store) reconnectLoop() {
	// Bound the health check so a hung network call cannot stall the loop.
	const pingTimeout = 10 * time.Second

	ticker := time.NewTicker(reconnectInterval)
	defer ticker.Stop()

	for range ticker.C {
		if db := s.getDB(); db != nil {
			ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
			err := db.PingContext(ctx)
			cancel()
			if err != nil {
				log.Printf("[memory-service] ⚠ 数据库连接丢失: %v,开始重连", err)
				s.mu.Lock()
				// Only drop the handle we actually probed. A concurrent Reconnect
				// may already have installed a fresh, healthy pool.
				if s.db == db {
					_ = db.Close() // the pool is already broken; a close error is not actionable
					s.db = nil
				}
				s.mu.Unlock()
			}
		}

		if !s.IsReady() {
			if err := s.Reconnect(); err != nil {
				log.Printf("[memory-service] ⚠ 数据库重连失败: %v", err)
			}
		}
	}
}
// Reconnect makes sure the store holds a working connection and that the
// schema exists.
//
// If the current connection still answers a Ping, it returns nil at once.
// Otherwise it closes any stale handle, opens a new pool (25 open / 5 idle
// connections, 5-minute lifetime), pings it, and runs migrate. On any failure
// the store is left with db == nil and a wrapped error is returned. The error
// is not logged here; callers log it.
//
// The exclusive lock is held for the whole attempt because migrate works on
// s.db. Each Ping is therefore bounded by a timeout, so an unreachable
// database cannot block every other Store method indefinitely.
func (s *Store) Reconnect() error {
	const pingTimeout = 10 * time.Second

	ping := func(db *sql.DB) error {
		ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
		defer cancel()
		return db.PingContext(ctx)
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	// Reuse the existing connection if it is still healthy.
	if s.db != nil {
		if err := ping(s.db); err == nil {
			return nil
		}
		// Stale connection: release it before opening a new one.
		s.db.Close()
		s.db = nil
	}

	db, err := sql.Open("postgres", s.databaseURL)
	if err != nil {
		return fmt.Errorf("连接数据库失败: %w", err)
	}
	db.SetMaxOpenConns(25)
	db.SetMaxIdleConns(5)
	db.SetConnMaxLifetime(5 * time.Minute)

	if err := ping(db); err != nil {
		db.Close()
		return fmt.Errorf("数据库ping失败: %w", err)
	}

	// migrate operates on s.db, so install the handle first and roll back on failure.
	// NOTE(review): migration statements themselves run without a deadline.
	s.db = db
	if err := s.migrate(); err != nil {
		s.db.Close()
		s.db = nil
		return fmt.Errorf("数据库迁移失败: %w", err)
	}

	log.Println("[memory-service] ✅ 数据库重连成功,记忆系统已就绪")
	return nil
}
// IsReady reports whether the store currently holds a database connection.
// It does not ping, so a held connection may still turn out to be broken
// until the next background health check.
func (s *Store) IsReady() bool {
	return s.getDB() != nil
}
// getDB returns a snapshot of the current connection pool, read under the
// shared lock, or nil when not connected. Callers use the snapshot after the
// lock is released.
func (s *Store) getDB() *sql.DB {
	s.mu.RLock()
	current := s.db
	s.mu.RUnlock()
	return current
}
// migrate creates the schema the store relies on: the pgvector extension,
// the memory_entries table (with a vector(1536) embedding column), the
// thinking_logs table, and their indexes.
//
// Every statement uses IF NOT EXISTS, so re-running it on each reconnect is
// harmless. Statements run in order on s.db and stop at the first failure.
// The error includes the first 50 bytes of the failing statement.
// Reconnect calls this while holding s.mu exclusively.
//
// NOTE(review): gen_random_uuid() is built in only from PostgreSQL 13
// (earlier versions need pgcrypto). Confirm the deployment target.
func (s *Store) migrate() error {
	statements := [...]string{
		`CREATE EXTENSION IF NOT EXISTS vector`,
		`CREATE TABLE IF NOT EXISTS memory_entries (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id VARCHAR(64) NOT NULL,
content TEXT NOT NULL,
summary TEXT DEFAULT '',
category VARCHAR(32) DEFAULT 'knowledge',
priority INT DEFAULT 1,
importance INT DEFAULT 5,
keywords TEXT DEFAULT '[]',
session_id VARCHAR(64) DEFAULT '',
source TEXT DEFAULT 'conversation',
embedding vector(1536),
access_count INT DEFAULT 0,
last_access TIMESTAMPTZ DEFAULT NOW(),
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW(),
expires_at TIMESTAMPTZ
)`,
		`CREATE INDEX IF NOT EXISTS idx_me_user_id ON memory_entries(user_id)`,
		`CREATE INDEX IF NOT EXISTS idx_me_category ON memory_entries(category)`,
		`CREATE INDEX IF NOT EXISTS idx_me_priority ON memory_entries(priority)`,
		`CREATE INDEX IF NOT EXISTS idx_me_importance ON memory_entries(importance)`,
		`CREATE INDEX IF NOT EXISTS idx_me_user_priority ON memory_entries(user_id, priority DESC)`,
		`CREATE INDEX IF NOT EXISTS idx_me_user_importance ON memory_entries(user_id, importance DESC)`,
		`CREATE INDEX IF NOT EXISTS idx_me_source ON memory_entries(source)`,
		`CREATE INDEX IF NOT EXISTS idx_me_category_importance ON memory_entries(category, importance DESC)`,
		`CREATE TABLE IF NOT EXISTS thinking_logs (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id VARCHAR(64) NOT NULL DEFAULT 'admin_admin',
content TEXT NOT NULL,
tool_calls TEXT DEFAULT '[]',
tool_call_count INT DEFAULT 0,
content_length INT DEFAULT 0,
created_at TIMESTAMPTZ DEFAULT NOW()
)`,
		`CREATE INDEX IF NOT EXISTS idx_tl_user_id ON thinking_logs(user_id)`,
		`CREATE INDEX IF NOT EXISTS idx_tl_created_at ON thinking_logs(created_at DESC)`,
	}

	for _, stmt := range statements {
		if _, err := s.db.Exec(stmt); err != nil {
			preview := stmt
			if len(preview) > 50 {
				preview = preview[:50]
			}
			return fmt.Errorf("执行迁移 '%s' 失败: %w", preview, err)
		}
	}
	return nil
}
// Save inserts entry as a new memory_entries row. It writes the
// database-generated ID and CreatedAt back into entry.
//
// Defaults are applied to entry in place before the insert: Source becomes
// "conversation" when empty, and Importance becomes 5 when zero. An empty
// Embedding is stored as NULL. Otherwise it is sent as the vector text
// literal "[v1,v2,...]".
func (s *Store) Save(ctx context.Context, entry *model.MemoryEntry) error {
	db := s.getDB()
	if db == nil {
		return errDBNotReady
	}

	if entry.Source == "" {
		entry.Source = "conversation"
	}
	if entry.Importance == 0 {
		entry.Importance = 5
	}

	// A nil interface value is sent as SQL NULL for rows without an embedding.
	var embeddingParam interface{}
	if n := len(entry.Embedding); n > 0 {
		components := make([]float64, 0, n)
		for _, c := range entry.Embedding {
			components = append(components, float64(c))
		}
		embeddingParam = "[" + joinFloats(components) + "]"
	}

	const insertSQL = `INSERT INTO memory_entries (user_id, content, summary, category, priority, importance, keywords, session_id, source, embedding, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING id, created_at`
	row := db.QueryRowContext(ctx, insertSQL,
		entry.UserID, entry.Content, entry.Summary,
		string(entry.Category), int(entry.Priority),
		entry.Importance, entry.KeywordsJSON(),
		entry.SessionID, entry.Source, embeddingParam, entry.ExpiresAt,
	)
	return row.Scan(&entry.ID, &entry.CreatedAt)
}
// GetByID loads the memory with the given id.
//
// It returns (nil, nil) when no row matches. On success it bumps the row's
// access statistics asynchronously, so the read path is not slowed down.
func (s *Store) GetByID(ctx context.Context, id string) (*model.MemoryEntry, error) {
	// Upper bound for the fire-and-forget access-count update.
	const accessUpdateTimeout = 5 * time.Second

	db := s.getDB()
	if db == nil {
		return nil, errDBNotReady
	}
	query := `SELECT id, user_id, content, summary, category, priority, importance, keywords,
session_id, source, access_count, last_access, created_at, updated_at, expires_at
FROM memory_entries WHERE id = $1`
	entry := &model.MemoryEntry{}
	var category, keywordsRaw string
	err := db.QueryRowContext(ctx, query, id).Scan(
		&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
		&category, &entry.Priority, &entry.Importance, &keywordsRaw,
		&entry.SessionID, &entry.Source, &entry.AccessCount, &entry.LastAccess,
		&entry.CreatedAt, &entry.UpdatedAt, &entry.ExpiresAt,
	)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("查询记忆失败: %w", err)
	}
	entry.Category = model.MemoryCategory(category)
	entry.Keywords = model.ParseKeywords(keywordsRaw)

	// The request ctx is deliberately not reused, because it may be cancelled as
	// soon as we return. The update gets its own deadline instead, so a stuck
	// database cannot pile up goroutines.
	go func() {
		bgCtx, cancel := context.WithTimeout(context.Background(), accessUpdateTimeout)
		defer cancel()
		s.incrementAccess(bgCtx, id)
	}()
	return entry, nil
}
// Query lists a user's memories with optional filters. Results are ordered
// by priority, importance and creation time, all descending.
//
// Filters, each bound as a positional parameter:
//   - Category: exact match when non-empty.
//   - Priority: priority >= q.Priority when q.Priority >= 0. A negative value
//     disables the filter; the zero value still applies "priority >= 0".
//   - MinImportance: importance >= q.MinImportance when > 0.
//
// Limit defaults to 10 when <= 0. A negative Offset is treated as 0, because
// PostgreSQL rejects a negative OFFSET.
func (s *Store) Query(ctx context.Context, q model.MemoryQuery) ([]model.MemoryEntry, error) {
	db := s.getDB()
	if db == nil {
		return nil, errDBNotReady
	}
	if q.Limit <= 0 {
		q.Limit = 10
	}
	if q.Offset < 0 {
		q.Offset = 0
	}

	query := `SELECT id, user_id, content, summary, category, priority, importance, keywords,
session_id, source, access_count, last_access, created_at, updated_at, expires_at
FROM memory_entries WHERE user_id = $1`
	args := []interface{}{q.UserID}
	argIdx := 2

	if q.Category != "" {
		query += fmt.Sprintf(" AND category = $%d", argIdx)
		args = append(args, string(q.Category))
		argIdx++
	}
	if q.Priority >= 0 {
		query += fmt.Sprintf(" AND priority >= $%d", argIdx)
		args = append(args, int(q.Priority))
		argIdx++
	}
	if q.MinImportance > 0 {
		query += fmt.Sprintf(" AND importance >= $%d", argIdx)
		args = append(args, q.MinImportance)
		argIdx++
	}

	query += fmt.Sprintf(" ORDER BY priority DESC, importance DESC, created_at DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1)
	args = append(args, q.Limit, q.Offset)

	rows, err := db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("查询记忆失败: %w", err)
	}
	defer rows.Close()
	return scanMemoryRows(rows)
}
// Delete removes the memory with the given id. Deleting an id that does
// not exist is not reported as an error.
func (s *Store) Delete(ctx context.Context, id string) error {
	db := s.getDB()
	if db == nil {
		return errDBNotReady
	}
	if _, err := db.ExecContext(ctx, `DELETE FROM memory_entries WHERE id = $1`, id); err != nil {
		return err
	}
	return nil
}
// PurgeExpired deletes every memory whose expires_at lies in the past and
// reports how many rows were removed. Rows with a NULL expires_at never expire.
func (s *Store) PurgeExpired(ctx context.Context) (int64, error) {
	db := s.getDB()
	if db == nil {
		return 0, errDBNotReady
	}
	const purgeSQL = `DELETE FROM memory_entries WHERE expires_at IS NOT NULL AND expires_at < NOW()`
	res, err := db.ExecContext(ctx, purgeSQL)
	if err != nil {
		return 0, err
	}
	return res.RowsAffected()
}
// SearchByVector returns up to limit memories of userID that have an
// embedding. Results are ordered by pgvector distance (`<=>`) to the query
// embedding, nearest first. limit defaults to 5 when <= 0.
//
// An empty embedding is rejected up front. It would otherwise be sent as the
// invalid vector literal "[]".
//
// NOTE(review): the query computes `1 - distance AS similarity`, but the
// value is scanned and then discarded. Store it on the entry if the model
// has a field for it, or drop the column.
func (s *Store) SearchByVector(ctx context.Context, userID string, embedding []float64, limit int) ([]model.MemoryEntry, error) {
	db := s.getDB()
	if db == nil {
		return nil, errDBNotReady
	}
	if len(embedding) == 0 {
		return nil, fmt.Errorf("向量搜索失败: 查询向量为空")
	}
	if limit <= 0 {
		limit = 5
	}
	vecStr := fmt.Sprintf("[%s]", joinFloats(embedding))
	query := `SELECT id, user_id, content, summary, category, priority, importance, keywords,
session_id, source, access_count, last_access, created_at, updated_at, expires_at,
1 - (embedding <=> $1) AS similarity
FROM memory_entries
WHERE user_id = $2 AND embedding IS NOT NULL
ORDER BY embedding <=> $1
LIMIT $3`
	rows, err := db.QueryContext(ctx, query, vecStr, userID, limit)
	if err != nil {
		return nil, fmt.Errorf("向量搜索失败: %w", err)
	}
	defer rows.Close()

	var entries []model.MemoryEntry
	for rows.Next() {
		var entry model.MemoryEntry
		var category, keywordsRaw string
		var similarity float64
		if err := rows.Scan(
			&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
			&category, &entry.Priority, &entry.Importance, &keywordsRaw,
			&entry.SessionID, &entry.Source, &entry.AccessCount, &entry.LastAccess,
			&entry.CreatedAt, &entry.UpdatedAt, &entry.ExpiresAt,
			&similarity,
		); err != nil {
			return nil, fmt.Errorf("扫描向量搜索结果失败: %w", err)
		}
		entry.Category = model.MemoryCategory(category)
		entry.Keywords = model.ParseKeywords(keywordsRaw)
		entries = append(entries, entry)
	}
	return entries, rows.Err()
}
// SearchByKeyword returns up to limit memories of userID whose content,
// summary or stored keywords text contain keyword (case-insensitive ILIKE).
// Results are ordered by priority, then importance, both descending. limit
// defaults to 20 when <= 0.
//
// LIKE metacharacters in keyword (`%`, `_`, and the default escape `\`) are
// escaped, so user input is matched literally instead of acting as a
// wildcard. An empty keyword still matches every row, as before.
func (s *Store) SearchByKeyword(ctx context.Context, userID, keyword string, limit int) ([]model.MemoryEntry, error) {
	db := s.getDB()
	if db == nil {
		return nil, errDBNotReady
	}
	if limit <= 0 {
		limit = 20
	}
	query := `SELECT id, user_id, content, summary, category, priority, importance, keywords,
session_id, source, access_count, last_access, created_at, updated_at, expires_at
FROM memory_entries
WHERE user_id = $1 AND (content ILIKE $2 OR summary ILIKE $2 OR keywords ILIKE $2)
ORDER BY priority DESC, importance DESC
LIMIT $3`

	// Scanning byte-wise is UTF-8 safe: '\\', '%' and '_' are ASCII and never
	// occur inside a multi-byte sequence.
	pattern := make([]byte, 0, len(keyword)+2)
	pattern = append(pattern, '%')
	for i := 0; i < len(keyword); i++ {
		switch c := keyword[i]; c {
		case '\\', '%', '_':
			pattern = append(pattern, '\\', c)
		default:
			pattern = append(pattern, c)
		}
	}
	pattern = append(pattern, '%')
	likePattern := string(pattern)

	rows, err := db.QueryContext(ctx, query, userID, likePattern, limit)
	if err != nil {
		return nil, fmt.Errorf("关键词搜索失败: %w", err)
	}
	defer rows.Close()
	return scanMemoryRows(rows)
}
// Update overwrites content, summary, category, priority, importance,
// keywords and source of the row identified by entry.ID, and sets
// updated_at to NOW(). A non-matching ID is not reported as an error.
func (s *Store) Update(ctx context.Context, entry *model.MemoryEntry) error {
	db := s.getDB()
	if db == nil {
		return errDBNotReady
	}
	const updateSQL = `UPDATE memory_entries SET content = $1, summary = $2, category = $3, priority = $4,
importance = $5, keywords = $6, source = $7, updated_at = NOW()
WHERE id = $8`
	args := []interface{}{
		entry.Content,
		entry.Summary,
		string(entry.Category),
		int(entry.Priority),
		entry.Importance,
		entry.KeywordsJSON(),
		entry.Source,
		entry.ID,
	}
	_, err := db.ExecContext(ctx, updateSQL, args...)
	return err
}
// GetCategories returns, for userID, the number of memories in each category.
func (s *Store) GetCategories(ctx context.Context, userID string) (map[string]int, error) {
	db := s.getDB()
	if db == nil {
		return nil, errDBNotReady
	}
	const categorySQL = `SELECT category, COUNT(*) FROM memory_entries WHERE user_id = $1 GROUP BY category ORDER BY category`
	rows, err := db.QueryContext(ctx, categorySQL, userID)
	if err != nil {
		return nil, fmt.Errorf("查询分类统计失败: %w", err)
	}
	defer rows.Close()

	counts := make(map[string]int)
	for rows.Next() {
		var (
			name string
			n    int
		)
		if err := rows.Scan(&name, &n); err != nil {
			return nil, fmt.Errorf("扫描分类统计失败: %w", err)
		}
		counts[name] = n
	}
	return counts, rows.Err()
}
// Count returns the total number of memories stored for userID.
func (s *Store) Count(ctx context.Context, userID string) (int, error) {
	db := s.getDB()
	if db == nil {
		return 0, errDBNotReady
	}
	const countSQL = `SELECT COUNT(*) FROM memory_entries WHERE user_id = $1`
	var total int
	if err := db.QueryRowContext(ctx, countSQL, userID).Scan(&total); err != nil {
		return 0, fmt.Errorf("统计记忆失败: %w", err)
	}
	return total, nil
}
// ConsolidateMemories merges near-duplicate memories of userID.
//
// It loads up to 500 of the user's memories via Query and compares every
// pair with SimilarityScore. For each pair scoring at least deDupThreshold:
//   - The entry with the higher Importance or Priority is kept; otherwise the
//     earlier one is kept.
//   - Both keyword lists are unioned into the kept entry, in first-seen order
//     with the kept entry's own keywords first.
//   - The kept entry's Importance goes up by one (capped at 10) and its
//     Source becomes "consolidated".
//   - The kept entry is persisted with Update, and the other is removed with
//     Delete.
//
// It returns the number of deleted duplicates. A failed Update or Delete is
// logged and that pair is skipped. In-memory state changes only after Update
// succeeds. If ctx is cancelled, the count so far is returned with ctx.Err().
//
// NOTE(review): Update and Delete are separate statements. A failed Delete
// leaves the kept entry updated while the duplicate is still present.
func (s *Store) ConsolidateMemories(ctx context.Context, userID string) (int, error) {
	if s.getDB() == nil {
		return 0, errDBNotReady
	}

	allMems, err := s.Query(ctx, model.MemoryQuery{
		UserID: userID,
		Limit:  500,
	})
	if err != nil {
		return 0, fmt.Errorf("查询记忆失败: %w", err)
	}
	if len(allMems) < 2 {
		return 0, nil
	}

	merged := 0
	for i := 0; i < len(allMems); i++ {
		// Deleted entries have their ID cleared; skip them.
		if allMems[i].ID == "" {
			continue
		}
		for j := i + 1; j < len(allMems); j++ {
			if err := ctx.Err(); err != nil {
				return merged, err
			}
			if allMems[j].ID == "" {
				continue
			}
			score := allMems[i].SimilarityScore(&allMems[j])
			if score < deDupThreshold {
				continue
			}

			keep, discard := &allMems[i], &allMems[j]
			if discard.Importance > keep.Importance || discard.Priority > keep.Priority {
				keep, discard = discard, keep
			}

			// Union the keyword lists in first-seen order, so the stored
			// result is deterministic (map iteration order is not).
			seen := make(map[string]bool, len(keep.Keywords)+len(discard.Keywords))
			mergedKeywords := make([]string, 0, len(keep.Keywords)+len(discard.Keywords))
			for _, k := range keep.Keywords {
				if !seen[k] {
					seen[k] = true
					mergedKeywords = append(mergedKeywords, k)
				}
			}
			for _, k := range discard.Keywords {
				if !seen[k] {
					seen[k] = true
					mergedKeywords = append(mergedKeywords, k)
				}
			}

			// Work on a copy so a failed Update leaves the in-memory entry untouched.
			updated := *keep
			updated.Keywords = mergedKeywords
			if updated.Importance < 10 {
				updated.Importance++
			}
			updated.Source = "consolidated"
			if err := s.Update(ctx, &updated); err != nil {
				log.Printf("[memory-service] 合并更新记忆 %s 失败: %v", keep.ID, err)
				continue
			}
			*keep = updated

			// Capture the ID before clearing it, so the log line below shows it.
			discardID := discard.ID
			if err := s.Delete(ctx, discardID); err != nil {
				log.Printf("[memory-service] 合并删除记忆 %s 失败: %v", discardID, err)
				continue
			}
			discard.ID = ""
			merged++
			log.Printf("[memory-service] 合并相似记忆: %s <- %s (相似度 %.0f%%)",
				keep.ID[:min(8, len(keep.ID))], discardID[:min(8, len(discardID))], score*100)

			// If the outer entry was the one deleted, stop comparing against it.
			if allMems[i].ID == "" {
				break
			}
		}
	}

	if merged > 0 {
		log.Printf("[memory-service] 记忆整理完成: 用户 %s 合并 %d 条相似记忆", userID, merged)
	}
	return merged, nil
}
// DecayMemories ages userID's memories with two independent statements and
// returns (demoted, deleted) row counts:
//  1. Demote: priority-1 (floored at 0) for memories with fewer than 3
//     accesses, last accessed over 30 days ago, importance below 3 and
//     priority above 0. The personal_info and user_preference categories
//     are excluded.
//  2. Delete: memories with priority 0 and no accesses that were last
//     accessed over 14 days ago.
//
// NOTE(review): step 2 does not exclude the protected categories that step 1
// skips, and the two statements do not share a transaction. Confirm both
// are intended.
func (s *Store) DecayMemories(ctx context.Context, userID string) (int, int, error) {
	db := s.getDB()
	if db == nil {
		return 0, 0, errDBNotReady
	}

	const demoteSQL = `
UPDATE memory_entries SET priority = GREATEST(priority - 1, 0), updated_at = NOW()
WHERE user_id = $1
AND access_count < 3
AND last_access < NOW() - INTERVAL '30 days'
AND importance < 3
AND priority > 0
AND category NOT IN ('personal_info', 'user_preference')
`
	const dropStaleSQL = `
DELETE FROM memory_entries
WHERE user_id = $1
AND priority = 0
AND access_count = 0
AND last_access < NOW() - INTERVAL '14 days'
`

	demoteRes, err := db.ExecContext(ctx, demoteSQL, userID)
	if err != nil {
		return 0, 0, fmt.Errorf("衰减低活跃记忆失败: %w", err)
	}
	// The counts only feed the log line and return values; an unavailable count reads as 0.
	demoted, _ := demoteRes.RowsAffected()

	dropRes, err := db.ExecContext(ctx, dropStaleSQL, userID)
	if err != nil {
		return 0, 0, fmt.Errorf("清理临时记忆失败: %w", err)
	}
	dropped, _ := dropRes.RowsAffected()

	if demoted+dropped > 0 {
		log.Printf("[memory-service] 记忆衰减完成: 用户 %s 降级 %d 条, 删除 %d 条过期临时记忆",
			userID, demoted, dropped)
	}
	return int(demoted), int(dropped), nil
}
// incrementAccess records one read of memory id by incrementing access_count
// and setting last_access to NOW(). It is a no-op when the store is not
// connected. The update is best-effort (there is no return value), so a
// failure is logged rather than silently dropped.
func (s *Store) incrementAccess(ctx context.Context, id string) {
	db := s.getDB()
	if db == nil {
		return
	}
	if _, err := db.ExecContext(ctx,
		`UPDATE memory_entries SET access_count = access_count + 1, last_access = NOW() WHERE id = $1`, id); err != nil {
		log.Printf("[memory-service] 更新访问计数失败 %s: %v", id, err)
	}
}
// Close closes the current connection pool, if any, and clears it. After
// that, IsReady reports false and data-access methods return errDBNotReady
// instead of failing on a closed pool. Calling Close again returns nil.
//
// NOTE(review): Close does not stop the background reconnect goroutine
// started by NewStore, which will open a new connection on its next tick.
// Making Close final needs a stop channel on Store.
func (s *Store) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.db == nil {
		return nil
	}
	err := s.db.Close()
	s.db = nil
	return err
}
// scanMemoryRows reads all remaining rows of a SELECT over the 15 standard
// memory_entries columns (id through expires_at, in that order) into entries.
// It converts category to model.MemoryCategory and decodes the keywords text
// with model.ParseKeywords. It does not close rows; callers defer that.
func scanMemoryRows(rows *sql.Rows) ([]model.MemoryEntry, error) {
	var entries []model.MemoryEntry
	for rows.Next() {
		var (
			e           model.MemoryEntry
			categoryStr string
			keywordsStr string
		)
		err := rows.Scan(
			&e.ID, &e.UserID, &e.Content, &e.Summary,
			&categoryStr, &e.Priority, &e.Importance, &keywordsStr,
			&e.SessionID, &e.Source, &e.AccessCount, &e.LastAccess,
			&e.CreatedAt, &e.UpdatedAt, &e.ExpiresAt,
		)
		if err != nil {
			return nil, fmt.Errorf("扫描记忆行失败: %w", err)
		}
		e.Category = model.MemoryCategory(categoryStr)
		e.Keywords = model.ParseKeywords(keywordsStr)
		entries = append(entries, e)
	}
	return entries, rows.Err()
}
// joinFloats renders vec as comma-separated numbers without brackets. For
// example, []float64{1, 0.5, -2.25} becomes "1,0.5,-2.25". An empty or nil
// slice yields "".
//
// Each value uses %g, the shortest representation that round-trips. The
// previous %f form fixed six decimals and flattened small embedding
// components (1e-7 became "0.000000"). Output may use exponent notation
// such as "1e-07", which pgvector's vector text input accepts.
//
// A single byte buffer is appended to, avoiding the quadratic cost of
// repeated string concatenation on long vectors.
func joinFloats(vec []float64) string {
	if len(vec) == 0 {
		return ""
	}
	// Rough up-front estimate of ~20 bytes per component.
	buf := make([]byte, 0, len(vec)*20)
	for i, v := range vec {
		if i > 0 {
			buf = append(buf, ',')
		}
		buf = fmt.Appendf(buf, "%g", v)
	}
	return string(buf)
}
// SaveThinkingLog inserts one thinking-log row and writes the generated ID
// and CreatedAt back into the argument.
//
// Defaults are applied in place: UserID becomes "admin_admin" when empty,
// and ToolCalls becomes "[]" when empty. The parameter is named tl so it
// does not shadow the log package.
func (s *Store) SaveThinkingLog(ctx context.Context, tl *model.ThinkingLog) error {
	db := s.getDB()
	if db == nil {
		return errDBNotReady
	}

	if tl.UserID == "" {
		tl.UserID = "admin_admin"
	}
	if tl.ToolCalls == "" {
		tl.ToolCalls = "[]"
	}

	const insertSQL = `INSERT INTO thinking_logs (user_id, content, tool_calls, tool_call_count, content_length)
VALUES ($1, $2, $3, $4, $5)
RETURNING id, created_at`
	row := db.QueryRowContext(ctx, insertSQL,
		tl.UserID, tl.Content, tl.ToolCalls, tl.ToolCallCount, tl.ContentLength,
	)
	return row.Scan(&tl.ID, &tl.CreatedAt)
}
// QueryThinkingLogs returns one page of thinking logs, newest first. It is
// filtered by UserID when that is non-empty; otherwise it covers all users.
//
// Limit defaults to 20 when <= 0. A negative Offset is treated as 0, because
// PostgreSQL rejects a negative OFFSET.
func (s *Store) QueryThinkingLogs(ctx context.Context, q model.ThinkingQuery) ([]model.ThinkingLog, error) {
	db := s.getDB()
	if db == nil {
		return nil, errDBNotReady
	}
	if q.Limit <= 0 {
		q.Limit = 20
	}
	if q.Offset < 0 {
		q.Offset = 0
	}

	query := `SELECT id, user_id, content, tool_calls, tool_call_count, content_length, created_at
FROM thinking_logs`
	args := []interface{}{}
	argIdx := 1
	if q.UserID != "" {
		query += fmt.Sprintf(" WHERE user_id = $%d", argIdx)
		args = append(args, q.UserID)
		argIdx++
	}
	query += fmt.Sprintf(" ORDER BY created_at DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1)
	args = append(args, q.Limit, q.Offset)

	rows, err := db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("查询思考日志失败: %w", err)
	}
	defer rows.Close()

	var logs []model.ThinkingLog
	for rows.Next() {
		var tl model.ThinkingLog
		if err := rows.Scan(&tl.ID, &tl.UserID, &tl.Content, &tl.ToolCalls,
			&tl.ToolCallCount, &tl.ContentLength, &tl.CreatedAt); err != nil {
			return nil, fmt.Errorf("扫描思考日志行失败: %w", err)
		}
		logs = append(logs, tl)
	}
	return logs, rows.Err()
}
// GetThinkingLogByID loads a single thinking log by id. It returns
// (nil, nil) when no row matches.
func (s *Store) GetThinkingLogByID(ctx context.Context, id string) (*model.ThinkingLog, error) {
	db := s.getDB()
	if db == nil {
		return nil, errDBNotReady
	}
	const selectSQL = `SELECT id, user_id, content, tool_calls, tool_call_count, content_length, created_at
FROM thinking_logs WHERE id = $1`

	var tl model.ThinkingLog
	err := db.QueryRowContext(ctx, selectSQL, id).Scan(
		&tl.ID, &tl.UserID, &tl.Content, &tl.ToolCalls,
		&tl.ToolCallCount, &tl.ContentLength, &tl.CreatedAt,
	)
	switch {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("查询思考日志失败: %w", err)
	}
	return &tl, nil
}
// GetThinkingStats aggregates over all thinking logs, across all users. It
// returns the row count, the sum of tool_call_count, the average
// content_length, and the latest created_at cast to text ("" when the table
// is empty).
//
// NOTE(review): the ::TEXT rendering of a timestamptz follows the server's
// DateStyle/TimeZone settings, not RFC 3339. Confirm consumers expect that.
func (s *Store) GetThinkingStats(ctx context.Context) (*model.ThinkingStats, error) {
	db := s.getDB()
	if db == nil {
		return nil, errDBNotReady
	}
	const statsSQL = `SELECT
COALESCE(COUNT(*), 0),
COALESCE(SUM(tool_call_count), 0),
COALESCE(AVG(content_length), 0),
COALESCE(MAX(created_at)::TEXT, '')
FROM thinking_logs`

	var stats model.ThinkingStats
	row := db.QueryRowContext(ctx, statsSQL)
	if err := row.Scan(&stats.TotalLogs, &stats.TotalToolCalls, &stats.AvgContentLen, &stats.LatestAt); err != nil {
		return nil, fmt.Errorf("查询思考日志统计失败: %w", err)
	}
	return &stats, nil
}