package llm

import (
	"bufio"
	"bytes"
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"git.yeij.top/AskaEth/Cyrene/ai-core/internal/model"
	"git.yeij.top/AskaEth/Cyrene/pkg/logger"
)

// openAIDefaultTemperature is the sampling temperature sent with every chat request.
const openAIDefaultTemperature = 0.8

// OpenAIConfig configures the OpenAI-compatible adapter.
type OpenAIConfig struct {
	BaseURL       string        // API base URL; "/chat/completions" is appended to it
	APIKey        string        // API key, sent as "Authorization: Bearer <key>"
	Model         string        // primary model
	FallbackModel string        // fallback model, tried once when the primary call fails
	MaxRetries    int           // maximum retry count (defaults to 3); NOTE(review): not read anywhere in this file
	Timeout       time.Duration // request timeout (defaults to 60s); for streams it only bounds the wait for response headers
}

// OpenAIProvider is an LLM provider for OpenAI-compatible chat completion APIs.
type OpenAIProvider struct {
	config OpenAIConfig
	// httpClient is bounded end to end by config.Timeout; it serves synchronous
	// calls and image downloads.
	httpClient *http.Client
	// streamClient has no overall deadline, because http.Client.Timeout would also
	// interrupt reading a long-running SSE body. Streams are bounded by the caller's
	// ctx instead. When nil (e.g. a provider not built by NewOpenAIProvider),
	// httpClient is used.
	streamClient *http.Client
}

// NewOpenAIProvider creates an OpenAI provider, applying defaults for unset
// MaxRetries (3) and Timeout (60s).
func NewOpenAIProvider(cfg OpenAIConfig) *OpenAIProvider {
	if cfg.MaxRetries == 0 {
		cfg.MaxRetries = 3
	}
	if cfg.Timeout == 0 {
		cfg.Timeout = 60 * time.Second
	}
	return &OpenAIProvider{
		config: cfg,
		httpClient: &http.Client{
			Timeout: cfg.Timeout,
		},
		streamClient: newOpenAIStreamClient(cfg.Timeout),
	}
}

// newOpenAIStreamClient returns a client suitable for SSE streaming: no overall
// timeout, but the time to receive response headers is capped at headerTimeout.
func newOpenAIStreamClient(headerTimeout time.Duration) *http.Client {
	base, ok := http.DefaultTransport.(*http.Transport)
	if !ok {
		// DefaultTransport was replaced by something else; use it as-is without a deadline.
		return &http.Client{}
	}
	t := base.Clone()
	t.ResponseHeaderTimeout = headerTimeout
	return &http.Client{Transport: t}
}

// openAIRequest is the chat completion request body.
type openAIRequest struct {
	Model       string          `json:"model"`
	Messages    []openAIMessage `json:"messages"`
	Temperature float64         `json:"temperature"`
	MaxTokens   int             `json:"max_tokens,omitempty"`
	Stream      bool            `json:"stream"`
	Tools       []OpenAITool    `json:"tools,omitempty"`
	ToolChoice  string          `json:"tool_choice,omitempty"` // "auto", "none", or specific tool
}

// openAIMessage is one chat message; it is also used for response messages and stream deltas.
type openAIMessage struct {
	Role             string           `json:"role"`
	Content          interface{}      `json:"content,omitempty"` // string or []model.ImageContent for multimodal
	Name             string           `json:"name,omitempty"`
	ToolCalls        []openAIToolCall `json:"tool_calls,omitempty"`
	ToolCallID       string           `json:"tool_call_id,omitempty"`
	ReasoningContent string           `json:"reasoning_content,omitempty"` // DeepSeek chain-of-thought
}

// openAIToolCall is a tool call issued by (or replayed to) the model.
type openAIToolCall struct {
	ID       string                 `json:"id"`
	Type     string                 `json:"type"`
	Function openAIToolCallFunction `json:"function"`
}

type openAIToolCallFunction struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"` // JSON string
}

// openAIResponse is the non-streaming response body (also used to decode error bodies).
type openAIResponse struct {
	ID      string         `json:"id"`
	Object  string         `json:"object"`
	Choices []openAIChoice `json:"choices"`
	Usage   openAIUsage    `json:"usage,omitempty"`
	Error   *openAIError   `json:"error,omitempty"`
}

type openAIChoice struct {
	Index        int           `json:"index"`
	Message      openAIMessage `json:"message"`
	Delta        openAIMessage `json:"delta,omitempty"`
	FinishReason string        `json:"finish_reason"`
}

type openAIUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type openAIError struct {
	Message string `json:"message"`
	Type    string `json:"type"`
	Code    string `json:"code,omitempty"`
}

// Chat performs a synchronous chat completion without tools.
func (p *OpenAIProvider) Chat(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error) {
	return p.ChatWithTools(ctx, messages, nil)
}

