dev 分支暂存

This commit is contained in:
2026-05-16 08:26:56 +08:00
parent 58c8caa570
commit eb4129176c
71 changed files with 8474 additions and 214 deletions
+83
View File
@@ -0,0 +1,83 @@
package llm
import (
"context"
"io"
"github.com/yourname/cyrene-ai/ai-core/internal/model"
)
// Adapter LLM适配器接口
// 支持不同的LLM后端(OpenAI、Ollama、vLLM等)
type Adapter struct {
provider LLMProvider
}
// LLMProvider LLM提供商接口
type LLMProvider interface {
// Chat 同步对话
Chat(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error)
// ChatStream 流式对话,返回一个channel逐token推送
ChatStream(ctx context.Context, messages []model.LLMMessage) (<-chan StreamChunk, error)
// ModelName 返回当前使用的模型名称
ModelName() string
}
// StreamChunk 流式响应的单个片段
type StreamChunk struct {
Content string // delta内容
Done bool // 是否为最后一块
Error error // 错误信息
Usage *model.Usage // 最后一块时返回token统计
}
// NewAdapter 创建LLM适配器
func NewAdapter(provider LLMProvider) *Adapter {
return &Adapter{provider: provider}
}
// Chat 同步对话
func (a *Adapter) Chat(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error) {
return a.provider.Chat(ctx, messages)
}
// ChatStream 流式对话
func (a *Adapter) ChatStream(ctx context.Context, messages []model.LLMMessage) (<-chan StreamChunk, error) {
return a.provider.ChatStream(ctx, messages)
}
// ModelName 返回模型名称
func (a *Adapter) ModelName() string {
return a.provider.ModelName()
}
// collectStream 辅助函数:将流式响应收集为完整响应
func collectStream(ch <-chan StreamChunk) (*model.LLMResponse, error) {
var content string
var lastUsage *model.Usage
for chunk := range ch {
if chunk.Error != nil {
return nil, chunk.Error
}
if chunk.Done {
lastUsage = chunk.Usage
break
}
content += chunk.Content
}
resp := &model.LLMResponse{
Content: content,
FinishReason: "stop",
}
if lastUsage != nil {
resp.Usage = *lastUsage
}
return resp, nil
}
// Ensure io is used (will be needed for SSE parsing)
var _ io.Reader
+313
View File
@@ -0,0 +1,313 @@
package llm
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/yourname/cyrene-ai/ai-core/internal/model"
)
// OpenAIConfig OpenAI适配器配置
type OpenAIConfig struct {
BaseURL string // API基础URL
APIKey string // API密钥
Model string // 主模型
FallbackModel string // 备用模型(主模型不可用时)
MaxRetries int // 最大重试次数
Timeout time.Duration // 请求超时
}
// OpenAIProvider OpenAI兼容的LLM提供商
type OpenAIProvider struct {
config OpenAIConfig
httpClient *http.Client
}
// NewOpenAIProvider 创建OpenAI提供商
func NewOpenAIProvider(cfg OpenAIConfig) *OpenAIProvider {
if cfg.MaxRetries == 0 {
cfg.MaxRetries = 3
}
if cfg.Timeout == 0 {
cfg.Timeout = 60 * time.Second
}
return &OpenAIProvider{
config: cfg,
httpClient: &http.Client{
Timeout: cfg.Timeout,
},
}
}
// openAIRequest OpenAI请求结构
type openAIRequest struct {
Model string `json:"model"`
Messages []openAIMessage `json:"messages"`
Temperature float64 `json:"temperature"`
MaxTokens int `json:"max_tokens,omitempty"`
Stream bool `json:"stream"`
}
type openAIMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
// openAIResponse OpenAI响应结构
type openAIResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Choices []openAIChoice `json:"choices"`
Usage openAIUsage `json:"usage,omitempty"`
Error *openAIError `json:"error,omitempty"`
}
type openAIChoice struct {
Index int `json:"index"`
Message openAIMessage `json:"message"`
FinishReason string `json:"finish_reason"`
}
type openAIUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type openAIError struct {
Message string `json:"message"`
Type string `json:"type"`
Code string `json:"code,omitempty"`
}
// Chat 同步对话
func (p *OpenAIProvider) Chat(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error) {
resp, err := p.doChat(ctx, messages, p.config.Model, false)
if err != nil {
// 尝试fallback模型
if p.config.FallbackModel != "" && p.config.FallbackModel != p.config.Model {
log.Printf("[LLM] 主模型 %s 调用失败,降级到 %s: %v", p.config.Model, p.config.FallbackModel, err)
return p.doChat(ctx, messages, p.config.FallbackModel, false)
}
return nil, err
}
return resp, nil
}
// ChatStream 流式对话
func (p *OpenAIProvider) ChatStream(ctx context.Context, messages []model.LLMMessage) (<-chan StreamChunk, error) {
ch := make(chan StreamChunk, 100)
go func() {
defer close(ch)
resp, err := p.doChatStream(ctx, messages, p.config.Model)
if err != nil {
// Fallback
if p.config.FallbackModel != "" {
log.Printf("[LLM] 流式调用主模型失败,降级: %v", err)
resp, err = p.doChatStream(ctx, messages, p.config.FallbackModel)
}
if err != nil {
ch <- StreamChunk{Error: err, Done: true}
return
}
}
defer resp.Body.Close()
scanner := bufio.NewScanner(resp.Body)
// 增大scanner buffer以处理大块SSE数据
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
for scanner.Scan() {
line := scanner.Text()
// SSE格式: data: {...}
if !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
// 流结束标记
if data == "[DONE]" {
ch <- StreamChunk{Done: true}
return
}
var streamResp openAIStreamResponse
if err := json.Unmarshal([]byte(data), &streamResp); err != nil {
continue
}
if len(streamResp.Choices) > 0 {
delta := streamResp.Choices[0].Delta
if delta.Content != "" {
ch <- StreamChunk{Content: delta.Content}
}
if streamResp.Choices[0].FinishReason != "" {
usage := &model.Usage{}
if streamResp.Usage != nil {
usage.PromptTokens = streamResp.Usage.PromptTokens
usage.CompletionTokens = streamResp.Usage.CompletionTokens
usage.TotalTokens = streamResp.Usage.TotalTokens
}
ch <- StreamChunk{Done: true, Usage: usage}
return
}
}
}
if err := scanner.Err(); err != nil {
ch <- StreamChunk{Error: fmt.Errorf("读取流式响应失败: %w", err), Done: true}
return
}
ch <- StreamChunk{Done: true}
}()
return ch, nil
}
// openAIStreamResponse 流式响应结构
type openAIStreamResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Choices []openAIStreamChoice `json:"choices"`
Usage *openAIUsage `json:"usage,omitempty"`
}
type openAIStreamChoice struct {
Index int `json:"index"`
Delta openAIMessage `json:"delta"`
FinishReason string `json:"finish_reason"`
}
// doChat 执行同步对话请求
func (p *OpenAIProvider) doChat(ctx context.Context, messages []model.LLMMessage, model string, stream bool) (*model.LLMResponse, error) {
// 转换消息格式
oaiMessages := make([]openAIMessage, len(messages))
for i, msg := range messages {
oaiMessages[i] = openAIMessage{
Role: string(msg.Role),
Content: msg.Content,
}
}
reqBody := openAIRequest{
Model: model,
Messages: oaiMessages,
Temperature: 0.8,
Stream: stream,
}
jsonBody, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("序列化请求失败: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", p.config.BaseURL+"/chat/completions", bytes.NewReader(jsonBody))
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+p.config.APIKey)
resp, err := p.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("请求失败: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
if resp.StatusCode != http.StatusOK {
var errResp openAIResponse
if json.Unmarshal(body, &errResp) == nil && errResp.Error != nil {
return nil, fmt.Errorf("API错误 [%s]: %s", errResp.Error.Code, errResp.Error.Message)
}
return nil, fmt.Errorf("API返回状态码 %d: %s", resp.StatusCode, string(body))
}
var oaiResp openAIResponse
if err := json.Unmarshal(body, &oaiResp); err != nil {
return nil, fmt.Errorf("解析响应失败: %w", err)
}
if len(oaiResp.Choices) == 0 {
return nil, fmt.Errorf("API返回空choices")
}
return &model.LLMResponse{
Content: oaiResp.Choices[0].Message.Content,
FinishReason: oaiResp.Choices[0].FinishReason,
Usage: model.Usage{
PromptTokens: oaiResp.Usage.PromptTokens,
CompletionTokens: oaiResp.Usage.CompletionTokens,
TotalTokens: oaiResp.Usage.TotalTokens,
},
}, nil
}
// doChatStream 执行流式对话请求(返回原始HTTP响应)
func (p *OpenAIProvider) doChatStream(ctx context.Context, messages []model.LLMMessage, model string) (*http.Response, error) {
oaiMessages := make([]openAIMessage, len(messages))
for i, msg := range messages {
oaiMessages[i] = openAIMessage{
Role: string(msg.Role),
Content: msg.Content,
}
}
reqBody := openAIRequest{
Model: model,
Messages: oaiMessages,
Temperature: 0.8,
Stream: true,
}
jsonBody, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("序列化请求失败: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", p.config.BaseURL+"/chat/completions", bytes.NewReader(jsonBody))
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+p.config.APIKey)
req.Header.Set("Accept", "text/event-stream")
resp, err := p.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("请求失败: %w", err)
}
if resp.StatusCode != http.StatusOK {
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("API返回状态码 %d: %s", resp.StatusCode, string(body))
}
return resp, nil
}
// ModelName 返回模型名称
func (p *OpenAIProvider) ModelName() string {
return p.config.Model
}
+191
View File
@@ -0,0 +1,191 @@
package llm
import (
"strings"
"sync"
"unicode"
)
// Segmenter 断句器 —— 将流式文本按句号切分为语音播放片段
type Segmenter struct {
mu sync.Mutex
buffer strings.Builder
segments []Segment
index int
}
// Segment 语音片段
type Segment struct {
Index int `json:"index"`
Text string `json:"text"`
}
// NewSegmenter 创建断句器
func NewSegmenter() *Segmenter {
return &Segmenter{}
}
// Feed 喂入新的文本片段
// 返回已完成的断句列表
func (s *Segmenter) Feed(delta string) []Segment {
s.mu.Lock()
defer s.mu.Unlock()
s.buffer.WriteString(delta)
content := s.buffer.String()
var newSegments []Segment
for {
idx := findSentenceEnd(content)
if idx == -1 {
break
}
segmentText := strings.TrimSpace(content[:idx+len(string(content[idx]))])
// 检查是否是完整中文字符的句末
// idx 指向标点符号的位置
runes := []rune(content)
var byteIdx int
for i, r := range runes {
if i == idx {
// 标点之后的字符
break
}
byteIdx += len(string(r))
}
// 简化处理:直接取到idx+1字节 (对于ASCII标点)
// 对于中文标点,需要用rune处理
realIdx := 0
runeCount := 0
for i, r := range content {
if runeCount == idx {
realIdx = i
break
}
runeCount++
_ = r
}
// 包含标点符号本身
endIdx := realIdx + len(string([]rune(content)[idx]))
if endIdx <= realIdx {
endIdx = realIdx + 3 // fallback for UTF-8 multi-byte
}
segmentText = strings.TrimSpace(content[:endIdx])
if segmentText == "" {
content = strings.TrimSpace(content[endIdx:])
s.buffer.Reset()
s.buffer.WriteString(content)
continue
}
s.index++
seg := Segment{
Index: s.index,
Text: segmentText,
}
s.segments = append(s.segments, seg)
newSegments = append(newSegments, seg)
// 更新buffer,移除已处理的部分
content = strings.TrimSpace(content[endIdx:])
s.buffer.Reset()
s.buffer.WriteString(content)
}
return newSegments
}
// Flush 强制输出buffer中剩余的内容
func (s *Segmenter) Flush() *Segment {
s.mu.Lock()
defer s.mu.Unlock()
remaining := strings.TrimSpace(s.buffer.String())
if remaining == "" {
return nil
}
s.index++
seg := Segment{
Index: s.index,
Text: remaining,
}
s.segments = append(s.segments, seg)
s.buffer.Reset()
return &seg
}
// AllSegments 返回所有已完成的断句
func (s *Segmenter) AllSegments() []Segment {
s.mu.Lock()
defer s.mu.Unlock()
result := make([]Segment, len(s.segments))
copy(result, s.segments)
return result
}
// findSentenceEnd 查找句子结束位置(返回标点符号在rune数组中的索引)
// 中文标点:。!? 英文标点:. ! ?
func findSentenceEnd(text string) int {
runes := []rune(text)
for i, r := range runes {
if isSentenceEnd(r) {
return i
}
}
return -1
}
// isSentenceEnd 判断是否为句末标点
func isSentenceEnd(r rune) bool {
switch r {
case '。', '', '', '.', '!', '?', '\n':
return true
}
return false
}
// splitIntoSegments 将完整文本按句号断句(用于post-processing
func splitIntoSegments(text string) []Segment {
var segments []Segment
runes := []rune(text)
start := 0
index := 0
for i, r := range runes {
if isSentenceEnd(r) {
segText := strings.TrimSpace(string(runes[start : i+1]))
if segText != "" {
index++
segments = append(segments, Segment{
Index: index,
Text: segText,
})
}
start = i + 1
}
}
// 处理末尾无标点的剩余文本
if start < len(runes) {
remaining := strings.TrimSpace(string(runes[start:]))
if remaining != "" {
index++
segments = append(segments, Segment{
Index: index,
Text: remaining,
})
}
}
return segments
}
// Ensure unicode is used
var _ = unicode.Is