dev 分支暂存
This commit is contained in:
@@ -0,0 +1,206 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/model"
|
||||
)
|
||||
|
||||
// Extractor 记忆提取器 —— 从对话中提取结构化记忆
|
||||
type Extractor struct {
|
||||
store *Store
|
||||
llmChat func(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error)
|
||||
}
|
||||
|
||||
// NewExtractor 创建记忆提取器
|
||||
// llmChat: LLM对话函数,用于分析对话内容并提取记忆
|
||||
// 如果为nil,则使用规则提取(降级模式)
|
||||
func NewExtractor(store *Store, llmChat func(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error)) *Extractor {
|
||||
return &Extractor{
|
||||
store: store,
|
||||
llmChat: llmChat,
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractAndStore 从一轮对话中提取记忆并存储
|
||||
// 异步执行,不阻塞主流程
|
||||
func (e *Extractor) ExtractAndStore(ctx context.Context, userID, sessionID, userMessage, assistantResponse string) {
|
||||
memories, err := e.extract(ctx, userMessage, assistantResponse)
|
||||
if err != nil {
|
||||
log.Printf("[memory] 记忆提取失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, mem := range memories {
|
||||
mem.UserID = userID
|
||||
mem.SessionID = sessionID
|
||||
|
||||
if err := e.store.Save(ctx, &mem); err != nil {
|
||||
log.Printf("[memory] 记忆保存失败: %v", err)
|
||||
continue
|
||||
}
|
||||
log.Printf("[memory] 新记忆已保存 [%s]: %s", mem.Category, mem.Summary)
|
||||
}
|
||||
}
|
||||
|
||||
// extract 从对话中提取记忆
|
||||
func (e *Extractor) extract(ctx context.Context, userMessage, assistantResponse string) ([]model.MemoryEntry, error) {
|
||||
// 如果有LLM,使用LLM提取
|
||||
if e.llmChat != nil {
|
||||
return e.extractWithLLM(ctx, userMessage, assistantResponse)
|
||||
}
|
||||
// 降级:规则提取
|
||||
return e.extractWithRules(userMessage, assistantResponse), nil
|
||||
}
|
||||
|
||||
// MemoryExtractionResult LLM提取结果的结构
|
||||
type MemoryExtractionResult struct {
|
||||
Memories []struct {
|
||||
Content string `json:"content"`
|
||||
Summary string `json:"summary"`
|
||||
Category string `json:"category"`
|
||||
Priority int `json:"priority"`
|
||||
} `json:"memories"`
|
||||
}
|
||||
|
||||
// extractWithLLM 使用LLM提取记忆
|
||||
func (e *Extractor) extractWithLLM(ctx context.Context, userMessage, assistantResponse string) ([]model.MemoryEntry, error) {
|
||||
prompt := fmt.Sprintf(`分析以下对话,提取关于用户(开拓者)的重要信息作为记忆。
|
||||
|
||||
用户消息: %s
|
||||
昔涟回复: %s
|
||||
|
||||
请以JSON格式返回提取的记忆。每条记忆需要包含:
|
||||
- content: 完整的记忆内容(一句话描述)
|
||||
- summary: 简短摘要(10字以内)
|
||||
- category: 分类 (preference/fact/event/relationship/habit/other)
|
||||
- priority: 优先级 (0=临时, 1=普通, 2=重要, 3=核心)
|
||||
|
||||
只提取有意义的信息,不要提取无意义的闲聊。如果没有值得记住的内容,返回空数组。
|
||||
|
||||
输出格式:
|
||||
{"memories": [{"content": "...", "summary": "...", "category": "...", "priority": 1}]}
|
||||
`, userMessage, assistantResponse)
|
||||
|
||||
resp, err := e.llmChat(ctx, []model.LLMMessage{
|
||||
{Role: "system", Content: "你是一个记忆提取助手。你只输出JSON格式的结果,不输出其他内容。"},
|
||||
{Role: "user", Content: prompt},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("LLM提取记忆失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析JSON
|
||||
result := MemoryExtractionResult{}
|
||||
content := extractJSON(resp.Content)
|
||||
if err := json.Unmarshal([]byte(content), &result); err != nil {
|
||||
// 尝试作为数组解析
|
||||
var arrResult []struct {
|
||||
Content string `json:"content"`
|
||||
Summary string `json:"summary"`
|
||||
Category string `json:"category"`
|
||||
Priority int `json:"priority"`
|
||||
}
|
||||
if err2 := json.Unmarshal([]byte(content), &arrResult); err2 != nil {
|
||||
return nil, fmt.Errorf("解析记忆JSON失败: %w (原始: %s)", err, content[:min(len(content), 100)])
|
||||
}
|
||||
result.Memories = arrResult
|
||||
}
|
||||
|
||||
var entries []model.MemoryEntry
|
||||
for _, m := range result.Memories {
|
||||
entries = append(entries, model.MemoryEntry{
|
||||
Content: m.Content,
|
||||
Summary: m.Summary,
|
||||
Category: model.MemoryCategory(m.Category),
|
||||
Priority: model.MemoryPriority(m.Priority),
|
||||
})
|
||||
}
|
||||
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
// extractWithRules 基于规则提取记忆(降级方案)
|
||||
func (e *Extractor) extractWithRules(userMessage, _ string) []model.MemoryEntry {
|
||||
var entries []model.MemoryEntry
|
||||
|
||||
// 规则1: 检测用户偏好表达
|
||||
prefPatterns := map[string]string{
|
||||
"喜欢": "preference",
|
||||
"爱": "preference",
|
||||
"最喜欢": "preference",
|
||||
"讨厌": "preference",
|
||||
"不喜欢": "preference",
|
||||
"经常": "habit",
|
||||
"每天都": "habit",
|
||||
"一直": "habit",
|
||||
"我叫": "fact",
|
||||
"我是": "fact",
|
||||
"我家": "fact",
|
||||
"住在": "fact",
|
||||
"生日": "fact",
|
||||
}
|
||||
|
||||
for pattern, category := range prefPatterns {
|
||||
if idx := strings.Index(userMessage, pattern); idx != -1 {
|
||||
// 提取包含关键词的句子片段
|
||||
start := max(0, idx-5)
|
||||
end := min(len([]rune(userMessage)), idx+len([]rune(pattern))+15)
|
||||
content := strings.TrimSpace(string([]rune(userMessage)[start:end]))
|
||||
|
||||
entries = append(entries, model.MemoryEntry{
|
||||
Content: content,
|
||||
Summary: truncateString(content, 20),
|
||||
Category: model.MemoryCategory(category),
|
||||
Priority: model.MemoryNormal,
|
||||
})
|
||||
break // 每条消息最多提取一条规则记忆
|
||||
}
|
||||
}
|
||||
|
||||
return entries
|
||||
}
|
||||
|
||||
// extractJSON 从LLM回复中提取JSON内容
|
||||
func extractJSON(text string) string {
|
||||
text = strings.TrimSpace(text)
|
||||
|
||||
// 移除 markdown 代码块标记
|
||||
if strings.HasPrefix(text, "```json") {
|
||||
text = strings.TrimPrefix(text, "```json")
|
||||
text = strings.TrimSuffix(text, "```")
|
||||
text = strings.TrimSpace(text)
|
||||
} else if strings.HasPrefix(text, "```") {
|
||||
text = strings.TrimPrefix(text, "```")
|
||||
text = strings.TrimSuffix(text, "```")
|
||||
text = strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
return text
|
||||
}
|
||||
|
||||
func truncateString(s string, maxLen int) string {
|
||||
runes := []rune(s)
|
||||
if len(runes) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return string(runes[:maxLen]) + "..."
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -0,0 +1,152 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/model"
|
||||
)
|
||||
|
||||
// MemoryEntry 记忆条目别名(避免与model包冲突)
|
||||
type MemoryEntry = model.MemoryEntry
|
||||
|
||||
// Retriever 记忆检索器
|
||||
type Retriever struct {
|
||||
store *Store
|
||||
embedder Embedder // 文本转向量的接口
|
||||
}
|
||||
|
||||
// Embedder 文本嵌入接口
|
||||
type Embedder interface {
|
||||
Embed(ctx context.Context, text string) ([]float64, error)
|
||||
}
|
||||
|
||||
// SimpleEmbedder 基于关键词的简单嵌入(MVP阶段可用,无需外部API)
|
||||
type SimpleEmbedder struct{}
|
||||
|
||||
// Embed 简单的关键词哈希嵌入(用于MVP快速验证)
|
||||
func (e *SimpleEmbedder) Embed(ctx context.Context, text string) ([]float64, error) {
|
||||
// 生成一个简单的1536维特征向量
|
||||
// 基于字符频率的简单表示,用于MVP阶段
|
||||
vec := make([]float64, 1536)
|
||||
|
||||
runes := []rune(strings.ToLower(text))
|
||||
for i, r := range runes {
|
||||
idx := int(r) % 1536
|
||||
vec[idx] += 1.0 / float64(len(runes))
|
||||
// 考虑位置信息
|
||||
posIdx := (int(r) + i) % 1536
|
||||
vec[posIdx] += 0.5 / float64(len(runes))
|
||||
}
|
||||
|
||||
return vec, nil
|
||||
}
|
||||
|
||||
// NewRetriever 创建记忆检索器
|
||||
func NewRetriever(store *Store, embedder Embedder) *Retriever {
|
||||
if embedder == nil {
|
||||
embedder = &SimpleEmbedder{}
|
||||
}
|
||||
return &Retriever{
|
||||
store: store,
|
||||
embedder: embedder,
|
||||
}
|
||||
}
|
||||
|
||||
// Retrieve 检索与查询相关的记忆
|
||||
// 策略: 向量相似度 + 关键词匹配混合
|
||||
func (r *Retriever) Retrieve(ctx context.Context, userID string, query string) ([]MemoryEntry, error) {
|
||||
var allEntries []MemoryEntry
|
||||
seen := make(map[string]bool)
|
||||
|
||||
// 1. 向量相似度检索
|
||||
embedding, err := r.embedder.Embed(ctx, query)
|
||||
if err == nil {
|
||||
vecEntries, err := r.store.SearchByVector(ctx, userID, embedding, 5)
|
||||
if err == nil {
|
||||
for _, e := range vecEntries {
|
||||
if !seen[e.ID] {
|
||||
seen[e.ID] = true
|
||||
allEntries = append(allEntries, e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 关键词匹配检索(核心/重要记忆优先)
|
||||
keywordEntries, err := r.keywordSearch(ctx, userID, query)
|
||||
if err == nil {
|
||||
for _, e := range keywordEntries {
|
||||
if !seen[e.ID] {
|
||||
seen[e.ID] = true
|
||||
allEntries = append(allEntries, e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 如果没有匹配,返回最近的重要记忆
|
||||
if len(allEntries) == 0 {
|
||||
recentEntries, err := r.store.Query(ctx, model.MemoryQuery{
|
||||
UserID: userID,
|
||||
Priority: int(model.MemoryImportant),
|
||||
Limit: 3,
|
||||
})
|
||||
if err == nil {
|
||||
allEntries = recentEntries
|
||||
}
|
||||
}
|
||||
|
||||
// 限制返回数量
|
||||
if len(allEntries) > 10 {
|
||||
allEntries = allEntries[:10]
|
||||
}
|
||||
|
||||
return allEntries, nil
|
||||
}
|
||||
|
||||
// keywordSearch 关键词匹配检索
|
||||
func (r *Retriever) keywordSearch(ctx context.Context, userID string, query string) ([]MemoryEntry, error) {
|
||||
// 查询最近的核心和重要记忆
|
||||
entries, err := r.store.Query(ctx, model.MemoryQuery{
|
||||
UserID: userID,
|
||||
Priority: int(model.MemoryImportant),
|
||||
Limit: 50,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 简单的关键词匹配过滤
|
||||
var matched []MemoryEntry
|
||||
queryLower := strings.ToLower(query)
|
||||
|
||||
for _, entry := range entries {
|
||||
contentLower := strings.ToLower(entry.Content)
|
||||
summaryLower := strings.ToLower(entry.Summary)
|
||||
if strings.Contains(contentLower, queryLower) || strings.Contains(summaryLower, queryLower) {
|
||||
matched = append(matched, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// 也匹配普通记忆
|
||||
normalEntries, err := r.store.Query(ctx, model.MemoryQuery{
|
||||
UserID: userID,
|
||||
Priority: int(model.MemoryNormal),
|
||||
Limit: 100,
|
||||
})
|
||||
if err == nil {
|
||||
for _, entry := range normalEntries {
|
||||
contentLower := strings.ToLower(entry.Content)
|
||||
summaryLower := strings.ToLower(entry.Summary)
|
||||
if strings.Contains(contentLower, queryLower) || strings.Contains(summaryLower, queryLower) {
|
||||
matched = append(matched, entry)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return matched, nil
|
||||
}
|
||||
|
||||
// Ensure fmt is used
|
||||
var _ = fmt.Sprintf
|
||||
|
||||
@@ -0,0 +1,251 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/model"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
// Store 记忆持久化存储(PostgreSQL + pgvector)
|
||||
type Store struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewStore 创建记忆存储
|
||||
func NewStore(connStr string) (*Store, error) {
|
||||
db, err := sql.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("连接数据库失败: %w", err)
|
||||
}
|
||||
|
||||
db.SetMaxOpenConns(25)
|
||||
db.SetMaxIdleConns(5)
|
||||
db.SetConnMaxLifetime(5 * time.Minute)
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("数据库ping失败: %w", err)
|
||||
}
|
||||
|
||||
s := &Store{db: db}
|
||||
if err := s.migrate(); err != nil {
|
||||
return nil, fmt.Errorf("数据库迁移失败: %w", err)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// migrate 创建表结构
|
||||
func (s *Store) migrate() error {
|
||||
queries := []string{
|
||||
`CREATE EXTENSION IF NOT EXISTS vector`,
|
||||
`CREATE TABLE IF NOT EXISTS memories (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id VARCHAR(64) NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
summary TEXT DEFAULT '',
|
||||
category VARCHAR(32) DEFAULT 'other',
|
||||
priority INT DEFAULT 1,
|
||||
session_id VARCHAR(64) DEFAULT '',
|
||||
source TEXT DEFAULT '',
|
||||
embedding vector(1536),
|
||||
access_count INT DEFAULT 0,
|
||||
last_access TIMESTAMPTZ DEFAULT NOW(),
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
expires_at TIMESTAMPTZ
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_memories_user_id ON memories(user_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_memories_category ON memories(category)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_memories_priority ON memories(priority)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_memories_user_priority ON memories(user_id, priority DESC)`,
|
||||
}
|
||||
|
||||
for _, q := range queries {
|
||||
if _, err := s.db.Exec(q); err != nil {
|
||||
return fmt.Errorf("执行迁移 '%s' 失败: %w", q[:50], err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Save 保存记忆
|
||||
func (s *Store) Save(ctx context.Context, entry *model.MemoryEntry) error {
|
||||
query := `INSERT INTO memories (user_id, content, summary, category, priority, session_id, source, embedding, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
RETURNING id, created_at`
|
||||
|
||||
var embedding interface{}
|
||||
if len(entry.Embedding) > 0 {
|
||||
vec := make([]float64, len(entry.Embedding))
|
||||
for i, v := range entry.Embedding {
|
||||
vec[i] = float64(v)
|
||||
}
|
||||
embedding = fmt.Sprintf("[%s]", joinFloats(vec))
|
||||
}
|
||||
|
||||
return s.db.QueryRowContext(ctx, query,
|
||||
entry.UserID, entry.Content, entry.Summary,
|
||||
string(entry.Category), int(entry.Priority),
|
||||
entry.SessionID, entry.Source, embedding, entry.ExpiresAt,
|
||||
).Scan(&entry.ID, &entry.CreatedAt)
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取记忆
|
||||
func (s *Store) GetByID(ctx context.Context, id string) (*model.MemoryEntry, error) {
|
||||
query := `SELECT id, user_id, content, summary, category, priority, session_id, source,
|
||||
access_count, last_access, created_at, expires_at
|
||||
FROM memories WHERE id = $1`
|
||||
|
||||
entry := &model.MemoryEntry{}
|
||||
var category string
|
||||
err := s.db.QueryRowContext(ctx, query, id).Scan(
|
||||
&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
|
||||
&category, &entry.Priority, &entry.SessionID, &entry.Source,
|
||||
&entry.AccessCount, &entry.LastAccess, &entry.CreatedAt, &entry.ExpiresAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询记忆失败: %w", err)
|
||||
}
|
||||
entry.Category = model.MemoryCategory(category)
|
||||
|
||||
// 更新访问计数
|
||||
go s.incrementAccess(context.Background(), id)
|
||||
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
// Query 按条件查询记忆
|
||||
func (s *Store) Query(ctx context.Context, q model.MemoryQuery) ([]model.MemoryEntry, error) {
|
||||
if q.Limit <= 0 {
|
||||
q.Limit = 10
|
||||
}
|
||||
|
||||
query := `SELECT id, user_id, content, summary, category, priority, session_id, source,
|
||||
access_count, last_access, created_at, expires_at
|
||||
FROM memories WHERE user_id = $1`
|
||||
args := []interface{}{q.UserID}
|
||||
argIdx := 2
|
||||
|
||||
if q.Category != "" {
|
||||
query += fmt.Sprintf(" AND category = $%d", argIdx)
|
||||
args = append(args, string(q.Category))
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if q.Priority >= 0 {
|
||||
query += fmt.Sprintf(" AND priority >= $%d", argIdx)
|
||||
args = append(args, int(q.Priority))
|
||||
argIdx++
|
||||
}
|
||||
|
||||
query += fmt.Sprintf(" ORDER BY priority DESC, created_at DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1)
|
||||
args = append(args, q.Limit, q.Offset)
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询记忆失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var entries []model.MemoryEntry
|
||||
for rows.Next() {
|
||||
var entry model.MemoryEntry
|
||||
var category string
|
||||
if err := rows.Scan(
|
||||
&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
|
||||
&category, &entry.Priority, &entry.SessionID, &entry.Source,
|
||||
&entry.AccessCount, &entry.LastAccess, &entry.CreatedAt, &entry.ExpiresAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("扫描记忆行失败: %w", err)
|
||||
}
|
||||
entry.Category = model.MemoryCategory(category)
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
// Delete 删除记忆
|
||||
func (s *Store) Delete(ctx context.Context, id string) error {
|
||||
_, err := s.db.ExecContext(ctx, `DELETE FROM memories WHERE id = $1`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// PurgeExpired 清理过期记忆
|
||||
func (s *Store) PurgeExpired(ctx context.Context) (int64, error) {
|
||||
result, err := s.db.ExecContext(ctx,
|
||||
`DELETE FROM memories WHERE expires_at IS NOT NULL AND expires_at < NOW()`)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
// SearchByVector 向量相似度搜索
|
||||
func (s *Store) SearchByVector(ctx context.Context, userID string, embedding []float64, limit int) ([]model.MemoryEntry, error) {
|
||||
if limit <= 0 {
|
||||
limit = 5
|
||||
}
|
||||
|
||||
vecStr := fmt.Sprintf("[%s]", joinFloats(embedding))
|
||||
query := `SELECT id, user_id, content, summary, category, priority, session_id, source,
|
||||
access_count, last_access, created_at, expires_at,
|
||||
1 - (embedding <=> $1) AS similarity
|
||||
FROM memories
|
||||
WHERE user_id = $2 AND embedding IS NOT NULL
|
||||
ORDER BY embedding <=> $1
|
||||
LIMIT $3`
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, vecStr, userID, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("向量搜索失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var entries []model.MemoryEntry
|
||||
for rows.Next() {
|
||||
var entry model.MemoryEntry
|
||||
var category string
|
||||
var similarity float64
|
||||
if err := rows.Scan(
|
||||
&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
|
||||
&category, &entry.Priority, &entry.SessionID, &entry.Source,
|
||||
&entry.AccessCount, &entry.LastAccess, &entry.CreatedAt, &entry.ExpiresAt,
|
||||
&similarity,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("扫描向量搜索结果失败: %w", err)
|
||||
}
|
||||
entry.Category = model.MemoryCategory(category)
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) incrementAccess(ctx context.Context, id string) {
|
||||
s.db.ExecContext(ctx,
|
||||
`UPDATE memories SET access_count = access_count + 1, last_access = NOW() WHERE id = $1`, id)
|
||||
}
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func (s *Store) Close() error {
|
||||
return s.db.Close()
|
||||
}
|
||||
|
||||
// joinFloats 将 float64 切片转为逗号分隔字符串
|
||||
func joinFloats(vec []float64) string {
|
||||
if len(vec) == 0 {
|
||||
return ""
|
||||
}
|
||||
s := fmt.Sprintf("%f", vec[0])
|
||||
for i := 1; i < len(vec); i++ {
|
||||
s += fmt.Sprintf(",%f", vec[i])
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user