// ChatWithTools performs a synchronous chat completion, optionally advertising tools.
// If the primary model fails, FallbackModel is tried once (see canFallback) and its
// result is returned as-is.
func (p *OpenAIProvider) ChatWithTools(ctx context.Context, messages []model.LLMMessage, tools []OpenAITool) (*model.LLMResponse, error) {
	resp, err := p.doChat(ctx, messages, p.config.Model, false, tools)
	if err == nil {
		return resp, nil
	}
	if !p.canFallback(ctx) {
		return nil, err
	}
	logger.Printf("[LLM] 主模型 %s 调用失败,降级到 %s: %v", p.config.Model, p.config.FallbackModel, err)
	return p.doChat(ctx, messages, p.config.FallbackModel, false, tools)
}

// canFallback reports whether a failed primary call should be retried on
// FallbackModel: a fallback distinct from Model must be configured, and ctx must
// still be live (a retry with a cancelled or expired context cannot succeed).
func (p *OpenAIProvider) canFallback(ctx context.Context) bool {
	return p.config.FallbackModel != "" &&
		p.config.FallbackModel != p.config.Model &&
		ctx.Err() == nil
}

// ChatStream performs a streaming chat completion without tools.
func (p *OpenAIProvider) ChatStream(ctx context.Context, messages []model.LLMMessage) (<-chan StreamChunk, error) {
	return p.ChatStreamWithTools(ctx, messages, nil)
}

// ChatStreamWithTools performs a streaming chat completion, optionally advertising tools.
//
// The returned error is always nil; failures are delivered in-channel. The channel
// receives content deltas and normally ends with a chunk that has Done set (carrying
// Error on failure and Usage when the provider reported it), after which it is closed.
// If the primary model's request fails, FallbackModel is tried once (see canFallback).
// The producing goroutine exits once ctx is done, even if the consumer stops reading.
//
// NOTE(review): streamed tool_calls and reasoning_content deltas are not forwarded;
// StreamChunk (declared elsewhere) would need fields to carry them.
func (p *OpenAIProvider) ChatStreamWithTools(ctx context.Context, messages []model.LLMMessage, tools []OpenAITool) (<-chan StreamChunk, error) {
	ch := make(chan StreamChunk, 100)

	go func() {
		defer close(ch)

		startTime := time.Now()
		modelName := p.config.Model
		var streamErr error
		var finalUsage *model.Usage

		// Emit exactly one call record per stream, whatever the outcome.
		defer func() {
			r := CallRecord{
				Model:    modelName,
				Duration: time.Since(startTime),
				Success:  streamErr == nil,
			}
			if streamErr != nil {
				r.Error = streamErr.Error()
			}
			if finalUsage != nil {
				r.PromptTokens = finalUsage.PromptTokens
				r.CompletionTokens = finalUsage.CompletionTokens
				r.TotalTokens = finalUsage.TotalTokens
			}
			LogCall(r)
		}()

		// send delivers c, preferring delivery whenever the buffer has room, but
		// gives up once ctx is done so an abandoned consumer cannot block this
		// goroutine forever. It reports whether c was delivered.
		send := func(c StreamChunk) bool {
			select {
			case ch <- c:
				return true
			default:
			}
			select {
			case ch <- c:
				return true
			case <-ctx.Done():
				if streamErr == nil {
					streamErr = ctx.Err()
				}
				return false
			}
		}

		resp, err := p.doChatStream(ctx, messages, p.config.Model, tools)
		if err != nil && p.canFallback(ctx) {
			logger.Printf("[LLM] 流式调用主模型失败,降级: %v", err)
			modelName = p.config.FallbackModel
			resp, err = p.doChatStream(ctx, messages, p.config.FallbackModel, tools)
		}
		if err != nil {
			streamErr = err
			send(StreamChunk{Error: err, Done: true})
			return
		}
		defer resp.Body.Close()

		scanner := bufio.NewScanner(resp.Body)
		// Enlarge the scanner buffer so large SSE events (up to 1 MiB per line) fit.
		scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)

		for scanner.Scan() {
			line := scanner.Text()
			// SSE data lines look like "data: {...}"; the space after the colon is optional.
			if !strings.HasPrefix(line, "data:") {
				continue
			}
			data := strings.TrimPrefix(strings.TrimPrefix(line, "data:"), " ")

			// End-of-stream marker.
			if data == "[DONE]" {
				send(StreamChunk{Done: true, Usage: finalUsage})
				return
			}

			var streamResp openAIStreamResponse
			if err := json.Unmarshal([]byte(data), &streamResp); err != nil {
				// Skip events that are not valid JSON instead of aborting the stream.
				continue
			}

			// Usage may arrive on any chunk (some providers send it on a chunk
			// with empty choices); keep the most recent report.
			if streamResp.Usage != nil {
				finalUsage = &model.Usage{
					PromptTokens:     streamResp.Usage.PromptTokens,
					CompletionTokens: streamResp.Usage.CompletionTokens,
					TotalTokens:      streamResp.Usage.TotalTokens,
				}
			}
			if len(streamResp.Choices) == 0 {
				continue
			}

			choice := streamResp.Choices[0]
			if text := contentString(choice.Delta.Content); text != "" {
				if !send(StreamChunk{Content: text}) {
					return
				}
			}
			if choice.FinishReason != "" {
				send(StreamChunk{Done: true, Usage: finalUsage})
				return
			}
		}

		if err := scanner.Err(); err != nil {
			streamErr = fmt.Errorf("读取流式响应失败: %w", err)
			send(StreamChunk{Error: streamErr, Done: true})
			return
		}
		send(StreamChunk{Done: true, Usage: finalUsage})
	}()

	return ch, nil
}

