Files
Cyrene/backend/ai-core/internal/tools/registry.go
T

154 lines
3.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package tools
import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"sort"
	"sync"
)
// ToolDefinition describes a tool as advertised to the LLM for function
// calling: its name, a human-readable description, and its parameter schema.
// Registry.ToJSON emits these three fields inside a
// {"type": "function", "function": {...}} envelope.
type ToolDefinition struct {
	// Name identifies the tool; Registry.Register uses it as the lookup key.
	Name string `json:"name"`

	// Description is sent to the model alongside the name.
	Description string `json:"description"`

	// Parameters is passed through verbatim as the function's parameter
	// schema — presumably a JSON Schema object; confirm against executors.
	Parameters map[string]interface{} `json:"parameters"`
}
// ToolResult is the outcome of a single tool invocation, as returned by
// ToolExecutor.Execute and Registry.Execute.
type ToolResult struct {
	// ToolName is the name of the tool that was called.
	ToolName string `json:"tool_name"`

	// Success reports whether the call succeeded.
	Success bool `json:"success"`

	// Data holds the tool's output; omitted from JSON when empty.
	Data string `json:"data,omitempty"`

	// Error describes the failure when Success is false; omitted from JSON
	// when empty.
	Error string `json:"error,omitempty"`
}
// ToolExecutor is implemented by every tool that can be registered in a
// Registry.
type ToolExecutor interface {
	// Execute runs the tool with the given arguments (presumably decoded from
	// the model's function-call arguments — confirm at the call site).
	// Registry.Execute turns a non-nil error into a failed ToolResult;
	// implementations should return a non-nil *ToolResult whenever err is nil.
	Execute(ctx context.Context, arguments map[string]interface{}) (*ToolResult, error)

	// Definition returns the tool's schema. Registry.Register uses its Name
	// as the registry key, and Registry.GetDefinitions calls it again each
	// time the definitions are listed.
	Definition() ToolDefinition
}
// Registry maps tool names to their executors and dispatches tool calls.
// Construct it with NewRegistry: the zero value has a nil tools map, so
// calling Register on a zero Registry would panic.
type Registry struct {
	// mu guards tools and enabled; the Registry methods take it before
	// touching either field.
	mu sync.RWMutex

	// tools holds executors keyed by their ToolDefinition.Name at
	// registration time.
	tools map[string]ToolExecutor

	// enabled is the global on/off switch read by IsEnabled and written by
	// SetEnabled. NOTE(review): Execute does not check it — callers
	// presumably consult IsEnabled themselves; confirm.
	enabled bool
}
// NewRegistry returns an empty, ready-to-use Registry with the tool system
// switched on.
func NewRegistry() *Registry {
	reg := new(Registry)
	reg.tools = make(map[string]ToolExecutor)
	reg.enabled = true
	return reg
}
// Register adds executor to the registry under the name reported by its
// Definition. Registering a second executor with an existing name replaces
// the earlier one; a warning is logged so accidental duplicates are visible
// instead of silently shadowing a tool.
func (r *Registry) Register(executor ToolExecutor) {
	// Query the definition before locking so executor code never runs while
	// the registry lock is held.
	def := executor.Definition()

	r.mu.Lock()
	_, replaced := r.tools[def.Name]
	r.tools[def.Name] = executor
	r.mu.Unlock()

	if replaced {
		log.Printf("[工具注册] 警告: 工具 %s 已存在,旧的执行器已被覆盖", def.Name)
	}
	log.Printf("[工具注册] 已注册工具: %s", def.Name)
}
// GetDefinitions returns the definitions of all registered tools, sorted by
// name, for use in LLM function calling.
//
// Sorting makes the result deterministic: Go randomizes map iteration order,
// so without it the tool list (and therefore the request payload) would change
// between otherwise identical calls.
func (r *Registry) GetDefinitions() []ToolDefinition {
	r.mu.RLock()
	defer r.mu.RUnlock()

	defs := make([]ToolDefinition, 0, len(r.tools))
	for _, executor := range r.tools {
		defs = append(defs, executor.Definition())
	}
	sort.Slice(defs, func(i, j int) bool { return defs[i].Name < defs[j].Name })
	return defs
}
// Execute looks up toolName and runs its executor with arguments.
//
// Failures are reported in-band rather than via the error return. An unknown
// tool, an executor error, or an executor that returns no result all yield a
// ToolResult with Success=false and a nil error, so the caller can hand the
// message back to the LLM. The returned error is currently always nil.
//
// NOTE(review): Execute does not consult the enabled flag; callers presumably
// check IsEnabled before dispatching — confirm.
func (r *Registry) Execute(ctx context.Context, toolName string, arguments map[string]interface{}) (*ToolResult, error) {
	r.mu.RLock()
	executor, ok := r.tools[toolName]
	r.mu.RUnlock()
	if !ok {
		return &ToolResult{
			ToolName: toolName,
			Success:  false,
			Error:    fmt.Sprintf("未知工具: %s", toolName),
		}, nil
	}

	log.Printf("[工具执行] 调用工具 %s,参数: %v", toolName, arguments)
	result, err := executor.Execute(ctx, arguments)
	if err != nil {
		log.Printf("[工具执行] 工具 %s 执行失败: %v", toolName, err)
		return &ToolResult{
			ToolName: toolName,
			Success:  false,
			Error:    err.Error(),
		}, nil
	}

	// An executor returning (nil, nil) would otherwise panic on the
	// result.Success access below (and in every caller); treat it as a
	// failed call instead.
	if result == nil {
		log.Printf("[工具执行] 工具 %s 未返回结果", toolName)
		return &ToolResult{
			ToolName: toolName,
			Success:  false,
			Error:    fmt.Sprintf("工具 %s 未返回结果", toolName),
		}, nil
	}

	if result.Success {
		log.Printf("[工具执行] 工具 %s 执行成功 (数据长度: %d)", toolName, len(result.Data))
	} else {
		log.Printf("[工具执行] 工具 %s 返回错误: %s", toolName, result.Error)
	}
	return result, nil
}
// IsEnabled reports whether the tool system is currently switched on.
// It reads the flag under the registry's read lock.
func (r *Registry) IsEnabled() bool {
	r.mu.RLock()
	on := r.enabled
	r.mu.RUnlock()
	return on
}
// SetEnabled switches the tool system on (true) or off (false).
// The flag is written under the registry's write lock.
func (r *Registry) SetEnabled(enabled bool) {
	r.mu.Lock()
	r.enabled = enabled
	r.mu.Unlock()
}
// HasTool reports whether an executor is registered under name.
func (r *Registry) HasTool(name string) bool {
	r.mu.RLock()
	_, found := r.tools[name]
	r.mu.RUnlock()
	return found
}
// ListTools returns the names of all registered tools in ascending order.
//
// The names are sorted because Go randomizes map iteration order; an unsorted
// result would differ from call to call.
func (r *Registry) ListTools() []string {
	r.mu.RLock()
	names := make([]string, 0, len(r.tools))
	for name := range r.tools {
		names = append(names, name)
	}
	r.mu.RUnlock()

	sort.Strings(names)
	return names
}
// ToJSON serializes every tool definition into a JSON array of the form
//
//	[{"type": "function", "function": {"name": ..., "description": ..., "parameters": ...}}]
//
// presumably for the "tools" field of an OpenAI-compatible chat request.
// Tools appear in the order returned by GetDefinitions.
//
// A tool whose Parameters map is nil is given an empty object schema rather
// than being encoded as "parameters": null. null is not a JSON Schema object,
// and function-calling APIs may reject it.
func (r *Registry) ToJSON() ([]byte, error) {
	defs := r.GetDefinitions()
	tools := make([]map[string]interface{}, 0, len(defs))
	for _, d := range defs {
		params := d.Parameters
		if params == nil {
			// Declares a function that takes no arguments.
			params = map[string]interface{}{
				"type":       "object",
				"properties": map[string]interface{}{},
			}
		}
		tools = append(tools, map[string]interface{}{
			"type": "function",
			"function": map[string]interface{}{
				"name":        d.Name,
				"description": d.Description,
				"parameters":  params,
			},
		})
	}
	return json.Marshal(tools)
}