package memory

import (
	"context"
	"fmt"
	"strings"

	"github.com/yourname/cyrene-ai/ai-core/internal/model"
)

// MemoryEntry is an alias for model.MemoryEntry (avoids a naming conflict
// with the model package).
type MemoryEntry = model.MemoryEntry

// Retriever looks up stored memories for a user by combining vector search
// and keyword matching.
type Retriever struct {
	store    *Store
	embedder Embedder // turns text into a vector
}

// Embedder converts text into an embedding vector.
type Embedder interface {
	Embed(ctx context.Context, text string) ([]float64, error)
}

// SimpleEmbedder is a character-frequency based embedder intended for the MVP
// stage; it needs no external API.
type SimpleEmbedder struct{}

// Embed returns a 1536-dimensional vector derived from the lowercased runes of
// text. Every rune adds 1/len to the bucket selected by its code point and
// 0.5/len to the bucket selected by code point plus position, where len is the
// number of runes. Empty text yields the zero vector. ctx is unused and the
// returned error is always nil.
func (e *SimpleEmbedder) Embed(ctx context.Context, text string) ([]float64, error) {
	const dim = 1536

	vec := make([]float64, dim)
	runes := []rune(strings.ToLower(text))
	if len(runes) == 0 {
		return vec, nil
	}

	// Both weights are loop-invariant, so compute them once.
	total := float64(len(runes))
	charWeight := 1.0 / total
	posWeight := 0.5 / total

	for i, r := range runes {
		vec[int(r)%dim] += charWeight
		// Mix in the position so that reordered text embeds differently.
		vec[(int(r)+i)%dim] += posWeight
	}
	return vec, nil
}

// NewRetriever builds a Retriever on top of store. When embedder is nil a
// SimpleEmbedder is used.
func NewRetriever(store *Store, embedder Embedder) *Retriever {
	if embedder == nil {
		embedder = &SimpleEmbedder{}
	}
	return &Retriever{
		store:    store,
		embedder: embedder,
	}
}

// Retrieve returns up to 10 memories of userID that relate to query.
//
// Strategy: vector-similarity hits and keyword hits are merged (first
// occurrence of an ID wins); if neither produced anything, a small set of
// important memories is used instead. The merged list is de-duplicated,
// ordered by importance (descending) and truncated.
//
// Retrieval is best effort: a failure of any single source is ignored and the
// remaining sources are still consulted, so the returned error is always nil.
func (r *Retriever) Retrieve(ctx context.Context, userID string, query string) ([]MemoryEntry, error) {
	const (
		vectorLimit   = 8  // max hits requested from the vector search
		fallbackLimit = 5  // max entries requested when nothing matched
		maxResults    = 10 // cap on the returned slice
	)

	var allEntries []MemoryEntry
	seen := make(map[string]bool)
	// merge appends the entries whose ID has not been collected yet.
	merge := func(entries []MemoryEntry) {
		for _, e := range entries {
			if seen[e.ID] {
				continue
			}
			seen[e.ID] = true
			allEntries = append(allEntries, e)
		}
	}

	// 1. Vector-similarity search.
	if embedding, err := r.embedder.Embed(ctx, query); err == nil {
		if vecEntries, err := r.store.SearchByVector(ctx, userID, embedding, vectorLimit); err == nil {
			merge(vecEntries)
		}
	}

	// 2. Keyword search (content, summary and keyword tags).
	if keywordEntries, err := r.keywordSearch(ctx, userID, query); err == nil {
		merge(keywordEntries)
	}

	// 3. Nothing matched: fall back to important memories.
	if len(allEntries) == 0 {
		recentEntries, err := r.store.Query(ctx, model.MemoryQuery{
			UserID:   userID,
			Priority: model.MemoryImportant,
			Limit:    fallbackLimit,
		})
		if err == nil {
			allEntries = recentEntries
		}
	}

	// 4. Collapse highly similar memories, keeping the more important one.
	allEntries = r.deduplicate(allEntries)

	// 5. Most important first.
	sortByImportance(allEntries)

	if len(allEntries) > maxResults {
		allEntries = allEntries[:maxResults]
	}
	return allEntries, nil
}