// openAIStreamResponse is one SSE chunk of a streaming response.
type openAIStreamResponse struct {
	ID      string               `json:"id"`
	Object  string               `json:"object"`
	Choices []openAIStreamChoice `json:"choices"`
	Usage   *openAIUsage         `json:"usage,omitempty"`
}

type openAIStreamChoice struct {
	Index        int           `json:"index"`
	Delta        openAIMessage `json:"delta"`
	FinishReason string        `json:"finish_reason"`
}

// doChat executes one synchronous chat completion request against modelName and
// records the call (duration, outcome, token usage) via LogCall.
func (p *OpenAIProvider) doChat(ctx context.Context, messages []model.LLMMessage, modelName string, stream bool, tools []OpenAITool) (llmResp *model.LLMResponse, err error) {
	startTime := time.Now()
	// Named results let this deferred logger observe the final outcome of every return path.
	defer func() {
		r := CallRecord{
			Model:    modelName,
			Duration: time.Since(startTime),
			Success:  err == nil,
		}
		if err != nil {
			r.Error = err.Error()
		}
		if llmResp != nil {
			r.PromptTokens = llmResp.Usage.PromptTokens
			r.CompletionTokens = llmResp.Usage.CompletionTokens
			r.TotalTokens = llmResp.Usage.TotalTokens
		}
		LogCall(r)
	}()

	req, err := p.newChatRequest(ctx, messages, modelName, stream, tools)
	if err != nil {
		return nil, err
	}

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("请求失败: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("读取响应失败: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, openAIStatusError(resp.StatusCode, body)
	}

	var oaiResp openAIResponse
	if err := json.Unmarshal(body, &oaiResp); err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}
	if len(oaiResp.Choices) == 0 {
		return nil, fmt.Errorf("API返回空choices")
	}

	choice := oaiResp.Choices[0]
	llmResp = &model.LLMResponse{
		Content:          contentString(choice.Message.Content),
		FinishReason:     choice.FinishReason,
		ReasoningContent: choice.Message.ReasoningContent,
		Usage: model.Usage{
			PromptTokens:     oaiResp.Usage.PromptTokens,
			CompletionTokens: oaiResp.Usage.CompletionTokens,
			TotalTokens:      oaiResp.Usage.TotalTokens,
		},
	}

	// Surface any tool calls requested by the model.
	if len(choice.Message.ToolCalls) > 0 {
		llmResp.ToolCalls = make([]model.ToolCall, 0, len(choice.Message.ToolCalls))
		for _, tc := range choice.Message.ToolCalls {
			llmResp.ToolCalls = append(llmResp.ToolCalls, model.ToolCall{
				ID:        tc.ID,
				Name:      tc.Function.Name,
				Arguments: tc.Function.Arguments,
			})
		}
	}
	return llmResp, nil
}

