dev 分支暂存
This commit is contained in:
@@ -0,0 +1,152 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/model"
|
||||
)
|
||||
|
||||
// MemoryEntry 记忆条目别名(避免与model包冲突)
|
||||
type MemoryEntry = model.MemoryEntry
|
||||
|
||||
// Retriever 记忆检索器
|
||||
type Retriever struct {
|
||||
store *Store
|
||||
embedder Embedder // 文本转向量的接口
|
||||
}
|
||||
|
||||
// Embedder 文本嵌入接口
|
||||
type Embedder interface {
|
||||
Embed(ctx context.Context, text string) ([]float64, error)
|
||||
}
|
||||
|
||||
// SimpleEmbedder 基于关键词的简单嵌入(MVP阶段可用,无需外部API)
|
||||
type SimpleEmbedder struct{}
|
||||
|
||||
// Embed 简单的关键词哈希嵌入(用于MVP快速验证)
|
||||
func (e *SimpleEmbedder) Embed(ctx context.Context, text string) ([]float64, error) {
|
||||
// 生成一个简单的1536维特征向量
|
||||
// 基于字符频率的简单表示,用于MVP阶段
|
||||
vec := make([]float64, 1536)
|
||||
|
||||
runes := []rune(strings.ToLower(text))
|
||||
for i, r := range runes {
|
||||
idx := int(r) % 1536
|
||||
vec[idx] += 1.0 / float64(len(runes))
|
||||
// 考虑位置信息
|
||||
posIdx := (int(r) + i) % 1536
|
||||
vec[posIdx] += 0.5 / float64(len(runes))
|
||||
}
|
||||
|
||||
return vec, nil
|
||||
}
|
||||
|
||||
// NewRetriever 创建记忆检索器
|
||||
func NewRetriever(store *Store, embedder Embedder) *Retriever {
|
||||
if embedder == nil {
|
||||
embedder = &SimpleEmbedder{}
|
||||
}
|
||||
return &Retriever{
|
||||
store: store,
|
||||
embedder: embedder,
|
||||
}
|
||||
}
|
||||
|
||||
// Retrieve 检索与查询相关的记忆
|
||||
// 策略: 向量相似度 + 关键词匹配混合
|
||||
func (r *Retriever) Retrieve(ctx context.Context, userID string, query string) ([]MemoryEntry, error) {
|
||||
var allEntries []MemoryEntry
|
||||
seen := make(map[string]bool)
|
||||
|
||||
// 1. 向量相似度检索
|
||||
embedding, err := r.embedder.Embed(ctx, query)
|
||||
if err == nil {
|
||||
vecEntries, err := r.store.SearchByVector(ctx, userID, embedding, 5)
|
||||
if err == nil {
|
||||
for _, e := range vecEntries {
|
||||
if !seen[e.ID] {
|
||||
seen[e.ID] = true
|
||||
allEntries = append(allEntries, e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 关键词匹配检索(核心/重要记忆优先)
|
||||
keywordEntries, err := r.keywordSearch(ctx, userID, query)
|
||||
if err == nil {
|
||||
for _, e := range keywordEntries {
|
||||
if !seen[e.ID] {
|
||||
seen[e.ID] = true
|
||||
allEntries = append(allEntries, e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 如果没有匹配,返回最近的重要记忆
|
||||
if len(allEntries) == 0 {
|
||||
recentEntries, err := r.store.Query(ctx, model.MemoryQuery{
|
||||
UserID: userID,
|
||||
Priority: int(model.MemoryImportant),
|
||||
Limit: 3,
|
||||
})
|
||||
if err == nil {
|
||||
allEntries = recentEntries
|
||||
}
|
||||
}
|
||||
|
||||
// 限制返回数量
|
||||
if len(allEntries) > 10 {
|
||||
allEntries = allEntries[:10]
|
||||
}
|
||||
|
||||
return allEntries, nil
|
||||
}
|
||||
|
||||
// keywordSearch 关键词匹配检索
|
||||
func (r *Retriever) keywordSearch(ctx context.Context, userID string, query string) ([]MemoryEntry, error) {
|
||||
// 查询最近的核心和重要记忆
|
||||
entries, err := r.store.Query(ctx, model.MemoryQuery{
|
||||
UserID: userID,
|
||||
Priority: int(model.MemoryImportant),
|
||||
Limit: 50,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 简单的关键词匹配过滤
|
||||
var matched []MemoryEntry
|
||||
queryLower := strings.ToLower(query)
|
||||
|
||||
for _, entry := range entries {
|
||||
contentLower := strings.ToLower(entry.Content)
|
||||
summaryLower := strings.ToLower(entry.Summary)
|
||||
if strings.Contains(contentLower, queryLower) || strings.Contains(summaryLower, queryLower) {
|
||||
matched = append(matched, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// 也匹配普通记忆
|
||||
normalEntries, err := r.store.Query(ctx, model.MemoryQuery{
|
||||
UserID: userID,
|
||||
Priority: int(model.MemoryNormal),
|
||||
Limit: 100,
|
||||
})
|
||||
if err == nil {
|
||||
for _, entry := range normalEntries {
|
||||
contentLower := strings.ToLower(entry.Content)
|
||||
summaryLower := strings.ToLower(entry.Summary)
|
||||
if strings.Contains(contentLower, queryLower) || strings.Contains(summaryLower, queryLower) {
|
||||
matched = append(matched, entry)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return matched, nil
|
||||
}
|
||||
|
||||
// Ensure fmt is used
|
||||
var _ = fmt.Sprintf
|
||||
|
||||
Reference in New Issue
Block a user