// RetrieveByCategory returns the memories of userID in the given category.
// A limit <= 0 is replaced by the default of 20. The store's result and error
// are returned unchanged.
func (r *Retriever) RetrieveByCategory(ctx context.Context, userID string, category model.MemoryCategory, limit int) ([]MemoryEntry, error) {
	if limit <= 0 {
		limit = 20
	}
	return r.store.Query(ctx, model.MemoryQuery{
		UserID:   userID,
		Category: category,
		Limit:    limit,
	})
}

// keywordSearch returns the memories of userID whose content, summary or
// keyword tags match query case-insensitively.
//
// Two store queries are made: one with Priority model.MemoryImportant (limit
// 50) and one with Priority model.MemoryNormal (limit 100). A failure of the
// first query is returned as an error; a failure of the second is ignored and
// the matches found so far are returned.
//
// NOTE(review): the two queries may return overlapping entries depending on
// how Store.Query interprets Priority, so the result can contain the same
// entry twice; Retrieve filters by ID afterwards.
func (r *Retriever) keywordSearch(ctx context.Context, userID string, query string) ([]MemoryEntry, error) {
	entries, err := r.store.Query(ctx, model.MemoryQuery{
		UserID:   userID,
		Priority: model.MemoryImportant,
		Limit:    50,
	})
	if err != nil {
		return nil, fmt.Errorf("querying important memories: %w", err)
	}

	queryLower := strings.ToLower(query)

	var matched []MemoryEntry
	for i := range entries {
		if entryMatchesQuery(&entries[i], queryLower) {
			matched = append(matched, entries[i])
		}
	}

	// Normal-priority memories are matched as well, on a best-effort basis.
	normalEntries, err := r.store.Query(ctx, model.MemoryQuery{
		UserID:   userID,
		Priority: model.MemoryNormal,
		Limit:    100,
	})
	if err != nil {
		return matched, nil
	}
	for i := range normalEntries {
		if entryMatchesQuery(&normalEntries[i], queryLower) {
			matched = append(matched, normalEntries[i])
		}
	}
	return matched, nil
}

// entryMatchesQuery reports whether entry matches queryLower, which must
// already be lowercased. An entry matches when its content or summary contains
// the query, or when one of its keywords contains the query or is contained in
// it. Note that an empty query is a substring of everything and therefore
// matches every entry.
func entryMatchesQuery(entry *MemoryEntry, queryLower string) bool {
	if strings.Contains(strings.ToLower(entry.Content), queryLower) ||
		strings.Contains(strings.ToLower(entry.Summary), queryLower) {
		return true
	}
	for _, kw := range entry.Keywords {
		// A blank keyword is a substring of every query and would make the
		// entry match unconditionally, so it is skipped.
		if strings.TrimSpace(kw) == "" {
			continue
		}
		kwLower := strings.ToLower(kw)
		if strings.Contains(queryLower, kwLower) || strings.Contains(kwLower, queryLower) {
			return true
		}
	}
	return false
}
// deduplicate 去重合并:对高度相似的记忆只保留 Importance 更高的 func (r *Retriever) deduplicate(entries []MemoryEntry) []MemoryEntry { if len(entries) < 2 { return entries } result := make([]MemoryEntry, 0, len(entries)) discarded := make(map[int]bool) for i := 0; i < len(entries); i++ { if discarded[i] { continue } for j := i + 1; j < len(entries); j++ { if discarded[j] { continue } score := entries[i].SimilarityScore(&entries[j]) if score >= deDupThreshold { // 保留更重要的那条 if entries[j].Importance > entries[i].Importance || (entries[j].Importance == entries[i].Importance && entries[j].Priority > entries[i].Priority) { discarded[i] = true break } else { discarded[j] = true } } } if !discarded[i] { result = append(result, entries[i]) } } return result } // sortByImportance 按 Importance 降序, Priority 降序排列 func sortByImportance(entries []MemoryEntry) { for i := 0; i < len(entries); i++ { for j := i + 1; j < len(entries); j++ { if entries[j].Importance > entries[i].Importance || (entries[j].Importance == entries[i].Importance && entries[j].Priority > entries[i].Priority) { entries[i], entries[j] = entries[j], entries[i] } } } } // Ensure fmt is used var _ = fmt.Sprintf