dev 分支暂存

This commit is contained in:
2026-05-16 08:26:56 +08:00
parent 58c8caa570
commit eb4129176c
71 changed files with 8474 additions and 214 deletions
@@ -0,0 +1,206 @@
package memory
import (
"context"
"encoding/json"
"fmt"
"log"
"strings"
"github.com/yourname/cyrene-ai/ai-core/internal/model"
)
// Extractor extracts structured memory entries from a conversation and saves
// them through its store.
type Extractor struct {
	// store is the persistence layer extracted entries are saved to.
	store *Store
	// llmChat sends a list of messages to an LLM and returns its reply.
	// When it is nil, extraction falls back to keyword rules.
	llmChat func(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error)
}
// NewExtractor builds an Extractor that saves into store.
//
// llmChat is the chat function used to analyse a conversation and pull
// memories out of it. Passing nil selects the rule-based fallback mode.
func NewExtractor(store *Store, llmChat func(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error)) *Extractor {
	extractor := new(Extractor)
	extractor.store = store
	extractor.llmChat = llmChat
	return extractor
}
// ExtractAndStore extracts memories from one user/assistant exchange and
// saves each of them tagged with userID and sessionID.
//
// The call itself runs synchronously; invoke it in a goroutine when it must
// not block the caller. Nothing is returned: an extraction failure is logged
// and aborts the call, while a failed save is logged and only skips that
// entry.
func (e *Extractor) ExtractAndStore(ctx context.Context, userID, sessionID, userMessage, assistantResponse string) {
	extracted, err := e.extract(ctx, userMessage, assistantResponse)
	if err != nil {
		log.Printf("[memory] 记忆提取失败: %v", err)
		return
	}

	for i := range extracted {
		entry := &extracted[i]
		entry.UserID, entry.SessionID = userID, sessionID

		if saveErr := e.store.Save(ctx, entry); saveErr != nil {
			log.Printf("[memory] 记忆保存失败: %v", saveErr)
			continue
		}
		log.Printf("[memory] 新记忆已保存 [%s]: %s", entry.Category, entry.Summary)
	}
}
// extract picks the extraction strategy: the rule-based fallback when no LLM
// function was configured, the LLM otherwise.
func (e *Extractor) extract(ctx context.Context, userMessage, assistantResponse string) ([]model.MemoryEntry, error) {
	if e.llmChat == nil {
		// Degraded mode: keyword rules never fail, so the error is always nil.
		return e.extractWithRules(userMessage, assistantResponse), nil
	}
	return e.extractWithLLM(ctx, userMessage, assistantResponse)
}
// MemoryExtractionResult is the JSON object shape the LLM is asked to return
// when extracting memories: {"memories": [...]}.
type MemoryExtractionResult struct {
	// Memories holds one element per extracted memory.
	Memories []struct {
		// Content is the full memory text.
		Content string `json:"content"`
		// Summary is a short digest of Content.
		Summary string `json:"summary"`
		// Category is the classification label; the extraction prompt offers
		// preference/fact/event/relationship/habit/other.
		Category string `json:"category"`
		// Priority is the numeric importance; the extraction prompt defines
		// 0=temporary, 1=normal, 2=important, 3=core.
		Priority int `json:"priority"`
	} `json:"memories"`
}
// extractWithLLM asks the configured LLM to extract memories from one
// user/assistant exchange and converts its JSON reply into memory entries.
//
// The reply is expected to be {"memories": [...]}; a bare JSON array of the
// same elements is accepted as a fallback. Entries whose content is blank are
// dropped. UserID and SessionID are not set here; the caller fills them in.
//
// It returns an error when the LLM call fails, when the LLM returns a nil
// response, or when the reply cannot be parsed in either accepted shape.
func (e *Extractor) extractWithLLM(ctx context.Context, userMessage, assistantResponse string) ([]model.MemoryEntry, error) {
	prompt := fmt.Sprintf(`分析以下对话,提取关于用户(开拓者)的重要信息作为记忆。
用户消息: %s
昔涟回复: %s
请以JSON格式返回提取的记忆。每条记忆需要包含:
- content: 完整的记忆内容(一句话描述)
- summary: 简短摘要(10字以内)
- category: 分类 (preference/fact/event/relationship/habit/other)
- priority: 优先级 (0=临时, 1=普通, 2=重要, 3=核心)
只提取有意义的信息,不要提取无意义的闲聊。如果没有值得记住的内容,返回空数组。
输出格式:
{"memories": [{"content": "...", "summary": "...", "category": "...", "priority": 1}]}
`, userMessage, assistantResponse)

	resp, err := e.llmChat(ctx, []model.LLMMessage{
		{Role: "system", Content: "你是一个记忆提取助手。你只输出JSON格式的结果,不输出其他内容。"},
		{Role: "user", Content: prompt},
	})
	if err != nil {
		return nil, fmt.Errorf("LLM提取记忆失败: %w", err)
	}
	// Guard against a (nil, nil) return so resp.Content below cannot panic.
	if resp == nil {
		return nil, fmt.Errorf("LLM提取记忆失败: 响应为空")
	}

	// Strip an optional markdown code fence before decoding.
	content := extractJSON(resp.Content)

	var result MemoryExtractionResult
	if err := json.Unmarshal([]byte(content), &result); err != nil {
		// Fallback: some models return the bare array instead of the wrapper
		// object. The element type must stay identical to the one declared in
		// MemoryExtractionResult so the slice can be assigned directly.
		var arrResult []struct {
			Content  string `json:"content"`
			Summary  string `json:"summary"`
			Category string `json:"category"`
			Priority int    `json:"priority"`
		}
		if err2 := json.Unmarshal([]byte(content), &arrResult); err2 != nil {
			// Truncate by runes, not bytes, so multi-byte text in the reply is
			// never cut in the middle of a character.
			return nil, fmt.Errorf("解析记忆JSON失败: %w (原始: %s)", err, truncateString(content, 100))
		}
		result.Memories = arrResult
	}

	entries := make([]model.MemoryEntry, 0, len(result.Memories))
	for _, m := range result.Memories {
		// A memory without any content carries no information; skip it rather
		// than persisting an empty row.
		if strings.TrimSpace(m.Content) == "" {
			continue
		}
		entries = append(entries, model.MemoryEntry{
			Content:  m.Content,
			Summary:  m.Summary,
			Category: model.MemoryCategory(m.Category),
			Priority: model.MemoryPriority(m.Priority),
		})
	}
	return entries, nil
}
// extractWithRules is the degraded, LLM-free extraction path. It scans
// userMessage for a fixed list of keywords and, for the keyword that occurs
// earliest in the message, records the surrounding text (up to 5 characters
// before the keyword and 15 after it) as a single normal-priority memory.
//
// The assistant response is ignored. At most one entry is returned; the
// result is nil when no keyword occurs.
func (e *Extractor) extractWithRules(userMessage, _ string) []model.MemoryEntry {
	// An ordered list (instead of a map) keeps the outcome deterministic when
	// several keywords occur in the same message.
	rules := []struct {
		keyword  string
		category string
	}{
		{"最喜欢", "preference"},
		{"不喜欢", "preference"},
		{"喜欢", "preference"},
		{"讨厌", "preference"},
		{"爱", "preference"},
		{"每天都", "habit"},
		{"经常", "habit"},
		{"一直", "habit"},
		{"我叫", "fact"},
		{"我是", "fact"},
		{"我家", "fact"},
		{"住在", "fact"},
		{"生日", "fact"},
	}

	// Pick the keyword with the smallest byte offset; on a tie the rule listed
	// first wins.
	bestAt, bestRule := -1, -1
	for i, rule := range rules {
		at := strings.Index(userMessage, rule.keyword)
		if at == -1 {
			continue
		}
		if bestAt == -1 || at < bestAt {
			bestAt, bestRule = at, i
		}
	}
	if bestRule == -1 {
		return nil
	}

	// strings.Index reports a byte offset, while the window below is measured
	// in characters; convert before slicing the rune view of the message.
	runes := []rune(userMessage)
	hit := len([]rune(userMessage[:bestAt]))
	keywordLen := len([]rune(rules[bestRule].keyword))
	start := max(0, hit-5)
	end := min(len(runes), hit+keywordLen+15)
	content := strings.TrimSpace(string(runes[start:end]))

	return []model.MemoryEntry{{
		Content:  content,
		Summary:  truncateString(content, 20),
		Category: model.MemoryCategory(rules[bestRule].category),
		Priority: model.MemoryNormal,
	}}
}
// extractJSON returns text with surrounding whitespace removed and, when the
// text starts with a markdown code fence ("```" or "```json"), with that
// opening fence and a trailing "```" stripped as well.
func extractJSON(text string) string {
	const fence = "```"

	trimmed := strings.TrimSpace(text)
	if !strings.HasPrefix(trimmed, fence) {
		return trimmed
	}

	// Drop the opening fence, then the optional "json" language tag that may
	// follow it directly.
	body := strings.TrimPrefix(trimmed, fence)
	body = strings.TrimPrefix(body, "json")
	body = strings.TrimSuffix(body, fence)
	return strings.TrimSpace(body)
}
// truncateString shortens s to at most maxLen characters (runes, not bytes).
// When anything is cut off, "..." is appended to the kept prefix; otherwise s
// is returned unchanged.
func truncateString(s string, maxLen int) string {
	if chars := []rune(s); len(chars) > maxLen {
		return string(chars[:maxLen]) + "..."
	}
	return s
}
// min returns the smaller of a and b (b when they are equal).
func min(a, b int) int {
	smaller := b
	if a < b {
		smaller = a
	}
	return smaller
}
// max returns the larger of a and b (b when they are equal).
func max(a, b int) int {
	larger := b
	if a > b {
		larger = a
	}
	return larger
}
@@ -0,0 +1,152 @@
package memory
import (
"context"
"fmt"
"strings"
"github.com/yourname/cyrene-ai/ai-core/internal/model"
)
// MemoryEntry is an alias for model.MemoryEntry, so both names denote the
// same type and can be used interchangeably within this package.
type MemoryEntry = model.MemoryEntry
// Retriever looks up stored memories that are relevant to a query.
type Retriever struct {
	// store is the persistence layer that is searched.
	store *Store
	// embedder converts query text into a vector for similarity search.
	embedder Embedder
}
// Embedder turns text into an embedding vector.
type Embedder interface {
	// Embed returns the vector representation of text, or an error when the
	// text cannot be embedded.
	Embed(ctx context.Context, text string) ([]float64, error)
}
// SimpleEmbedder is a dependency-free Embedder intended for the MVP stage; it
// needs no external API.
type SimpleEmbedder struct{}

