package llm

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/yourname/cyrene-ai/ai-core/internal/model"
	"github.com/yourname/cyrene-ai/pkg/logger"
)

// OpenAIConfig holds the settings for the OpenAI-compatible adapter.
type OpenAIConfig struct {
	BaseURL       string        // API base URL; "/chat/completions" is appended to it
	APIKey        string        // API key, sent as a Bearer token
	Model         string        // primary model
	FallbackModel string        // model tried when a call to the primary model fails
	MaxRetries    int           // maximum retries (defaulted by NewOpenAIProvider; not read by the request paths in this file)
	Timeout       time.Duration // HTTP client timeout
}

// OpenAIProvider is an LLM provider that talks to an OpenAI-compatible
// chat-completions endpoint.
type OpenAIProvider struct {
	config     OpenAIConfig
	httpClient *http.Client
}

// NewOpenAIProvider builds a provider from cfg. A zero MaxRetries defaults to
// 3 and a zero Timeout defaults to 60 seconds.
//
// NOTE(review): per the net/http documentation, http.Client.Timeout also
// covers reading the response body, so a streamed response that runs longer
// than Timeout is interrupted. Confirm this is intended for ChatStream.
func NewOpenAIProvider(cfg OpenAIConfig) *OpenAIProvider {
	if cfg.MaxRetries == 0 {
		cfg.MaxRetries = 3
	}
	if cfg.Timeout == 0 {
		cfg.Timeout = 60 * time.Second
	}
	return &OpenAIProvider{
		config: cfg,
		httpClient: &http.Client{
			Timeout: cfg.Timeout,
		},
	}
}

// openAIRequest is the JSON body sent to the chat-completions endpoint.
type openAIRequest struct {
	Model       string          `json:"model"`
	Messages    []openAIMessage `json:"messages"`
	Temperature float64         `json:"temperature"`
	MaxTokens   int             `json:"max_tokens,omitempty"`
	Stream      bool            `json:"stream"`
	Tools       []OpenAITool    `json:"tools,omitempty"`
	ToolChoice  string          `json:"tool_choice,omitempty"` // "auto", "none", or specific tool
}

// openAIMessage is one chat message, used both in requests and in responses.
type openAIMessage struct {
	Role             string           `json:"role"`
	Content          string           `json:"content,omitempty"`
	Name             string           `json:"name,omitempty"`
	ToolCalls        []openAIToolCall `json:"tool_calls,omitempty"`
	ToolCallID       string           `json:"tool_call_id,omitempty"`
	ReasoningContent string           `json:"reasoning_content,omitempty"` // DeepSeek chain-of-thought
}

// openAIToolCall is a tool invocation attached to a message.
type openAIToolCall struct {
	ID       string                 `json:"id"`
	Type     string                 `json:"type"`
	Function openAIToolCallFunction `json:"function"`
}

// openAIToolCallFunction names the function of a tool call and carries its
// arguments.
type openAIToolCallFunction struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"` // JSON string
}

// openAIResponse is the JSON body of a non-streaming response. It is also
// used to decode error bodies, which populate Error.
type openAIResponse struct {
	ID      string         `json:"id"`
	Object  string         `json:"object"`
	Choices []openAIChoice `json:"choices"`
	Usage   openAIUsage    `json:"usage,omitempty"`
	Error   *openAIError   `json:"error,omitempty"`
}

// openAIChoice is one completion candidate.
type openAIChoice struct {
	Index        int           `json:"index"`
	Message      openAIMessage `json:"message"`
	Delta        openAIMessage `json:"delta,omitempty"`
	FinishReason string        `json:"finish_reason"`
}

// openAIUsage reports token accounting for a request.
type openAIUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// openAIError is the structured error object of an error response.
type openAIError struct {
	Message string `json:"message"`
	Type    string `json:"type"`
	Code    string `json:"code,omitempty"`
}

// Chat performs a synchronous chat completion without tools.
func (p *OpenAIProvider) Chat(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error) {
	return p.ChatWithTools(ctx, messages, nil)
}

// ChatWithTools performs a synchronous chat completion, optionally offering
// tools to the model. If the primary model fails and a distinct fallback model
// is configured, the call is repeated once against the fallback and that
// result (or error) is returned.
func (p *OpenAIProvider) ChatWithTools(ctx context.Context, messages []model.LLMMessage, tools []OpenAITool) (*model.LLMResponse, error) {
	resp, err := p.doChat(ctx, messages, p.config.Model, false, tools)
	if err == nil {
		return resp, nil
	}
	if !p.shouldFallback(ctx) {
		return nil, err
	}
	logger.Printf("[LLM] 主模型 %s 调用失败,降级到 %s: %v", p.config.Model, p.config.FallbackModel, err)
	return p.doChat(ctx, messages, p.config.FallbackModel, false, tools)
}

