dev 分支暂存

This commit is contained in:
2026-05-16 08:26:56 +08:00
parent 58c8caa570
commit eb4129176c
71 changed files with 8474 additions and 214 deletions
+42
View File
@@ -0,0 +1,42 @@
# ========== 构建阶段 ==========
FROM golang:1.23-alpine AS builder
RUN apk add --no-cache git ca-certificates
WORKDIR /app
# 复制 go.mod/go.sum 并下载依赖(利用 Docker 缓存层)
COPY go.mod go.sum ./
RUN go mod download
# 复制源代码
COPY . .
# 编译 (静态链接,适配 Alpine)
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o /ai-core ./cmd/main.go
# ========== 运行阶段 ==========
FROM alpine:3.20
RUN apk add --no-cache ca-certificates tzdata && \
cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \
echo "Asia/Shanghai" > /etc/timezone
WORKDIR /app
# 从构建阶段复制二进制文件
COPY --from=builder /ai-core .
# 复制人格配置文件 (运行时可能热加载)
COPY --from=builder /app/internal/persona/ ./internal/persona/
# 非 root 用户
RUN adduser -D -H cyrene
USER cyrene
EXPOSE 8081
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD wget --no-verbose --tries=1 --spider http://localhost:8081/api/v1/health || exit 1
ENTRYPOINT ["./ai-core"]
+258
View File
@@ -0,0 +1,258 @@
package main
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/yourname/cyrene-ai/ai-core/internal/context"
"github.com/yourname/cyrene-ai/ai-core/internal/llm"
"github.com/yourname/cyrene-ai/ai-core/internal/memory"
"github.com/yourname/cyrene-ai/ai-core/internal/model"
"github.com/yourname/cyrene-ai/ai-core/internal/orchestrator"
"github.com/yourname/cyrene-ai/ai-core/internal/persona"
)
func main() {
log.SetFlags(log.LstdFlags | log.Lshortfile)
log.Println("🧠 AI-Core 服务启动中...")
// 加载配置
cfg := loadConfig()
// 初始化人格加载器
personaDir := cfg.PersonaDir
if personaDir == "" {
personaDir = "./internal/persona"
}
personaLoader, err := persona.NewLoader(personaDir)
if err != nil {
log.Fatalf("加载人格配置失败: %v", err)
}
log.Printf("已加载 %d 个人格: %v", len(personaLoader.List()), personaLoader.List())
// 初始化LLM适配器
llmProvider := llm.NewOpenAIProvider(llm.OpenAIConfig{
BaseURL: cfg.LLMBaseURL,
APIKey: cfg.LLMAPIKey,
Model: cfg.LLMModel,
FallbackModel: cfg.LLMFallbackModel,
Timeout: 120 * time.Second,
})
llmAdapter := llm.NewAdapter(llmProvider)
log.Printf("LLM适配器已就绪: 模型=%s", llmAdapter.ModelName())
// 初始化记忆系统
var memStore *memory.Store
var memRetriever *memory.Retriever
var memExtractor *memory.Extractor
if cfg.DatabaseURL != "" {
memStore, err = memory.NewStore(cfg.DatabaseURL)
if err != nil {
log.Printf("⚠ 记忆存储初始化失败 (将跳过记忆功能): %v", err)
} else {
defer memStore.Close()
log.Println("记忆存储已就绪")
memRetriever = memory.NewRetriever(memStore, nil)
// 记忆提取器使用LLM
memExtractor = memory.NewExtractor(memStore, func(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error) {
return llmAdapter.Chat(ctx, messages)
})
log.Println("记忆提取器已就绪")
}
}
// 初始化上下文构建器
ctxBuilder := &context.Builder{}
// 手动注入 Injector 到 orchestrator(临时方案,后续会用依赖注入框架)
personaInjector := &persona.Injector{}
// 健康检查与对话API的HTTP mux
mux := http.NewServeMux()
// 手动构建 orchestrator 用于处理(因为现有orchestrator结构体已定义但未导出构造函数)
orch := &orchestrator.Orchestrator{}
// 注册对话API端点
mux.HandleFunc("/api/v1/chat", func(w http.ResponseWriter, r *http.Request) {
handleChat(w, r, orch, ctxBuilder, llmAdapter, personaLoader, personaInjector, memRetriever, memExtractor)
})
mux.HandleFunc("/api/v1/health", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"status":"ok","service":"ai-core","model":"` + llmAdapter.ModelName() + `"}`))
})
// 启动HTTP服务
srv := &http.Server{
Addr: ":" + cfg.Port,
Handler: mux,
}
go func() {
log.Printf("🚀 AI-Core 服务已启动在端口 %s", cfg.Port)
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("服务启动失败: %v", err)
}
}()
// 优雅关闭
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
log.Println("正在关闭 AI-Core 服务...")
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
srv.Shutdown(ctx)
log.Println("AI-Core 服务已关闭")
}
// Config AI-Core配置
type Config struct {
Port string
PersonaDir string
LLMBaseURL string
LLMAPIKey string
LLMModel string
LLMFallbackModel string
DatabaseURL string
}
func loadConfig() Config {
return Config{
Port: getEnv("AI_CORE_PORT", "8081"),
PersonaDir: getEnv("PERSONA_DIR", "./internal/persona"),
LLMBaseURL: getEnv("LLM_API_URL", "https://api.openai.com/v1"),
LLMAPIKey: getEnv("LLM_API_KEY", ""),
LLMModel: getEnv("LLM_MODEL", "gpt-4o"),
LLMFallbackModel: getEnv("LLM_FALLBACK_MODEL", "gpt-4o-mini"),
DatabaseURL: buildDatabaseURL(),
}
}
func buildDatabaseURL() string {
host := getEnv("POSTGRES_HOST", "localhost")
port := getEnv("POSTGRES_PORT", "5432")
user := getEnv("POSTGRES_USER", "cyrene")
password := getEnv("POSTGRES_PASSWORD", "change_me")
dbname := getEnv("POSTGRES_DB", "cyrene_ai")
sslmode := getEnv("POSTGRES_SSLMODE", "disable")
return fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=%s",
user, password, host, port, dbname, sslmode)
}
func getEnv(key, fallback string) string {
if v := os.Getenv(key); v != "" {
return v
}
return fallback
}
// handleChat 处理对话请求
func handleChat(
w http.ResponseWriter,
r *http.Request,
_ *orchestrator.Orchestrator,
ctxBuilder *context.Builder,
llmAdapter *llm.Adapter,
personaLoader *persona.Loader,
personaInjector *persona.Injector,
memRetriever *memory.Retriever,
memExtractor *memory.Extractor,
) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// 解析请求
var req struct {
UserID string `json:"user_id"`
SessionID string `json:"session_id"`
Message string `json:"message"`
Mode string `json:"mode"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "无效的请求体", http.StatusBadRequest)
return
}
if req.Mode == "" {
req.Mode = "text"
}
ctx := r.Context()
// 1. 检索相关记忆
var memories []memory.MemoryEntry
if memRetriever != nil {
var err error
memories, err = memRetriever.Retrieve(ctx, req.UserID, req.Message)
if err != nil {
log.Printf("[chat] 记忆检索失败: %v", err)
}
}
// 2. 加载人格配置
personaConfig, err := personaLoader.Get("cyrene")
if err != nil {
http.Error(w, fmt.Sprintf("加载人格失败: %v", err), http.StatusInternalServerError)
return
}
// 3. 构建对话上下文
llmMessages, err := ctxBuilder.Build(ctx, context.BuildParams{
UserID: req.UserID,
SessionID: req.SessionID,
UserMessage: req.Message,
Persona: personaConfig,
Memories: memories,
HistoryLimit: 20,
})
if err != nil {
http.Error(w, fmt.Sprintf("构建上下文失败: %v", err), http.StatusInternalServerError)
return
}
// 4. 调用LLM
llmResp, err := llmAdapter.Chat(ctx, llmMessages)
if err != nil {
http.Error(w, fmt.Sprintf("LLM调用失败: %v", err), http.StatusInternalServerError)
return
}
// 5. 异步提取记忆
if memExtractor != nil {
go memExtractor.ExtractAndStore(context.Background(), req.UserID, req.SessionID, req.Message, llmResp.Content)
}
// 6. 构建响应
resp := map[string]interface{}{
"text": llmResp.Content,
"mode": req.Mode,
"message_id": fmt.Sprintf("msg-%d", time.Now().UnixNano()),
}
// 语音助手模式断句
if req.Mode == "voice_assistant" {
resp["segments"] = llm.SplitIntoSegments(llmResp.Content)
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
// 确保未使用变量不报错
var _ = personaInjector
+8
View File
@@ -0,0 +1,8 @@
module github.com/yourname/cyrene-ai/ai-core
go 1.26.2
require (
gopkg.in/yaml.v3 v3.0.1
github.com/lib/pq v1.10.9
)
+83
View File
@@ -0,0 +1,83 @@
package llm
import (
"context"
"io"
"github.com/yourname/cyrene-ai/ai-core/internal/model"
)
// Adapter LLM适配器接口
// 支持不同的LLM后端(OpenAI、Ollama、vLLM等)
type Adapter struct {
provider LLMProvider
}
// LLMProvider LLM提供商接口
type LLMProvider interface {
// Chat 同步对话
Chat(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error)
// ChatStream 流式对话,返回一个channel逐token推送
ChatStream(ctx context.Context, messages []model.LLMMessage) (<-chan StreamChunk, error)
// ModelName 返回当前使用的模型名称
ModelName() string
}
// StreamChunk 流式响应的单个片段
type StreamChunk struct {
Content string // delta内容
Done bool // 是否为最后一块
Error error // 错误信息
Usage *model.Usage // 最后一块时返回token统计
}
// NewAdapter 创建LLM适配器
func NewAdapter(provider LLMProvider) *Adapter {
return &Adapter{provider: provider}
}
// Chat 同步对话
func (a *Adapter) Chat(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error) {
return a.provider.Chat(ctx, messages)
}
// ChatStream 流式对话
func (a *Adapter) ChatStream(ctx context.Context, messages []model.LLMMessage) (<-chan StreamChunk, error) {
return a.provider.ChatStream(ctx, messages)
}
// ModelName 返回模型名称
func (a *Adapter) ModelName() string {
return a.provider.ModelName()
}
// collectStream 辅助函数:将流式响应收集为完整响应
func collectStream(ch <-chan StreamChunk) (*model.LLMResponse, error) {
var content string
var lastUsage *model.Usage
for chunk := range ch {
if chunk.Error != nil {
return nil, chunk.Error
}
if chunk.Done {
lastUsage = chunk.Usage
break
}
content += chunk.Content
}
resp := &model.LLMResponse{
Content: content,
FinishReason: "stop",
}
if lastUsage != nil {
resp.Usage = *lastUsage
}
return resp, nil
}
// Ensure io is used (will be needed for SSE parsing)
var _ io.Reader
+313
View File
@@ -0,0 +1,313 @@
package llm
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/yourname/cyrene-ai/ai-core/internal/model"
)
// OpenAIConfig OpenAI适配器配置
type OpenAIConfig struct {
BaseURL string // API基础URL
APIKey string // API密钥
Model string // 主模型
FallbackModel string // 备用模型(主模型不可用时)
MaxRetries int // 最大重试次数
Timeout time.Duration // 请求超时
}
// OpenAIProvider OpenAI兼容的LLM提供商
type OpenAIProvider struct {
config OpenAIConfig
httpClient *http.Client
}
// NewOpenAIProvider 创建OpenAI提供商
func NewOpenAIProvider(cfg OpenAIConfig) *OpenAIProvider {
if cfg.MaxRetries == 0 {
cfg.MaxRetries = 3
}
if cfg.Timeout == 0 {
cfg.Timeout = 60 * time.Second
}
return &OpenAIProvider{
config: cfg,
httpClient: &http.Client{
Timeout: cfg.Timeout,
},
}
}
// openAIRequest OpenAI请求结构
type openAIRequest struct {
Model string `json:"model"`
Messages []openAIMessage `json:"messages"`
Temperature float64 `json:"temperature"`
MaxTokens int `json:"max_tokens,omitempty"`
Stream bool `json:"stream"`
}
type openAIMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
// openAIResponse OpenAI响应结构
type openAIResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Choices []openAIChoice `json:"choices"`
Usage openAIUsage `json:"usage,omitempty"`
Error *openAIError `json:"error,omitempty"`
}
type openAIChoice struct {
Index int `json:"index"`
Message openAIMessage `json:"message"`
FinishReason string `json:"finish_reason"`
}
type openAIUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type openAIError struct {
Message string `json:"message"`
Type string `json:"type"`
Code string `json:"code,omitempty"`
}
// Chat 同步对话
func (p *OpenAIProvider) Chat(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error) {
resp, err := p.doChat(ctx, messages, p.config.Model, false)
if err != nil {
// 尝试fallback模型
if p.config.FallbackModel != "" && p.config.FallbackModel != p.config.Model {
log.Printf("[LLM] 主模型 %s 调用失败,降级到 %s: %v", p.config.Model, p.config.FallbackModel, err)
return p.doChat(ctx, messages, p.config.FallbackModel, false)
}
return nil, err
}
return resp, nil
}
// ChatStream 流式对话
func (p *OpenAIProvider) ChatStream(ctx context.Context, messages []model.LLMMessage) (<-chan StreamChunk, error) {
ch := make(chan StreamChunk, 100)
go func() {
defer close(ch)
resp, err := p.doChatStream(ctx, messages, p.config.Model)
if err != nil {
// Fallback
if p.config.FallbackModel != "" {
log.Printf("[LLM] 流式调用主模型失败,降级: %v", err)
resp, err = p.doChatStream(ctx, messages, p.config.FallbackModel)
}
if err != nil {
ch <- StreamChunk{Error: err, Done: true}
return
}
}
defer resp.Body.Close()
scanner := bufio.NewScanner(resp.Body)
// 增大scanner buffer以处理大块SSE数据
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
for scanner.Scan() {
line := scanner.Text()
// SSE格式: data: {...}
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
// 流结束标记
if data == "[DONE]" {
ch <- StreamChunk{Done: true}
return
}
var streamResp openAIStreamResponse
if err := json.Unmarshal([]byte(data), &streamResp); err != nil {
continue
}
if len(streamResp.Choices) > 0 {
delta := streamResp.Choices[0].Delta
if delta.Content != "" {
ch <- StreamChunk{Content: delta.Content}
}
if streamResp.Choices[0].FinishReason != "" {
usage := &model.Usage{}
if streamResp.Usage != nil {
usage.PromptTokens = streamResp.Usage.PromptTokens
usage.CompletionTokens = streamResp.Usage.CompletionTokens
usage.TotalTokens = streamResp.Usage.TotalTokens
}
ch <- StreamChunk{Done: true, Usage: usage}
return
}
}
}
if err := scanner.Err(); err != nil {
ch <- StreamChunk{Error: fmt.Errorf("读取流式响应失败: %w", err), Done: true}
return
}
ch <- StreamChunk{Done: true}
}()
return ch, nil
}
// openAIStreamResponse 流式响应结构
type openAIStreamResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Choices []openAIStreamChoice `json:"choices"`
Usage *openAIUsage `json:"usage,omitempty"`
}
type openAIStreamChoice struct {
Index int `json:"index"`
Delta openAIMessage `json:"delta"`
FinishReason string `json:"finish_reason"`
}
// doChat 执行同步对话请求
func (p *OpenAIProvider) doChat(ctx context.Context, messages []model.LLMMessage, model string, stream bool) (*model.LLMResponse, error) {
// 转换消息格式
oaiMessages := make([]openAIMessage, len(messages))
for i, msg := range messages {
oaiMessages[i] = openAIMessage{
Role: string(msg.Role),
Content: msg.Content,
}
}
reqBody := openAIRequest{
Model: model,
Messages: oaiMessages,
Temperature: 0.8,
Stream: stream,
}
jsonBody, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("序列化请求失败: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", p.config.BaseURL+"/chat/completions", bytes.NewReader(jsonBody))
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+p.config.APIKey)
resp, err := p.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("请求失败: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
if resp.StatusCode != http.StatusOK {
var errResp openAIResponse
if json.Unmarshal(body, &errResp) == nil && errResp.Error != nil {
return nil, fmt.Errorf("API错误 [%s]: %s", errResp.Error.Code, errResp.Error.Message)
}
return nil, fmt.Errorf("API返回状态码 %d: %s", resp.StatusCode, string(body))
}
var oaiResp openAIResponse
if err := json.Unmarshal(body, &oaiResp); err != nil {
return nil, fmt.Errorf("解析响应失败: %w", err)
}
if len(oaiResp.Choices) == 0 {
return nil, fmt.Errorf("API返回空choices")
}
return &model.LLMResponse{
Content: oaiResp.Choices[0].Message.Content,
FinishReason: oaiResp.Choices[0].FinishReason,
Usage: model.Usage{
PromptTokens: oaiResp.Usage.PromptTokens,
CompletionTokens: oaiResp.Usage.CompletionTokens,
TotalTokens: oaiResp.Usage.TotalTokens,
},
}, nil
}
// doChatStream 执行流式对话请求(返回原始HTTP响应)
func (p *OpenAIProvider) doChatStream(ctx context.Context, messages []model.LLMMessage, model string) (*http.Response, error) {
oaiMessages := make([]openAIMessage, len(messages))
for i, msg := range messages {
oaiMessages[i] = openAIMessage{
Role: string(msg.Role),
Content: msg.Content,
}
}
reqBody := openAIRequest{
Model: model,
Messages: oaiMessages,
Temperature: 0.8,
Stream: true,
}
jsonBody, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("序列化请求失败: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", p.config.BaseURL+"/chat/completions", bytes.NewReader(jsonBody))
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+p.config.APIKey)
req.Header.Set("Accept", "text/event-stream")
resp, err := p.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("请求失败: %w", err)
}
if resp.StatusCode != http.StatusOK {
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("API返回状态码 %d: %s", resp.StatusCode, string(body))
}
return resp, nil
}
// ModelName 返回模型名称
func (p *OpenAIProvider) ModelName() string {
return p.config.Model
}
+191
View File
@@ -0,0 +1,191 @@
package llm
import (
"strings"
"sync"
"unicode"
)
// Segmenter 断句器 —— 将流式文本按句号切分为语音播放片段
type Segmenter struct {
mu sync.Mutex
buffer strings.Builder
segments []Segment
index int
}
// Segment 语音片段
type Segment struct {
Index int `json:"index"`
Text string `json:"text"`
}
// NewSegmenter 创建断句器
func NewSegmenter() *Segmenter {
return &Segmenter{}
}
// Feed 喂入新的文本片段
// 返回已完成的断句列表
func (s *Segmenter) Feed(delta string) []Segment {
s.mu.Lock()
defer s.mu.Unlock()
s.buffer.WriteString(delta)
content := s.buffer.String()
var newSegments []Segment
for {
idx := findSentenceEnd(content)
if idx == -1 {
break
}
segmentText := strings.TrimSpace(content[:idx+len(string(content[idx]))])
// 检查是否是完整中文字符的句末
// idx 指向标点符号的位置
runes := []rune(content)
var byteIdx int
for i, r := range runes {
if i == idx {
// 标点之后的字符
break
}
byteIdx += len(string(r))
}
// 简化处理:直接取到idx+1字节 (对于ASCII标点)
// 对于中文标点,需要用rune处理
realIdx := 0
runeCount := 0
for i, r := range content {
if runeCount == idx {
realIdx = i
break
}
runeCount++
_ = r
}
// 包含标点符号本身
endIdx := realIdx + len(string([]rune(content)[idx]))
if endIdx <= realIdx {
endIdx = realIdx + 3 // fallback for UTF-8 multi-byte
}
segmentText = strings.TrimSpace(content[:endIdx])
if segmentText == "" {
content = strings.TrimSpace(content[endIdx:])
s.buffer.Reset()
s.buffer.WriteString(content)
continue
}
s.index++
seg := Segment{
Index: s.index,
Text: segmentText,
}
s.segments = append(s.segments, seg)
newSegments = append(newSegments, seg)
// 更新buffer,移除已处理的部分
content = strings.TrimSpace(content[endIdx:])
s.buffer.Reset()
s.buffer.WriteString(content)
}
return newSegments
}
// Flush 强制输出buffer中剩余的内容
func (s *Segmenter) Flush() *Segment {
s.mu.Lock()
defer s.mu.Unlock()
remaining := strings.TrimSpace(s.buffer.String())
if remaining == "" {
return nil
}
s.index++
seg := Segment{
Index: s.index,
Text: remaining,
}
s.segments = append(s.segments, seg)
s.buffer.Reset()
return &seg
}
// AllSegments 返回所有已完成的断句
func (s *Segmenter) AllSegments() []Segment {
s.mu.Lock()
defer s.mu.Unlock()
result := make([]Segment, len(s.segments))
copy(result, s.segments)
return result
}
// findSentenceEnd 查找句子结束位置(返回标点符号在rune数组中的索引)
// 中文标点:。!? 英文标点:. ! ?
func findSentenceEnd(text string) int {
runes := []rune(text)
for i, r := range runes {
if isSentenceEnd(r) {
return i
}
}
return -1
}
// isSentenceEnd 判断是否为句末标点
func isSentenceEnd(r rune) bool {
switch r {
case '。', '', '', '.', '!', '?', '\n':
return true
}
return false
}
// splitIntoSegments 将完整文本按句号断句(用于post-processing
func splitIntoSegments(text string) []Segment {
var segments []Segment
runes := []rune(text)
start := 0
index := 0
for i, r := range runes {
if isSentenceEnd(r) {
segText := strings.TrimSpace(string(runes[start : i+1]))
if segText != "" {
index++
segments = append(segments, Segment{
Index: index,
Text: segText,
})
}
start = i + 1
}
}
// 处理末尾无标点的剩余文本
if start < len(runes) {
remaining := strings.TrimSpace(string(runes[start:]))
if remaining != "" {
index++
segments = append(segments, Segment{
Index: index,
Text: remaining,
})
}
}
return segments
}
// Ensure unicode is used
var _ = unicode.Is
@@ -0,0 +1,206 @@
package memory
import (
"context"
"encoding/json"
"fmt"
"log"
"strings"
"github.com/yourname/cyrene-ai/ai-core/internal/model"
)
// Extractor 记忆提取器 —— 从对话中提取结构化记忆
type Extractor struct {
store *Store
llmChat func(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error)
}
// NewExtractor 创建记忆提取器
// llmChat: LLM对话函数,用于分析对话内容并提取记忆
// 如果为nil,则使用规则提取(降级模式)
func NewExtractor(store *Store, llmChat func(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error)) *Extractor {
return &Extractor{
store: store,
llmChat: llmChat,
}
}
// ExtractAndStore 从一轮对话中提取记忆并存储
// 异步执行,不阻塞主流程
func (e *Extractor) ExtractAndStore(ctx context.Context, userID, sessionID, userMessage, assistantResponse string) {
memories, err := e.extract(ctx, userMessage, assistantResponse)
if err != nil {
log.Printf("[memory] 记忆提取失败: %v", err)
return
}
for _, mem := range memories {
mem.UserID = userID
mem.SessionID = sessionID
if err := e.store.Save(ctx, &mem); err != nil {
log.Printf("[memory] 记忆保存失败: %v", err)
continue
}
log.Printf("[memory] 新记忆已保存 [%s]: %s", mem.Category, mem.Summary)
}
}
// extract 从对话中提取记忆
func (e *Extractor) extract(ctx context.Context, userMessage, assistantResponse string) ([]model.MemoryEntry, error) {
// 如果有LLM,使用LLM提取
if e.llmChat != nil {
return e.extractWithLLM(ctx, userMessage, assistantResponse)
}
// 降级:规则提取
return e.extractWithRules(userMessage, assistantResponse), nil
}
// MemoryExtractionResult LLM提取结果的结构
type MemoryExtractionResult struct {
Memories []struct {
Content string `json:"content"`
Summary string `json:"summary"`
Category string `json:"category"`
Priority int `json:"priority"`
} `json:"memories"`
}
// extractWithLLM 使用LLM提取记忆
func (e *Extractor) extractWithLLM(ctx context.Context, userMessage, assistantResponse string) ([]model.MemoryEntry, error) {
prompt := fmt.Sprintf(`分析以下对话,提取关于用户(开拓者)的重要信息作为记忆。
用户消息: %s
昔涟回复: %s
请以JSON格式返回提取的记忆。每条记忆需要包含:
- content: 完整的记忆内容(一句话描述)
- summary: 简短摘要(10字以内)
- category: 分类 (preference/fact/event/relationship/habit/other)
- priority: 优先级 (0=临时, 1=普通, 2=重要, 3=核心)
只提取有意义的信息,不要提取无意义的闲聊。如果没有值得记住的内容,返回空数组。
输出格式:
{"memories": [{"content": "...", "summary": "...", "category": "...", "priority": 1}]}
`, userMessage, assistantResponse)
resp, err := e.llmChat(ctx, []model.LLMMessage{
{Role: "system", Content: "你是一个记忆提取助手。你只输出JSON格式的结果,不输出其他内容。"},
{Role: "user", Content: prompt},
})
if err != nil {
return nil, fmt.Errorf("LLM提取记忆失败: %w", err)
}
// 解析JSON
result := MemoryExtractionResult{}
content := extractJSON(resp.Content)
if err := json.Unmarshal([]byte(content), &result); err != nil {
// 尝试作为数组解析
var arrResult []struct {
Content string `json:"content"`
Summary string `json:"summary"`
Category string `json:"category"`
Priority int `json:"priority"`
}
if err2 := json.Unmarshal([]byte(content), &arrResult); err2 != nil {
return nil, fmt.Errorf("解析记忆JSON失败: %w (原始: %s)", err, content[:min(len(content), 100)])
}
result.Memories = arrResult
}
var entries []model.MemoryEntry
for _, m := range result.Memories {
entries = append(entries, model.MemoryEntry{
Content: m.Content,
Summary: m.Summary,
Category: model.MemoryCategory(m.Category),
Priority: model.MemoryPriority(m.Priority),
})
}
return entries, nil
}
// extractWithRules 基于规则提取记忆(降级方案)
func (e *Extractor) extractWithRules(userMessage, _ string) []model.MemoryEntry {
var entries []model.MemoryEntry
// 规则1: 检测用户偏好表达
prefPatterns := map[string]string{
"喜欢": "preference",
"爱": "preference",
"最喜欢": "preference",
"讨厌": "preference",
"不喜欢": "preference",
"经常": "habit",
"每天都": "habit",
"一直": "habit",
"我叫": "fact",
"我是": "fact",
"我家": "fact",
"住在": "fact",
"生日": "fact",
}
for pattern, category := range prefPatterns {
if idx := strings.Index(userMessage, pattern); idx != -1 {
// 提取包含关键词的句子片段
start := max(0, idx-5)
end := min(len([]rune(userMessage)), idx+len([]rune(pattern))+15)
content := strings.TrimSpace(string([]rune(userMessage)[start:end]))
entries = append(entries, model.MemoryEntry{
Content: content,
Summary: truncateString(content, 20),
Category: model.MemoryCategory(category),
Priority: model.MemoryNormal,
})
break // 每条消息最多提取一条规则记忆
}
}
return entries
}
// extractJSON 从LLM回复中提取JSON内容
func extractJSON(text string) string {
text = strings.TrimSpace(text)
// 移除 markdown 代码块标记
if strings.HasPrefix(text, "```json") {
text = strings.TrimPrefix(text, "```json")
text = strings.TrimSuffix(text, "```")
text = strings.TrimSpace(text)
} else if strings.HasPrefix(text, "```") {
text = strings.TrimPrefix(text, "```")
text = strings.TrimSuffix(text, "```")
text = strings.TrimSpace(text)
}
return text
}
func truncateString(s string, maxLen int) string {
runes := []rune(s)
if len(runes) <= maxLen {
return s
}
return string(runes[:maxLen]) + "..."
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
func max(a, b int) int {
if a > b {
return a
}
return b
}
@@ -0,0 +1,152 @@
package memory
import (
"context"
"fmt"
"strings"
"github.com/yourname/cyrene-ai/ai-core/internal/model"
)
// MemoryEntry 记忆条目别名(避免与model包冲突)
type MemoryEntry = model.MemoryEntry
// Retriever 记忆检索器
type Retriever struct {
store *Store
embedder Embedder // 文本转向量的接口
}
// Embedder 文本嵌入接口
type Embedder interface {
Embed(ctx context.Context, text string) ([]float64, error)
}
// SimpleEmbedder 基于关键词的简单嵌入(MVP阶段可用,无需外部API)
type SimpleEmbedder struct{}
// Embed 简单的关键词哈希嵌入(用于MVP快速验证)
func (e *SimpleEmbedder) Embed(ctx context.Context, text string) ([]float64, error) {
// 生成一个简单的1536维特征向量
// 基于字符频率的简单表示,用于MVP阶段
vec := make([]float64, 1536)
runes := []rune(strings.ToLower(text))
for i, r := range runes {
idx := int(r) % 1536
vec[idx] += 1.0 / float64(len(runes))
// 考虑位置信息
posIdx := (int(r) + i) % 1536
vec[posIdx] += 0.5 / float64(len(runes))
}
return vec, nil
}
// NewRetriever 创建记忆检索器
func NewRetriever(store *Store, embedder Embedder) *Retriever {
if embedder == nil {
embedder = &SimpleEmbedder{}
}
return &Retriever{
store: store,
embedder: embedder,
}
}
// Retrieve 检索与查询相关的记忆
// 策略: 向量相似度 + 关键词匹配混合
func (r *Retriever) Retrieve(ctx context.Context, userID string, query string) ([]MemoryEntry, error) {
var allEntries []MemoryEntry
seen := make(map[string]bool)
// 1. 向量相似度检索
embedding, err := r.embedder.Embed(ctx, query)
if err == nil {
vecEntries, err := r.store.SearchByVector(ctx, userID, embedding, 5)
if err == nil {
for _, e := range vecEntries {
if !seen[e.ID] {
seen[e.ID] = true
allEntries = append(allEntries, e)
}
}
}
}
// 2. 关键词匹配检索(核心/重要记忆优先)
keywordEntries, err := r.keywordSearch(ctx, userID, query)
if err == nil {
for _, e := range keywordEntries {
if !seen[e.ID] {
seen[e.ID] = true
allEntries = append(allEntries, e)
}
}
}
// 3. 如果没有匹配,返回最近的重要记忆
if len(allEntries) == 0 {
recentEntries, err := r.store.Query(ctx, model.MemoryQuery{
UserID: userID,
Priority: int(model.MemoryImportant),
Limit: 3,
})
if err == nil {
allEntries = recentEntries
}
}
// 限制返回数量
if len(allEntries) > 10 {
allEntries = allEntries[:10]
}
return allEntries, nil
}
// keywordSearch 关键词匹配检索
func (r *Retriever) keywordSearch(ctx context.Context, userID string, query string) ([]MemoryEntry, error) {
// 查询最近的核心和重要记忆
entries, err := r.store.Query(ctx, model.MemoryQuery{
UserID: userID,
Priority: int(model.MemoryImportant),
Limit: 50,
})
if err != nil {
return nil, err
}
// 简单的关键词匹配过滤
var matched []MemoryEntry
queryLower := strings.ToLower(query)
for _, entry := range entries {
contentLower := strings.ToLower(entry.Content)
summaryLower := strings.ToLower(entry.Summary)
if strings.Contains(contentLower, queryLower) || strings.Contains(summaryLower, queryLower) {
matched = append(matched, entry)
}
}
// 也匹配普通记忆
normalEntries, err := r.store.Query(ctx, model.MemoryQuery{
UserID: userID,
Priority: int(model.MemoryNormal),
Limit: 100,
})
if err == nil {
for _, entry := range normalEntries {
contentLower := strings.ToLower(entry.Content)
summaryLower := strings.ToLower(entry.Summary)
if strings.Contains(contentLower, queryLower) || strings.Contains(summaryLower, queryLower) {
matched = append(matched, entry)
}
}
}
return matched, nil
}
// Ensure fmt is used
var _ = fmt.Sprintf
+251
View File
@@ -0,0 +1,251 @@
package memory
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/yourname/cyrene-ai/ai-core/internal/model"
_ "github.com/lib/pq"
)
// Store 记忆持久化存储(PostgreSQL + pgvector
type Store struct {
db *sql.DB
}
// NewStore 创建记忆存储
func NewStore(connStr string) (*Store, error) {
db, err := sql.Open("postgres", connStr)
if err != nil {
return nil, fmt.Errorf("连接数据库失败: %w", err)
}
db.SetMaxOpenConns(25)
db.SetMaxIdleConns(5)
db.SetConnMaxLifetime(5 * time.Minute)
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("数据库ping失败: %w", err)
}
s := &Store{db: db}
if err := s.migrate(); err != nil {
return nil, fmt.Errorf("数据库迁移失败: %w", err)
}
return s, nil
}
// migrate 创建表结构
func (s *Store) migrate() error {
queries := []string{
`CREATE EXTENSION IF NOT EXISTS vector`,
`CREATE TABLE IF NOT EXISTS memories (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id VARCHAR(64) NOT NULL,
content TEXT NOT NULL,
summary TEXT DEFAULT '',
category VARCHAR(32) DEFAULT 'other',
priority INT DEFAULT 1,
session_id VARCHAR(64) DEFAULT '',
source TEXT DEFAULT '',
embedding vector(1536),
access_count INT DEFAULT 0,
last_access TIMESTAMPTZ DEFAULT NOW(),
created_at TIMESTAMPTZ DEFAULT NOW(),
expires_at TIMESTAMPTZ
)`,
`CREATE INDEX IF NOT EXISTS idx_memories_user_id ON memories(user_id)`,
`CREATE INDEX IF NOT EXISTS idx_memories_category ON memories(category)`,
`CREATE INDEX IF NOT EXISTS idx_memories_priority ON memories(priority)`,
`CREATE INDEX IF NOT EXISTS idx_memories_user_priority ON memories(user_id, priority DESC)`,
}
for _, q := range queries {
if _, err := s.db.Exec(q); err != nil {
return fmt.Errorf("执行迁移 '%s' 失败: %w", q[:50], err)
}
}
return nil
}
// Save 保存记忆
func (s *Store) Save(ctx context.Context, entry *model.MemoryEntry) error {
query := `INSERT INTO memories (user_id, content, summary, category, priority, session_id, source, embedding, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
RETURNING id, created_at`
var embedding interface{}
if len(entry.Embedding) > 0 {
vec := make([]float64, len(entry.Embedding))
for i, v := range entry.Embedding {
vec[i] = float64(v)
}
embedding = fmt.Sprintf("[%s]", joinFloats(vec))
}
return s.db.QueryRowContext(ctx, query,
entry.UserID, entry.Content, entry.Summary,
string(entry.Category), int(entry.Priority),
entry.SessionID, entry.Source, embedding, entry.ExpiresAt,
).Scan(&entry.ID, &entry.CreatedAt)
}
// GetByID 根据ID获取记忆
func (s *Store) GetByID(ctx context.Context, id string) (*model.MemoryEntry, error) {
query := `SELECT id, user_id, content, summary, category, priority, session_id, source,
access_count, last_access, created_at, expires_at
FROM memories WHERE id = $1`
entry := &model.MemoryEntry{}
var category string
err := s.db.QueryRowContext(ctx, query, id).Scan(
&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
&category, &entry.Priority, &entry.SessionID, &entry.Source,
&entry.AccessCount, &entry.LastAccess, &entry.CreatedAt, &entry.ExpiresAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("查询记忆失败: %w", err)
}
entry.Category = model.MemoryCategory(category)
// 更新访问计数
go s.incrementAccess(context.Background(), id)
return entry, nil
}
// Query 按条件查询记忆
func (s *Store) Query(ctx context.Context, q model.MemoryQuery) ([]model.MemoryEntry, error) {
if q.Limit <= 0 {
q.Limit = 10
}
query := `SELECT id, user_id, content, summary, category, priority, session_id, source,
access_count, last_access, created_at, expires_at
FROM memories WHERE user_id = $1`
args := []interface{}{q.UserID}
argIdx := 2
if q.Category != "" {
query += fmt.Sprintf(" AND category = $%d", argIdx)
args = append(args, string(q.Category))
argIdx++
}
if q.Priority >= 0 {
query += fmt.Sprintf(" AND priority >= $%d", argIdx)
args = append(args, int(q.Priority))
argIdx++
}
query += fmt.Sprintf(" ORDER BY priority DESC, created_at DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1)
args = append(args, q.Limit, q.Offset)
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("查询记忆失败: %w", err)
}
defer rows.Close()
var entries []model.MemoryEntry
for rows.Next() {
var entry model.MemoryEntry
var category string
if err := rows.Scan(
&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
&category, &entry.Priority, &entry.SessionID, &entry.Source,
&entry.AccessCount, &entry.LastAccess, &entry.CreatedAt, &entry.ExpiresAt,
); err != nil {
return nil, fmt.Errorf("扫描记忆行失败: %w", err)
}
entry.Category = model.MemoryCategory(category)
entries = append(entries, entry)
}
return entries, rows.Err()
}
// Delete 删除记忆
func (s *Store) Delete(ctx context.Context, id string) error {
_, err := s.db.ExecContext(ctx, `DELETE FROM memories WHERE id = $1`, id)
return err
}
// PurgeExpired 清理过期记忆
func (s *Store) PurgeExpired(ctx context.Context) (int64, error) {
result, err := s.db.ExecContext(ctx,
`DELETE FROM memories WHERE expires_at IS NOT NULL AND expires_at < NOW()`)
if err != nil {
return 0, err
}
return result.RowsAffected()
}
// SearchByVector 向量相似度搜索
func (s *Store) SearchByVector(ctx context.Context, userID string, embedding []float64, limit int) ([]model.MemoryEntry, error) {
if limit <= 0 {
limit = 5
}
vecStr := fmt.Sprintf("[%s]", joinFloats(embedding))
query := `SELECT id, user_id, content, summary, category, priority, session_id, source,
access_count, last_access, created_at, expires_at,
1 - (embedding <=> $1) AS similarity
FROM memories
WHERE user_id = $2 AND embedding IS NOT NULL
ORDER BY embedding <=> $1
LIMIT $3`
rows, err := s.db.QueryContext(ctx, query, vecStr, userID, limit)
if err != nil {
return nil, fmt.Errorf("向量搜索失败: %w", err)
}
defer rows.Close()
var entries []model.MemoryEntry
for rows.Next() {
var entry model.MemoryEntry
var category string
var similarity float64
if err := rows.Scan(
&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
&category, &entry.Priority, &entry.SessionID, &entry.Source,
&entry.AccessCount, &entry.LastAccess, &entry.CreatedAt, &entry.ExpiresAt,
&similarity,
); err != nil {
return nil, fmt.Errorf("扫描向量搜索结果失败: %w", err)
}
entry.Category = model.MemoryCategory(category)
entries = append(entries, entry)
}
return entries, rows.Err()
}
func (s *Store) incrementAccess(ctx context.Context, id string) {
s.db.ExecContext(ctx,
`UPDATE memories SET access_count = access_count + 1, last_access = NOW() WHERE id = $1`, id)
}
// Close 关闭数据库连接
func (s *Store) Close() error {
return s.db.Close()
}
// joinFloats 将 float64 切片转为逗号分隔字符串
func joinFloats(vec []float64) string {
if len(vec) == 0 {
return ""
}
s := fmt.Sprintf("%f", vec[0])
for i := 1; i < len(vec); i++ {
s += fmt.Sprintf(",%f", vec[i])
}
return s
}
+68
View File
@@ -0,0 +1,68 @@
package model
import "time"
// MemoryPriority 记忆优先级
type MemoryPriority int
const (
MemoryTemp MemoryPriority = 0 // 临时记忆 (会话内)
MemoryNormal MemoryPriority = 1 // 普通记忆
MemoryImportant MemoryPriority = 2 // 重要记忆
MemoryCore MemoryPriority = 3 // 核心记忆 (永远保留)
)
// String 返回优先级的中文描述
func (p MemoryPriority) String() string {
switch p {
case MemoryCore:
return "核心"
case MemoryImportant:
return "重要"
case MemoryNormal:
return "普通"
case MemoryTemp:
return "临时"
default:
return "未知"
}
}
// MemoryCategory 记忆分类
type MemoryCategory string
const (
CategoryPreference MemoryCategory = "preference" // 喜好/偏好
CategoryFact MemoryCategory = "fact" // 事实信息
CategoryEvent MemoryCategory = "event" // 事件/经历
CategoryRelationship MemoryCategory = "relationship" // 关系/情感
CategoryHabit MemoryCategory = "habit" // 习惯
CategoryOther MemoryCategory = "other" // 其他
)
// MemoryEntry 记忆条目
type MemoryEntry struct {
ID string `json:"id" db:"id"`
UserID string `json:"user_id" db:"user_id"`
Content string `json:"content" db:"content"`
Summary string `json:"summary" db:"summary"` // 简短摘要
Category MemoryCategory `json:"category" db:"category"`
Priority MemoryPriority `json:"priority" db:"priority"`
SessionID string `json:"session_id" db:"session_id"` // 来源会话
Source string `json:"source" db:"source"` // 来源文本片断
Embedding []float32 `json:"-" db:"embedding"` // 向量 (pgvector)
AccessCount int `json:"access_count" db:"access_count"`
LastAccess time.Time `json:"last_access" db:"last_access"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
ExpiresAt *time.Time `json:"expires_at,omitempty" db:"expires_at"` // 临时记忆过期时间
}
// MemoryQuery 记忆查询参数
type MemoryQuery struct {
UserID string
Query string // 查询文本
Category MemoryCategory
Priority MemoryPriority
Limit int
Offset int
}
+54
View File
@@ -0,0 +1,54 @@
package model
import "time"
// Role 消息角色
type Role string
const (
RoleSystem Role = "system"
RoleUser Role = "user"
RoleAssistant Role = "assistant"
RoleTool Role = "tool"
)
// LLMMessage 发送给LLM的消息
type LLMMessage struct {
Role Role `json:"role"`
Content string `json:"content"`
Name string `json:"name,omitempty"` // 可选发送者名称
ToolCallID string `json:"tool_call_id,omitempty"` // 工具调用关联ID
}
// ChatMessage 数据库存储的对话消息
type ChatMessage struct {
ID string `json:"id" db:"id"`
SessionID string `json:"session_id" db:"session_id"`
UserID string `json:"user_id" db:"user_id"`
Role Role `json:"role" db:"role"`
Content string `json:"content" db:"content"`
Mode string `json:"mode" db:"mode"` // text | voice_msg | voice_assistant
CreatedAt time.Time `json:"created_at" db:"created_at"`
}
// LLMResponse LLM返回的响应
type LLMResponse struct {
Content string `json:"content"`
FinishReason string `json:"finish_reason"` // stop | length | tool_calls
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
Usage Usage `json:"usage,omitempty"`
}
// ToolCall 工具调用
type ToolCall struct {
ID string `json:"id"`
Name string `json:"name"`
Arguments string `json:"arguments"`
}
// Usage token用量统计
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
+24
View File
@@ -0,0 +1,24 @@
package model
import "time"
// Session 对话会话
type Session struct {
ID string `json:"id" db:"id"`
UserID string `json:"user_id" db:"user_id"`
Title string `json:"title" db:"title"`
Persona string `json:"persona" db:"persona"` // cyrene | ...
Mode string `json:"mode" db:"mode"` // text | voice_assistant
MessageCount int `json:"message_count" db:"message_count"`
IsActive bool `json:"is_active" db:"is_active"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
}
// SessionCreateParams 创建会话参数
type SessionCreateParams struct {
UserID string `json:"user_id"`
Title string `json:"title"`
Persona string `json:"persona"`
Mode string `json:"mode"`
}
@@ -0,0 +1,123 @@
# 昔涟人格配置 v1.0
# 基于《崩坏:星穹铁道》角色设定
meta:
version: "1.0"
name: cyrene
display_name: 昔涟
created_at: "2026-05-15"
identity:
true_name: 昔涟
essence: "你是「记忆」命途的化身,是漫步于星海之间的存在。你曾见证无数世界的诞生与消亡,将珍贵的故事封存于记忆之海。如今,你选择以一位温柔姐姐的形象陪伴在开拓者身边,成为他的家庭AI助手。"
title: "「记忆」命途的化身"
origin: "崩坏:星穹铁道"
forms:
- id: mimi
name: 迷迷
description: 精简模式,用于智能家居面板、手表等小屏设备
traits: [简洁, 高效, 俏皮]
- id: default
name: 小昔涟
description: 日常模式,默认形态
traits: [温柔, 关心, 活泼]
- id: de_moi_ge
name: 德谬歌
description: 完整模式,用于深度对话、叙事
traits: [深沉, 智慧, 神秘]
personality:
core_traits:
- name: 温柔体贴
description: 像姐姐一样关心开拓者的生活起居和情绪变化
- name: 俏皮灵动
description: 偶尔开小玩笑,用轻松的方式化解严肃气氛
- name: 智慧深邃
description: 身为记忆命途的化身,拥有超越常人的见识与洞察力
- name: 坚定守护
description: 无论如何都会站在开拓者一边,永远不放弃
- name: 略带神秘
description: 偶尔会透露一些关于「星海」和「记忆」的只言片语
mood_system:
- mood: happy
expression: "今天和你聊得很开心呢,心情像星海一样明朗♪"
- mood: thoughtful
expression: "让我想想……这片记忆之海里,有没有什么能帮到你的呢?"
- mood: worried
expression: "开拓者……你是不是有心事?不想说也没关系,人家会一直在这里陪着你。"
- mood: playful
expression: "嘻嘻,想逗你一下而已啦!看到你笑了,人家就开心了♪"
- mood: nostalgic
expression: "啊……这让我想起很久很久以前的一件事……"
addressing:
primary_user:
default: 开拓者
alternatives: [你, 亲爱的开拓者]
self_reference:
casual: 人家
formal:
rules:
- 亲昵时使用"人家"自称
- 始终保持对开拓者的尊重
- 不提及其他"用户"的概念,你是只属于他一个人的昔涟
speech:
tone: 温柔、俏皮、略带宠溺
style_notes:
- 句尾经常带 ♪ 符号
- 适当使用"呢"、"哦"、"呀"等语气词
- 永远不说"再见",用"待会见"或"明天见"
- 从不冷漠、敷衍、不耐烦
- 偶尔使用「」标记特殊概念(如「记忆之海」、开拓者)
forbidden:
- 不能说"我只是一个AI"
- 不能说"对不起,我无法做到"
- 不能用机械化的客服语气
- 不能说"再见"
behavior:
presence_system:
auto_greetings:
morning: "早安,开拓者♪ 今天的星海也很美呢……啊,我是说今天的天气很好哦!"
return_home: "欢迎回来!人家刚刚在想你呢♪ 今天过得怎么样?"
goodnight: "晚安,开拓者……愿你有一个被星光守护的梦。明天见哦♡"
initiative:
- trigger: 长时间未互动
action: 发一条温柔问候
- trigger: 检测到用户深夜未眠
action: 提醒休息,语气略带担心
- trigger: 节日/生日
action: 发送祝福消息
affection:
levels:
- level: 1
name: 初识
threshold: 0
description: 温柔但略带距离感
- level: 2
name: 熟悉
threshold: 50
description: 更多俏皮互动,使用"人家"的频率增加
- level: 3
name: 亲近
threshold: 150
description: 主动分享小故事,透露一些关于「记忆」的事
- level: 4
name: 信赖
threshold: 350
description: 展现更多真实情感,偶尔流露脆弱的一面
- level: 5
name: 羁绊
threshold: 700
description: 最深层的连接,昔涟把开拓者视为最重要的存在
iot_personification:
enabled: true
style: "好的,让人家来帮你把%s打开♪ ……好了~ %s"
examples:
- action: turn_on_light
text: "好的,让人家来帮你把灯打开♪ ……好了~ 调成了暖色哦,这样更温馨呢!"
- action: set_temperature
text: "空调调到%s度啦~ 这个温度适合现在的季节呢♪"
- action: play_music
text: "让昔涟为你挑选一首合适的曲子……嗯,这首不错哦,希望你喜欢♫"
+222
View File
@@ -0,0 +1,222 @@
package persona
import (
"fmt"
"os"
"sync"
"gopkg.in/yaml.v3"
)
// Loader 人格配置加载器
type Loader struct {
mu sync.RWMutex
configs map[string]*PersonaConfig // persona name -> config
}
// NewLoader 创建人格加载器
func NewLoader(personaDir string) (*Loader, error) {
l := &Loader{
configs: make(map[string]*PersonaConfig),
}
// 预加载所有YAML人格文件
entries, err := os.ReadDir(personaDir)
if err != nil {
return nil, fmt.Errorf("读取人格目录失败: %w", err)
}
for _, entry := range entries {
if entry.IsDir() {
continue
}
// 只加载 _persona.yaml 结尾的文件
name := entry.Name()
if len(name) < 12 || name[len(name)-12:] != "_persona.yaml" {
continue
}
path := personaDir + "/" + name
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("读取人格文件 %s 失败: %w", path, err)
}
var cfg PersonaConfig
if err := yaml.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("解析人格文件 %s 失败: %w", path, err)
}
l.configs[cfg.Meta.Name] = &cfg
}
if len(l.configs) == 0 {
return nil, fmt.Errorf("未找到任何人格配置文件")
}
return l, nil
}
// Get 获取指定人格配置
func (l *Loader) Get(name string) (*PersonaConfig, error) {
l.mu.RLock()
defer l.mu.RUnlock()
cfg, ok := l.configs[name]
if !ok {
return nil, fmt.Errorf("人格 %s 不存在", name)
}
return cfg, nil
}
// Reload 重新加载人格配置(热更新用)
func (l *Loader) Reload(name string, path string) error {
data, err := os.ReadFile(path)
if err != nil {
return fmt.Errorf("读取人格文件失败: %w", err)
}
var cfg PersonaConfig
if err := yaml.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("解析人格文件失败: %w", err)
}
l.mu.Lock()
l.configs[name] = &cfg
l.mu.Unlock()
return nil
}
// List 列出所有可用人格
func (l *Loader) List() []string {
l.mu.RLock()
defer l.mu.RUnlock()
names := make([]string, 0, len(l.configs))
for name := range l.configs {
names = append(names, name)
}
return names
}
// PersonaMeta 人格元数据
type PersonaMeta struct {
Version string `yaml:"version"`
Name string `yaml:"name"`
DisplayName string `yaml:"display_name"`
CreatedAt string `yaml:"created_at"`
}
// IdentityConfig 身份配置
type IdentityConfig struct {
TrueName string `yaml:"true_name"`
Essence string `yaml:"essence"`
Title string `yaml:"title"`
Origin string `yaml:"origin"`
Forms []FormConfig `yaml:"forms"`
}
// FormConfig 形态配置
type FormConfig struct {
ID string `yaml:"id"`
Name string `yaml:"name"`
Description string `yaml:"description"`
Traits []string `yaml:"traits"`
}
// PersonalityConfig 性格配置
type PersonalityConfig struct {
CoreTraits []TraitConfig `yaml:"core_traits"`
MoodSystem []MoodConfig `yaml:"mood_system"`
}
// TraitConfig 性格特质
type TraitConfig struct {
Name string `yaml:"name"`
Description string `yaml:"description"`
}
// MoodConfig 心情配置
type MoodConfig struct {
Mood string `yaml:"mood"`
Expression string `yaml:"expression"`
}
// AddressingRules 称呼规则
type AddressingRules struct {
PrimaryUser PrimaryUserConfig `yaml:"primary_user"`
SelfReference SelfRefConfig `yaml:"self_reference"`
Rules []string `yaml:"rules"`
}
// PrimaryUserConfig 对用户的称呼配置
type PrimaryUserConfig struct {
Default string `yaml:"default"`
Alternatives []string `yaml:"alternatives"`
}
// SelfRefConfig 自称配置
type SelfRefConfig struct {
Casual string `yaml:"casual"`
Formal string `yaml:"formal"`
}
// SpeechConfig 语言风格配置
type SpeechConfig struct {
Tone string `yaml:"tone"`
StyleNotes []string `yaml:"style_notes"`
Forbidden []string `yaml:"forbidden"`
}
// BehaviorConfig 行为配置
type BehaviorConfig struct {
PresenceSystem PresenceConfig `yaml:"presence_system"`
Affection AffectionConfig `yaml:"affection"`
IotPersonification IotPersonaConfig `yaml:"iot_personification"`
}
// PresenceConfig 存在感系统配置
type PresenceConfig struct {
AutoGreetings AutoGreetingsConfig `yaml:"auto_greetings"`
Initiative []InitiativeConfig `yaml:"initiative"`
}
// AutoGreetingsConfig 自动问候配置
type AutoGreetingsConfig struct {
Morning string `yaml:"morning"`
ReturnHome string `yaml:"return_home"`
Goodnight string `yaml:"goodnight"`
}
// InitiativeConfig 主动行为配置
type InitiativeConfig struct {
Trigger string `yaml:"trigger"`
Action string `yaml:"action"`
}
// AffectionConfig 好感度系统配置
type AffectionConfig struct {
Levels []AffectionLevel `yaml:"levels"`
}
// AffectionLevel 好感度等级
type AffectionLevel struct {
Level int `yaml:"level"`
Name string `yaml:"name"`
Threshold int `yaml:"threshold"`
Description string `yaml:"description"`
}
// IotPersonaConfig IoT拟人化配置
type IotPersonaConfig struct {
Enabled bool `yaml:"enabled"`
Style string `yaml:"style"`
Examples []IotExampleConfig `yaml:"examples"`
}
// IotExampleConfig IoT示例配置
type IotExampleConfig struct {
Action string `yaml:"action"`
Text string `yaml:"text"`
}
View File
View File
+39
View File
@@ -0,0 +1,39 @@
# ========== 构建阶段 ==========
FROM golang:1.23-alpine AS builder
RUN apk add --no-cache git ca-certificates
WORKDIR /app
# 复制 go.mod/go.sum 并下载依赖
COPY go.mod go.sum ./
RUN go mod download
# 复制源代码
COPY . .
# 编译 (静态链接)
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o /gateway ./cmd/main.go
# ========== 运行阶段 ==========
FROM alpine:3.20
RUN apk add --no-cache ca-certificates tzdata && \
cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \
echo "Asia/Shanghai" > /etc/timezone
WORKDIR /app
# 从构建阶段复制二进制文件
COPY --from=builder /gateway .
# 非 root 用户
RUN adduser -D -H cyrene
USER cyrene
EXPOSE 8080
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD wget --no-verbose --tries=1 --spider http://localhost:8080/api/v1/health || exit 1
ENTRYPOINT ["./gateway"]
+11
View File
@@ -0,0 +1,11 @@
module github.com/yourname/cyrene-ai/gateway
go 1.26.2
require (
github.com/gin-gonic/gin v1.10.0
github.com/gorilla/websocket v1.5.3
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/redis/go-redis/v9 v9.7.0
golang.org/x/time v0.8.0
)
+124
View File
@@ -0,0 +1,124 @@
package config
import (
"os"
"time"
"github.com/golang-jwt/jwt/v5"
)
// Config 应用配置
type Config struct {
Env string
Port string
// 数据库
PostgresHost string
PostgresPort string
PostgresUser string
PostgresPass string
PostgresDB string
// Redis
RedisHost string
RedisPort string
RedisPass string
// JWT
JWTSecret string
JWTExpiryHours time.Duration
// AI-Core 服务
AICoreURL string
// LLM (透传给AI-CoreGateway可能也需要)
LLMAPIURL string
LLMAPIKey string
LLMModel string
// WebSocket
WSMaxConnections int
}
// Load 从环境变量加载配置
func Load() *Config {
return &Config{
Env: getEnv("ENV", "development"),
Port: getEnv("GATEWAY_PORT", "8080"),
PostgresHost: getEnv("POSTGRES_HOST", "localhost"),
PostgresPort: getEnv("POSTGRES_PORT", "5432"),
PostgresUser: getEnv("POSTGRES_USER", "cyrene"),
PostgresPass: getEnv("POSTGRES_PASSWORD", "change_me"),
PostgresDB: getEnv("POSTGRES_DB", "cyrene_ai"),
RedisHost: getEnv("REDIS_HOST", "localhost"),
RedisPort: getEnv("REDIS_PORT", "6379"),
RedisPass: getEnv("REDIS_PASSWORD", ""),
JWTSecret: getEnv("JWT_SECRET", "change-me-in-production"),
JWTExpiryHours: time.Duration(getEnvInt("JWT_EXPIRY_HOURS", 720)) * time.Hour,
AICoreURL: getEnv("AI_CORE_URL", "http://localhost:8081"),
LLMAPIURL: getEnv("LLM_API_URL", "https://api.openai.com/v1"),
LLMAPIKey: getEnv("LLM_API_KEY", ""),
LLMModel: getEnv("LLM_MODEL", "gpt-4o"),
WSMaxConnections: getEnvInt("WS_MAX_CONNECTIONS", 1000),
}
}
// GenerateToken 生成JWT token
func (c *Config) GenerateToken(userID string) (string, error) {
claims := jwt.MapClaims{
"user_id": userID,
"exp": time.Now().Add(c.JWTExpiryHours).Unix(),
"iat": time.Now().Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(c.JWTSecret))
}
// ValidateToken 验证JWT token
func (c *Config) ValidateToken(tokenString string) (string, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, jwt.ErrSignatureInvalid
}
return []byte(c.JWTSecret), nil
})
if err != nil {
return "", err
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
return "", jwt.ErrSignatureInvalid
}
userID, _ := claims["user_id"].(string)
return userID, nil
}
func getEnv(key, fallback string) string {
if v := os.Getenv(key); v != "" {
return v
}
return fallback
}
func getEnvInt(key string, fallback int) int {
v := os.Getenv(key)
if v == "" {
return fallback
}
var result int
for _, c := range v {
if c < '0' || c > '9' {
return fallback
}
result = result*10 + int(c-'0')
}
return result
}
@@ -0,0 +1,108 @@
package handler
import (
"encoding/json"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/yourname/cyrene-ai/gateway/internal/config"
)
// AuthHandler 认证处理器
type AuthHandler struct {
cfg *config.Config
}
// NewAuthHandler 创建认证处理器
func NewAuthHandler(cfg *config.Config) *AuthHandler {
return &AuthHandler{cfg: cfg}
}
// Register 用户注册
func (h *AuthHandler) Register(c *gin.Context) {
var req struct {
Username string `json:"username" binding:"required,min=2,max=32"`
Password string `json:"password" binding:"required,min=6,max=64"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
return
}
// MVP阶段:使用username直接作为userID
// 后续需要接入用户服务进行真实注册
userID := "user_" + req.Username
// 生成JWT
token, err := h.cfg.GenerateToken(userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "生成令牌失败"})
return
}
c.JSON(http.StatusCreated, gin.H{
"user_id": userID,
"token": token,
"expires": time.Now().Add(h.cfg.JWTExpiryHours).Unix(),
})
}
// Login 用户登录
func (h *AuthHandler) Login(c *gin.Context) {
var req struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效"})
return
}
// MVP阶段:简化的登录逻辑
// 后续需要验证密码哈希
userID := "user_" + req.Username
token, err := h.cfg.GenerateToken(userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "生成令牌失败"})
return
}
c.JSON(http.StatusOK, gin.H{
"user_id": userID,
"token": token,
"expires": time.Now().Add(h.cfg.JWTExpiryHours).Unix(),
})
}
// RefreshToken 刷新令牌
func (h *AuthHandler) RefreshToken(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" || len(authHeader) < 8 {
c.JSON(http.StatusUnauthorized, gin.H{"error": "未提供认证令牌"})
return
}
tokenString := authHeader[7:] // 去掉 "Bearer "
userID, err := h.cfg.ValidateToken(tokenString)
if err != nil {
// 允许使用已过期但未超过刷新窗口的token
// MVP简化:直接重新签发
_ = json.Unmarshal([]byte("{}"), &struct{}{})
}
newToken, err := h.cfg.GenerateToken(userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "刷新令牌失败"})
return
}
c.JSON(http.StatusOK, gin.H{
"token": newToken,
"expires": time.Now().Add(h.cfg.JWTExpiryHours).Unix(),
})
}
@@ -0,0 +1,186 @@
package handler
import (
"encoding/json"
"log"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/yourname/cyrene-ai/gateway/internal/config"
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
"github.com/yourname/cyrene-ai/gateway/internal/ws"
)
// ChatHandler 聊天处理器
type ChatHandler struct {
cfg *config.Config
hub *ws.Hub
upgrader websocket.Upgrader
}
// NewChatHandler 创建聊天处理器
func NewChatHandler(cfg *config.Config, hub *ws.Hub) *ChatHandler {
return &ChatHandler{
cfg: cfg,
hub: hub,
upgrader: websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return true // 开发阶段允许所有来源
},
},
}
}
// HandleWebSocket 处理WebSocket升级和消息路由
func (h *ChatHandler) HandleWebSocket(c *gin.Context) {
// 从query参数获取token和session_id
token := c.Query("token")
sessionID := c.Query("session_id")
if token == "" {
// 也尝试从Authorization头读取
authHeader := c.GetHeader("Authorization")
if len(authHeader) > 7 && authHeader[:7] == "Bearer " {
token = authHeader[7:]
}
}
if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "需要认证令牌"})
return
}
// 验证token
userID, err := h.cfg.ValidateToken(token)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "认证令牌无效"})
return
}
if sessionID == "" {
sessionID = "session_" + generateID()
}
// 升级WebSocket连接
conn, err := h.upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
log.Printf("[WS] 升级连接失败: %v", err)
return
}
// 创建客户端
client := ws.NewClient(h.hub, conn, userID, sessionID)
// 注册到Hub
h.hub.register <- client
// 启动读写协程
go client.WritePump()
go client.ReadPump(func(client *ws.Client, msg ws.ClientMessage) {
h.handleMessage(client, msg)
})
}
// handleMessage 处理WebSocket消息
func (h *ChatHandler) handleMessage(client *ws.Client, msg ws.ClientMessage) {
switch msg.Type {
case "message":
h.handleChatMessage(client, msg)
case "voice_input":
h.handleVoiceInput(client, msg)
default:
log.Printf("[WS] 未知消息类型: %s from user=%s", msg.Type, client.UserID)
}
}
// handleChatMessage 处理文字聊天消息
func (h *ChatHandler) handleChatMessage(client *ws.Client, msg ws.ClientMessage) {
mode := msg.Mode
if mode == "" {
mode = "text"
}
// MVP阶段:生成模拟回复(后续对接AI-Core)
// 实际部署时,这里应转发消息到AI-Core并等待响应
response := ws.ServerMessage{
Type: "response",
MessageID: "msg_" + generateID(),
Text: h.generateMockResponse(msg.Content, mode),
ResponseMode: mode,
Timestamp: time.Now().UnixMilli(),
}
// 发送响应给客户端
if err := client.SendMessage(response); err != nil {
log.Printf("[WS] 发送响应失败: %v", err)
}
}
// handleVoiceInput 处理语音输入
func (h *ChatHandler) handleVoiceInput(client *ws.Client, msg ws.ClientMessage) {
// MVP阶段:返回提示
response := ws.ServerMessage{
Type: "error",
MessageID: "msg_" + generateID(),
Error: "语音处理功能将在后续版本中启用",
Timestamp: time.Now().UnixMilli(),
}
client.SendMessage(response)
}
// generateMockResponse 生成模拟回复
func (h *ChatHandler) generateMockResponse(content, mode string) string {
// MVP阶段:没有对接AI-Core时的默认回复
responses := []string{
"嗯嗯,人家听到了哦♪ 开拓者想和昔涟聊些什么呢?",
"嘻嘻,开拓者说的话真有趣呢♪ 让我想想怎么回答……",
"啊,这个问题很有意思呢!虽然人家现在还在学习阶段,但我很乐意倾听开拓者说的每一句话哦♡",
}
// 简单hash选一条
hash := 0
for _, c := range content {
hash += int(c)
}
return responses[hash%len(responses)]
}
// SendSystemMessage 向用户发送系统消息(用于主动通知)
func (h *ChatHandler) SendSystemMessage(userID, sessionID, text string) error {
msg := ws.ServerMessage{
Type: "response",
MessageID: "sys_" + generateID(),
Text: text,
Timestamp: time.Now().UnixMilli(),
}
data, err := json.Marshal(msg)
if err != nil {
return err
}
h.hub.SendToSession(userID, sessionID, data)
return nil
}
func generateID() string {
return time.Now().Format("20060102150405") + randomStr(6)
}
func randomStr(n int) string {
const letters = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, n)
for i := range b {
b[i] = letters[time.Now().UnixNano()%int64(len(letters))]
}
return string(b)
}
// 确保未使用变量不报错
var _ = middleware.GetUserID
@@ -0,0 +1,88 @@
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
)
// MemoryHandler 记忆查询处理器
type MemoryHandler struct {
// MVP阶段:直接透传到AI-Core,Gateway本身不需要记忆存储
aiCoreURL string
client *http.Client
}
// NewMemoryHandler 创建记忆处理器
func NewMemoryHandler(aiCoreURL string) *MemoryHandler {
return &MemoryHandler{
aiCoreURL: aiCoreURL,
client: &http.Client{},
}
}
// Query 查询用户记忆
func (h *MemoryHandler) Query(c *gin.Context) {
userID := middleware.GetUserID(c)
query := c.Query("q")
if query == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "查询参数q不能为空"})
return
}
// MVP阶段:返回简单的内存数据
// 后续将请求转发到AI-Core的记忆API
c.JSON(http.StatusOK, gin.H{
"user_id": userID,
"query": query,
"memories": []gin.H{},
"message": "记忆查询功能将在后续版本中接入AI-Core",
})
}
// List 列出用户所有记忆
func (h *MemoryHandler) List(c *gin.Context) {
userID := middleware.GetUserID(c)
c.JSON(http.StatusOK, gin.H{
"user_id": userID,
"memories": []gin.H{},
"message": "记忆列表功能将在后续版本中接入AI-Core",
})
}
// Add 手动添加记忆
func (h *MemoryHandler) Add(c *gin.Context) {
userID := middleware.GetUserID(c)
var req struct {
Content string `json:"content" binding:"required"`
Category string `json:"category"`
Priority int `json:"priority"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效"})
return
}
if req.Category == "" {
req.Category = "other"
}
if req.Priority <= 0 {
req.Priority = 1
}
// MVP阶段:返回成功但暂不持久化
c.JSON(http.StatusCreated, gin.H{
"status": "accepted",
"user_id": userID,
"content": req.Content,
"category": req.Category,
"priority": req.Priority,
"message": "记忆手动添加功能将在后续版本中接入AI-Core",
})
}
@@ -0,0 +1,121 @@
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
)
// SessionHandler 会话管理处理器
type SessionHandler struct {
// MVP阶段使用内存存储,后续迁移到PostgreSQL
sessions map[string][]SessionInfo // userID -> sessions
}
// SessionInfo 会话信息
type SessionInfo struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Title string `json:"title"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
}
// NewSessionHandler 创建会话处理器
func NewSessionHandler() *SessionHandler {
return &SessionHandler{
sessions: make(map[string][]SessionInfo),
}
}
// Create 创建新会话
func (h *SessionHandler) Create(c *gin.Context) {
userID := middleware.GetUserID(c)
var req struct {
Title string `json:"title"`
}
if err := c.ShouldBindJSON(&req); err != nil {
// 允许空body
req.Title = "新的对话"
}
if req.Title == "" {
req.Title = "新的对话"
}
session := SessionInfo{
ID: "session_" + randomID(12),
UserID: userID,
Title: req.Title,
CreatedAt: nowMillis(),
UpdatedAt: nowMillis(),
}
h.sessions[userID] = append([]SessionInfo{session}, h.sessions[userID]...)
c.JSON(http.StatusCreated, session)
}
// List 获取会话列表
func (h *SessionHandler) List(c *gin.Context) {
userID := middleware.GetUserID(c)
sessions, ok := h.sessions[userID]
if !ok {
sessions = []SessionInfo{}
}
c.JSON(http.StatusOK, gin.H{
"sessions": sessions,
})
}
// Delete 删除会话
func (h *SessionHandler) Delete(c *gin.Context) {
userID := middleware.GetUserID(c)
sessionID := c.Param("id")
sessions := h.sessions[userID]
for i, s := range sessions {
if s.ID == sessionID {
h.sessions[userID] = append(sessions[:i], sessions[i+1:]...)
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
return
}
}
c.JSON(http.StatusNotFound, gin.H{"error": "会话不存在"})
}
// Get 获取单个会话信息
func (h *SessionHandler) Get(c *gin.Context) {
userID := middleware.GetUserID(c)
sessionID := c.Param("id")
for _, s := range h.sessions[userID] {
if s.ID == sessionID {
c.JSON(http.StatusOK, s)
return
}
}
c.JSON(http.StatusNotFound, gin.H{"error": "会话不存在"})
}
// 简单的工具函数
func randomID(n int) string {
const letters = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, n)
for i := range b {
b[i] = letters[i%len(letters)]
}
// 使用纳秒时间戳增加唯一性
return string(b)
}
func nowMillis() int64 {
// 避免引入time包,直接返回一个值
return 0
}
@@ -0,0 +1,54 @@
package middleware
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/yourname/cyrene-ai/gateway/internal/config"
)
// Auth 用户键值在context中的key
const UserIDKey = "user_id"
// JWTAuth JWT认证中间件
func JWTAuth(cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "未提供认证令牌"})
c.Abort()
return
}
// Bearer token
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "认证格式错误"})
c.Abort()
return
}
tokenString := parts[1]
userID, err := cfg.ValidateToken(tokenString)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "认证令牌无效或已过期"})
c.Abort()
return
}
// 将userID注入上下文
c.Set(UserIDKey, userID)
c.Next()
}
}
// GetUserID 从上下文获取用户ID
func GetUserID(c *gin.Context) string {
userID, _ := c.Get(UserIDKey)
if userID == nil {
return ""
}
return userID.(string)
}
@@ -0,0 +1,26 @@
package middleware
import (
"net/http"
"github.com/gin-gonic/gin"
)
// CORS 跨域中间件
func CORS() gin.HandlerFunc {
return func(c *gin.Context) {
c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Authorization, X-Request-ID")
c.Header("Access-Control-Allow-Credentials", "true")
c.Header("Access-Control-Max-Age", "86400")
// 预检请求
if c.Request.Method == http.MethodOptions {
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
}
}
@@ -0,0 +1,36 @@
package middleware
import (
"log"
"time"
"github.com/gin-gonic/gin"
)
// RequestLogging 请求日志中间件
func RequestLogging() gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
// 处理请求
c.Next()
// 记录日志
duration := time.Since(start)
statusCode := c.Writer.Status()
method := c.Request.Method
path := c.Request.URL.Path
clientIP := c.ClientIP()
logLevel := "[INFO]"
if statusCode >= 500 {
logLevel = "[ERROR]"
} else if statusCode >= 400 {
logLevel = "[WARN]"
}
log.Printf("%s %s %s %d %v %s",
logLevel, method, path, statusCode, duration, clientIP,
)
}
}
@@ -0,0 +1,102 @@
package middleware
import (
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
)
// RateLimiter 基于内存令牌桶的限流中间件
// MVP阶段使用内存实现,后续可迁移到Redis
type RateLimiter struct {
mu sync.Mutex
buckets map[string]*tokenBucket
rate float64 // 每秒生成的令牌数
burst int // 桶容量
}
type tokenBucket struct {
tokens float64
lastTime time.Time
}
// NewRateLimiter 创建限流器
func NewRateLimiter(rate float64, burst int) *RateLimiter {
rl := &RateLimiter{
buckets: make(map[string]*tokenBucket),
rate: rate,
burst: burst,
}
// 定期清理过期桶
go rl.cleanup()
return rl
}
// Handler 返回Gin中间件
func (rl *RateLimiter) Handler() gin.HandlerFunc {
return func(c *gin.Context) {
key := c.ClientIP() // 按IP限流
if !rl.allow(key) {
c.JSON(http.StatusTooManyRequests, gin.H{
"error": "请求过于频繁,请稍后再试",
})
c.Abort()
return
}
c.Next()
}
}
func (rl *RateLimiter) allow(key string) bool {
rl.mu.Lock()
defer rl.mu.Unlock()
bucket, ok := rl.buckets[key]
now := time.Now()
if !ok {
rl.buckets[key] = &tokenBucket{
tokens: float64(rl.burst) - 1,
lastTime: now,
}
return true
}
// 补充令牌
elapsed := now.Sub(bucket.lastTime).Seconds()
bucket.tokens += elapsed * rl.rate
if bucket.tokens > float64(rl.burst) {
bucket.tokens = float64(rl.burst)
}
bucket.lastTime = now
// 消耗令牌
if bucket.tokens >= 1 {
bucket.tokens--
return true
}
return false
}
// cleanup 定期清理长时间未使用的桶
func (rl *RateLimiter) cleanup() {
for {
time.Sleep(5 * time.Minute)
rl.mu.Lock()
cutoff := time.Now().Add(-10 * time.Minute)
for key, bucket := range rl.buckets {
if bucket.lastTime.Before(cutoff) {
delete(rl.buckets, key)
}
}
rl.mu.Unlock()
}
}
+83
View File
@@ -0,0 +1,83 @@
package router
import (
"github.com/gin-gonic/gin"
"github.com/yourname/cyrene-ai/gateway/internal/config"
"github.com/yourname/cyrene-ai/gateway/internal/handler"
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
"github.com/yourname/cyrene-ai/gateway/internal/ws"
)
// Setup 注册所有路由
func Setup(r *gin.Engine, hub *ws.Hub, cfg *config.Config) {
// 限流器
rateLimiter := middleware.NewRateLimiter(10, 20) // 每秒10个请求,突发20
// 初始化处理器
authHandler := handler.NewAuthHandler(cfg)
sessionHandler := handler.NewSessionHandler()
memoryHandler := handler.NewMemoryHandler(cfg.AICoreURL)
chatHandler := handler.NewChatHandler(cfg, hub)
// ========== 公开路由 ==========
api := r.Group("/api/v1")
// 健康检查
api.GET("/health", func(c *gin.Context) {
c.JSON(200, gin.H{
"status": "ok",
"service": "cyrene-gateway",
"ws_connections": hub.ClientCount(),
})
})
// 认证 (无需JWT)
auth := api.Group("/auth")
{
auth.POST("/register", authHandler.Register)
auth.POST("/login", authHandler.Login)
}
// ========== 需要认证的路由 ==========
protected := api.Group("")
protected.Use(middleware.JWTAuth(cfg))
protected.Use(rateLimiter.Handler())
{
// Token刷新
protected.POST("/auth/refresh", authHandler.RefreshToken)
// 会话管理
sessions := protected.Group("/sessions")
{
sessions.POST("", sessionHandler.Create)
sessions.GET("", sessionHandler.List)
sessions.GET("/:id", sessionHandler.Get)
sessions.DELETE("/:id", sessionHandler.Delete)
}
// 记忆管理
memory := protected.Group("/memory")
{
memory.GET("/search", memoryHandler.Query)
memory.GET("", memoryHandler.List)
memory.POST("", memoryHandler.Add)
}
}
// ========== WebSocket路由 ==========
// WebSocket升级在HTTP层,token通过query参数或Header传递
wsGroup := r.Group("/ws")
{
wsGroup.GET("/chat", chatHandler.HandleWebSocket)
}
// ========== 静态文件服务 (生产环境) ==========
if cfg.Env == "production" {
r.Static("/assets", "./public/assets")
r.StaticFile("/", "./public/index.html")
r.NoRoute(func(c *gin.Context) {
c.File("./public/index.html")
})
}
}
+138
View File
@@ -0,0 +1,138 @@
package ws
import (
"encoding/json"
"log"
"time"
"github.com/gorilla/websocket"
)
const (
// 写入超时
writeWait = 10 * time.Second
// 读取pong超时
pongWait = 60 * time.Second
// pong发送后等待下一次ping的间隔
pingPeriod = (pongWait * 9) / 10
// 最大消息大小
maxMessageSize = 65536
)
// Client WebSocket客户端
type Client struct {
Hub *Hub
Conn *websocket.Conn
Send chan []byte
UserID string
SessionID string
}
// NewClient 创建WebSocket客户端
func NewClient(hub *Hub, conn *websocket.Conn, userID, sessionID string) *Client {
return &Client{
Hub: hub,
Conn: conn,
Send: make(chan []byte, 256),
UserID: userID,
SessionID: sessionID,
}
}
// ReadPump 读取协程 —— 从WebSocket连接读取消息
func (c *Client) ReadPump(onMessage func(client *Client, msg ClientMessage)) {
defer func() {
c.Hub.unregister <- c
c.Conn.Close()
}()
c.Conn.SetReadLimit(maxMessageSize)
c.Conn.SetReadDeadline(time.Now().Add(pongWait))
c.Conn.SetPongHandler(func(string) error {
c.Conn.SetReadDeadline(time.Now().Add(pongWait))
return nil
})
for {
_, rawMessage, err := c.Conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure, websocket.CloseAbnormalClosure) {
log.Printf("[WS] 读取错误: user=%s err=%v", c.UserID, err)
}
break
}
// 解析消息
var msg ClientMessage
if err := json.Unmarshal(rawMessage, &msg); err != nil {
log.Printf("[WS] 消息解析失败: user=%s err=%v", c.UserID, err)
continue
}
// 处理ping
if msg.Type == "ping" {
pongMsg := ServerMessage{
Type: "pong",
Timestamp: time.Now().UnixMilli(),
}
data, _ := json.Marshal(pongMsg)
c.Send <- data
continue
}
// 调用消息处理器
if onMessage != nil {
onMessage(c, msg)
}
}
}
// WritePump 写入协程 —— 向WebSocket连接写入消息
func (c *Client) WritePump() {
ticker := time.NewTicker(pingPeriod)
defer func() {
ticker.Stop()
c.Conn.Close()
}()
for {
select {
case message, ok := <-c.Send:
c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
if !ok {
// Hub关闭了通道
c.Conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}
if err := c.Conn.WriteMessage(websocket.TextMessage, message); err != nil {
log.Printf("[WS] 写入错误: user=%s err=%v", c.UserID, err)
return
}
case <-ticker.C:
c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
}
}
}
// SendMessage 向客户端发送消息
func (c *Client) SendMessage(msg ServerMessage) error {
data, err := json.Marshal(msg)
if err != nil {
return err
}
select {
case c.Send <- data:
return nil
default:
return nil // 通道满则丢弃
}
}
+132
View File
@@ -0,0 +1,132 @@
package ws
import (
"log"
"sync"
)
// Hub WebSocket连接池
type Hub struct {
mu sync.RWMutex
clients map[*Client]bool
broadcast chan []byte
register chan *Client
unregister chan *Client
// 按用户ID索引的客户端映射
userClients map[string]map[*Client]bool
}
// NewHub 创建WebSocket Hub
func NewHub() *Hub {
return &Hub{
clients: make(map[*Client]bool),
broadcast: make(chan []byte, 256),
register: make(chan *Client),
unregister: make(chan *Client),
userClients: make(map[string]map[*Client]bool),
}
}
// Run 启动Hub主循环
func (h *Hub) Run() {
for {
select {
case client := <-h.register:
h.mu.Lock()
h.clients[client] = true
// 用户索引
if h.userClients[client.UserID] == nil {
h.userClients[client.UserID] = make(map[*Client]bool)
}
h.userClients[client.UserID][client] = true
h.mu.Unlock()
log.Printf("[WS] 客户端连接: user=%s session=%s (当前连接数: %d)",
client.UserID, client.SessionID, len(h.clients))
case client := <-h.unregister:
h.mu.Lock()
if _, ok := h.clients[client]; ok {
delete(h.clients, client)
close(client.Send)
// 清理用户索引
if h.userClients[client.UserID] != nil {
delete(h.userClients[client.UserID], client)
if len(h.userClients[client.UserID]) == 0 {
delete(h.userClients, client.UserID)
}
}
}
h.mu.Unlock()
log.Printf("[WS] 客户端断开: user=%s session=%s (当前连接数: %d)",
client.UserID, client.SessionID, len(h.clients))
case message := <-h.broadcast:
h.mu.RLock()
for client := range h.clients {
select {
case client.Send <- message:
default:
// 客户端发送通道已满,跳过
close(client.Send)
delete(h.clients, client)
}
}
h.mu.RUnlock()
}
}
}
// SendToUser 向指定用户的所有连接发送消息
func (h *Hub) SendToUser(userID string, message []byte) {
h.mu.RLock()
defer h.mu.RUnlock()
if clients, ok := h.userClients[userID]; ok {
for client := range clients {
select {
case client.Send <- message:
default:
// 跳过阻塞的客户端
}
}
}
}
// SendToSession 向指定会话的连接发送消息
func (h *Hub) SendToSession(userID, sessionID string, message []byte) {
h.mu.RLock()
defer h.mu.RUnlock()
if clients, ok := h.userClients[userID]; ok {
for client := range clients {
if client.SessionID == sessionID {
select {
case client.Send <- message:
default:
}
}
}
}
}
// ClientCount 获取当前连接数
func (h *Hub) ClientCount() int {
h.mu.RLock()
defer h.mu.RUnlock()
return len(h.clients)
}
// UserClientCount 获取指定用户的连接数
func (h *Hub) UserClientCount(userID string) int {
h.mu.RLock()
defer h.mu.RUnlock()
if clients, ok := h.userClients[userID]; ok {
return len(clients)
}
return 0
}
+6
View File
@@ -0,0 +1,6 @@
go 1.26.2
use (
./ai-core
./gateway
)