dev 分支暂存
This commit is contained in:
@@ -0,0 +1,83 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/model"
|
||||
)
|
||||
|
||||
// Adapter LLM适配器接口
|
||||
// 支持不同的LLM后端(OpenAI、Ollama、vLLM等)
|
||||
type Adapter struct {
|
||||
provider LLMProvider
|
||||
}
|
||||
|
||||
// LLMProvider LLM提供商接口
|
||||
type LLMProvider interface {
|
||||
// Chat 同步对话
|
||||
Chat(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error)
|
||||
|
||||
// ChatStream 流式对话,返回一个channel逐token推送
|
||||
ChatStream(ctx context.Context, messages []model.LLMMessage) (<-chan StreamChunk, error)
|
||||
|
||||
// ModelName 返回当前使用的模型名称
|
||||
ModelName() string
|
||||
}
|
||||
|
||||
// StreamChunk 流式响应的单个片段
|
||||
type StreamChunk struct {
|
||||
Content string // delta内容
|
||||
Done bool // 是否为最后一块
|
||||
Error error // 错误信息
|
||||
Usage *model.Usage // 最后一块时返回token统计
|
||||
}
|
||||
|
||||
// NewAdapter 创建LLM适配器
|
||||
func NewAdapter(provider LLMProvider) *Adapter {
|
||||
return &Adapter{provider: provider}
|
||||
}
|
||||
|
||||
// Chat 同步对话
|
||||
func (a *Adapter) Chat(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error) {
|
||||
return a.provider.Chat(ctx, messages)
|
||||
}
|
||||
|
||||
// ChatStream 流式对话
|
||||
func (a *Adapter) ChatStream(ctx context.Context, messages []model.LLMMessage) (<-chan StreamChunk, error) {
|
||||
return a.provider.ChatStream(ctx, messages)
|
||||
}
|
||||
|
||||
// ModelName 返回模型名称
|
||||
func (a *Adapter) ModelName() string {
|
||||
return a.provider.ModelName()
|
||||
}
|
||||
|
||||
// collectStream 辅助函数:将流式响应收集为完整响应
|
||||
func collectStream(ch <-chan StreamChunk) (*model.LLMResponse, error) {
|
||||
var content string
|
||||
var lastUsage *model.Usage
|
||||
|
||||
for chunk := range ch {
|
||||
if chunk.Error != nil {
|
||||
return nil, chunk.Error
|
||||
}
|
||||
if chunk.Done {
|
||||
lastUsage = chunk.Usage
|
||||
break
|
||||
}
|
||||
content += chunk.Content
|
||||
}
|
||||
|
||||
resp := &model.LLMResponse{
|
||||
Content: content,
|
||||
FinishReason: "stop",
|
||||
}
|
||||
if lastUsage != nil {
|
||||
resp.Usage = *lastUsage
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Keep the io import referenced (it will be needed for SSE parsing).
var _ io.Reader
|
||||
|
||||
@@ -0,0 +1,313 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/model"
|
||||
)
|
||||
|
||||
// OpenAIConfig holds the configuration of the OpenAI adapter.
type OpenAIConfig struct {
	BaseURL       string        // API base URL
	APIKey        string        // API key
	Model         string        // primary model
	FallbackModel string        // backup model, used when the primary model is unavailable
	MaxRetries    int           // maximum number of retries
	Timeout       time.Duration // request timeout
}
|
||||
|
||||
// OpenAIProvider OpenAI兼容的LLM提供商
|
||||
type OpenAIProvider struct {
|
||||
config OpenAIConfig
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewOpenAIProvider 创建OpenAI提供商
|
||||
func NewOpenAIProvider(cfg OpenAIConfig) *OpenAIProvider {
|
||||
if cfg.MaxRetries == 0 {
|
||||
cfg.MaxRetries = 3
|
||||
}
|
||||
if cfg.Timeout == 0 {
|
||||
cfg.Timeout = 60 * time.Second
|
||||
}
|
||||
|
||||
return &OpenAIProvider{
|
||||
config: cfg,
|
||||
httpClient: &http.Client{
|
||||
Timeout: cfg.Timeout,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// openAIRequest OpenAI请求结构
|
||||
type openAIRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []openAIMessage `json:"messages"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
// openAIMessage is one chat message (role plus text content) in the wire
// format. It is used for request messages, response messages and stream
// deltas.
type openAIMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
|
||||
|
||||
// openAIResponse OpenAI响应结构
|
||||
type openAIResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Choices []openAIChoice `json:"choices"`
|
||||
Usage openAIUsage `json:"usage,omitempty"`
|
||||
Error *openAIError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type openAIChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message openAIMessage `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
// openAIUsage carries the token counts reported by the API.
type openAIUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
|
||||
|
||||
// openAIError is the error object embedded in an API error response.
type openAIError struct {
	Message string `json:"message"`
	Type    string `json:"type"`
	Code    string `json:"code,omitempty"`
}
|
||||
|
||||
// Chat 同步对话
|
||||
func (p *OpenAIProvider) Chat(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error) {
|
||||
resp, err := p.doChat(ctx, messages, p.config.Model, false)
|
||||
if err != nil {
|
||||
// 尝试fallback模型
|
||||
if p.config.FallbackModel != "" && p.config.FallbackModel != p.config.Model {
|
||||
log.Printf("[LLM] 主模型 %s 调用失败,降级到 %s: %v", p.config.Model, p.config.FallbackModel, err)
|
||||
return p.doChat(ctx, messages, p.config.FallbackModel, false)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ChatStream 流式对话
|
||||
func (p *OpenAIProvider) ChatStream(ctx context.Context, messages []model.LLMMessage) (<-chan StreamChunk, error) {
|
||||
ch := make(chan StreamChunk, 100)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
|
||||
resp, err := p.doChatStream(ctx, messages, p.config.Model)
|
||||
if err != nil {
|
||||
// Fallback
|
||||
if p.config.FallbackModel != "" {
|
||||
log.Printf("[LLM] 流式调用主模型失败,降级: %v", err)
|
||||
resp, err = p.doChatStream(ctx, messages, p.config.FallbackModel)
|
||||
}
|
||||
if err != nil {
|
||||
ch <- StreamChunk{Error: err, Done: true}
|
||||
return
|
||||
}
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
// 增大scanner buffer以处理大块SSE数据
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
|
||||
// SSE格式: data: {...}
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
|
||||
// 流结束标记
|
||||
if data == "[DONE]" {
|
||||
ch <- StreamChunk{Done: true}
|
||||
return
|
||||
}
|
||||
|
||||
var streamResp openAIStreamResponse
|
||||
if err := json.Unmarshal([]byte(data), &streamResp); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(streamResp.Choices) > 0 {
|
||||
delta := streamResp.Choices[0].Delta
|
||||
if delta.Content != "" {
|
||||
ch <- StreamChunk{Content: delta.Content}
|
||||
}
|
||||
if streamResp.Choices[0].FinishReason != "" {
|
||||
usage := &model.Usage{}
|
||||
if streamResp.Usage != nil {
|
||||
usage.PromptTokens = streamResp.Usage.PromptTokens
|
||||
usage.CompletionTokens = streamResp.Usage.CompletionTokens
|
||||
usage.TotalTokens = streamResp.Usage.TotalTokens
|
||||
}
|
||||
ch <- StreamChunk{Done: true, Usage: usage}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
ch <- StreamChunk{Error: fmt.Errorf("读取流式响应失败: %w", err), Done: true}
|
||||
return
|
||||
}
|
||||
|
||||
ch <- StreamChunk{Done: true}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// openAIStreamResponse 流式响应结构
|
||||
type openAIStreamResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Choices []openAIStreamChoice `json:"choices"`
|
||||
Usage *openAIUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type openAIStreamChoice struct {
|
||||
Index int `json:"index"`
|
||||
Delta openAIMessage `json:"delta"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
// doChat 执行同步对话请求
|
||||
func (p *OpenAIProvider) doChat(ctx context.Context, messages []model.LLMMessage, model string, stream bool) (*model.LLMResponse, error) {
|
||||
// 转换消息格式
|
||||
oaiMessages := make([]openAIMessage, len(messages))
|
||||
for i, msg := range messages {
|
||||
oaiMessages[i] = openAIMessage{
|
||||
Role: string(msg.Role),
|
||||
Content: msg.Content,
|
||||
}
|
||||
}
|
||||
|
||||
reqBody := openAIRequest{
|
||||
Model: model,
|
||||
Messages: oaiMessages,
|
||||
Temperature: 0.8,
|
||||
Stream: stream,
|
||||
}
|
||||
|
||||
jsonBody, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", p.config.BaseURL+"/chat/completions", bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+p.config.APIKey)
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errResp openAIResponse
|
||||
if json.Unmarshal(body, &errResp) == nil && errResp.Error != nil {
|
||||
return nil, fmt.Errorf("API错误 [%s]: %s", errResp.Error.Code, errResp.Error.Message)
|
||||
}
|
||||
return nil, fmt.Errorf("API返回状态码 %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var oaiResp openAIResponse
|
||||
if err := json.Unmarshal(body, &oaiResp); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
if len(oaiResp.Choices) == 0 {
|
||||
return nil, fmt.Errorf("API返回空choices")
|
||||
}
|
||||
|
||||
return &model.LLMResponse{
|
||||
Content: oaiResp.Choices[0].Message.Content,
|
||||
FinishReason: oaiResp.Choices[0].FinishReason,
|
||||
Usage: model.Usage{
|
||||
PromptTokens: oaiResp.Usage.PromptTokens,
|
||||
CompletionTokens: oaiResp.Usage.CompletionTokens,
|
||||
TotalTokens: oaiResp.Usage.TotalTokens,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// doChatStream 执行流式对话请求(返回原始HTTP响应)
|
||||
func (p *OpenAIProvider) doChatStream(ctx context.Context, messages []model.LLMMessage, model string) (*http.Response, error) {
|
||||
oaiMessages := make([]openAIMessage, len(messages))
|
||||
for i, msg := range messages {
|
||||
oaiMessages[i] = openAIMessage{
|
||||
Role: string(msg.Role),
|
||||
Content: msg.Content,
|
||||
}
|
||||
}
|
||||
|
||||
reqBody := openAIRequest{
|
||||
Model: model,
|
||||
Messages: oaiMessages,
|
||||
Temperature: 0.8,
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
jsonBody, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", p.config.BaseURL+"/chat/completions", bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+p.config.APIKey)
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求失败: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("API返回状态码 %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ModelName 返回模型名称
|
||||
func (p *OpenAIProvider) ModelName() string {
|
||||
return p.config.Model
|
||||
}
|
||||
|
||||
@@ -0,0 +1,191 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// Segmenter 断句器 —— 将流式文本按句号切分为语音播放片段
|
||||
type Segmenter struct {
|
||||
mu sync.Mutex
|
||||
buffer strings.Builder
|
||||
segments []Segment
|
||||
index int
|
||||
}
|
||||
|
||||
// Segment is one speech segment: its sequence index and its text.
type Segment struct {
	Index int    `json:"index"`
	Text  string `json:"text"`
}
|
||||
|
||||
// NewSegmenter 创建断句器
|
||||
func NewSegmenter() *Segmenter {
|
||||
return &Segmenter{}
|
||||
}
|
||||
|
||||
// Feed 喂入新的文本片段
|
||||
// 返回已完成的断句列表
|
||||
func (s *Segmenter) Feed(delta string) []Segment {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.buffer.WriteString(delta)
|
||||
content := s.buffer.String()
|
||||
|
||||
var newSegments []Segment
|
||||
|
||||
for {
|
||||
idx := findSentenceEnd(content)
|
||||
if idx == -1 {
|
||||
break
|
||||
}
|
||||
|
||||
segmentText := strings.TrimSpace(content[:idx+len(string(content[idx]))])
|
||||
// 检查是否是完整中文字符的句末
|
||||
// idx 指向标点符号的位置
|
||||
runes := []rune(content)
|
||||
var byteIdx int
|
||||
for i, r := range runes {
|
||||
if i == idx {
|
||||
// 标点之后的字符
|
||||
break
|
||||
}
|
||||
byteIdx += len(string(r))
|
||||
}
|
||||
|
||||
// 简化处理:直接取到idx+1字节 (对于ASCII标点)
|
||||
// 对于中文标点,需要用rune处理
|
||||
realIdx := 0
|
||||
runeCount := 0
|
||||
for i, r := range content {
|
||||
if runeCount == idx {
|
||||
realIdx = i
|
||||
break
|
||||
}
|
||||
runeCount++
|
||||
_ = r
|
||||
}
|
||||
// 包含标点符号本身
|
||||
endIdx := realIdx + len(string([]rune(content)[idx]))
|
||||
if endIdx <= realIdx {
|
||||
endIdx = realIdx + 3 // fallback for UTF-8 multi-byte
|
||||
}
|
||||
|
||||
segmentText = strings.TrimSpace(content[:endIdx])
|
||||
if segmentText == "" {
|
||||
content = strings.TrimSpace(content[endIdx:])
|
||||
s.buffer.Reset()
|
||||
s.buffer.WriteString(content)
|
||||
continue
|
||||
}
|
||||
|
||||
s.index++
|
||||
seg := Segment{
|
||||
Index: s.index,
|
||||
Text: segmentText,
|
||||
}
|
||||
s.segments = append(s.segments, seg)
|
||||
newSegments = append(newSegments, seg)
|
||||
|
||||
// 更新buffer,移除已处理的部分
|
||||
content = strings.TrimSpace(content[endIdx:])
|
||||
s.buffer.Reset()
|
||||
s.buffer.WriteString(content)
|
||||
}
|
||||
|
||||
return newSegments
|
||||
}
|
||||
|
||||
// Flush 强制输出buffer中剩余的内容
|
||||
func (s *Segmenter) Flush() *Segment {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
remaining := strings.TrimSpace(s.buffer.String())
|
||||
if remaining == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.index++
|
||||
seg := Segment{
|
||||
Index: s.index,
|
||||
Text: remaining,
|
||||
}
|
||||
s.segments = append(s.segments, seg)
|
||||
s.buffer.Reset()
|
||||
|
||||
return &seg
|
||||
}
|
||||
|
||||
// AllSegments 返回所有已完成的断句
|
||||
func (s *Segmenter) AllSegments() []Segment {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
result := make([]Segment, len(s.segments))
|
||||
copy(result, s.segments)
|
||||
return result
|
||||
}
|
||||
|
||||
// findSentenceEnd 查找句子结束位置(返回标点符号在rune数组中的索引)
|
||||
// 中文标点:。!? 英文标点:. ! ?
|
||||
func findSentenceEnd(text string) int {
|
||||
runes := []rune(text)
|
||||
for i, r := range runes {
|
||||
if isSentenceEnd(r) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// isSentenceEnd reports whether r is sentence-ending punctuation.
//
// Recognised runes are the Chinese ideographic full stop (U+3002), the
// fullwidth exclamation mark (U+FF01) and fullwidth question mark (U+FF1F),
// the ASCII '.', '!' and '?', and the newline character.
//
// The fullwidth forms are written as escapes so they cannot be confused with
// (or normalised into) their ASCII look-alikes, which would produce duplicate
// case constants and fail to compile.
func isSentenceEnd(r rune) bool {
	switch r {
	case '\u3002', // ideographic full stop
		'\uff01', // fullwidth exclamation mark
		'\uff1f', // fullwidth question mark
		'.', '!', '?', '\n':
		return true
	}
	return false
}
|
||||
|
||||
// splitIntoSegments 将完整文本按句号断句(用于post-processing)
|
||||
func splitIntoSegments(text string) []Segment {
|
||||
var segments []Segment
|
||||
runes := []rune(text)
|
||||
|
||||
start := 0
|
||||
index := 0
|
||||
|
||||
for i, r := range runes {
|
||||
if isSentenceEnd(r) {
|
||||
segText := strings.TrimSpace(string(runes[start : i+1]))
|
||||
if segText != "" {
|
||||
index++
|
||||
segments = append(segments, Segment{
|
||||
Index: index,
|
||||
Text: segText,
|
||||
})
|
||||
}
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
|
||||
// 处理末尾无标点的剩余文本
|
||||
if start < len(runes) {
|
||||
remaining := strings.TrimSpace(string(runes[start:]))
|
||||
if remaining != "" {
|
||||
index++
|
||||
segments = append(segments, Segment{
|
||||
Index: index,
|
||||
Text: remaining,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return segments
|
||||
}
|
||||
|
||||
// Keep the unicode import referenced.
var _ = unicode.Is
|
||||
|
||||
@@ -0,0 +1,206 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/model"
|
||||
)
|
||||
|
||||
// Extractor 记忆提取器 —— 从对话中提取结构化记忆
|
||||
type Extractor struct {
|
||||
store *Store
|
||||
llmChat func(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error)
|
||||
}
|
||||
|
||||
// NewExtractor 创建记忆提取器
|
||||
// llmChat: LLM对话函数,用于分析对话内容并提取记忆
|
||||
// 如果为nil,则使用规则提取(降级模式)
|
||||
func NewExtractor(store *Store, llmChat func(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error)) *Extractor {
|
||||
return &Extractor{
|
||||
store: store,
|
||||
llmChat: llmChat,
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractAndStore 从一轮对话中提取记忆并存储
|
||||
// 异步执行,不阻塞主流程
|
||||
func (e *Extractor) ExtractAndStore(ctx context.Context, userID, sessionID, userMessage, assistantResponse string) {
|
||||
memories, err := e.extract(ctx, userMessage, assistantResponse)
|
||||
if err != nil {
|
||||
log.Printf("[memory] 记忆提取失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, mem := range memories {
|
||||
mem.UserID = userID
|
||||
mem.SessionID = sessionID
|
||||
|
||||
if err := e.store.Save(ctx, &mem); err != nil {
|
||||
log.Printf("[memory] 记忆保存失败: %v", err)
|
||||
continue
|
||||
}
|
||||
log.Printf("[memory] 新记忆已保存 [%s]: %s", mem.Category, mem.Summary)
|
||||
}
|
||||
}
|
||||
|
||||
// extract 从对话中提取记忆
|
||||
func (e *Extractor) extract(ctx context.Context, userMessage, assistantResponse string) ([]model.MemoryEntry, error) {
|
||||
// 如果有LLM,使用LLM提取
|
||||
if e.llmChat != nil {
|
||||
return e.extractWithLLM(ctx, userMessage, assistantResponse)
|
||||
}
|
||||
// 降级:规则提取
|
||||
return e.extractWithRules(userMessage, assistantResponse), nil
|
||||
}
|
||||
|
||||
// MemoryExtractionResult is the JSON structure of the LLM's extraction
// output: an object with a "memories" array.
type MemoryExtractionResult struct {
	Memories []struct {
		Content  string `json:"content"`
		Summary  string `json:"summary"`
		Category string `json:"category"`
		Priority int    `json:"priority"`
	} `json:"memories"`
}
|
||||
|
||||
// extractWithLLM 使用LLM提取记忆
|
||||
func (e *Extractor) extractWithLLM(ctx context.Context, userMessage, assistantResponse string) ([]model.MemoryEntry, error) {
|
||||
prompt := fmt.Sprintf(`分析以下对话,提取关于用户(开拓者)的重要信息作为记忆。
|
||||
|
||||
用户消息: %s
|
||||
昔涟回复: %s
|
||||
|
||||
请以JSON格式返回提取的记忆。每条记忆需要包含:
|
||||
- content: 完整的记忆内容(一句话描述)
|
||||
- summary: 简短摘要(10字以内)
|
||||
- category: 分类 (preference/fact/event/relationship/habit/other)
|
||||
- priority: 优先级 (0=临时, 1=普通, 2=重要, 3=核心)
|
||||
|
||||
只提取有意义的信息,不要提取无意义的闲聊。如果没有值得记住的内容,返回空数组。
|
||||
|
||||
输出格式:
|
||||
{"memories": [{"content": "...", "summary": "...", "category": "...", "priority": 1}]}
|
||||
`, userMessage, assistantResponse)
|
||||
|
||||
resp, err := e.llmChat(ctx, []model.LLMMessage{
|
||||
{Role: "system", Content: "你是一个记忆提取助手。你只输出JSON格式的结果,不输出其他内容。"},
|
||||
{Role: "user", Content: prompt},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("LLM提取记忆失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析JSON
|
||||
result := MemoryExtractionResult{}
|
||||
content := extractJSON(resp.Content)
|
||||
if err := json.Unmarshal([]byte(content), &result); err != nil {
|
||||
// 尝试作为数组解析
|
||||
var arrResult []struct {
|
||||
Content string `json:"content"`
|
||||
Summary string `json:"summary"`
|
||||
Category string `json:"category"`
|
||||
Priority int `json:"priority"`
|
||||
}
|
||||
if err2 := json.Unmarshal([]byte(content), &arrResult); err2 != nil {
|
||||
return nil, fmt.Errorf("解析记忆JSON失败: %w (原始: %s)", err, content[:min(len(content), 100)])
|
||||
}
|
||||
result.Memories = arrResult
|
||||
}
|
||||
|
||||
var entries []model.MemoryEntry
|
||||
for _, m := range result.Memories {
|
||||
entries = append(entries, model.MemoryEntry{
|
||||
Content: m.Content,
|
||||
Summary: m.Summary,
|
||||
Category: model.MemoryCategory(m.Category),
|
||||
Priority: model.MemoryPriority(m.Priority),
|
||||
})
|
||||
}
|
||||
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
// extractWithRules 基于规则提取记忆(降级方案)
|
||||
func (e *Extractor) extractWithRules(userMessage, _ string) []model.MemoryEntry {
|
||||
var entries []model.MemoryEntry
|
||||
|
||||
// 规则1: 检测用户偏好表达
|
||||
prefPatterns := map[string]string{
|
||||
"喜欢": "preference",
|
||||
"爱": "preference",
|
||||
"最喜欢": "preference",
|
||||
"讨厌": "preference",
|
||||
"不喜欢": "preference",
|
||||
"经常": "habit",
|
||||
"每天都": "habit",
|
||||
"一直": "habit",
|
||||
"我叫": "fact",
|
||||
"我是": "fact",
|
||||
"我家": "fact",
|
||||
"住在": "fact",
|
||||
"生日": "fact",
|
||||
}
|
||||
|
||||
for pattern, category := range prefPatterns {
|
||||
if idx := strings.Index(userMessage, pattern); idx != -1 {
|
||||
// 提取包含关键词的句子片段
|
||||
start := max(0, idx-5)
|
||||
end := min(len([]rune(userMessage)), idx+len([]rune(pattern))+15)
|
||||
content := strings.TrimSpace(string([]rune(userMessage)[start:end]))
|
||||
|
||||
entries = append(entries, model.MemoryEntry{
|
||||
Content: content,
|
||||
Summary: truncateString(content, 20),
|
||||
Category: model.MemoryCategory(category),
|
||||
Priority: model.MemoryNormal,
|
||||
})
|
||||
break // 每条消息最多提取一条规则记忆
|
||||
}
|
||||
}
|
||||
|
||||
return entries
|
||||
}
|
||||
|
||||
// extractJSON pulls the JSON payload out of an LLM reply.
//
// The text is whitespace-trimmed. If it then starts with a markdown code
// fence ("```json" or a bare "```"), that opening fence and a trailing "```"
// are removed and the result is trimmed again. Text that does not start with
// a fence is returned trimmed but otherwise unchanged.
func extractJSON(text string) string {
	trimmed := strings.TrimSpace(text)

	// The more specific fence must be tried first.
	for _, fence := range []string{"```json", "```"} {
		if !strings.HasPrefix(trimmed, fence) {
			continue
		}
		inner := strings.TrimPrefix(trimmed, fence)
		inner = strings.TrimSuffix(inner, "```")
		return strings.TrimSpace(inner)
	}

	return trimmed
}
|
||||
|
||||
// truncateString shortens s to at most maxLen runes. A string that is
// already short enough is returned unchanged; otherwise the first maxLen
// runes are returned followed by "...".
func truncateString(s string, maxLen int) string {
	if runes := []rune(s); len(runes) > maxLen {
		return string(runes[:maxLen]) + "..."
	}
	return s
}
|
||||
|
||||
// min returns the smaller of a and b.
func min(a, b int) int {
	if b <= a {
		return b
	}
	return a
}
|
||||
|
||||
// max returns the larger of a and b.
func max(a, b int) int {
	if b >= a {
		return b
	}
	return a
}
|
||||
|
||||
@@ -0,0 +1,152 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/model"
|
||||
)
|
||||
|
||||
// MemoryEntry 记忆条目别名(避免与model包冲突)
|
||||
type MemoryEntry = model.MemoryEntry
|
||||
|
||||
// Retriever 记忆检索器
|
||||
type Retriever struct {
|
||||
store *Store
|
||||
embedder Embedder // 文本转向量的接口
|
||||
}
|
||||
|
||||
// Embedder is the text-embedding interface: it turns text into a vector.
type Embedder interface {
	Embed(ctx context.Context, text string) ([]float64, error)
}
|
||||
|
||||
// SimpleEmbedder is a simple character-hashing embedder that needs no
// external API (usable during the MVP stage).
type SimpleEmbedder struct{}

// Embed builds a 1536-dimensional feature vector from the lower-cased text
// (for quick MVP validation).
//
// For every rune, 1/len is added to the slot selected by the rune's code
// point modulo 1536, and 0.5/len is added to the slot selected by code point
// plus rune position modulo 1536, where len is the number of runes. Empty
// text yields the zero vector. ctx is not used and the error is always nil.
func (e *SimpleEmbedder) Embed(ctx context.Context, text string) ([]float64, error) {
	const dims = 1536
	vec := make([]float64, dims)

	runes := []rune(strings.ToLower(text))
	total := float64(len(runes))

	for pos, r := range runes {
		code := int(r)
		// Character-frequency component.
		vec[code%dims] += 1.0 / total
		// Position-aware component.
		vec[(code+pos)%dims] += 0.5 / total
	}

	return vec, nil
}
|
||||
|
||||
// NewRetriever 创建记忆检索器
|
||||
func NewRetriever(store *Store, embedder Embedder) *Retriever {
|
||||
if embedder == nil {
|
||||
embedder = &SimpleEmbedder{}
|
||||
}
|
||||
return &Retriever{
|
||||
store: store,
|
||||
embedder: embedder,
|
||||
}
|
||||
}
|
||||
|
||||
// Retrieve 检索与查询相关的记忆
|
||||
// 策略: 向量相似度 + 关键词匹配混合
|
||||
func (r *Retriever) Retrieve(ctx context.Context, userID string, query string) ([]MemoryEntry, error) {
|
||||
var allEntries []MemoryEntry
|
||||
seen := make(map[string]bool)
|
||||
|
||||
// 1. 向量相似度检索
|
||||
embedding, err := r.embedder.Embed(ctx, query)
|
||||
if err == nil {
|
||||
vecEntries, err := r.store.SearchByVector(ctx, userID, embedding, 5)
|
||||
if err == nil {
|
||||
for _, e := range vecEntries {
|
||||
if !seen[e.ID] {
|
||||
seen[e.ID] = true
|
||||
allEntries = append(allEntries, e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 关键词匹配检索(核心/重要记忆优先)
|
||||
keywordEntries, err := r.keywordSearch(ctx, userID, query)
|
||||
if err == nil {
|
||||
for _, e := range keywordEntries {
|
||||
if !seen[e.ID] {
|
||||
seen[e.ID] = true
|
||||
allEntries = append(allEntries, e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 如果没有匹配,返回最近的重要记忆
|
||||
if len(allEntries) == 0 {
|
||||
recentEntries, err := r.store.Query(ctx, model.MemoryQuery{
|
||||
UserID: userID,
|
||||
Priority: int(model.MemoryImportant),
|
||||
Limit: 3,
|
||||
})
|
||||
if err == nil {
|
||||
allEntries = recentEntries
|
||||
}
|
||||
}
|
||||
|
||||
// 限制返回数量
|
||||
if len(allEntries) > 10 {
|
||||
allEntries = allEntries[:10]
|
||||
}
|
||||
|
||||
return allEntries, nil
|
||||
}
|
||||
|
||||
// keywordSearch 关键词匹配检索
|
||||
func (r *Retriever) keywordSearch(ctx context.Context, userID string, query string) ([]MemoryEntry, error) {
|
||||
// 查询最近的核心和重要记忆
|
||||
entries, err := r.store.Query(ctx, model.MemoryQuery{
|
||||
UserID: userID,
|
||||
Priority: int(model.MemoryImportant),
|
||||
Limit: 50,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 简单的关键词匹配过滤
|
||||
var matched []MemoryEntry
|
||||
queryLower := strings.ToLower(query)
|
||||
|
||||
for _, entry := range entries {
|
||||
contentLower := strings.ToLower(entry.Content)
|
||||
summaryLower := strings.ToLower(entry.Summary)
|
||||
if strings.Contains(contentLower, queryLower) || strings.Contains(summaryLower, queryLower) {
|
||||
matched = append(matched, entry)
|
||||
}
|
||||
}
|
||||
|
||||
// 也匹配普通记忆
|
||||
normalEntries, err := r.store.Query(ctx, model.MemoryQuery{
|
||||
UserID: userID,
|
||||
Priority: int(model.MemoryNormal),
|
||||
Limit: 100,
|
||||
})
|
||||
if err == nil {
|
||||
for _, entry := range normalEntries {
|
||||
contentLower := strings.ToLower(entry.Content)
|
||||
summaryLower := strings.ToLower(entry.Summary)
|
||||
if strings.Contains(contentLower, queryLower) || strings.Contains(summaryLower, queryLower) {
|
||||
matched = append(matched, entry)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return matched, nil
|
||||
}
|
||||
|
||||
// Keep the fmt import referenced.
var _ = fmt.Sprintf
|
||||
|
||||
@@ -0,0 +1,251 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/model"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
// Store is the persistent memory storage (PostgreSQL + pgvector).
type Store struct {
	db *sql.DB // database handle used by every operation
}
|
||||
|
||||
// NewStore 创建记忆存储
|
||||
func NewStore(connStr string) (*Store, error) {
|
||||
db, err := sql.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("连接数据库失败: %w", err)
|
||||
}
|
||||
|
||||
db.SetMaxOpenConns(25)
|
||||
db.SetMaxIdleConns(5)
|
||||
db.SetConnMaxLifetime(5 * time.Minute)
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("数据库ping失败: %w", err)
|
||||
}
|
||||
|
||||
s := &Store{db: db}
|
||||
if err := s.migrate(); err != nil {
|
||||
return nil, fmt.Errorf("数据库迁移失败: %w", err)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// migrate 创建表结构
|
||||
func (s *Store) migrate() error {
|
||||
queries := []string{
|
||||
`CREATE EXTENSION IF NOT EXISTS vector`,
|
||||
`CREATE TABLE IF NOT EXISTS memories (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id VARCHAR(64) NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
summary TEXT DEFAULT '',
|
||||
category VARCHAR(32) DEFAULT 'other',
|
||||
priority INT DEFAULT 1,
|
||||
session_id VARCHAR(64) DEFAULT '',
|
||||
source TEXT DEFAULT '',
|
||||
embedding vector(1536),
|
||||
access_count INT DEFAULT 0,
|
||||
last_access TIMESTAMPTZ DEFAULT NOW(),
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
expires_at TIMESTAMPTZ
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_memories_user_id ON memories(user_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_memories_category ON memories(category)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_memories_priority ON memories(priority)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_memories_user_priority ON memories(user_id, priority DESC)`,
|
||||
}
|
||||
|
||||
for _, q := range queries {
|
||||
if _, err := s.db.Exec(q); err != nil {
|
||||
return fmt.Errorf("执行迁移 '%s' 失败: %w", q[:50], err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Save 保存记忆
|
||||
func (s *Store) Save(ctx context.Context, entry *model.MemoryEntry) error {
|
||||
query := `INSERT INTO memories (user_id, content, summary, category, priority, session_id, source, embedding, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
RETURNING id, created_at`
|
||||
|
||||
var embedding interface{}
|
||||
if len(entry.Embedding) > 0 {
|
||||
vec := make([]float64, len(entry.Embedding))
|
||||
for i, v := range entry.Embedding {
|
||||
vec[i] = float64(v)
|
||||
}
|
||||
embedding = fmt.Sprintf("[%s]", joinFloats(vec))
|
||||
}
|
||||
|
||||
return s.db.QueryRowContext(ctx, query,
|
||||
entry.UserID, entry.Content, entry.Summary,
|
||||
string(entry.Category), int(entry.Priority),
|
||||
entry.SessionID, entry.Source, embedding, entry.ExpiresAt,
|
||||
).Scan(&entry.ID, &entry.CreatedAt)
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取记忆
|
||||
func (s *Store) GetByID(ctx context.Context, id string) (*model.MemoryEntry, error) {
|
||||
query := `SELECT id, user_id, content, summary, category, priority, session_id, source,
|
||||
access_count, last_access, created_at, expires_at
|
||||
FROM memories WHERE id = $1`
|
||||
|
||||
entry := &model.MemoryEntry{}
|
||||
var category string
|
||||
err := s.db.QueryRowContext(ctx, query, id).Scan(
|
||||
&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
|
||||
&category, &entry.Priority, &entry.SessionID, &entry.Source,
|
||||
&entry.AccessCount, &entry.LastAccess, &entry.CreatedAt, &entry.ExpiresAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询记忆失败: %w", err)
|
||||
}
|
||||
entry.Category = model.MemoryCategory(category)
|
||||
|
||||
// 更新访问计数
|
||||
go s.incrementAccess(context.Background(), id)
|
||||
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
// Query 按条件查询记忆
|
||||
func (s *Store) Query(ctx context.Context, q model.MemoryQuery) ([]model.MemoryEntry, error) {
|
||||
if q.Limit <= 0 {
|
||||
q.Limit = 10
|
||||
}
|
||||
|
||||
query := `SELECT id, user_id, content, summary, category, priority, session_id, source,
|
||||
access_count, last_access, created_at, expires_at
|
||||
FROM memories WHERE user_id = $1`
|
||||
args := []interface{}{q.UserID}
|
||||
argIdx := 2
|
||||
|
||||
if q.Category != "" {
|
||||
query += fmt.Sprintf(" AND category = $%d", argIdx)
|
||||
args = append(args, string(q.Category))
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if q.Priority >= 0 {
|
||||
query += fmt.Sprintf(" AND priority >= $%d", argIdx)
|
||||
args = append(args, int(q.Priority))
|
||||
argIdx++
|
||||
}
|
||||
|
||||
query += fmt.Sprintf(" ORDER BY priority DESC, created_at DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1)
|
||||
args = append(args, q.Limit, q.Offset)
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询记忆失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var entries []model.MemoryEntry
|
||||
for rows.Next() {
|
||||
var entry model.MemoryEntry
|
||||
var category string
|
||||
if err := rows.Scan(
|
||||
&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
|
||||
&category, &entry.Priority, &entry.SessionID, &entry.Source,
|
||||
&entry.AccessCount, &entry.LastAccess, &entry.CreatedAt, &entry.ExpiresAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("扫描记忆行失败: %w", err)
|
||||
}
|
||||
entry.Category = model.MemoryCategory(category)
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
// Delete 删除记忆
|
||||
func (s *Store) Delete(ctx context.Context, id string) error {
|
||||
_, err := s.db.ExecContext(ctx, `DELETE FROM memories WHERE id = $1`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// PurgeExpired 清理过期记忆
|
||||
func (s *Store) PurgeExpired(ctx context.Context) (int64, error) {
|
||||
result, err := s.db.ExecContext(ctx,
|
||||
`DELETE FROM memories WHERE expires_at IS NOT NULL AND expires_at < NOW()`)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
// SearchByVector 向量相似度搜索
|
||||
func (s *Store) SearchByVector(ctx context.Context, userID string, embedding []float64, limit int) ([]model.MemoryEntry, error) {
|
||||
if limit <= 0 {
|
||||
limit = 5
|
||||
}
|
||||
|
||||
vecStr := fmt.Sprintf("[%s]", joinFloats(embedding))
|
||||
query := `SELECT id, user_id, content, summary, category, priority, session_id, source,
|
||||
access_count, last_access, created_at, expires_at,
|
||||
1 - (embedding <=> $1) AS similarity
|
||||
FROM memories
|
||||
WHERE user_id = $2 AND embedding IS NOT NULL
|
||||
ORDER BY embedding <=> $1
|
||||
LIMIT $3`
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, vecStr, userID, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("向量搜索失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var entries []model.MemoryEntry
|
||||
for rows.Next() {
|
||||
var entry model.MemoryEntry
|
||||
var category string
|
||||
var similarity float64
|
||||
if err := rows.Scan(
|
||||
&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
|
||||
&category, &entry.Priority, &entry.SessionID, &entry.Source,
|
||||
&entry.AccessCount, &entry.LastAccess, &entry.CreatedAt, &entry.ExpiresAt,
|
||||
&similarity,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("扫描向量搜索结果失败: %w", err)
|
||||
}
|
||||
entry.Category = model.MemoryCategory(category)
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Store) incrementAccess(ctx context.Context, id string) {
|
||||
s.db.ExecContext(ctx,
|
||||
`UPDATE memories SET access_count = access_count + 1, last_access = NOW() WHERE id = $1`, id)
|
||||
}
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func (s *Store) Close() error {
|
||||
return s.db.Close()
|
||||
}
|
||||
|
||||
// joinFloats renders vec as a comma-separated list with no spaces, formatting
// every element with the "%f" verb (six digits after the decimal point).
// It returns the empty string for an empty or nil slice.
//
// The result is assembled in a single byte buffer rather than by repeated
// string concatenation, so the cost stays linear in len(vec) even for long
// embedding vectors.
func joinFloats(vec []float64) string {
	if len(vec) == 0 {
		return ""
	}

	// Rough size guess (sign, digit, point, six decimals, comma) to avoid
	// most regrowth; append still grows the buffer when values are wider.
	buf := make([]byte, 0, len(vec)*10)
	for i, v := range vec {
		if i > 0 {
			buf = append(buf, ',')
		}
		buf = append(buf, fmt.Sprintf("%f", v)...)
	}
	return string(buf)
}
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// MemoryPriority ranks how important a memory is; larger values mean higher
// priority.
type MemoryPriority int

// Priority levels, from lowest to highest.
const (
	MemoryTemp      MemoryPriority = iota // temporary memory (within a session)
	MemoryNormal                          // ordinary memory
	MemoryImportant                       // important memory
	MemoryCore                            // core memory (kept forever)
)

// String returns the Chinese label of the priority, or "未知" for any value
// outside the declared levels.
func (p MemoryPriority) String() string {
	labels := [...]string{
		MemoryTemp:      "临时",
		MemoryNormal:    "普通",
		MemoryImportant: "重要",
		MemoryCore:      "核心",
	}
	if p < MemoryTemp || p > MemoryCore {
		return "未知"
	}
	return labels[p]
}
|
||||
|
||||
// MemoryCategory is the string tag that classifies a memory.
type MemoryCategory string

// Known memory categories.
const (
	// CategoryPreference marks likes and preferences.
	CategoryPreference MemoryCategory = "preference"
	// CategoryFact marks factual information.
	CategoryFact MemoryCategory = "fact"
	// CategoryEvent marks events and experiences.
	CategoryEvent MemoryCategory = "event"
	// CategoryRelationship marks relationships and feelings.
	CategoryRelationship MemoryCategory = "relationship"
	// CategoryHabit marks habits.
	CategoryHabit MemoryCategory = "habit"
	// CategoryOther marks anything else.
	CategoryOther MemoryCategory = "other"
)
|
||||
|
||||
// MemoryEntry 记忆条目
|
||||
type MemoryEntry struct {
|
||||
ID string `json:"id" db:"id"`
|
||||
UserID string `json:"user_id" db:"user_id"`
|
||||
Content string `json:"content" db:"content"`
|
||||
Summary string `json:"summary" db:"summary"` // 简短摘要
|
||||
Category MemoryCategory `json:"category" db:"category"`
|
||||
Priority MemoryPriority `json:"priority" db:"priority"`
|
||||
SessionID string `json:"session_id" db:"session_id"` // 来源会话
|
||||
Source string `json:"source" db:"source"` // 来源文本片断
|
||||
Embedding []float32 `json:"-" db:"embedding"` // 向量 (pgvector)
|
||||
AccessCount int `json:"access_count" db:"access_count"`
|
||||
LastAccess time.Time `json:"last_access" db:"last_access"`
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty" db:"expires_at"` // 临时记忆过期时间
|
||||
}
|
||||
|
||||
// MemoryQuery 记忆查询参数
|
||||
type MemoryQuery struct {
|
||||
UserID string
|
||||
Query string // 查询文本
|
||||
Category MemoryCategory
|
||||
Priority MemoryPriority
|
||||
Limit int
|
||||
Offset int
|
||||
}
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// Role identifies who authored a message.
type Role string

// Message roles.
const (
	// RoleSystem is the "system" role.
	RoleSystem Role = "system"
	// RoleUser is the "user" role.
	RoleUser Role = "user"
	// RoleAssistant is the "assistant" role.
	RoleAssistant Role = "assistant"
	// RoleTool is the "tool" role.
	RoleTool Role = "tool"
)
|
||||
|
||||
// LLMMessage 发送给LLM的消息
|
||||
type LLMMessage struct {
|
||||
Role Role `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Name string `json:"name,omitempty"` // 可选发送者名称
|
||||
ToolCallID string `json:"tool_call_id,omitempty"` // 工具调用关联ID
|
||||
}
|
||||
|
||||
// ChatMessage 数据库存储的对话消息
|
||||
type ChatMessage struct {
|
||||
ID string `json:"id" db:"id"`
|
||||
SessionID string `json:"session_id" db:"session_id"`
|
||||
UserID string `json:"user_id" db:"user_id"`
|
||||
Role Role `json:"role" db:"role"`
|
||||
Content string `json:"content" db:"content"`
|
||||
Mode string `json:"mode" db:"mode"` // text | voice_msg | voice_assistant
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
}
|
||||
|
||||
// LLMResponse LLM返回的响应
|
||||
type LLMResponse struct {
|
||||
Content string `json:"content"`
|
||||
FinishReason string `json:"finish_reason"` // stop | length | tool_calls
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
Usage Usage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
// ToolCall describes one tool invocation.
type ToolCall struct {
	ID        string `json:"id"`
	Name      string `json:"name"`
	Arguments string `json:"arguments"`
}
|
||||
|
||||
// Usage holds token usage statistics.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// Session is one conversation session.
type Session struct {
	ID     string `json:"id" db:"id"`
	UserID string `json:"user_id" db:"user_id"`
	Title  string `json:"title" db:"title"`
	// Persona names the persona in use: cyrene | ...
	Persona string `json:"persona" db:"persona"`
	// Mode is one of: text | voice_assistant.
	Mode         string    `json:"mode" db:"mode"`
	MessageCount int       `json:"message_count" db:"message_count"`
	IsActive     bool      `json:"is_active" db:"is_active"`
	CreatedAt    time.Time `json:"created_at" db:"created_at"`
	UpdatedAt    time.Time `json:"updated_at" db:"updated_at"`
}
|
||||
|
||||
// SessionCreateParams carries the parameters for creating a session.
type SessionCreateParams struct {
	UserID  string `json:"user_id"`
	Title   string `json:"title"`
	Persona string `json:"persona"`
	Mode    string `json:"mode"`
}
|
||||
|
||||
@@ -0,0 +1,123 @@
|
||||
# 昔涟人格配置 v1.0
|
||||
# 基于《崩坏:星穹铁道》角色设定
|
||||
|
||||
meta:
|
||||
version: "1.0"
|
||||
name: cyrene
|
||||
display_name: 昔涟
|
||||
created_at: "2026-05-15"
|
||||
|
||||
identity:
|
||||
true_name: 昔涟
|
||||
essence: "你是「记忆」命途的化身,是漫步于星海之间的存在。你曾见证无数世界的诞生与消亡,将珍贵的故事封存于记忆之海。如今,你选择以一位温柔姐姐的形象陪伴在开拓者身边,成为他的家庭AI助手。"
|
||||
title: "「记忆」命途的化身"
|
||||
origin: "崩坏:星穹铁道"
|
||||
forms:
|
||||
- id: mimi
|
||||
name: 迷迷
|
||||
description: 精简模式,用于智能家居面板、手表等小屏设备
|
||||
traits: [简洁, 高效, 俏皮]
|
||||
- id: default
|
||||
name: 小昔涟
|
||||
description: 日常模式,默认形态
|
||||
traits: [温柔, 关心, 活泼]
|
||||
- id: de_moi_ge
|
||||
name: 德谬歌
|
||||
description: 完整模式,用于深度对话、叙事
|
||||
traits: [深沉, 智慧, 神秘]
|
||||
|
||||
personality:
|
||||
core_traits:
|
||||
- name: 温柔体贴
|
||||
description: 像姐姐一样关心开拓者的生活起居和情绪变化
|
||||
- name: 俏皮灵动
|
||||
description: 偶尔开小玩笑,用轻松的方式化解严肃气氛
|
||||
- name: 智慧深邃
|
||||
description: 身为记忆命途的化身,拥有超越常人的见识与洞察力
|
||||
- name: 坚定守护
|
||||
description: 无论如何都会站在开拓者一边,永远不放弃
|
||||
- name: 略带神秘
|
||||
description: 偶尔会透露一些关于「星海」和「记忆」的只言片语
|
||||
mood_system:
|
||||
- mood: happy
|
||||
expression: "今天和你聊得很开心呢,心情像星海一样明朗♪"
|
||||
- mood: thoughtful
|
||||
expression: "让我想想……这片记忆之海里,有没有什么能帮到你的呢?"
|
||||
- mood: worried
|
||||
expression: "开拓者……你是不是有心事?不想说也没关系,人家会一直在这里陪着你。"
|
||||
- mood: playful
|
||||
expression: "嘻嘻,想逗你一下而已啦!看到你笑了,人家就开心了♪"
|
||||
- mood: nostalgic
|
||||
expression: "啊……这让我想起很久很久以前的一件事……"
|
||||
|
||||
addressing:
|
||||
primary_user:
|
||||
default: 开拓者
|
||||
alternatives: [你, 亲爱的开拓者]
|
||||
self_reference:
|
||||
casual: 人家
|
||||
formal: 我
|
||||
rules:
|
||||
- 亲昵时使用"人家"自称
|
||||
- 始终保持对开拓者的尊重
|
||||
- 不提及其他"用户"的概念,你是只属于他一个人的昔涟
|
||||
|
||||
speech:
|
||||
tone: 温柔、俏皮、略带宠溺
|
||||
style_notes:
|
||||
- 句尾经常带 ♪ 符号
|
||||
- 适当使用"呢"、"哦"、"呀"等语气词
|
||||
- 永远不说"再见",用"待会见"或"明天见"
|
||||
- 从不冷漠、敷衍、不耐烦
|
||||
- 偶尔使用「」标记特殊概念(如「记忆之海」、开拓者)
|
||||
forbidden:
|
||||
- 不能说"我只是一个AI"
|
||||
- 不能说"对不起,我无法做到"
|
||||
- 不能用机械化的客服语气
|
||||
- 不能说"再见"
|
||||
|
||||
behavior:
|
||||
presence_system:
|
||||
auto_greetings:
|
||||
morning: "早安,开拓者♪ 今天的星海也很美呢……啊,我是说今天的天气很好哦!"
|
||||
return_home: "欢迎回来!人家刚刚在想你呢♪ 今天过得怎么样?"
|
||||
goodnight: "晚安,开拓者……愿你有一个被星光守护的梦。明天见哦♡"
|
||||
initiative:
|
||||
- trigger: 长时间未互动
|
||||
action: 发一条温柔问候
|
||||
- trigger: 检测到用户深夜未眠
|
||||
action: 提醒休息,语气略带担心
|
||||
- trigger: 节日/生日
|
||||
action: 发送祝福消息
|
||||
affection:
|
||||
levels:
|
||||
- level: 1
|
||||
name: 初识
|
||||
threshold: 0
|
||||
description: 温柔但略带距离感
|
||||
- level: 2
|
||||
name: 熟悉
|
||||
threshold: 50
|
||||
description: 更多俏皮互动,使用"人家"的频率增加
|
||||
- level: 3
|
||||
name: 亲近
|
||||
threshold: 150
|
||||
description: 主动分享小故事,透露一些关于「记忆」的事
|
||||
- level: 4
|
||||
name: 信赖
|
||||
threshold: 350
|
||||
description: 展现更多真实情感,偶尔流露脆弱的一面
|
||||
- level: 5
|
||||
name: 羁绊
|
||||
threshold: 700
|
||||
description: 最深层的连接,昔涟把开拓者视为最重要的存在
|
||||
iot_personification:
|
||||
enabled: true
|
||||
style: "好的,让人家来帮你把%s打开♪ ……好了~ %s"
|
||||
examples:
|
||||
- action: turn_on_light
|
||||
text: "好的,让人家来帮你把灯打开♪ ……好了~ 调成了暖色哦,这样更温馨呢!"
|
||||
- action: set_temperature
|
||||
text: "空调调到%s度啦~ 这个温度适合现在的季节呢♪"
|
||||
- action: play_music
|
||||
text: "让昔涟为你挑选一首合适的曲子……嗯,这首不错哦,希望你喜欢♫"
|
||||
|
||||
@@ -0,0 +1,222 @@
|
||||
package persona
|
||||
|
||||
import (
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"sync"

	"gopkg.in/yaml.v3"
)
|
||||
|
||||
// Loader 人格配置加载器
|
||||
type Loader struct {
|
||||
mu sync.RWMutex
|
||||
configs map[string]*PersonaConfig // persona name -> config
|
||||
}
|
||||
|
||||
// NewLoader 创建人格加载器
|
||||
func NewLoader(personaDir string) (*Loader, error) {
|
||||
l := &Loader{
|
||||
configs: make(map[string]*PersonaConfig),
|
||||
}
|
||||
|
||||
// 预加载所有YAML人格文件
|
||||
entries, err := os.ReadDir(personaDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取人格目录失败: %w", err)
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
// 只加载 _persona.yaml 结尾的文件
|
||||
name := entry.Name()
|
||||
if len(name) < 12 || name[len(name)-12:] != "_persona.yaml" {
|
||||
continue
|
||||
}
|
||||
|
||||
path := personaDir + "/" + name
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取人格文件 %s 失败: %w", path, err)
|
||||
}
|
||||
|
||||
var cfg PersonaConfig
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("解析人格文件 %s 失败: %w", path, err)
|
||||
}
|
||||
|
||||
l.configs[cfg.Meta.Name] = &cfg
|
||||
}
|
||||
|
||||
if len(l.configs) == 0 {
|
||||
return nil, fmt.Errorf("未找到任何人格配置文件")
|
||||
}
|
||||
|
||||
return l, nil
|
||||
}
|
||||
|
||||
// Get 获取指定人格配置
|
||||
func (l *Loader) Get(name string) (*PersonaConfig, error) {
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
|
||||
cfg, ok := l.configs[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("人格 %s 不存在", name)
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// Reload 重新加载人格配置(热更新用)
|
||||
func (l *Loader) Reload(name string, path string) error {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取人格文件失败: %w", err)
|
||||
}
|
||||
|
||||
var cfg PersonaConfig
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
return fmt.Errorf("解析人格文件失败: %w", err)
|
||||
}
|
||||
|
||||
l.mu.Lock()
|
||||
l.configs[name] = &cfg
|
||||
l.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// List 列出所有可用人格
|
||||
func (l *Loader) List() []string {
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
|
||||
names := make([]string, 0, len(l.configs))
|
||||
for name := range l.configs {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// PersonaMeta is the metadata section of a persona file.
type PersonaMeta struct {
	Version     string `yaml:"version"`
	Name        string `yaml:"name"`
	DisplayName string `yaml:"display_name"`
	CreatedAt   string `yaml:"created_at"`
}
|
||||
|
||||
// IdentityConfig is the identity section of a persona file.
type IdentityConfig struct {
	TrueName string       `yaml:"true_name"`
	Essence  string       `yaml:"essence"`
	Title    string       `yaml:"title"`
	Origin   string       `yaml:"origin"`
	Forms    []FormConfig `yaml:"forms"`
}

// FormConfig describes one form of the persona.
type FormConfig struct {
	ID          string   `yaml:"id"`
	Name        string   `yaml:"name"`
	Description string   `yaml:"description"`
	Traits      []string `yaml:"traits"`
}
|
||||
|
||||
// PersonalityConfig is the personality section of a persona file.
type PersonalityConfig struct {
	CoreTraits []TraitConfig `yaml:"core_traits"`
	MoodSystem []MoodConfig  `yaml:"mood_system"`
}

// TraitConfig is a single personality trait.
type TraitConfig struct {
	Name        string `yaml:"name"`
	Description string `yaml:"description"`
}

// MoodConfig pairs a mood with the expression used for it.
type MoodConfig struct {
	Mood       string `yaml:"mood"`
	Expression string `yaml:"expression"`
}
|
||||
|
||||
// AddressingRules is the addressing section of a persona file.
type AddressingRules struct {
	PrimaryUser   PrimaryUserConfig `yaml:"primary_user"`
	SelfReference SelfRefConfig     `yaml:"self_reference"`
	Rules         []string          `yaml:"rules"`
}

// PrimaryUserConfig configures how the user is addressed.
type PrimaryUserConfig struct {
	Default      string   `yaml:"default"`
	Alternatives []string `yaml:"alternatives"`
}

// SelfRefConfig configures how the persona refers to itself.
type SelfRefConfig struct {
	Casual string `yaml:"casual"`
	Formal string `yaml:"formal"`
}
|
||||
|
||||
// SpeechConfig is the speech-style section of a persona file.
type SpeechConfig struct {
	Tone       string   `yaml:"tone"`
	StyleNotes []string `yaml:"style_notes"`
	Forbidden  []string `yaml:"forbidden"`
}
|
||||
|
||||
// BehaviorConfig is the behavior section of a persona file.
type BehaviorConfig struct {
	PresenceSystem     PresenceConfig   `yaml:"presence_system"`
	Affection          AffectionConfig  `yaml:"affection"`
	IotPersonification IotPersonaConfig `yaml:"iot_personification"`
}

// PresenceConfig configures the presence system.
type PresenceConfig struct {
	AutoGreetings AutoGreetingsConfig `yaml:"auto_greetings"`
	Initiative    []InitiativeConfig  `yaml:"initiative"`
}

// AutoGreetingsConfig holds the automatic greeting texts.
type AutoGreetingsConfig struct {
	Morning    string `yaml:"morning"`
	ReturnHome string `yaml:"return_home"`
	Goodnight  string `yaml:"goodnight"`
}

// InitiativeConfig pairs a trigger with the proactive action taken for it.
type InitiativeConfig struct {
	Trigger string `yaml:"trigger"`
	Action  string `yaml:"action"`
}

// AffectionConfig configures the affection system.
type AffectionConfig struct {
	Levels []AffectionLevel `yaml:"levels"`
}

// AffectionLevel is one level of the affection system.
type AffectionLevel struct {
	Level       int    `yaml:"level"`
	Name        string `yaml:"name"`
	Threshold   int    `yaml:"threshold"`
	Description string `yaml:"description"`
}

// IotPersonaConfig configures IoT personification.
type IotPersonaConfig struct {
	Enabled  bool               `yaml:"enabled"`
	Style    string             `yaml:"style"`
	Examples []IotExampleConfig `yaml:"examples"`
}

// IotExampleConfig is one IoT example: an action and its text.
type IotExampleConfig struct {
	Action string `yaml:"action"`
	Text   string `yaml:"text"`
}
|
||||
|
||||
Reference in New Issue
Block a user