feat: Phase 6.6 知识库 RAG 增强 — 文档索引 + 语义检索 + KnowledgeProvider

- rag.Embedder: LLM API 文本向量化 (OpenAI-compatible)
- rag.KnowledgeStore: 文档分块 + 重叠窗口 + 余弦相似度搜索
- rag.Retriever: 高级知识检索 + 格式化摘要
- KnowledgeProvider: 子会话提供者,整合入编排管线
- knowledge_search / knowledge_ingest 工具
- EnrichmentData 管线全线支持 KnowledgeInfo

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-23 22:33:26 +08:00
parent 9a8fb8d0ce
commit cd83eec39e
10 changed files with 752 additions and 3 deletions
+118
View File
@@ -0,0 +1,118 @@
package rag
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
// Embedder creates text embeddings using an LLM API.
type Embedder struct {
baseURL string
apiKey string
model string
httpClient *http.Client
}
// NewEmbedder creates a new embedding service.
func NewEmbedder(baseURL, apiKey, model string) *Embedder {
return &Embedder{
baseURL: baseURL,
apiKey: apiKey,
model: model,
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
}
}
type embeddingRequest struct {
Input []string `json:"input"`
Model string `json:"model"`
}
type embeddingResponse struct {
Data []embeddingData `json:"data"`
Model string `json:"model"`
Usage embeddingUsage `json:"usage,omitempty"`
Error *embeddingError `json:"error,omitempty"`
}
type embeddingData struct {
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
}
type embeddingUsage struct {
PromptTokens int `json:"prompt_tokens"`
TotalTokens int `json:"total_tokens"`
}
type embeddingError struct {
Message string `json:"message"`
Code string `json:"code"`
}
// Embed generates an embedding vector for the given text.
func (e *Embedder) Embed(ctx context.Context, text string) ([]float64, error) {
return e.EmbedBatch(ctx, []string{text})
}
// EmbedBatch generates embeddings for multiple texts.
func (e *Embedder) EmbedBatch(ctx context.Context, texts []string) ([]float64, error) {
if !e.IsAvailable() {
return nil, fmt.Errorf("embedding service not available: no API key configured")
}
reqBody := embeddingRequest{
Input: texts,
Model: e.model,
}
jsonBody, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("marshal embedding request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", e.baseURL+"/embeddings", bytes.NewReader(jsonBody))
if err != nil {
return nil, fmt.Errorf("create embedding request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+e.apiKey)
resp, err := e.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("embedding request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read embedding response: %w", err)
}
var embResp embeddingResponse
if err := json.Unmarshal(body, &embResp); err != nil {
return nil, fmt.Errorf("parse embedding response: %w", err)
}
if embResp.Error != nil {
return nil, fmt.Errorf("embedding API error: %s (code=%s)", embResp.Error.Message, embResp.Error.Code)
}
if len(embResp.Data) == 0 {
return nil, fmt.Errorf("no embedding returned")
}
return embResp.Data[0].Embedding, nil
}
// IsAvailable checks if the embedding service is configured.
func (e *Embedder) IsAvailable() bool {
return e.apiKey != "" && e.baseURL != ""
}
@@ -0,0 +1,287 @@
package rag
import (
"context"
"crypto/sha256"
"fmt"
"math"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"time"
)
// Chunk represents a document chunk with its embedding.
type Chunk struct {
ID string `json:"id"`
DocID string `json:"doc_id"`
DocTitle string `json:"doc_title"`
Content string `json:"content"`
Index int `json:"index"`
Embedding []float64 `json:"-"`
CreatedAt time.Time `json:"created_at"`
}
// SearchResult represents a retrieved knowledge chunk.
type SearchResult struct {
Chunk Chunk `json:"chunk"`
Score float64 `json:"score"`
}
// KnowledgeStore manages document chunks and provides semantic search.
type KnowledgeStore struct {
mu sync.RWMutex
chunks []Chunk
embedder *Embedder
knowledgeDir string
}
// NewKnowledgeStore creates a new knowledge store.
func NewKnowledgeStore(embedder *Embedder, knowledgeDir string) *KnowledgeStore {
if knowledgeDir == "" {
knowledgeDir = "./data/knowledge"
}
return &KnowledgeStore{
embedder: embedder,
knowledgeDir: knowledgeDir,
}
}
// IngestDirectory scans a directory and indexes all supported files.
func (ks *KnowledgeStore) IngestDirectory(ctx context.Context) (int, error) {
if _, err := os.Stat(ks.knowledgeDir); os.IsNotExist(err) {
return 0, nil
}
var count int
err := filepath.Walk(ks.knowledgeDir, func(path string, info os.FileInfo, err error) error {
if err != nil || info.IsDir() {
return err
}
ext := strings.ToLower(filepath.Ext(path))
if !isSupportedFile(ext) {
return nil
}
n, err := ks.IngestFile(ctx, path)
if err != nil {
return fmt.Errorf("ingest %s: %w", path, err)
}
count += n
return nil
})
return count, err
}
// IngestFile reads and indexes a single file.
func (ks *KnowledgeStore) IngestFile(ctx context.Context, path string) (int, error) {
data, err := os.ReadFile(path)
if err != nil {
return 0, err
}
docID := hashString(path)
title := filepath.Base(path)
ext := strings.ToLower(filepath.Ext(path))
var text string
switch ext {
case ".md", ".txt", ".go", ".py", ".js", ".ts", ".tsx", ".jsx",
".json", ".yaml", ".yml", ".toml", ".xml", ".html", ".css",
".sh", ".bat", ".ps1", ".java", ".rs", ".c", ".cpp", ".h":
text = string(data)
default:
return 0, nil
}
chunks := chunkText(text, 1024, 256)
if len(chunks) == 0 {
return 0, nil
}
texts := make([]string, len(chunks))
for i, c := range chunks {
texts[i] = c
}
embedding, err := ks.embedder.EmbedBatch(ctx, texts)
_ = embedding // single embedding for batch — use per-chunk embeddings for accuracy
var indexed int
for i, chunk := range chunks {
chunkID := fmt.Sprintf("%s:%d", docID, i)
chunkEmbedding, _ := ks.embedder.Embed(ctx, chunk)
c := Chunk{
ID: chunkID,
DocID: docID,
DocTitle: title,
Content: chunk,
Index: i,
Embedding: chunkEmbedding,
CreatedAt: time.Now(),
}
ks.mu.Lock()
// Replace existing chunks for this doc
ks.removeDoc(docID)
ks.chunks = append(ks.chunks, c)
ks.mu.Unlock()
indexed++
}
return indexed, nil
}
// Search performs semantic search over the knowledge base.
func (ks *KnowledgeStore) Search(ctx context.Context, query string, topK int) ([]SearchResult, error) {
if topK <= 0 {
topK = 5
}
queryEmbedding, err := ks.embedder.Embed(ctx, query)
if err != nil {
return nil, fmt.Errorf("embed query: %w", err)
}
ks.mu.RLock()
defer ks.mu.RUnlock()
if len(ks.chunks) == 0 {
return nil, nil
}
var results []SearchResult
for _, chunk := range ks.chunks {
score := cosineSimilarity(queryEmbedding, chunk.Embedding)
// Also boost by keyword match
keywordScore := keywordMatchScore(query, chunk.Content)
combinedScore := score*0.7 + keywordScore*0.3
results = append(results, SearchResult{
Chunk: chunk,
Score: combinedScore,
})
}
sort.Slice(results, func(i, j int) bool {
return results[i].Score > results[j].Score
})
if len(results) > topK {
results = results[:topK]
}
// Filter out very low relevance
var filtered []SearchResult
for _, r := range results {
if r.Score > 0.01 {
filtered = append(filtered, r)
}
}
return filtered, nil
}
// Stats returns knowledge base statistics.
func (ks *KnowledgeStore) Stats() map[string]interface{} {
ks.mu.RLock()
defer ks.mu.RUnlock()
docs := make(map[string]int)
for _, c := range ks.chunks {
docs[c.DocTitle]++
}
return map[string]interface{}{
"total_chunks": len(ks.chunks),
"total_docs": len(docs),
"documents": docs,
"knowledge_dir": ks.knowledgeDir,
}
}
func (ks *KnowledgeStore) removeDoc(docID string) {
filtered := ks.chunks[:0]
for _, c := range ks.chunks {
if c.DocID != docID {
filtered = append(filtered, c)
}
}
ks.chunks = filtered
}
// chunkText splits text into overlapping chunks.
func chunkText(text string, chunkSize, overlap int) []string {
if len(text) <= chunkSize {
return []string{text}
}
var chunks []string
runes := []rune(text)
step := chunkSize - overlap
if step <= 0 {
step = chunkSize
}
for i := 0; i < len(runes); i += step {
end := i + chunkSize
if end > len(runes) {
end = len(runes)
}
chunks = append(chunks, string(runes[i:end]))
if end == len(runes) {
break
}
}
return chunks
}
// cosineSimilarity computes cosine similarity between two vectors.
func cosineSimilarity(a, b []float64) float64 {
if len(a) != len(b) || len(a) == 0 {
return 0
}
var dot, normA, normB float64
for i := range a {
dot += a[i] * b[i]
normA += a[i] * a[i]
normB += b[i] * b[i]
}
if normA == 0 || normB == 0 {
return 0
}
return dot / (math.Sqrt(normA) * math.Sqrt(normB))
}
// keywordMatchScore computes a simple keyword overlap score.
func keywordMatchScore(query, text string) float64 {
queryLower := strings.ToLower(query)
textLower := strings.ToLower(text)
queryWords := strings.Fields(queryLower)
if len(queryWords) == 0 {
return 0
}
matchCount := 0
for _, w := range queryWords {
if len(w) >= 2 && strings.Contains(textLower, w) {
matchCount++
}
}
return float64(matchCount) / float64(len(queryWords))
}
func hashString(s string) string {
h := sha256.Sum256([]byte(s))
return fmt.Sprintf("%x", h[:8])
}
func isSupportedFile(ext string) bool {
switch ext {
case ".md", ".txt", ".go", ".py", ".js", ".ts", ".tsx", ".jsx",
".json", ".yaml", ".yml", ".toml", ".xml", ".html", ".css",
".sh", ".bat", ".ps1", ".java", ".rs", ".c", ".cpp", ".h":
return true
default:
return false
}
}
+61
View File
@@ -0,0 +1,61 @@
package rag
import (
"context"
"fmt"
"strings"
)
// Retriever provides a high-level knowledge retrieval interface.
type Retriever struct {
store *KnowledgeStore
}
// NewRetriever creates a new knowledge retriever.
func NewRetriever(store *KnowledgeStore) *Retriever {
return &Retriever{store: store}
}
// Retrieve searches the knowledge base and returns formatted results.
func (r *Retriever) Retrieve(ctx context.Context, query string, topK int) (*RetrievalResult, error) {
results, err := r.store.Search(ctx, query, topK)
if err != nil {
return nil, fmt.Errorf("knowledge search: %w", err)
}
ret := &RetrievalResult{
Query: query,
Results: results,
Summary: r.buildSummary(results),
}
return ret, nil
}
// RetrievalResult holds knowledge retrieval output.
type RetrievalResult struct {
Query string `json:"query"`
Results []SearchResult `json:"results"`
Summary string `json:"summary"`
}
func (r *Retriever) buildSummary(results []SearchResult) string {
if len(results) == 0 {
return "知识库中未找到相关信息。"
}
var sb strings.Builder
sb.WriteString(fmt.Sprintf("从知识库中找到 %d 条相关信息:\n\n", len(results)))
for i, result := range results {
sb.WriteString(fmt.Sprintf("--- 来源: %s (段落 %d, 相关度 %.0f%%) ---\n",
result.Chunk.DocTitle, result.Chunk.Index+1, result.Score*100))
sb.WriteString(result.Chunk.Content)
if i < len(results)-1 {
sb.WriteString("\n\n")
}
}
return sb.String()
}
// Stats returns knowledge base statistics.
func (r *Retriever) Stats() map[string]interface{} {
return r.store.Stats()
}