// Embed builds a 1536-dimensional vector from the lower-cased characters of
// text. Every character adds 1/len to the bucket selected by its code point
// and 0.5/len to the bucket selected by code point plus position, where len
// is the number of characters. Empty text yields the all-zero vector.
//
// ctx is not used and the returned error is always nil.
func (e *SimpleEmbedder) Embed(ctx context.Context, text string) ([]float64, error) {
	const dims = 1536

	chars := []rune(strings.ToLower(text))
	total := float64(len(chars))
	vec := make([]float64, dims)

	for pos, ch := range chars {
		code := int(ch)
		// Character-frequency component.
		vec[code%dims] += 1.0 / total
		// Position-aware component.
		vec[(code+pos)%dims] += 0.5 / total
	}
	return vec, nil
}
// NewRetriever builds a Retriever over store. A nil embedder is replaced by a
// SimpleEmbedder.
func NewRetriever(store *Store, embedder Embedder) *Retriever {
	retriever := &Retriever{store: store, embedder: embedder}
	if retriever.embedder == nil {
		retriever.embedder = &SimpleEmbedder{}
	}
	return retriever
}
// Retrieve returns up to 10 memories of userID that relate to query.
//
// Results are gathered in three steps: the 5 nearest entries by vector
// similarity, then keyword matches, with duplicates (by ID) removed; if both
// steps produce nothing, up to 3 entries of at least important priority are
// returned instead. Every step is best-effort: a failing step is skipped, so
// the returned error is always nil.
func (r *Retriever) Retrieve(ctx context.Context, userID string, query string) ([]MemoryEntry, error) {
	var merged []MemoryEntry
	seen := make(map[string]bool)

	// collect appends the entries of batch that have not been seen yet.
	collect := func(batch []MemoryEntry) {
		for _, item := range batch {
			if seen[item.ID] {
				continue
			}
			seen[item.ID] = true
			merged = append(merged, item)
		}
	}

	// Step 1: vector similarity search.
	if embedding, err := r.embedder.Embed(ctx, query); err == nil {
		if hits, err := r.store.SearchByVector(ctx, userID, embedding, 5); err == nil {
			collect(hits)
		}
	}

	// Step 2: keyword matching.
	if hits, err := r.keywordSearch(ctx, userID, query); err == nil {
		collect(hits)
	}

	// Step 3: nothing matched, so fall back to the top important memories.
	if len(merged) == 0 {
		fallback, err := r.store.Query(ctx, model.MemoryQuery{
			UserID:   userID,
			Priority: int(model.MemoryImportant),
			Limit:    3,
		})
		if err == nil {
			merged = fallback
		}
	}

	// Cap the number of returned entries.
	if len(merged) > 10 {
		merged = merged[:10]
	}
	return merged, nil
}
// keywordSearch returns the memories of userID whose content or summary
// contains query, compared case-insensitively as a plain substring.
//
// Two batches are scanned: up to 50 entries of at least important priority,
// then up to 100 entries of at least normal priority. Because the store
// filters with "priority >= N", the second batch can repeat entries from the
// first; each entry is therefore reported only once, keyed by ID.
//
// An error is returned only when the first query fails; a failure of the
// second query is tolerated and yields the matches found so far.
func (r *Retriever) keywordSearch(ctx context.Context, userID string, query string) ([]MemoryEntry, error) {
	needle := strings.ToLower(query)

	var matched []MemoryEntry
	seen := make(map[string]bool)

	// collect appends every not-yet-reported entry of batch that contains the
	// needle in its content or summary.
	collect := func(batch []MemoryEntry) {
		for _, entry := range batch {
			if seen[entry.ID] {
				continue
			}
			if strings.Contains(strings.ToLower(entry.Content), needle) ||
				strings.Contains(strings.ToLower(entry.Summary), needle) {
				seen[entry.ID] = true
				matched = append(matched, entry)
			}
		}
	}

	// Core and important memories first, so they lead the result.
	important, err := r.store.Query(ctx, model.MemoryQuery{
		UserID:   userID,
		Priority: int(model.MemoryImportant),
		Limit:    50,
	})
	if err != nil {
		return nil, err
	}
	collect(important)

	// Normal memories are a best-effort addition.
	normal, err := r.store.Query(ctx, model.MemoryQuery{
		UserID:   userID,
		Priority: int(model.MemoryNormal),
		Limit:    100,
	})
	if err == nil {
		collect(normal)
	}

	return matched, nil
}
// Reference fmt so its import stays in use; nothing else in this file calls
// into the package.
var _ = fmt.Sprintf
+251
View File
@@ -0,0 +1,251 @@
package memory
import (
	"context"
	"database/sql"
	"fmt"
	"strconv"
	"time"

	_ "github.com/lib/pq"
	"github.com/yourname/cyrene-ai/ai-core/internal/model"
)
// Store is the persistent memory storage (PostgreSQL + pgvector).
type Store struct {
	// db is the database handle every query goes through.
	db *sql.DB
}
// NewStore opens the PostgreSQL database described by connStr, configures the
// connection pool, verifies connectivity with a ping and creates the schema
// if it does not exist yet.
//
// On any failure the database handle is closed again before the error is
// returned, so a failed call leaves no pooled connections behind.
func NewStore(connStr string) (*Store, error) {
	db, err := sql.Open("postgres", connStr)
	if err != nil {
		return nil, fmt.Errorf("连接数据库失败: %w", err)
	}

	db.SetMaxOpenConns(25)
	db.SetMaxIdleConns(5)
	db.SetConnMaxLifetime(5 * time.Minute)

	if err := db.Ping(); err != nil {
		// The ping error is the one worth reporting; a close error adds nothing.
		_ = db.Close()
		return nil, fmt.Errorf("数据库ping失败: %w", err)
	}

	s := &Store{db: db}
	if err := s.migrate(); err != nil {
		// Same reasoning as above: report the migration error only.
		_ = db.Close()
		return nil, fmt.Errorf("数据库迁移失败: %w", err)
	}
	return s, nil
}
// migrate creates the pgvector extension, the memories table and its indexes
// when they do not exist yet. Statements run in order and the first failure
// is returned, labelled with the beginning of the failing statement.
func (s *Store) migrate() error {
	// maxLabel bounds how much of a statement is quoted in an error message.
	const maxLabel = 50

	queries := []string{
		`CREATE EXTENSION IF NOT EXISTS vector`,
		`CREATE TABLE IF NOT EXISTS memories (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id VARCHAR(64) NOT NULL,
content TEXT NOT NULL,
summary TEXT DEFAULT '',
category VARCHAR(32) DEFAULT 'other',
priority INT DEFAULT 1,
session_id VARCHAR(64) DEFAULT '',
source TEXT DEFAULT '',
embedding vector(1536),
access_count INT DEFAULT 0,
last_access TIMESTAMPTZ DEFAULT NOW(),
created_at TIMESTAMPTZ DEFAULT NOW(),
expires_at TIMESTAMPTZ
)`,
		`CREATE INDEX IF NOT EXISTS idx_memories_user_id ON memories(user_id)`,
		`CREATE INDEX IF NOT EXISTS idx_memories_category ON memories(category)`,
		`CREATE INDEX IF NOT EXISTS idx_memories_priority ON memories(priority)`,
		`CREATE INDEX IF NOT EXISTS idx_memories_user_priority ON memories(user_id, priority DESC)`,
	}

	for _, q := range queries {
		if _, err := s.db.Exec(q); err != nil {
			// Shorten only statements that are actually longer than the
			// limit; slicing a shorter one would panic.
			label := q
			if len(label) > maxLabel {
				label = label[:maxLabel]
			}
			return fmt.Errorf("执行迁移 '%s' 失败: %w", label, err)
		}
	}
	return nil
}
// Save inserts entry into the memories table and writes the generated ID and
// creation time back into entry. An entry without an embedding is stored with
// a NULL embedding column.
func (s *Store) Save(ctx context.Context, entry *model.MemoryEntry) error {
	const insertSQL = `INSERT INTO memories (user_id, content, summary, category, priority, session_id, source, embedding, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
RETURNING id, created_at`

	// vectorArg stays a nil interface (SQL NULL) unless an embedding exists,
	// in which case it becomes the "[v1,v2,...]" text form.
	var vectorArg interface{}
	if n := len(entry.Embedding); n > 0 {
		widened := make([]float64, 0, n)
		for _, component := range entry.Embedding {
			widened = append(widened, float64(component))
		}
		vectorArg = fmt.Sprintf("[%s]", joinFloats(widened))
	}

	row := s.db.QueryRowContext(ctx, insertSQL,
		entry.UserID,
		entry.Content,
		entry.Summary,
		string(entry.Category),
		int(entry.Priority),
		entry.SessionID,
		entry.Source,
		vectorArg,
		entry.ExpiresAt,
	)
	return row.Scan(&entry.ID, &entry.CreatedAt)
}
// GetByID loads the memory with the given id. It returns (nil, nil) when no
// such row exists and a wrapped error when the lookup itself fails.
//
// A successful lookup also starts a background goroutine that bumps the
// entry's access counter; that update is independent of ctx.
func (s *Store) GetByID(ctx context.Context, id string) (*model.MemoryEntry, error) {
	const selectSQL = `SELECT id, user_id, content, summary, category, priority, session_id, source,
access_count, last_access, created_at, expires_at
FROM memories WHERE id = $1`

	var (
		found    model.MemoryEntry
		category string
	)
	err := s.db.QueryRowContext(ctx, selectSQL, id).Scan(
		&found.ID, &found.UserID, &found.Content, &found.Summary,
		&category, &found.Priority, &found.SessionID, &found.Source,
		&found.AccessCount, &found.LastAccess, &found.CreatedAt, &found.ExpiresAt,
	)
	switch {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("查询记忆失败: %w", err)
	}
	// category is scanned as a plain string and converted afterwards.
	found.Category = model.MemoryCategory(category)

	// Record the access without delaying the caller.
	go s.incrementAccess(context.Background(), id)

	return &found, nil
}
// Query lists the memories of q.UserID, ordered by priority and then creation
// time, both descending.
//
// Optional filters: q.Category (exact match, skipped when empty) and
// q.Priority (minimum priority, applied whenever it is >= 0). q.Limit
// defaults to 10 when it is not positive; q.Offset is passed through as is.
func (s *Store) Query(ctx context.Context, q model.MemoryQuery) ([]model.MemoryEntry, error) {
	limit := q.Limit
	if limit <= 0 {
		limit = 10
	}

	sqlText := `SELECT id, user_id, content, summary, category, priority, session_id, source,
access_count, last_access, created_at, expires_at
FROM memories WHERE user_id = $1`
	params := []interface{}{q.UserID}

	// bind registers value as the next positional parameter and returns its
	// placeholder ("$2", "$3", ...).
	bind := func(value interface{}) string {
		params = append(params, value)
		return fmt.Sprintf("$%d", len(params))
	}

	if q.Category != "" {
		sqlText += " AND category = " + bind(string(q.Category))
	}
	if q.Priority >= 0 {
		sqlText += " AND priority >= " + bind(int(q.Priority))
	}
	limitMark := bind(limit)
	offsetMark := bind(q.Offset)
	sqlText += " ORDER BY priority DESC, created_at DESC LIMIT " + limitMark + " OFFSET " + offsetMark

	rows, err := s.db.QueryContext(ctx, sqlText, params...)
	if err != nil {
		return nil, fmt.Errorf("查询记忆失败: %w", err)
	}
	defer rows.Close()

	var results []model.MemoryEntry
	for rows.Next() {
		var (
			item     model.MemoryEntry
			category string
		)
		scanErr := rows.Scan(
			&item.ID, &item.UserID, &item.Content, &item.Summary,
			&category, &item.Priority, &item.SessionID, &item.Source,
			&item.AccessCount, &item.LastAccess, &item.CreatedAt, &item.ExpiresAt,
		)
		if scanErr != nil {
			return nil, fmt.Errorf("扫描记忆行失败: %w", scanErr)
		}
		item.Category = model.MemoryCategory(category)
		results = append(results, item)
	}
	return results, rows.Err()
}
// Delete removes the memory with the given id. Deleting an id that does not
// exist is not reported as an error by this method.
func (s *Store) Delete(ctx context.Context, id string) error {
	const deleteSQL = `DELETE FROM memories WHERE id = $1`
	if _, err := s.db.ExecContext(ctx, deleteSQL, id); err != nil {
		return err
	}
	return nil
}
// PurgeExpired deletes every memory whose expires_at is set and lies in the
// past, and returns the number of rows removed.
func (s *Store) PurgeExpired(ctx context.Context) (int64, error) {
	const purgeSQL = `DELETE FROM memories WHERE expires_at IS NOT NULL AND expires_at < NOW()`

	outcome, err := s.db.ExecContext(ctx, purgeSQL)
	if err != nil {
		return 0, err
	}
	return outcome.RowsAffected()
}
// SearchByVector returns the memories of userID that have an embedding,
// ordered by the pgvector "<=>" distance to embedding, nearest first. limit
// defaults to 5 when it is not positive.
//
// The query also selects a similarity value (1 - distance); it is scanned but
// not part of the returned entries.
func (s *Store) SearchByVector(ctx context.Context, userID string, embedding []float64, limit int) ([]model.MemoryEntry, error) {
	const searchSQL = `SELECT id, user_id, content, summary, category, priority, session_id, source,
access_count, last_access, created_at, expires_at,
1 - (embedding <=> $1) AS similarity
FROM memories
WHERE user_id = $2 AND embedding IS NOT NULL
ORDER BY embedding <=> $1
LIMIT $3`

	topK := limit
	if topK <= 0 {
		topK = 5
	}
	// The vector is sent in its "[v1,v2,...]" text form.
	literal := "[" + joinFloats(embedding) + "]"

	rows, err := s.db.QueryContext(ctx, searchSQL, literal, userID, topK)
	if err != nil {
		return nil, fmt.Errorf("向量搜索失败: %w", err)
	}
	defer rows.Close()

	var results []model.MemoryEntry
	for rows.Next() {
		var (
			item       model.MemoryEntry
			category   string
			similarity float64
		)
		scanErr := rows.Scan(
			&item.ID, &item.UserID, &item.Content, &item.Summary,
			&category, &item.Priority, &item.SessionID, &item.Source,
			&item.AccessCount, &item.LastAccess, &item.CreatedAt, &item.ExpiresAt,
			&similarity,
		)
		if scanErr != nil {
			return nil, fmt.Errorf("扫描向量搜索结果失败: %w", scanErr)
		}
		item.Category = model.MemoryCategory(category)
		results = append(results, item)
	}
	return results, rows.Err()
}
// incrementAccess bumps access_count and refreshes last_access for the memory
// with the given id. It is best-effort bookkeeping: the result and any error
// of the update are deliberately discarded.
func (s *Store) incrementAccess(ctx context.Context, id string) {
	const touchSQL = `UPDATE memories SET access_count = access_count + 1, last_access = NOW() WHERE id = $1`
	_, _ = s.db.ExecContext(ctx, touchSQL, id)
}
// Close closes the underlying database handle and returns the error reported
// by it, if any.
func (s *Store) Close() error {
	return s.db.Close()
}
// joinFloats renders vec as a comma-separated list without surrounding
// brackets, e.g. "1,-2.25,0.5". An empty or nil slice yields "".
//
// Each value is written in plain decimal notation with the fewest digits that
// still round-trip to the same float64, so small components are not rounded
// away to zero. The output is built in a single buffer to keep the cost linear
// in the vector length.
func joinFloats(vec []float64) string {
	if len(vec) == 0 {
		return ""
	}

	// Rough pre-size: a handful of characters per component plus separator.
	buf := make([]byte, 0, len(vec)*12)
	for i, v := range vec {
		if i > 0 {
			buf = append(buf, ',')
		}
		buf = strconv.AppendFloat(buf, v, 'f', -1, 64)
	}
	return string(buf)
}