// doChatStream executes a streaming chat completion request and returns the raw
// HTTP response on status 200; the caller owns and must close resp.Body.
func (p *OpenAIProvider) doChatStream(ctx context.Context, messages []model.LLMMessage, modelName string, tools []OpenAITool) (*http.Response, error) {
	req, err := p.newChatRequest(ctx, messages, modelName, true, tools)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Accept", "text/event-stream")

	client := p.streamClient
	if client == nil {
		client = p.httpClient
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("请求失败: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		defer resp.Body.Close()
		// Best effort: the body only enriches the error message.
		body, _ := io.ReadAll(resp.Body)
		return nil, openAIStatusError(resp.StatusCode, body)
	}
	return resp, nil
}

// newChatRequest builds the POST /chat/completions request shared by the sync and
// streaming paths, including message conversion (which resolves image URLs).
func (p *OpenAIProvider) newChatRequest(ctx context.Context, messages []model.LLMMessage, modelName string, stream bool, tools []OpenAITool) (*http.Request, error) {
	reqBody := openAIRequest{
		Model:       modelName,
		Messages:    p.convertMessages(messages),
		Temperature: openAIDefaultTemperature,
		Stream:      stream,
		Tools:       tools,
	}
	if len(tools) > 0 {
		reqBody.ToolChoice = "auto"
	}

	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("序列化请求失败: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, "POST", p.config.BaseURL+"/chat/completions", bytes.NewReader(jsonBody))
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+p.config.APIKey)
	return req, nil
}

// convertMessages maps internal messages to the OpenAI wire format, first resolving
// each message's image URLs to data URLs.
func (p *OpenAIProvider) convertMessages(messages []model.LLMMessage) []openAIMessage {
	out := make([]openAIMessage, len(messages))
	for i, msg := range messages {
		oaiMsg := openAIMessage{
			Role:             string(msg.Role),
			Content:          buildContent(msg.Content, p.resolveImages(msg.Images)),
			Name:             msg.Name,
			ToolCallID:       msg.ToolCallID,
			ReasoningContent: msg.ReasoningContent,
		}
		// Replay prior assistant tool calls in the wire format.
		if len(msg.ToolCalls) > 0 {
			oaiMsg.ToolCalls = make([]openAIToolCall, len(msg.ToolCalls))
			for j, tc := range msg.ToolCalls {
				oaiMsg.ToolCalls[j] = openAIToolCall{
					ID:   tc.ID,
					Type: "function",
					Function: openAIToolCallFunction{
						Name:      tc.Name,
						Arguments: tc.Arguments,
					},
				}
			}
		}
		out[i] = oaiMsg
	}
	return out
}

// openAIStatusError turns a non-200 response into an error, preferring the API's
// structured error object and falling back to the status code plus raw body.
func openAIStatusError(status int, body []byte) error {
	var errResp openAIResponse
	if json.Unmarshal(body, &errResp) == nil && errResp.Error != nil {
		return fmt.Errorf("API错误 [%s]: %s", errResp.Error.Code, errResp.Error.Message)
	}
	return fmt.Errorf("API返回状态码 %d: %s", status, string(body))
}

// ModelName returns the primary model name.
func (p *OpenAIProvider) ModelName() string {
	return p.config.Model
}

// contentString returns v when it holds a string, and "" otherwise (nil, or a
// multimodal array, which encoding/json decodes into []interface{}).
func contentString(v interface{}) string {
	s, _ := v.(string)
	return s
}

// resolveImages converts non-data URLs to base64 data URLs so external LLM APIs can
// use them without fetching the images themselves. Entries already starting with
// "data:" pass through unchanged; a failed download is logged and the original URL
// is kept as a fallback.
//
// NOTE(review): downloads run sequentially, are not tied to the request ctx, and are
// repeated for every message on every call (including a fallback retry) — consider
// caching or threading ctx through.
func (p *OpenAIProvider) resolveImages(images []string) []string {
	if len(images) == 0 {
		return images
	}
	resolved := make([]string, 0, len(images))
	for _, img := range images {
		if strings.HasPrefix(img, "data:") {
			resolved = append(resolved, img)
			continue
		}
		dataURL, err := p.downloadAsDataURL(img)
		if err != nil {
			logger.Printf("[openai] 图片下载失败, 保留原始 URL: %s, err=%v", img, err)
			resolved = append(resolved, img) // keep the original URL as a fallback
			continue
		}
		resolved = append(resolved, dataURL)
	}
	return resolved
}

// downloadAsDataURL downloads an image from a URL and returns it as a base64 data URL.
func (p *OpenAIProvider) downloadAsDataURL(url string) (string, error) { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return "", fmt.Errorf("创建请求失败: %w", err) } resp, err := p.httpClient.Do(req) if err != nil { return "", fmt.Errorf("下载失败: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return "", fmt.Errorf("HTTP %d", resp.StatusCode) } // 限制最大 20MB const maxSize = 20 * 1024 * 1024 body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize+1)) if err != nil { return "", fmt.Errorf("读取失败: %w", err) } if len(body) > maxSize { return "", fmt.Errorf("图片过大: %d bytes", len(body)) } mimeType := resp.Header.Get("Content-Type") if mimeType == "" { mimeType = http.DetectContentType(body) } b64 := base64.StdEncoding.EncodeToString(body) return fmt.Sprintf("data:%s;base64,%s", mimeType, b64), nil } // buildContent converts text + optional images to API content format. // Returns a plain string if no images, or a multimodal array otherwise. func buildContent(text string, images []string) interface{} { if len(images) == 0 { return text } parts := make([]model.ImageContent, 0, len(images)+1) if text != "" { parts = append(parts, model.ImageContent{ Type: "text", Text: text, }) } for _, img := range images { parts = append(parts, model.ImageContent{ Type: "image_url", ImageURL: &model.ImageURL{ URL: img, }, }) } return parts }