// shouldFallback reports whether a failed primary call should be retried on
// the fallback model: one must be configured, it must differ from the primary
// model, and ctx must still be live (otherwise the retry would fail at once).
func (p *OpenAIProvider) shouldFallback(ctx context.Context) bool {
	fallback := p.config.FallbackModel
	return fallback != "" && fallback != p.config.Model && ctx.Err() == nil
}

// ChatStream performs a streaming chat completion without tools.
func (p *OpenAIProvider) ChatStream(ctx context.Context, messages []model.LLMMessage) (<-chan StreamChunk, error) {
	return p.ChatStreamWithTools(ctx, messages, nil)
}

// ChatStreamWithTools performs a streaming chat completion, optionally
// offering tools to the model.
//
// The request is issued from a background goroutine, so the returned error is
// always nil; failures arrive as a chunk with Error set and Done true. The
// channel is closed when the stream ends. Only delta.Content is forwarded:
// tool-call and reasoning deltas in the stream are not emitted.
//
// Cancelling ctx stops the goroutine even when the consumer has stopped
// reading from the channel.
func (p *OpenAIProvider) ChatStreamWithTools(ctx context.Context, messages []model.LLMMessage, tools []OpenAITool) (<-chan StreamChunk, error) {
	ch := make(chan StreamChunk, 100)

	go func() {
		defer close(ch)

		// emit delivers a chunk and reports whether it was delivered. The
		// non-blocking attempt comes first so that a terminal chunk still
		// reaches a reading consumer after ctx is cancelled; the blocking
		// attempt gives up on cancellation so an abandoned channel cannot
		// pin this goroutine forever.
		emit := func(chunk StreamChunk) bool {
			select {
			case ch <- chunk:
				return true
			default:
			}
			select {
			case ch <- chunk:
				return true
			case <-ctx.Done():
				return false
			}
		}

		resp, err := p.doChatStream(ctx, messages, p.config.Model, tools)
		if err != nil && p.shouldFallback(ctx) {
			logger.Printf("[LLM] 流式调用主模型失败,降级: %v", err)
			resp, err = p.doChatStream(ctx, messages, p.config.FallbackModel, tools)
		}
		if err != nil {
			emit(StreamChunk{Error: err, Done: true})
			return
		}
		defer resp.Body.Close()

		scanner := bufio.NewScanner(resp.Body)
		// Enlarge the scanner buffer so large SSE events fit on one line.
		scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)

		for scanner.Scan() {
			line := scanner.Text()

			// SSE format: "data: {...}". The space after the colon is
			// optional, so match on "data:" and trim the remainder.
			if !strings.HasPrefix(line, "data:") {
				continue
			}
			data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
			if data == "" {
				continue
			}

			// End-of-stream marker.
			if data == "[DONE]" {
				emit(StreamChunk{Done: true})
				return
			}

			var streamResp openAIStreamResponse
			if err := json.Unmarshal([]byte(data), &streamResp); err != nil {
				// Best effort: skip the malformed event, but leave a trace.
				logger.Printf("[LLM] 跳过无法解析的流式数据块: %v", err)
				continue
			}
			if len(streamResp.Choices) == 0 {
				continue
			}

			choice := streamResp.Choices[0]
			if choice.Delta.Content != "" {
				if !emit(StreamChunk{Content: choice.Delta.Content}) {
					return
				}
			}

			if choice.FinishReason != "" {
				usage := &model.Usage{}
				if streamResp.Usage != nil {
					usage.PromptTokens = streamResp.Usage.PromptTokens
					usage.CompletionTokens = streamResp.Usage.CompletionTokens
					usage.TotalTokens = streamResp.Usage.TotalTokens
				}
				emit(StreamChunk{Done: true, Usage: usage})
				return
			}
		}

		if err := scanner.Err(); err != nil {
			emit(StreamChunk{Error: fmt.Errorf("读取流式响应失败: %w", err), Done: true})
			return
		}

		// The body ended without an explicit terminator; still signal Done.
		emit(StreamChunk{Done: true})
	}()

	return ch, nil
}

// openAIStreamResponse is one decoded SSE event of a streaming response.
type openAIStreamResponse struct {
	ID      string               `json:"id"`
	Object  string               `json:"object"`
	Choices []openAIStreamChoice `json:"choices"`
	Usage   *openAIUsage         `json:"usage,omitempty"`
}

// openAIStreamChoice is one incremental choice within a streaming event.
type openAIStreamChoice struct {
	Index        int           `json:"index"`
	Delta        openAIMessage `json:"delta"`
	FinishReason string        `json:"finish_reason"`
}

