154 lines
3.6 KiB
Go
154 lines
3.6 KiB
Go
package tools
|
||
|
||
import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"sort"
	"sync"
)
|
||
|
||
// ToolDefinition describes one tool for LLM function calling: its name,
// a human-readable description, and a parameters object that is serialized
// verbatim (presumably a JSON Schema object — nothing here validates it).
type ToolDefinition struct {
	// Name identifies the tool; Registry stores executors under this name.
	Name string `json:"name"`
	// Description is the free-text explanation sent alongside the name.
	Description string `json:"description"`
	// Parameters is marshaled as-is into the "parameters" field; a nil map
	// encodes as null because the tag has no omitempty.
	Parameters map[string]interface{} `json:"parameters"`
}
|
||
|
||
// ToolResult is the outcome of a single tool invocation.
//
// Success distinguishes the two cases: on success the payload is in Data,
// on failure the message is in Error. Both Data and Error are dropped from
// the JSON encoding when empty.
type ToolResult struct {
	// ToolName is the name of the tool that produced this result.
	ToolName string `json:"tool_name"`
	// Success reports whether the tool completed without error.
	Success bool `json:"success"`
	// Data holds the tool's output on success.
	Data string `json:"data,omitempty"`
	// Error holds a human-readable failure message.
	Error string `json:"error,omitempty"`
}
|
||
|
||
// ToolExecutor 工具执行器接口
|
||
type ToolExecutor interface {
|
||
// Execute 执行工具调用
|
||
Execute(ctx context.Context, arguments map[string]interface{}) (*ToolResult, error)
|
||
// Definition 返回工具定义
|
||
Definition() ToolDefinition
|
||
}
|
||
|
||
// Registry 工具注册中心
|
||
type Registry struct {
|
||
mu sync.RWMutex
|
||
tools map[string]ToolExecutor
|
||
enabled bool
|
||
}
|
||
|
||
// NewRegistry 创建工具注册中心
|
||
func NewRegistry() *Registry {
|
||
return &Registry{
|
||
tools: make(map[string]ToolExecutor),
|
||
enabled: true,
|
||
}
|
||
}
|
||
|
||
// Register 注册工具
|
||
func (r *Registry) Register(executor ToolExecutor) {
|
||
r.mu.Lock()
|
||
defer r.mu.Unlock()
|
||
def := executor.Definition()
|
||
r.tools[def.Name] = executor
|
||
log.Printf("[工具注册] 已注册工具: %s", def.Name)
|
||
}
|
||
|
||
// GetDefinitions 获取所有工具定义(用于 LLM function calling)
|
||
func (r *Registry) GetDefinitions() []ToolDefinition {
|
||
r.mu.RLock()
|
||
defer r.mu.RUnlock()
|
||
|
||
defs := make([]ToolDefinition, 0, len(r.tools))
|
||
for _, executor := range r.tools {
|
||
defs = append(defs, executor.Definition())
|
||
}
|
||
return defs
|
||
}
|
||
|
||
// Execute 执行工具调用
|
||
func (r *Registry) Execute(ctx context.Context, toolName string, arguments map[string]interface{}) (*ToolResult, error) {
|
||
r.mu.RLock()
|
||
executor, ok := r.tools[toolName]
|
||
r.mu.RUnlock()
|
||
|
||
if !ok {
|
||
return &ToolResult{
|
||
ToolName: toolName,
|
||
Success: false,
|
||
Error: fmt.Sprintf("未知工具: %s", toolName),
|
||
}, nil
|
||
}
|
||
|
||
log.Printf("[工具执行] 调用工具 %s,参数: %v", toolName, arguments)
|
||
result, err := executor.Execute(ctx, arguments)
|
||
if err != nil {
|
||
log.Printf("[工具执行] 工具 %s 执行失败: %v", toolName, err)
|
||
return &ToolResult{
|
||
ToolName: toolName,
|
||
Success: false,
|
||
Error: err.Error(),
|
||
}, nil
|
||
}
|
||
|
||
if result.Success {
|
||
log.Printf("[工具执行] 工具 %s 执行成功 (数据长度: %d)", toolName, len(result.Data))
|
||
} else {
|
||
log.Printf("[工具执行] 工具 %s 返回错误: %s", toolName, result.Error)
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// IsEnabled 检查工具系统是否启用
|
||
func (r *Registry) IsEnabled() bool {
|
||
r.mu.RLock()
|
||
defer r.mu.RUnlock()
|
||
return r.enabled
|
||
}
|
||
|
||
// SetEnabled 启用/禁用工具系统
|
||
func (r *Registry) SetEnabled(enabled bool) {
|
||
r.mu.Lock()
|
||
defer r.mu.Unlock()
|
||
r.enabled = enabled
|
||
}
|
||
|
||
// HasTool 检查工具是否存在
|
||
func (r *Registry) HasTool(name string) bool {
|
||
r.mu.RLock()
|
||
defer r.mu.RUnlock()
|
||
_, ok := r.tools[name]
|
||
return ok
|
||
}
|
||
|
||
// ListTools 列出所有已注册的工具名称
|
||
func (r *Registry) ListTools() []string {
|
||
r.mu.RLock()
|
||
defer r.mu.RUnlock()
|
||
|
||
names := make([]string, 0, len(r.tools))
|
||
for name := range r.tools {
|
||
names = append(names, name)
|
||
}
|
||
return names
|
||
}
|
||
|
||
// ToJSON 将工具定义序列化为 JSON(用于 LLM 请求)
|
||
func (r *Registry) ToJSON() ([]byte, error) {
|
||
defs := r.GetDefinitions()
|
||
tools := make([]map[string]interface{}, 0, len(defs))
|
||
for _, d := range defs {
|
||
tools = append(tools, map[string]interface{}{
|
||
"type": "function",
|
||
"function": map[string]interface{}{
|
||
"name": d.Name,
|
||
"description": d.Description,
|
||
"parameters": d.Parameters,
|
||
},
|
||
})
|
||
}
|
||
return json.Marshal(tools)
|
||
}
|