package tools

import (
	"context"
	"encoding/json"
	"fmt"
	"sort"
	"sync"

	"github.com/yourname/cyrene-ai/pkg/logger"
)

// ToolDefinition describes a tool in the form used for LLM function calling.
type ToolDefinition struct {
	// Name is the key the tool is registered and invoked under.
	Name string `json:"name"`
	// Description is the text describing the tool to the model.
	Description string `json:"description"`
	// Parameters is passed through verbatim by ToJSON.
	// NOTE(review): presumably a JSON Schema object — confirm against the LLM client.
	Parameters map[string]interface{} `json:"parameters"`
}

// ToolResult is the outcome of a single tool invocation.
type ToolResult struct {
	ToolName string `json:"tool_name"`
	Success  bool   `json:"success"`
	// Data holds the tool output; omitted from JSON when empty.
	Data string `json:"data,omitempty"`
	// Error holds the failure message; omitted from JSON when empty.
	Error string `json:"error,omitempty"`
}

// ToolExecutor is the interface every registrable tool implements.
type ToolExecutor interface {
	// Execute runs the tool with the given call arguments.
	Execute(ctx context.Context, arguments map[string]interface{}) (*ToolResult, error)
	// Definition returns the tool's definition; its Name is used as the registry key.
	Definition() ToolDefinition
}

// Registry holds registered tools keyed by name, guarded by an RWMutex, plus
// an enabled switch for the tool system as a whole.
type Registry struct {
	mu      sync.RWMutex
	tools   map[string]ToolExecutor
	enabled bool
}

// NewRegistry returns an empty Registry with the tool system enabled.
func NewRegistry() *Registry {
	return &Registry{
		tools:   make(map[string]ToolExecutor),
		enabled: true,
	}
}

// Register adds executor under the name returned by its Definition. If a tool
// with the same name is already registered it is replaced, and a warning is
// logged so the overwrite is not silent.
func (r *Registry) Register(executor ToolExecutor) {
	// Ask for the definition before locking so executor code never runs
	// while the registry's write lock is held.
	def := executor.Definition()

	r.mu.Lock()
	_, replaced := r.tools[def.Name]
	r.tools[def.Name] = executor
	r.mu.Unlock()

	if replaced {
		logger.Printf("[工具注册] 工具 %s 已存在,已被覆盖", def.Name)
	}
	logger.Printf("[工具注册] 已注册工具: %s", def.Name)
}

// GetDefinitions returns the definitions of all registered tools (for LLM
// function calling), sorted by name.
func (r *Registry) GetDefinitions() []ToolDefinition {
	r.mu.RLock()
	defer r.mu.RUnlock()

	defs := make([]ToolDefinition, 0, len(r.tools))
	for _, executor := range r.tools {
		defs = append(defs, executor.Definition())
	}
	// Map iteration order is randomized; sort so the tool list sent to the
	// model (and ToJSON's output) is identical from call to call.
	sort.Slice(defs, func(i, j int) bool { return defs[i].Name < defs[j].Name })
	return defs
}

// Execute looks up toolName and runs it with arguments.
//
// The returned error is always nil. An unknown tool, an executor error, or an
// executor that returns a nil result is reported as a ToolResult with
// Success == false and Error set.
//
// Execute does not read the enabled flag.
// NOTE(review): presumably callers check IsEnabled before calling — confirm.
func (r *Registry) Execute(ctx context.Context, toolName string, arguments map[string]interface{}) (*ToolResult, error) {
	r.mu.RLock()
	executor, ok := r.tools[toolName]
	r.mu.RUnlock()

	if !ok {
		return &ToolResult{
			ToolName: toolName,
			Success:  false,
			Error:    fmt.Sprintf("未知工具: %s", toolName),
		}, nil
	}

	logger.Printf("[工具执行] 调用工具 %s,参数: %v", toolName, arguments)
	result, err := executor.Execute(ctx, arguments)
	if err != nil {
		logger.Printf("[工具执行] 工具 %s 执行失败: %v", toolName, err)
		return &ToolResult{
			ToolName: toolName,
			Success:  false,
			Error:    err.Error(),
		}, nil
	}

	// An executor returning (nil, nil) would otherwise cause a nil-pointer
	// panic on result.Success below; report it as a failed call instead.
	if result == nil {
		msg := fmt.Sprintf("工具 %s 未返回结果", toolName)
		logger.Printf("[工具执行] %s", msg)
		return &ToolResult{
			ToolName: toolName,
			Success:  false,
			Error:    msg,
		}, nil
	}

	if result.Success {
		logger.Printf("[工具执行] 工具 %s 执行成功 (数据长度: %d)", toolName, len(result.Data))
	} else {
		logger.Printf("[工具执行] 工具 %s 返回错误: %s", toolName, result.Error)
	}
	return result, nil
}

// IsEnabled reports whether the tool system is enabled.
func (r *Registry) IsEnabled() bool {
	r.mu.RLock()
	defer r.mu.RUnlock()
	return r.enabled
}

// SetEnabled enables or disables the tool system.
func (r *Registry) SetEnabled(enabled bool) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.enabled = enabled
}

// HasTool reports whether a tool is registered under name.
func (r *Registry) HasTool(name string) bool {
	r.mu.RLock()
	defer r.mu.RUnlock()
	_, ok := r.tools[name]
	return ok
}

// ListTools returns the names of all registered tools in ascending order.
func (r *Registry) ListTools() []string {
	r.mu.RLock()
	defer r.mu.RUnlock()

	names := make([]string, 0, len(r.tools))
	for name := range r.tools {
		names = append(names, name)
	}
	// Sort for a deterministic result; map iteration order is random.
	sort.Strings(names)
	return names
}

// ToJSON serializes the tool definitions (for LLM requests) as a JSON array of
// {"type": "function", "function": {name, description, parameters}} objects.
// With no tools registered it yields "[]".
func (r *Registry) ToJSON() ([]byte, error) {
	defs := r.GetDefinitions()
	tools := make([]map[string]interface{}, 0, len(defs))
	for _, d := range defs {
		tools = append(tools, map[string]interface{}{
			"type": "function",
			"function": map[string]interface{}{
				"name":        d.Name,
				"description": d.Description,
				"parameters":  d.Parameters,
			},
		})
	}
	return json.Marshal(tools)
}