// toOpenAIMessages converts provider-neutral messages to the wire format.
// Every tool call is sent with type "function".
func toOpenAIMessages(messages []model.LLMMessage) []openAIMessage {
	out := make([]openAIMessage, len(messages))
	for i, msg := range messages {
		oaiMsg := openAIMessage{
			Role:             string(msg.Role),
			Content:          msg.Content,
			Name:             msg.Name,
			ToolCallID:       msg.ToolCallID,
			ReasoningContent: msg.ReasoningContent,
		}
		if len(msg.ToolCalls) > 0 {
			oaiMsg.ToolCalls = make([]openAIToolCall, len(msg.ToolCalls))
			for j, tc := range msg.ToolCalls {
				oaiMsg.ToolCalls[j] = openAIToolCall{
					ID:   tc.ID,
					Type: "function",
					Function: openAIToolCallFunction{
						Name:      tc.Name,
						Arguments: tc.Arguments,
					},
				}
			}
		}
		out[i] = oaiMsg
	}
	return out
}

// newChatRequest builds the POST request for the chat-completions endpoint,
// shared by the synchronous and streaming paths. Temperature is fixed at 0.8
// and tool_choice is "auto" whenever tools are supplied.
func (p *OpenAIProvider) newChatRequest(ctx context.Context, messages []model.LLMMessage, modelName string, stream bool, tools []OpenAITool) (*http.Request, error) {
	reqBody := openAIRequest{
		Model:       modelName,
		Messages:    toOpenAIMessages(messages),
		Temperature: 0.8,
		Stream:      stream,
		Tools:       tools,
	}
	if len(tools) > 0 {
		reqBody.ToolChoice = "auto"
	}

	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("序列化请求失败: %w", err)
	}

	// Tolerate a trailing slash on BaseURL so the path never contains "//".
	endpoint := strings.TrimRight(p.config.BaseURL, "/") + "/chat/completions"
	req, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewReader(jsonBody))
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+p.config.APIKey)
	return req, nil
}

// openAIStatusError turns a non-200 response into an error. It prefers the
// structured error object and falls back to the status code plus raw body
// when the body does not decode to one.
func openAIStatusError(status int, body []byte) error {
	var errResp openAIResponse
	if json.Unmarshal(body, &errResp) == nil && errResp.Error != nil {
		return fmt.Errorf("API错误 [%s]: %s", errResp.Error.Code, errResp.Error.Message)
	}
	return fmt.Errorf("API返回状态码 %d: %s", status, string(body))
}

// doChat executes one synchronous chat request against modelName and converts
// the first returned choice into a model.LLMResponse. It returns an error for
// transport failures, non-200 statuses, undecodable bodies, and responses
// with no choices.
func (p *OpenAIProvider) doChat(ctx context.Context, messages []model.LLMMessage, modelName string, stream bool, tools []OpenAITool) (*model.LLMResponse, error) {
	req, err := p.newChatRequest(ctx, messages, modelName, stream, tools)
	if err != nil {
		return nil, err
	}

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("请求失败: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("读取响应失败: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		return nil, openAIStatusError(resp.StatusCode, body)
	}

	var oaiResp openAIResponse
	if err := json.Unmarshal(body, &oaiResp); err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}
	if len(oaiResp.Choices) == 0 {
		return nil, fmt.Errorf("API返回空choices")
	}

	choice := oaiResp.Choices[0]
	llmResp := &model.LLMResponse{
		Content:          choice.Message.Content,
		FinishReason:     choice.FinishReason,
		ReasoningContent: choice.Message.ReasoningContent,
		Usage: model.Usage{
			PromptTokens:     oaiResp.Usage.PromptTokens,
			CompletionTokens: oaiResp.Usage.CompletionTokens,
			TotalTokens:      oaiResp.Usage.TotalTokens,
		},
	}

	// Carry over any tool calls the model requested.
	if len(choice.Message.ToolCalls) > 0 {
		llmResp.ToolCalls = make([]model.ToolCall, 0, len(choice.Message.ToolCalls))
		for _, tc := range choice.Message.ToolCalls {
			llmResp.ToolCalls = append(llmResp.ToolCalls, model.ToolCall{
				ID:        tc.ID,
				Name:      tc.Function.Name,
				Arguments: tc.Function.Arguments,
			})
		}
	}

	return llmResp, nil
}

// doChatStream executes one streaming chat request against modelName and
// returns the raw HTTP response. On success the caller owns resp.Body and
// must close it. A non-200 status is consumed and closed here and reported as
// an error.
func (p *OpenAIProvider) doChatStream(ctx context.Context, messages []model.LLMMessage, modelName string, tools []OpenAITool) (*http.Response, error) {
	req, err := p.newChatRequest(ctx, messages, modelName, true, tools)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Accept", "text/event-stream")

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("请求失败: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		defer resp.Body.Close()
		// Best effort: the read error is ignored on purpose, because
		// io.ReadAll still returns the bytes read so far and a partial body
		// is more useful in the error than none.
		body, _ := io.ReadAll(resp.Body)
		return nil, openAIStatusError(resp.StatusCode, body)
	}

	return resp, nil
}

// ModelName returns the configured primary model name.
func (p *OpenAIProvider) ModelName() string {
	return p.config.Model
}