Files
Cyrene/backend/ai-core/internal/llm/adapter.go
T
2026-05-16 08:26:56 +08:00

84 lines
2.0 KiB
Go

package llm
import (
	"context"
	"errors"
	"io"
	"strings"

	"github.com/yourname/cyrene-ai/ai-core/internal/model"
)
// Adapter is the LLM adapter: a concrete wrapper that delegates every call to
// an LLMProvider, so that different LLM backends (OpenAI, Ollama, vLLM, etc.)
// can be plugged in behind one type.
//
// Construct it with NewAdapter. The zero value has a nil provider, and its
// methods will panic when called because they invoke the provider directly
// without a nil check.
type Adapter struct {
	// provider is the backend every Adapter method forwards to.
	provider LLMProvider
}
// LLMProvider is the interface an LLM backend (provider) implements so that it
// can be wrapped by an Adapter.
type LLMProvider interface {
	// Chat performs a synchronous (non-streaming) chat completion.
	Chat(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error)
	// ChatStream performs a streaming chat completion and returns a channel
	// that pushes the response incrementally, token by token, as StreamChunks.
	ChatStream(ctx context.Context, messages []model.LLMMessage) (<-chan StreamChunk, error)
	// ModelName returns the name of the model currently in use.
	ModelName() string
}
// StreamChunk is a single fragment of a streaming response, as delivered on
// the channel returned by LLMProvider.ChatStream.
type StreamChunk struct {
	Content string       // delta content: the text added by this chunk
	Done    bool         // whether this is the final chunk
	Error   error        // error information, if the stream failed
	Usage   *model.Usage // token usage statistics, returned with the final chunk
}
// NewAdapter builds an Adapter that forwards every call to provider.
//
// The provider is stored as given and is not validated. Passing nil yields an
// Adapter whose methods panic when invoked.
func NewAdapter(provider LLMProvider) *Adapter {
	adapter := new(Adapter)
	adapter.provider = provider
	return adapter
}
// Chat runs a blocking (non-streaming) chat completion by handing messages to
// the wrapped provider. The response and error are returned unchanged.
func (a *Adapter) Chat(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error) {
	backend := a.provider
	return backend.Chat(ctx, messages)
}
// ChatStream starts a streaming chat completion on the wrapped provider. It
// returns the provider's chunk channel and error exactly as produced.
func (a *Adapter) ChatStream(ctx context.Context, messages []model.LLMMessage) (<-chan StreamChunk, error) {
	stream, err := a.provider.ChatStream(ctx, messages)
	return stream, err
}
// ModelName reports the name of the model used by the wrapped provider.
func (a *Adapter) ModelName() string {
	name := a.provider.ModelName()
	return name
}
// collectStream drains a streaming response channel and assembles the deltas
// into a single, complete LLMResponse.
//
// Chunks are consumed in order until one of the following happens:
//   - a chunk carries a non-nil Error: that error is returned and any content
//     gathered so far is discarded;
//   - a chunk has Done set: its Content (if any) is appended, its Usage is
//     recorded, and collection stops;
//   - the channel is closed: collection stops with whatever was received.
//
// FinishReason is always reported as "stop". Usage is populated only when the
// Done chunk carried a non-nil Usage.
//
// A nil channel is rejected with an error, because ranging over a nil
// channel blocks forever.
func collectStream(ch <-chan StreamChunk) (*model.LLMResponse, error) {
	if ch == nil {
		return nil, errors.New("llm: nil stream channel")
	}

	var (
		content   strings.Builder // avoids quadratic string += across many deltas
		lastUsage *model.Usage
	)
	for chunk := range ch {
		if chunk.Error != nil {
			// NOTE(review): returning here stops reading from ch. This assumes
			// the producer closes (or stops sending on) ch after reporting an
			// error; otherwise a send on an unbuffered channel would block it.
			return nil, chunk.Error
		}
		// Append before checking Done, so a final chunk that also carries
		// delta text does not lose it.
		content.WriteString(chunk.Content)
		if chunk.Done {
			lastUsage = chunk.Usage
			break
		}
	}

	resp := &model.LLMResponse{
		Content:      content.String(),
		FinishReason: "stop",
	}
	if lastUsage != nil {
		resp.Usage = *lastUsage
	}
	return resp, nil
}
// Blank-identifier reference that keeps the "io" import of this file in use
// until SSE stream parsing, which is expected to need it, is implemented. Go
// imports are file-scoped, so this is required for the file to compile while
// nothing else here references io.
// TODO: remove once this file uses io for real.
var _ io.Reader