Files
Cyrene/backend/ai-core/internal/llm/openai.go
T
AskaEth 91c9ee4b2d fix: 修复 AI 回复无法送达发送者 + 重复消息 + action角色泄露 + OS环境支持
广播逻辑重构:
- AI 回复 (stream_start/response/stream_segments/multi_message/stream_end) 改用 broadcastToUser 发送给所有客户端
- 用户消息回显保持 broadcastToUserExcept 排除发送者

消息去重与角色修复:
- CacheMessage(user) 移至回复生成后,避免本轮 LLM 调用出现重复用户消息
- action 角色消息在 DB 存储时映射为 assistant,DeepSeek 等模型不支持自定义角色
- stream_end defer 机制确保错误路径也会终止客户端思考指示器

OS 完整环境支持:
- host 包重构为 HostBackend 接口 + Direct/WSL/Docker 三种后端
- 新增 os_exec/os_file/os_system 工具供 AI 在完整 Linux 环境中自由操作

其他:
- 视觉模型注入 + 图片预处理后清空 Images 避免传给 Chat 模型
- 图片 URL 相对路径→绝对 URL 转换
- DevTools 链路追踪页面 + 重启修复
- 记忆搜索模糊匹配增强
- 后台思考定时调度支持
- 管理后台页面 (模型配置/用户管理等)
- docs/api 更新广播机制说明

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-29 12:46:17 +08:00

