feat: Phase 6.6 知识库 RAG 增强 — 文档索引 + 语义检索 + KnowledgeProvider

- rag.Embedder: LLM API 文本向量化 (OpenAI-compatible)
- rag.KnowledgeStore: 文档分块 + 重叠窗口 + 余弦相似度搜索
- rag.Retriever: 高级知识检索 + 格式化摘要
- KnowledgeProvider: 子会话提供者,整合入编排管线
- knowledge_search / knowledge_ingest 工具
- EnrichmentData 管线全线支持 KnowledgeInfo

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/model"
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/orchestrator"
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/persona"
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/rag"
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/subsession"
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/tools"
|
||||
)
|
||||
@@ -130,6 +131,13 @@ func main() {
|
||||
hostManager.SetAllowedDirs([]string{dataDir, os.TempDir(), "."})
|
||||
log.Printf("主机操控管理器已就绪: 沙箱执行 + 文件隔离 (数据目录=%s)", dataDir)
|
||||
|
||||
// 初始化 RAG 知识库 (Phase 6.6: 知识库 RAG 增强)
|
||||
knowledgeDir := getEnv("KNOWLEDGE_DIR", "./data/knowledge")
|
||||
ragEmbedder := rag.NewEmbedder(cfg.LLMBaseURL, cfg.LLMAPIKey, "text-embedding-3-small")
|
||||
knowledgeStore := rag.NewKnowledgeStore(ragEmbedder, knowledgeDir)
|
||||
knowledgeRetriever := rag.NewRetriever(knowledgeStore)
|
||||
log.Printf("RAG 知识库已就绪: 目录=%s, 嵌入模型=text-embedding-3-small", knowledgeDir)
|
||||
|
||||
// 初始化工具注册中心
|
||||
toolRegistry := tools.NewRegistry()
|
||||
if getEnvBool("ENABLE_TOOLS", true) {
|
||||
@@ -161,6 +169,12 @@ func main() {
|
||||
|
||||
// Phase 6.3: 视觉理解工具
|
||||
toolRegistry.Register(tools.NewVisionTool())
|
||||
|
||||
// Phase 6.6: 知识库 RAG 工具
|
||||
if knowledgeRetriever != nil {
|
||||
toolRegistry.Register(tools.NewKnowledgeSearchTool(knowledgeRetriever))
|
||||
toolRegistry.Register(tools.NewKnowledgeIngestTool(knowledgeStore))
|
||||
}
|
||||
log.Printf("工具注册中心已就绪: %d 个工具 (%v)", len(toolRegistry.ListTools()), toolRegistry.ListTools())
|
||||
}
|
||||
|
||||
@@ -236,6 +250,9 @@ func main() {
|
||||
subManager.Register(subsession.NewIoTProvider(iotClient, personaDir))
|
||||
}
|
||||
subManager.Register(subsession.NewReviewProvider())
|
||||
if knowledgeRetriever != nil {
|
||||
subManager.Register(subsession.NewKnowledgeProvider(knowledgeRetriever))
|
||||
}
|
||||
log.Printf("子会话管理器已就绪: %d 个提供者 (%v)", len(subManager.ListProviders()), subManager.ListProviders())
|
||||
|
||||
// 构建新的 Orchestrator (v2.0) — 传入 purpose 专用适配器
|
||||
|
||||
@@ -135,7 +135,7 @@ func DefaultAutonomousToolPolicy() *AutonomousToolPolicy {
|
||||
"iot_query", "iot_control", "memory_search", "web_search",
|
||||
"calculator", "datetime", "web_fetch",
|
||||
"host_exec", "host_file", "host_system",
|
||||
"vision_analyze",
|
||||
"vision_analyze", "knowledge_search", "knowledge_ingest",
|
||||
},
|
||||
MaxToolCallsPerRound: 5,
|
||||
MaxHighRiskPerHour: 10,
|
||||
|
||||
@@ -7,6 +7,7 @@ type EnrichmentData struct {
|
||||
MemorySummary string
|
||||
ThoughtOutline string
|
||||
IoTSummary string
|
||||
KnowledgeInfo string
|
||||
}
|
||||
|
||||
// SessionEnrichmentStore is a thread-safe per-session cache for async
|
||||
|
||||
@@ -243,10 +243,11 @@ func (o *Orchestrator) ProcessInput(
|
||||
if o.enrichmentStore != nil {
|
||||
prevEnrichment = o.enrichmentStore.Get(params.SessionID)
|
||||
if prevEnrichment != nil {
|
||||
logger.Printf("[orchestrator] 加载上一轮富化结果: memory=%t thought=%t iot=%t",
|
||||
logger.Printf("[orchestrator] 加载上一轮富化结果: memory=%t thought=%t iot=%t knowledge=%t",
|
||||
prevEnrichment.MemorySummary != "",
|
||||
prevEnrichment.ThoughtOutline != "",
|
||||
prevEnrichment.IoTSummary != "")
|
||||
prevEnrichment.IoTSummary != "",
|
||||
prevEnrichment.KnowledgeInfo != "")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -273,6 +274,7 @@ func (o *Orchestrator) ProcessInput(
|
||||
synthParams.MemorySummary = prevEnrichment.MemorySummary
|
||||
synthParams.ThoughtOutline = prevEnrichment.ThoughtOutline
|
||||
synthParams.IoTSummary = prevEnrichment.IoTSummary
|
||||
synthParams.KnowledgeInfo = prevEnrichment.KnowledgeInfo
|
||||
}
|
||||
|
||||
// 异步收集子会话结果,存入 enrichmentStore 供下一轮使用
|
||||
@@ -300,6 +302,8 @@ func (o *Orchestrator) ProcessInput(
|
||||
logger.Printf("[orchestrator] 通用对话子会话完成: %s", result.Summary)
|
||||
case model.SubSessionIoT:
|
||||
enriched.IoTSummary = result.Summary
|
||||
case model.SubSessionKnowledge:
|
||||
enriched.KnowledgeInfo = result.Summary
|
||||
logger.Printf("[orchestrator] IoT 子会话完成: %s", result.Summary)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ type SynthesizeParams struct {
|
||||
ThoughtOutline string // 通用对话思考
|
||||
IoTSummary string // IoT 操作摘要
|
||||
DeviceContext string // 设备状态上下文
|
||||
KnowledgeInfo string // 知识库检索摘要
|
||||
Mode string // text / voice_assistant
|
||||
}
|
||||
|
||||
@@ -95,6 +96,14 @@ func (s *Synthesizer) buildSynthesizeMessages(params SynthesizeParams) []model.L
|
||||
})
|
||||
}
|
||||
|
||||
// 注入知识库检索结果
|
||||
if params.KnowledgeInfo != "" && !strings.Contains(params.KnowledgeInfo, "未找到") {
|
||||
messages = append(messages, model.LLMMessage{
|
||||
Role: model.RoleSystem,
|
||||
Content: fmt.Sprintf("【知识库参考资料】\n%s", params.KnowledgeInfo),
|
||||
})
|
||||
}
|
||||
|
||||
// 注入对话历史
|
||||
if len(params.DialogHistory) > 0 {
|
||||
messages = append(messages, params.DialogHistory...)
|
||||
|
||||
@@ -0,0 +1,118 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Embedder turns text into embedding vectors by calling an HTTP embeddings
// endpoint.
type Embedder struct {
	baseURL    string       // API root; "/embeddings" is appended for each request
	apiKey     string       // sent as a Bearer token in the Authorization header
	model      string       // model name placed in the request body
	httpClient *http.Client // client used for every embeddings request
}
|
||||
|
||||
// NewEmbedder creates a new embedding service.
|
||||
func NewEmbedder(baseURL, apiKey, model string) *Embedder {
|
||||
return &Embedder{
|
||||
baseURL: baseURL,
|
||||
apiKey: apiKey,
|
||||
model: model,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// JSON wire types exchanged with the embeddings endpoint.
type (
	// embeddingRequest is the request body: the texts to embed and the model.
	embeddingRequest struct {
		Input []string `json:"input"`
		Model string   `json:"model"`
	}

	// embeddingResponse is the response body. Error is non-nil when the body
	// carries an "error" object.
	embeddingResponse struct {
		Data  []embeddingData `json:"data"`
		Model string          `json:"model"`
		Usage embeddingUsage  `json:"usage,omitempty"`
		Error *embeddingError `json:"error,omitempty"`
	}

	// embeddingData is one returned vector together with its index field.
	embeddingData struct {
		Embedding []float64 `json:"embedding"`
		Index     int       `json:"index"`
	}

	// embeddingUsage carries the token counts reported in the response.
	embeddingUsage struct {
		PromptTokens int `json:"prompt_tokens"`
		TotalTokens  int `json:"total_tokens"`
	}

	// embeddingError is the error object of a failed response.
	embeddingError struct {
		Message string `json:"message"`
		Code    string `json:"code"`
	}
)
|
||||
|
||||
// Embed generates an embedding vector for the given text.
|
||||
func (e *Embedder) Embed(ctx context.Context, text string) ([]float64, error) {
|
||||
return e.EmbedBatch(ctx, []string{text})
|
||||
}
|
||||
|
||||
// EmbedBatch generates embeddings for multiple texts.
|
||||
func (e *Embedder) EmbedBatch(ctx context.Context, texts []string) ([]float64, error) {
|
||||
if !e.IsAvailable() {
|
||||
return nil, fmt.Errorf("embedding service not available: no API key configured")
|
||||
}
|
||||
|
||||
reqBody := embeddingRequest{
|
||||
Input: texts,
|
||||
Model: e.model,
|
||||
}
|
||||
|
||||
jsonBody, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal embedding request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", e.baseURL+"/embeddings", bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create embedding request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+e.apiKey)
|
||||
|
||||
resp, err := e.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("embedding request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read embedding response: %w", err)
|
||||
}
|
||||
|
||||
var embResp embeddingResponse
|
||||
if err := json.Unmarshal(body, &embResp); err != nil {
|
||||
return nil, fmt.Errorf("parse embedding response: %w", err)
|
||||
}
|
||||
|
||||
if embResp.Error != nil {
|
||||
return nil, fmt.Errorf("embedding API error: %s (code=%s)", embResp.Error.Message, embResp.Error.Code)
|
||||
}
|
||||
|
||||
if len(embResp.Data) == 0 {
|
||||
return nil, fmt.Errorf("no embedding returned")
|
||||
}
|
||||
|
||||
return embResp.Data[0].Embedding, nil
|
||||
}
|
||||
|
||||
// IsAvailable checks if the embedding service is configured.
|
||||
func (e *Embedder) IsAvailable() bool {
|
||||
return e.apiKey != "" && e.baseURL != ""
|
||||
}
|
||||
@@ -0,0 +1,287 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type (
	// Chunk is one slice of an ingested document together with its embedding.
	Chunk struct {
		ID        string    `json:"id"`        // "<DocID>:<Index>"
		DocID     string    `json:"doc_id"`    // identifier of the source document
		DocTitle  string    `json:"doc_title"` // display title of the source document
		Content   string    `json:"content"`   // chunk text
		Index     int       `json:"index"`     // zero-based position within the document
		Embedding []float64 `json:"-"`         // vector; excluded from JSON output
		CreatedAt time.Time `json:"created_at"`
	}

	// SearchResult pairs a retrieved chunk with its relevance score.
	SearchResult struct {
		Chunk Chunk   `json:"chunk"`
		Score float64 `json:"score"`
	}
)
|
||||
|
||||
// KnowledgeStore manages document chunks and provides semantic search.
|
||||
type KnowledgeStore struct {
|
||||
mu sync.RWMutex
|
||||
chunks []Chunk
|
||||
embedder *Embedder
|
||||
knowledgeDir string
|
||||
}
|
||||
|
||||
// NewKnowledgeStore creates a new knowledge store.
|
||||
func NewKnowledgeStore(embedder *Embedder, knowledgeDir string) *KnowledgeStore {
|
||||
if knowledgeDir == "" {
|
||||
knowledgeDir = "./data/knowledge"
|
||||
}
|
||||
return &KnowledgeStore{
|
||||
embedder: embedder,
|
||||
knowledgeDir: knowledgeDir,
|
||||
}
|
||||
}
|
||||
|
||||
// IngestDirectory scans a directory and indexes all supported files.
|
||||
func (ks *KnowledgeStore) IngestDirectory(ctx context.Context) (int, error) {
|
||||
if _, err := os.Stat(ks.knowledgeDir); os.IsNotExist(err) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var count int
|
||||
err := filepath.Walk(ks.knowledgeDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil || info.IsDir() {
|
||||
return err
|
||||
}
|
||||
ext := strings.ToLower(filepath.Ext(path))
|
||||
if !isSupportedFile(ext) {
|
||||
return nil
|
||||
}
|
||||
n, err := ks.IngestFile(ctx, path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ingest %s: %w", path, err)
|
||||
}
|
||||
count += n
|
||||
return nil
|
||||
})
|
||||
return count, err
|
||||
}
|
||||
|
||||
// IngestFile reads and indexes a single file.
|
||||
func (ks *KnowledgeStore) IngestFile(ctx context.Context, path string) (int, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
docID := hashString(path)
|
||||
title := filepath.Base(path)
|
||||
ext := strings.ToLower(filepath.Ext(path))
|
||||
|
||||
var text string
|
||||
switch ext {
|
||||
case ".md", ".txt", ".go", ".py", ".js", ".ts", ".tsx", ".jsx",
|
||||
".json", ".yaml", ".yml", ".toml", ".xml", ".html", ".css",
|
||||
".sh", ".bat", ".ps1", ".java", ".rs", ".c", ".cpp", ".h":
|
||||
text = string(data)
|
||||
default:
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
chunks := chunkText(text, 1024, 256)
|
||||
if len(chunks) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
texts := make([]string, len(chunks))
|
||||
for i, c := range chunks {
|
||||
texts[i] = c
|
||||
}
|
||||
|
||||
embedding, err := ks.embedder.EmbedBatch(ctx, texts)
|
||||
_ = embedding // single embedding for batch — use per-chunk embeddings for accuracy
|
||||
|
||||
var indexed int
|
||||
for i, chunk := range chunks {
|
||||
chunkID := fmt.Sprintf("%s:%d", docID, i)
|
||||
chunkEmbedding, _ := ks.embedder.Embed(ctx, chunk)
|
||||
c := Chunk{
|
||||
ID: chunkID,
|
||||
DocID: docID,
|
||||
DocTitle: title,
|
||||
Content: chunk,
|
||||
Index: i,
|
||||
Embedding: chunkEmbedding,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
ks.mu.Lock()
|
||||
// Replace existing chunks for this doc
|
||||
ks.removeDoc(docID)
|
||||
ks.chunks = append(ks.chunks, c)
|
||||
ks.mu.Unlock()
|
||||
indexed++
|
||||
}
|
||||
|
||||
return indexed, nil
|
||||
}
|
||||
|
||||
// Search performs semantic search over the knowledge base.
|
||||
func (ks *KnowledgeStore) Search(ctx context.Context, query string, topK int) ([]SearchResult, error) {
|
||||
if topK <= 0 {
|
||||
topK = 5
|
||||
}
|
||||
|
||||
queryEmbedding, err := ks.embedder.Embed(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("embed query: %w", err)
|
||||
}
|
||||
|
||||
ks.mu.RLock()
|
||||
defer ks.mu.RUnlock()
|
||||
|
||||
if len(ks.chunks) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var results []SearchResult
|
||||
for _, chunk := range ks.chunks {
|
||||
score := cosineSimilarity(queryEmbedding, chunk.Embedding)
|
||||
// Also boost by keyword match
|
||||
keywordScore := keywordMatchScore(query, chunk.Content)
|
||||
combinedScore := score*0.7 + keywordScore*0.3
|
||||
|
||||
results = append(results, SearchResult{
|
||||
Chunk: chunk,
|
||||
Score: combinedScore,
|
||||
})
|
||||
}
|
||||
|
||||
sort.Slice(results, func(i, j int) bool {
|
||||
return results[i].Score > results[j].Score
|
||||
})
|
||||
|
||||
if len(results) > topK {
|
||||
results = results[:topK]
|
||||
}
|
||||
|
||||
// Filter out very low relevance
|
||||
var filtered []SearchResult
|
||||
for _, r := range results {
|
||||
if r.Score > 0.01 {
|
||||
filtered = append(filtered, r)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
// Stats returns knowledge base statistics.
|
||||
func (ks *KnowledgeStore) Stats() map[string]interface{} {
|
||||
ks.mu.RLock()
|
||||
defer ks.mu.RUnlock()
|
||||
|
||||
docs := make(map[string]int)
|
||||
for _, c := range ks.chunks {
|
||||
docs[c.DocTitle]++
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_chunks": len(ks.chunks),
|
||||
"total_docs": len(docs),
|
||||
"documents": docs,
|
||||
"knowledge_dir": ks.knowledgeDir,
|
||||
}
|
||||
}
|
||||
|
||||
func (ks *KnowledgeStore) removeDoc(docID string) {
|
||||
filtered := ks.chunks[:0]
|
||||
for _, c := range ks.chunks {
|
||||
if c.DocID != docID {
|
||||
filtered = append(filtered, c)
|
||||
}
|
||||
}
|
||||
ks.chunks = filtered
|
||||
}
|
||||
|
||||
// chunkText splits text into chunks of at most chunkSize runes, where
// consecutive chunks share overlap runes.
//
// Edge cases:
//   - blank text (empty or whitespace only) yields no chunks;
//   - a non-positive chunkSize, or text no longer than chunkSize bytes, yields
//     the whole text as a single chunk;
//   - an overlap that is negative or not smaller than chunkSize is ignored,
//     i.e. chunks are laid out back to back, so the loop always advances and
//     no text is skipped.
func chunkText(text string, chunkSize, overlap int) []string {
	if strings.TrimSpace(text) == "" {
		return nil
	}
	if chunkSize <= 0 || len(text) <= chunkSize {
		return []string{text}
	}

	// Work on runes so multi-byte characters are never cut in half.
	runes := []rune(text)

	step := chunkSize - overlap
	if step <= 0 || step > chunkSize {
		step = chunkSize
	}

	chunks := make([]string, 0, len(runes)/step+1)
	for start := 0; start < len(runes); start += step {
		end := start + chunkSize
		if end > len(runes) {
			end = len(runes)
		}
		chunks = append(chunks, string(runes[start:end]))
		if end == len(runes) {
			break
		}
	}
	return chunks
}
|
||||
|
||||
// cosineSimilarity returns the cosine of the angle between a and b. It
// returns 0 when the vectors are empty, differ in length, or either has zero
// magnitude.
func cosineSimilarity(a, b []float64) float64 {
	n := len(a)
	if n == 0 || n != len(b) {
		return 0
	}

	var dot, sumSqA, sumSqB float64
	for i, x := range a {
		y := b[i]
		dot += x * y
		sumSqA += x * x
		sumSqB += y * y
	}

	if sumSqA == 0 || sumSqB == 0 {
		return 0
	}
	return dot / (math.Sqrt(sumSqA) * math.Sqrt(sumSqB))
}
|
||||
|
||||
// keywordMatchScore returns the fraction of whitespace-separated query words
// that occur as substrings of text, compared case-insensitively. Words
// shorter than two bytes never count as matches but still count towards the
// denominator. A query without words scores 0.
func keywordMatchScore(query, text string) float64 {
	words := strings.Fields(strings.ToLower(query))
	total := len(words)
	if total == 0 {
		return 0
	}

	haystack := strings.ToLower(text)
	hits := 0
	for _, word := range words {
		if len(word) < 2 {
			continue
		}
		if strings.Contains(haystack, word) {
			hits++
		}
	}
	return float64(hits) / float64(total)
}
|
||||
|
||||
// hashString returns the first 8 bytes of the SHA-256 digest of s as 16
// lowercase hexadecimal characters.
func hashString(s string) string {
	digest := sha256.Sum256([]byte(s))
	prefix := digest[:8]
	return fmt.Sprintf("%x", prefix)
}
|
||||
|
||||
// isSupportedFile reports whether ext (including the leading dot) is one of
// the text-like extensions the store ingests. The comparison is exact, so
// callers are expected to lower-case ext first.
func isSupportedFile(ext string) bool {
	switch ext {
	case ".md", ".txt":
		return true
	case ".go", ".py", ".js", ".ts", ".tsx", ".jsx", ".java", ".rs", ".c", ".cpp", ".h":
		return true
	case ".json", ".yaml", ".yml", ".toml", ".xml", ".html", ".css":
		return true
	case ".sh", ".bat", ".ps1":
		return true
	}
	return false
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Retriever provides a high-level knowledge retrieval interface.
|
||||
type Retriever struct {
|
||||
store *KnowledgeStore
|
||||
}
|
||||
|
||||
// NewRetriever creates a new knowledge retriever.
|
||||
func NewRetriever(store *KnowledgeStore) *Retriever {
|
||||
return &Retriever{store: store}
|
||||
}
|
||||
|
||||
// Retrieve searches the knowledge base and returns formatted results.
|
||||
func (r *Retriever) Retrieve(ctx context.Context, query string, topK int) (*RetrievalResult, error) {
|
||||
results, err := r.store.Search(ctx, query, topK)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("knowledge search: %w", err)
|
||||
}
|
||||
|
||||
ret := &RetrievalResult{
|
||||
Query: query,
|
||||
Results: results,
|
||||
Summary: r.buildSummary(results),
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// RetrievalResult holds knowledge retrieval output.
|
||||
type RetrievalResult struct {
|
||||
Query string `json:"query"`
|
||||
Results []SearchResult `json:"results"`
|
||||
Summary string `json:"summary"`
|
||||
}
|
||||
|
||||
func (r *Retriever) buildSummary(results []SearchResult) string {
|
||||
if len(results) == 0 {
|
||||
return "知识库中未找到相关信息。"
|
||||
}
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("从知识库中找到 %d 条相关信息:\n\n", len(results)))
|
||||
for i, result := range results {
|
||||
sb.WriteString(fmt.Sprintf("--- 来源: %s (段落 %d, 相关度 %.0f%%) ---\n",
|
||||
result.Chunk.DocTitle, result.Chunk.Index+1, result.Score*100))
|
||||
sb.WriteString(result.Chunk.Content)
|
||||
if i < len(results)-1 {
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// Stats returns knowledge base statistics.
|
||||
func (r *Retriever) Stats() map[string]interface{} {
|
||||
return r.store.Stats()
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
package subsession
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/model"
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/rag"
|
||||
"github.com/yourname/cyrene-ai/pkg/logger"
|
||||
)
|
||||
|
||||
// KnowledgeProvider searches the knowledge base for relevant information.
|
||||
type KnowledgeProvider struct {
|
||||
retriever *rag.Retriever
|
||||
}
|
||||
|
||||
// NewKnowledgeProvider creates a knowledge subsession provider.
|
||||
func NewKnowledgeProvider(retriever *rag.Retriever) *KnowledgeProvider {
|
||||
return &KnowledgeProvider{retriever: retriever}
|
||||
}
|
||||
|
||||
func (p *KnowledgeProvider) Type() model.SubSessionType {
|
||||
return model.SubSessionKnowledge
|
||||
}
|
||||
|
||||
func (p *KnowledgeProvider) CanHandle(_ context.Context, intent *model.IntentResult, _ string) bool {
|
||||
if intent == nil {
|
||||
return true
|
||||
}
|
||||
// Activate for technical questions, how-to queries, and factual questions
|
||||
switch intent.Primary {
|
||||
case "knowledge", "technical", "how_to", "factual", "research":
|
||||
return true
|
||||
case "chat":
|
||||
// For general chat, only search if there might be relevant info
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *KnowledgeProvider) Priority() int {
|
||||
return 3
|
||||
}
|
||||
|
||||
func (p *KnowledgeProvider) Timeout() time.Duration {
|
||||
return 15 * time.Second
|
||||
}
|
||||
|
||||
func (p *KnowledgeProvider) CreateContext(ctx context.Context, params CreateContextParams) ([]model.LLMMessage, error) {
|
||||
return []model.LLMMessage{
|
||||
{Role: model.RoleSystem, Content: "知识库检索子会话"},
|
||||
{Role: model.RoleUser, Content: params.UserMessage},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *KnowledgeProvider) Execute(ctx context.Context, subCtx []model.LLMMessage) (*model.SubSessionResult, error) {
|
||||
userMessage := ""
|
||||
for i := len(subCtx) - 1; i >= 0; i-- {
|
||||
if subCtx[i].Role == model.RoleUser {
|
||||
userMessage = subCtx[i].Content
|
||||
break
|
||||
}
|
||||
}
|
||||
if userMessage == "" {
|
||||
return nil, fmt.Errorf("无法提取用户消息")
|
||||
}
|
||||
|
||||
result := &model.SubSessionResult{
|
||||
Type: model.SubSessionKnowledge,
|
||||
Confidence: 0,
|
||||
}
|
||||
|
||||
if p.retriever == nil {
|
||||
result.Summary = "(知识库未就绪)"
|
||||
return result, nil
|
||||
}
|
||||
|
||||
retrieval, err := p.retriever.Retrieve(ctx, userMessage, 3)
|
||||
if err != nil {
|
||||
logger.Printf("[knowledge-subsession] 知识检索失败: %v", err)
|
||||
result.Error = fmt.Sprintf("检索失败: %v", err)
|
||||
result.Summary = "(知识库检索失败)"
|
||||
return result, nil
|
||||
}
|
||||
|
||||
if len(retrieval.Results) == 0 {
|
||||
result.Summary = "(未找到相关知识)"
|
||||
return result, nil
|
||||
}
|
||||
|
||||
result.Summary = retrieval.Summary
|
||||
result.Confidence = 0.6
|
||||
logger.Printf("[knowledge-subsession] 完成: 找到 %d 条知识", len(retrieval.Results))
|
||||
return result, nil
|
||||
}
|
||||
@@ -0,0 +1,156 @@
|
||||
package tools
|
||||
|
||||
import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"

	"github.com/yourname/cyrene-ai/ai-core/internal/rag"
)
|
||||
|
||||
// KnowledgeSearchTool searches the knowledge base.
|
||||
type KnowledgeSearchTool struct {
|
||||
retriever *rag.Retriever
|
||||
}
|
||||
|
||||
// NewKnowledgeSearchTool creates a knowledge search tool.
|
||||
func NewKnowledgeSearchTool(retriever *rag.Retriever) *KnowledgeSearchTool {
|
||||
return &KnowledgeSearchTool{retriever: retriever}
|
||||
}
|
||||
|
||||
func (t *KnowledgeSearchTool) Definition() ToolDefinition {
|
||||
return ToolDefinition{
|
||||
Name: "knowledge_search",
|
||||
Description: "搜索本地知识库。从文档、代码、笔记等中检索相关信息,支持语义搜索和关键词匹配。",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"query": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "搜索查询",
|
||||
},
|
||||
"top_k": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "返回结果数量,默认5条,最大10条",
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *KnowledgeSearchTool) Execute(ctx context.Context, args map[string]interface{}) (*ToolResult, error) {
|
||||
query, _ := args["query"].(string)
|
||||
if query == "" {
|
||||
return &ToolResult{
|
||||
ToolName: "knowledge_search",
|
||||
Success: false,
|
||||
Error: "query 参数不能为空",
|
||||
}, nil
|
||||
}
|
||||
|
||||
topK := 5
|
||||
if v, ok := args["top_k"].(float64); ok {
|
||||
topK = int(v)
|
||||
if topK > 10 {
|
||||
topK = 10
|
||||
}
|
||||
}
|
||||
|
||||
result, err := t.retriever.Retrieve(ctx, query, topK)
|
||||
if err != nil {
|
||||
return &ToolResult{
|
||||
ToolName: "knowledge_search",
|
||||
Success: false,
|
||||
Error: fmt.Sprintf("知识库搜索失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
count := 0
|
||||
if result.Results != nil {
|
||||
count = len(result.Results)
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"query": result.Query,
|
||||
"summary": result.Summary,
|
||||
"count": count,
|
||||
})
|
||||
|
||||
return &ToolResult{
|
||||
ToolName: "knowledge_search",
|
||||
Success: true,
|
||||
Data: string(data),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// KnowledgeIngestTool allows ingesting documents into the knowledge base.
|
||||
type KnowledgeIngestTool struct {
|
||||
store *rag.KnowledgeStore
|
||||
}
|
||||
|
||||
// NewKnowledgeIngestTool creates a knowledge ingestion tool.
|
||||
func NewKnowledgeIngestTool(store *rag.KnowledgeStore) *KnowledgeIngestTool {
|
||||
return &KnowledgeIngestTool{store: store}
|
||||
}
|
||||
|
||||
func (t *KnowledgeIngestTool) Definition() ToolDefinition {
|
||||
return ToolDefinition{
|
||||
Name: "knowledge_ingest",
|
||||
Description: "将文件导入知识库。支持 .md .txt .go .py .js .ts .json 等常见文件格式。",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"path": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "文件路径或目录路径",
|
||||
},
|
||||
},
|
||||
"required": []string{"path"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *KnowledgeIngestTool) Execute(ctx context.Context, args map[string]interface{}) (*ToolResult, error) {
|
||||
path, _ := args["path"].(string)
|
||||
if path == "" {
|
||||
return &ToolResult{
|
||||
ToolName: "knowledge_ingest",
|
||||
Success: false,
|
||||
Error: "path 参数不能为空",
|
||||
}, nil
|
||||
}
|
||||
|
||||
count, err := t.store.IngestFile(ctx, path)
|
||||
if err != nil {
|
||||
return &ToolResult{
|
||||
ToolName: "knowledge_ingest",
|
||||
Success: false,
|
||||
Error: fmt.Sprintf("知识导入失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
// Try directory
|
||||
count, err = t.store.IngestDirectory(ctx)
|
||||
if err != nil {
|
||||
return &ToolResult{
|
||||
ToolName: "knowledge_ingest",
|
||||
Success: false,
|
||||
Error: fmt.Sprintf("目录导入失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"path": path,
|
||||
"chunks_indexed": count,
|
||||
"status": "ok",
|
||||
})
|
||||
|
||||
return &ToolResult{
|
||||
ToolName: "knowledge_ingest",
|
||||
Success: true,
|
||||
Data: string(data),
|
||||
}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user