545 lines
15 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package llm
import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/yourname/cyrene-ai/ai-core/internal/model"
"github.com/yourname/cyrene-ai/pkg/logger"
)
// OpenAIConfig holds the settings used to construct an OpenAIProvider.
type OpenAIConfig struct {
	BaseURL       string        // API base URL; requests are sent to BaseURL + "/chat/completions"
	APIKey        string        // API key, sent as a Bearer token in the Authorization header
	Model         string        // primary model name
	FallbackModel string        // fallback model, tried when a call to the primary model fails
	MaxRetries    int           // maximum retry count; NewOpenAIProvider sets it to 3 when zero
	Timeout       time.Duration // request timeout; NewOpenAIProvider sets it to 60s when zero
}
// OpenAIProvider is an LLM provider that talks to an OpenAI-compatible
// chat-completions HTTP API.
type OpenAIProvider struct {
	config     OpenAIConfig // provider settings, with defaults applied by NewOpenAIProvider
	httpClient *http.Client // shared client used for API calls and for image downloads
}
// NewOpenAIProvider builds an OpenAIProvider from cfg.
//
// Zero values of cfg.MaxRetries and cfg.Timeout are replaced by defaults
// (3 retries, 60 seconds). The resulting timeout is installed on the
// provider's HTTP client. cfg is received by value, so the caller's copy is
// not modified.
func NewOpenAIProvider(cfg OpenAIConfig) *OpenAIProvider {
	const (
		defaultMaxRetries = 3
		defaultTimeout    = 60 * time.Second
	)

	if cfg.Timeout == 0 {
		cfg.Timeout = defaultTimeout
	}
	if cfg.MaxRetries == 0 {
		cfg.MaxRetries = defaultMaxRetries
	}

	client := &http.Client{Timeout: cfg.Timeout}
	provider := &OpenAIProvider{
		config:     cfg,
		httpClient: client,
	}
	return provider
}
// openAIRequest is the JSON body posted to the chat-completions endpoint.
// MaxTokens, Tools and ToolChoice are omitted from the JSON when empty;
// Temperature and Stream are always serialized.
type openAIRequest struct {
	Model       string          `json:"model"`
	Messages    []openAIMessage `json:"messages"`
	Temperature float64         `json:"temperature"`
	MaxTokens   int             `json:"max_tokens,omitempty"`
	Stream      bool            `json:"stream"`
	Tools       []OpenAITool    `json:"tools,omitempty"`
	ToolChoice  string          `json:"tool_choice,omitempty"` // "auto", "none", or specific tool
}
// openAIMessage is a single chat message in the API's JSON format. It is used
// both for outgoing request messages and for decoding the message/delta
// objects found in responses.
type openAIMessage struct {
	Role             string           `json:"role"`
	Content          interface{}      `json:"content,omitempty"` // string or []model.ImageContent for multimodal
	Name             string           `json:"name,omitempty"`
	ToolCalls        []openAIToolCall `json:"tool_calls,omitempty"`
	ToolCallID       string           `json:"tool_call_id,omitempty"`
	ReasoningContent string           `json:"reasoning_content,omitempty"` // DeepSeek chain-of-thought
}
// openAIToolCall is one tool invocation attached to a message, in the API's
// JSON format. Outgoing messages built in this file always set Type to
// "function".
type openAIToolCall struct {
	ID       string                 `json:"id"`
	Type     string                 `json:"type"`
	Function openAIToolCallFunction `json:"function"`
}
// openAIToolCallFunction names the function targeted by a tool call and
// carries its arguments as a JSON-encoded string (not a nested object).
type openAIToolCallFunction struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"` // JSON string
}
// openAIResponse is the JSON body of a non-streaming chat-completions reply.
// It is also used to decode error bodies, in which case Error is non-nil.
type openAIResponse struct {
	ID      string         `json:"id"`
	Object  string         `json:"object"`
	Choices []openAIChoice `json:"choices"`
	Usage   openAIUsage    `json:"usage,omitempty"`
	Error   *openAIError   `json:"error,omitempty"`
}
// openAIChoice is one entry of openAIResponse.Choices. doChat reads Message
// and FinishReason from the first choice; streaming chunks are decoded via
// openAIStreamChoice instead.
type openAIChoice struct {
	Index        int           `json:"index"`
	Message      openAIMessage `json:"message"`
	Delta        openAIMessage `json:"delta,omitempty"`
	FinishReason string        `json:"finish_reason"`
}
// openAIUsage carries the token counts reported by the API in the "usage"
// object of a response.
type openAIUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
// openAIError is the "error" object of an API response body. Code is decoded
// as a string, so a body whose code is a JSON number fails to unmarshal.
type openAIError struct {
	Message string `json:"message"`
	Type    string `json:"type"`
	Code    string `json:"code,omitempty"`
}
// Chat performs a synchronous (non-streaming) chat completion without tools.
// It delegates to ChatWithTools with a nil tool list, so the same
// primary/fallback model handling applies.
func (p *OpenAIProvider) Chat(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error) {
	return p.ChatWithTools(ctx, messages, nil)
}
// ChatWithTools performs a synchronous (non-streaming) chat completion,
// optionally offering tools to the model.
//
// The primary model is tried first. If that call fails and a fallback model
// different from the primary is configured, the request is retried once
// against the fallback and that attempt's result is returned as-is. The
// fallback is skipped when ctx is already cancelled or past its deadline:
// the retry would fail immediately for the same reason, hide the original
// error and log a spurious call record.
//
// It returns the model response, or the error of the last attempt made.
func (p *OpenAIProvider) ChatWithTools(ctx context.Context, messages []model.LLMMessage, tools []OpenAITool) (*model.LLMResponse, error) {
	resp, err := p.doChat(ctx, messages, p.config.Model, false, tools)
	if err == nil {
		return resp, nil
	}

	fallback := p.config.FallbackModel
	if fallback == "" || fallback == p.config.Model || ctx.Err() != nil {
		return nil, err
	}

	logger.Printf("[LLM] 主模型 %s 调用失败,降级到 %s: %v", p.config.Model, fallback, err)
	return p.doChat(ctx, messages, fallback, false, tools)
}
// ChatStream starts a streaming chat completion without tools. It delegates
// to ChatStreamWithTools with a nil tool list and returns that function's
// chunk channel and error unchanged.
func (p *OpenAIProvider) ChatStream(ctx context.Context, messages []model.LLMMessage) (<-chan StreamChunk, error) {
	return p.ChatStreamWithTools(ctx, messages, nil)
}
// ChatStreamWithTools starts a streaming chat completion, optionally offering
// tools to the model, and returns a channel of chunks.
//
// The request runs in a background goroutine. The returned error is always
// nil; failures are delivered on the channel as a chunk with Error set and
// Done true. The channel is closed when the goroutine finishes. A terminal
// chunk with Done true is sent on normal completion ("[DONE]" marker, a
// non-empty finish_reason, or end of body).
//
// The primary model is tried first. A fallback model is tried only when it is
// configured, differs from the primary, and ctx is still live.
//
// Chunk sends stop blocking once ctx is done, so a consumer that cancels ctx
// and stops reading does not leak the goroutine.
//
// NOTE: only textual content deltas are forwarded; tool-call deltas present
// in the stream are not surfaced by this function.
//
// NOTE(review): the provider's http.Client has a whole-request Timeout (set in
// NewOpenAIProvider). Per net/http that limit also covers reading the response
// body, so a stream that runs longer than the configured timeout is cut off.
// Confirm this is intended.
func (p *OpenAIProvider) ChatStreamWithTools(ctx context.Context, messages []model.LLMMessage, tools []OpenAITool) (<-chan StreamChunk, error) {
	ch := make(chan StreamChunk, 100)

	go func() {
		defer close(ch)

		startTime := time.Now()
		modelName := p.config.Model
		var streamErr error
		var finalUsage *model.Usage

		// Log exactly one call record when the goroutine exits, reflecting
		// whichever model was actually used and the final error/usage state.
		defer func() {
			r := CallRecord{
				Model:    modelName,
				Duration: time.Since(startTime),
				Success:  streamErr == nil,
			}
			if streamErr != nil {
				r.Error = streamErr.Error()
			}
			if finalUsage != nil {
				r.PromptTokens = finalUsage.PromptTokens
				r.CompletionTokens = finalUsage.CompletionTokens
				r.TotalTokens = finalUsage.TotalTokens
			}
			LogCall(r)
		}()

		// send delivers c to the consumer and reports whether it was
		// delivered. A non-blocking attempt comes first so that a chunk is
		// always delivered while the buffer has room, even if ctx is done.
		// Only when the send would block does cancellation abort it.
		send := func(c StreamChunk) bool {
			select {
			case ch <- c:
				return true
			default:
			}
			select {
			case ch <- c:
				return true
			case <-ctx.Done():
				return false
			}
		}

		resp, err := p.doChatStream(ctx, messages, p.config.Model, tools)
		if err != nil {
			// Fall back only to a distinct model, and only while ctx is live.
			fallback := p.config.FallbackModel
			if fallback != "" && fallback != p.config.Model && ctx.Err() == nil {
				logger.Printf("[LLM] 流式调用主模型失败,降级: %v", err)
				modelName = fallback
				resp, err = p.doChatStream(ctx, messages, fallback, tools)
			}
			if err != nil {
				streamErr = err
				send(StreamChunk{Error: err, Done: true})
				return
			}
		}
		defer resp.Body.Close()

		scanner := bufio.NewScanner(resp.Body)
		// Enlarge the scanner buffer so large SSE data lines (up to 1 MiB) fit.
		scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)

		for scanner.Scan() {
			line := scanner.Text()

			// SSE format: "data:" followed by an optional space and the
			// payload. Other lines (comments, blank separators) are skipped.
			if !strings.HasPrefix(line, "data:") {
				continue
			}
			data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
			if data == "" {
				continue
			}

			// End-of-stream marker.
			if data == "[DONE]" {
				send(StreamChunk{Done: true})
				return
			}

			var streamResp openAIStreamResponse
			if err := json.Unmarshal([]byte(data), &streamResp); err != nil {
				// Best effort: skip the malformed chunk but leave a trace.
				logger.Printf("[LLM] 流式数据块解析失败,已跳过: %v", err)
				continue
			}
			if len(streamResp.Choices) == 0 {
				continue
			}

			choice := streamResp.Choices[0]
			if text := contentString(choice.Delta.Content); text != "" {
				if !send(StreamChunk{Content: text}) {
					// Consumer is gone and ctx is done: stop reading.
					streamErr = ctx.Err()
					return
				}
			}

			if choice.FinishReason != "" {
				// Usage is captured only if it arrives on the finishing chunk.
				if streamResp.Usage != nil {
					finalUsage = &model.Usage{
						PromptTokens:     streamResp.Usage.PromptTokens,
						CompletionTokens: streamResp.Usage.CompletionTokens,
						TotalTokens:      streamResp.Usage.TotalTokens,
					}
				}
				send(StreamChunk{Done: true, Usage: finalUsage})
				return
			}
		}

		if err := scanner.Err(); err != nil {
			streamErr = fmt.Errorf("读取流式响应失败: %w", err)
			send(StreamChunk{Error: streamErr, Done: true})
			return
		}

		// Body ended without an explicit terminator; still signal completion.
		send(StreamChunk{Done: true})
	}()

	return ch, nil
}
// openAIStreamResponse is the JSON payload of one streamed "data:" chunk.
// Usage is a pointer so that chunks without a usage object decode to nil.
type openAIStreamResponse struct {
	ID      string               `json:"id"`
	Object  string               `json:"object"`
	Choices []openAIStreamChoice `json:"choices"`
	Usage   *openAIUsage         `json:"usage,omitempty"`
}
// openAIStreamChoice is one entry of openAIStreamResponse.Choices. Delta holds
// the incremental message fragment; a non-empty FinishReason is treated by
// ChatStreamWithTools as the end of the stream.
type openAIStreamChoice struct {
	Index        int           `json:"index"`
	Delta        openAIMessage `json:"delta"`
	FinishReason string        `json:"finish_reason"`
}
// doChat sends one non-streaming chat-completions request for modelName and
// converts the reply into a model.LLMResponse.
//
// Parameters:
//   - ctx bounds the HTTP request to the API.
//   - messages are converted to the API's JSON format; each message's image
//     URLs are first passed through resolveImages.
//   - stream is copied into the request's "stream" field. The response is
//     always read in full and parsed as a single JSON document, so callers in
//     this file pass false.
//   - tools, when non-empty, are attached with tool_choice "auto".
//
// It returns an error when the request cannot be built or sent, the status is
// not 200, the body cannot be parsed, or the body contains no choices. If the
// body carries an API error object, that error's code and message are
// reported, including when it arrives with a 200 status and no choices.
//
// Every call logs one CallRecord via LogCall, including failures.
func (p *OpenAIProvider) doChat(ctx context.Context, messages []model.LLMMessage, modelName string, stream bool, tools []OpenAITool) (llmResp *model.LLMResponse, err error) {
	startTime := time.Now()
	// The named results let this deferred hook observe the final outcome
	// regardless of which return path is taken.
	defer func() {
		r := CallRecord{
			Model:    modelName,
			Duration: time.Since(startTime),
			Success:  err == nil,
		}
		if err != nil {
			r.Error = err.Error()
		}
		if llmResp != nil {
			r.PromptTokens = llmResp.Usage.PromptTokens
			r.CompletionTokens = llmResp.Usage.CompletionTokens
			r.TotalTokens = llmResp.Usage.TotalTokens
		}
		LogCall(r)
	}()

	// Convert messages to the API format (image URLs are resolved to data
	// URLs first).
	oaiMessages := make([]openAIMessage, len(messages))
	for i, msg := range messages {
		resolvedImages := p.resolveImages(msg.Images)
		oaiMsg := openAIMessage{
			Role:             string(msg.Role),
			Content:          buildContent(msg.Content, resolvedImages),
			Name:             msg.Name,
			ToolCallID:       msg.ToolCallID,
			ReasoningContent: msg.ReasoningContent,
		}
		// Convert tool calls.
		if len(msg.ToolCalls) > 0 {
			oaiMsg.ToolCalls = make([]openAIToolCall, len(msg.ToolCalls))
			for j, tc := range msg.ToolCalls {
				oaiMsg.ToolCalls[j] = openAIToolCall{
					ID:   tc.ID,
					Type: "function",
					Function: openAIToolCallFunction{
						Name:      tc.Name,
						Arguments: tc.Arguments,
					},
				}
			}
		}
		oaiMessages[i] = oaiMsg
	}

	reqBody := openAIRequest{
		Model:       modelName,
		Messages:    oaiMessages,
		Temperature: 0.8,
		Stream:      stream,
		Tools:       tools,
	}
	if len(tools) > 0 {
		reqBody.ToolChoice = "auto"
	}

	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("序列化请求失败: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, "POST", p.config.BaseURL+"/chat/completions", bytes.NewReader(jsonBody))
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+p.config.APIKey)

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("请求失败: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("读取响应失败: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		// Prefer the structured API error; fall back to the raw body.
		var errResp openAIResponse
		if json.Unmarshal(body, &errResp) == nil && errResp.Error != nil {
			return nil, fmt.Errorf("API错误 [%s]: %s", errResp.Error.Code, errResp.Error.Message)
		}
		return nil, fmt.Errorf("API返回状态码 %d: %s", resp.StatusCode, string(body))
	}

	var oaiResp openAIResponse
	if err := json.Unmarshal(body, &oaiResp); err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}
	if len(oaiResp.Choices) == 0 {
		// If the body carried an error object despite the 200 status, report
		// it rather than the opaque "empty choices" error.
		if oaiResp.Error != nil {
			return nil, fmt.Errorf("API错误 [%s]: %s", oaiResp.Error.Code, oaiResp.Error.Message)
		}
		return nil, fmt.Errorf("API返回空choices")
	}

	// Only the first choice is used.
	choice := oaiResp.Choices[0]
	llmResp = &model.LLMResponse{
		Content:          contentString(choice.Message.Content),
		FinishReason:     choice.FinishReason,
		ReasoningContent: choice.Message.ReasoningContent,
		Usage: model.Usage{
			PromptTokens:     oaiResp.Usage.PromptTokens,
			CompletionTokens: oaiResp.Usage.CompletionTokens,
			TotalTokens:      oaiResp.Usage.TotalTokens,
		},
	}

	// Copy over any tool calls requested by the model.
	if len(choice.Message.ToolCalls) > 0 {
		llmResp.ToolCalls = make([]model.ToolCall, 0, len(choice.Message.ToolCalls))
		for _, tc := range choice.Message.ToolCalls {
			llmResp.ToolCalls = append(llmResp.ToolCalls, model.ToolCall{
				ID:        tc.ID,
				Name:      tc.Function.Name,
				Arguments: tc.Function.Arguments,
			})
		}
	}

	return llmResp, nil
}
// doChatStream sends one streaming chat-completions request for modelName and
// returns the raw HTTP response so the caller can read the SSE body.
//
// On success (status 200) the caller owns resp.Body and must close it. On any
// failure the body is closed here and a nil response is returned. A non-200
// reply is reported in the same form as doChat: the structured API error code
// and message when the body contains one, otherwise the status code and raw
// body.
//
// messages are converted to the API's JSON format with image URLs passed
// through resolveImages; tools, when non-empty, are attached with tool_choice
// "auto".
func (p *OpenAIProvider) doChatStream(ctx context.Context, messages []model.LLMMessage, modelName string, tools []OpenAITool) (*http.Response, error) {
	oaiMessages := make([]openAIMessage, len(messages))
	for i, msg := range messages {
		resolvedImages := p.resolveImages(msg.Images)
		oaiMsg := openAIMessage{
			Role:             string(msg.Role),
			Content:          buildContent(msg.Content, resolvedImages),
			Name:             msg.Name,
			ToolCallID:       msg.ToolCallID,
			ReasoningContent: msg.ReasoningContent,
		}
		if len(msg.ToolCalls) > 0 {
			oaiMsg.ToolCalls = make([]openAIToolCall, len(msg.ToolCalls))
			for j, tc := range msg.ToolCalls {
				oaiMsg.ToolCalls[j] = openAIToolCall{
					ID:   tc.ID,
					Type: "function",
					Function: openAIToolCallFunction{
						Name:      tc.Name,
						Arguments: tc.Arguments,
					},
				}
			}
		}
		oaiMessages[i] = oaiMsg
	}

	reqBody := openAIRequest{
		Model:       modelName,
		Messages:    oaiMessages,
		Temperature: 0.8,
		Stream:      true,
		Tools:       tools,
	}
	if len(tools) > 0 {
		reqBody.ToolChoice = "auto"
	}

	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("序列化请求失败: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, "POST", p.config.BaseURL+"/chat/completions", bytes.NewReader(jsonBody))
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+p.config.APIKey)
	req.Header.Set("Accept", "text/event-stream")

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("请求失败: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		defer resp.Body.Close()
		// The read error is deliberately ignored: the body is diagnostic
		// detail only, and the failure is reported either way.
		body, _ := io.ReadAll(resp.Body)
		// Prefer the structured API error, matching doChat.
		var errResp openAIResponse
		if json.Unmarshal(body, &errResp) == nil && errResp.Error != nil {
			return nil, fmt.Errorf("API错误 [%s]: %s", errResp.Error.Code, errResp.Error.Message)
		}
		return nil, fmt.Errorf("API返回状态码 %d: %s", resp.StatusCode, string(body))
	}

	return resp, nil
}
// ModelName returns the configured primary model name. It does not reflect a
// fallback model that an individual call may have used.
func (p *OpenAIProvider) ModelName() string {
	return p.config.Model
}
// contentString returns v unchanged when its dynamic type is string. For nil
// and for every other dynamic type (for example a decoded JSON array) it
// returns the empty string.
func contentString(v interface{}) string {
	switch text := v.(type) {
	case string:
		return text
	default:
		return ""
	}
}
// resolveImages returns a copy of images in which every entry that is not
// already a "data:" URL has been replaced by a base64 data URL obtained from
// downloadAsDataURL. An entry whose download fails is logged and kept as its
// original URL. An empty or nil input is returned unchanged.
func (p *OpenAIProvider) resolveImages(images []string) []string {
	if len(images) == 0 {
		return images
	}

	out := make([]string, len(images))
	for idx, src := range images {
		// Default to the original value; it is overwritten only after a
		// successful download.
		out[idx] = src
		if strings.HasPrefix(src, "data:") {
			continue
		}

		encoded, err := p.downloadAsDataURL(src)
		if err != nil {
			logger.Printf("[openai] 图片下载失败, 保留原始 URL: %s, err=%v", src, err)
			continue
		}
		out[idx] = encoded
	}
	return out
}
// downloadAsDataURL fetches url with an HTTP GET and returns the body encoded
// as a "data:<mime>;base64,<payload>" URL.
//
// It returns an error when the request cannot be built or sent, the status is
// not 200, the body cannot be read, or the body exceeds 20 MiB.
//
// The media type is taken from the Content-Type header with any parameters
// (such as "; charset=...") removed, because they would otherwise end up
// inside the data URL. If the header does not name an image type, the type is
// sniffed from the body: a sniffed image type replaces the header value, and
// the sniffed type is also used when the header is absent.
//
// NOTE: the download uses its own 30-second timeout derived from
// context.Background(), not a caller context, so cancelling the surrounding
// chat request does not cancel a download in flight.
//
// NOTE(review): url is fetched as given. If image URLs can originate from
// untrusted users, this lets them make the server request arbitrary hosts.
// Confirm that callers validate URLs upstream.
func (p *OpenAIProvider) downloadAsDataURL(url string) (string, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return "", fmt.Errorf("创建请求失败: %w", err)
	}

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("下载失败: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("HTTP %d", resp.StatusCode)
	}

	// Cap the download at 20 MiB. Reading one extra byte makes an oversized
	// body detectable without reading it in full.
	const maxSize = 20 * 1024 * 1024
	body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize+1))
	if err != nil {
		return "", fmt.Errorf("读取失败: %w", err)
	}
	if len(body) > maxSize {
		return "", fmt.Errorf("图片过大: %d bytes", len(body))
	}

	// bareType strips parameters and normalizes case: "Image/PNG; q=1" -> "image/png".
	bareType := func(s string) string {
		if i := strings.IndexByte(s, ';'); i >= 0 {
			s = s[:i]
		}
		return strings.ToLower(strings.TrimSpace(s))
	}

	mimeType := bareType(resp.Header.Get("Content-Type"))
	if !strings.HasPrefix(mimeType, "image/") {
		// Header missing or generic (e.g. application/octet-stream): sniff.
		sniffed := bareType(http.DetectContentType(body))
		if mimeType == "" || strings.HasPrefix(sniffed, "image/") {
			mimeType = sniffed
		}
	}

	b64 := base64.StdEncoding.EncodeToString(body)
	return fmt.Sprintf("data:%s;base64,%s", mimeType, b64), nil
}
// buildContent produces the value for a message's "content" field.
//
// With no images it returns text as a plain string. Otherwise it returns a
// []model.ImageContent holding an optional leading "text" part (omitted when
// text is empty) followed by one "image_url" part per image, in input order.
func buildContent(text string, images []string) interface{} {
	if len(images) == 0 {
		return text
	}

	parts := make([]model.ImageContent, 0, 1+len(images))
	if text != "" {
		textPart := model.ImageContent{Type: "text", Text: text}
		parts = append(parts, textPart)
	}
	for i := range images {
		link := &model.ImageURL{URL: images[i]}
		parts = append(parts, model.ImageContent{
			Type:     "image_url",
			ImageURL: link,
		})
	}
	return parts
}