package tools

import (
	"context"
	"encoding/json"
	"fmt"
	"sort"
	"sync"
	"time"

	"github.com/yourname/cyrene-ai/pkg/logger"
)

// ToolDefinition describes a tool for LLM function calling.
type ToolDefinition struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description"`
	Parameters  map[string]interface{} `json:"parameters"`
}

// ToolResult is the outcome of a single tool execution.
type ToolResult struct {
	ToolName string `json:"tool_name"`
	Success  bool   `json:"success"`
	Data     string `json:"data,omitempty"`
	Error    string `json:"error,omitempty"`
}

// ToolExecutor is the interface every registrable tool implements.
type ToolExecutor interface {
	// Execute runs the tool with the given call arguments.
	Execute(ctx context.Context, arguments map[string]interface{}) (*ToolResult, error)
	// Definition returns the tool's definition.
	Definition() ToolDefinition
}

// CallLogRecord is one entry in the tool-call log.
type CallLogRecord struct {
	CallID     string `json:"call_id"`
	ToolName   string `json:"tool_name"`
	Arguments  string `json:"arguments"`
	Output     string `json:"output"`
	Error      string `json:"error"`
	Success    bool   `json:"success"`
	DurationMs int    `json:"duration_ms"`
	Timestamp  int64  `json:"timestamp"`
}

// callLogRing is a fixed-capacity ring buffer of call records guarded by a
// mutex. Once full, each push overwrites the oldest record.
type callLogRing struct {
	mu         sync.Mutex
	records    []CallLogRecord
	capacity   int
	head       int   // index the next push writes to
	size       int   // number of valid records, never above capacity
	lastCallID int64 // last CallID handed out; keeps IDs strictly increasing
}

// newCallLogRing creates a ring holding at most capacity records.
// A capacity below 1 is clamped to 1, avoiding a modulo-by-zero in the ring arithmetic.
func newCallLogRing(capacity int) *callLogRing {
	if capacity < 1 {
		capacity = 1
	}
	return &callLogRing{capacity: capacity, records: make([]CallLogRecord, capacity)}
}

// push stamps rec with a CallID and a millisecond Timestamp and appends it,
// overwriting the oldest record when the ring is full.
func (r *callLogRing) push(rec CallLogRecord) {
	r.mu.Lock()
	defer r.mu.Unlock()

	now := time.Now()
	id := now.UnixNano()
	// The wall clock can be coarse or step backwards, so consecutive pushes may
	// observe the same (or a smaller) nanosecond value. Bump it to keep CallIDs
	// unique and increasing while preserving their integer format.
	if id <= r.lastCallID {
		id = r.lastCallID + 1
	}
	r.lastCallID = id

	rec.CallID = fmt.Sprintf("%d", id)
	rec.Timestamp = now.UnixMilli()
	r.records[r.head] = rec
	r.head = (r.head + 1) % r.capacity
	if r.size < r.capacity {
		r.size++
	}
}

// newest returns the i-th most recent record (0 = newest).
// The caller must hold r.mu and guarantee 0 <= i < r.size.
func (r *callLogRing) newest(i int) CallLogRecord {
	idx := (r.head - 1 - i) % r.capacity
	if idx < 0 {
		idx += r.capacity
	}
	return r.records[idx]
}

// get returns up to limit records, newest first. A limit <= 0 or larger than
// the number of stored records returns all of them. The result is a copy and
// is never nil.
func (r *callLogRing) get(limit int) []CallLogRecord {
	r.mu.Lock()
	defer r.mu.Unlock()

	if limit <= 0 || limit > r.size {
		limit = r.size
	}
	result := make([]CallLogRecord, limit)
	for i := range result {
		result[i] = r.newest(i)
	}
	return result
}

// statsByTool aggregates the stored records per tool name. Each value map has
// the keys "tool_name" (string) and "count", "success_count", "fail_count",
// "total_duration_ms" (all int). The maps are freshly allocated on every call,
// so callers may mutate them.
func (r *callLogRing) statsByTool() map[string]map[string]interface{} {
	r.mu.Lock()
	defer r.mu.Unlock()

	byTool := make(map[string]map[string]interface{})
	for i := 0; i < r.size; i++ {
		rec := r.newest(i)
		s, ok := byTool[rec.ToolName]
		if !ok {
			s = map[string]interface{}{
				"tool_name":         rec.ToolName,
				"count":             0,
				"success_count":     0,
				"fail_count":        0,
				"total_duration_ms": 0,
			}
			byTool[rec.ToolName] = s
		}
		s["count"] = s["count"].(int) + 1
		if rec.Success {
			s["success_count"] = s["success_count"].(int) + 1
		} else {
			s["fail_count"] = s["fail_count"].(int) + 1
		}
		s["total_duration_ms"] = s["total_duration_ms"].(int) + rec.DurationMs
	}
	return byTool
}

// Registry holds the registered tools and a bounded log of recent tool calls.
type Registry struct {
	mu      sync.RWMutex
	tools   map[string]ToolExecutor
	enabled bool
	callLog *callLogRing
}

// NewRegistry creates an empty, enabled registry that keeps the last 500 tool
// calls in its call log.
func NewRegistry() *Registry {
	return &Registry{
		tools:   make(map[string]ToolExecutor),
		enabled: true,
		callLog: newCallLogRing(500),
	}
}

// Register adds executor under the name from its Definition. A tool already
// registered under the same name is replaced.
func (r *Registry) Register(executor ToolExecutor) {
	r.mu.Lock()
	defer r.mu.Unlock()
	def := executor.Definition()
	r.tools[def.Name] = executor
	logger.Printf("[工具注册] 已注册工具: %s", def.Name)
}

// GetDefinitions returns the definitions of all registered tools, sorted by
// name. The order is deterministic, so the LLM request built from it is stable.
func (r *Registry) GetDefinitions() []ToolDefinition {
	r.mu.RLock()
	defer r.mu.RUnlock()
	defs := make([]ToolDefinition, 0, len(r.tools))
	for _, executor := range r.tools {
		defs = append(defs, executor.Definition())
	}
	sort.Slice(defs, func(i, j int) bool { return defs[i].Name < defs[j].Name })
	return defs
}

// Execute runs the named tool with the given arguments and records the call in
// the call log.
//
// Failures (unknown tool, executor error, nil result from the executor) are
// reported through the returned ToolResult with Success=false. The returned
// error is always nil.
func (r *Registry) Execute(ctx context.Context, toolName string, arguments map[string]interface{}) (*ToolResult, error) {
	r.mu.RLock()
	executor, ok := r.tools[toolName]
	r.mu.RUnlock()

	// Serialize the arguments once, before running the tool, so that every log
	// record (including failures) carries the request as it was received.
	// The marshal error is deliberately ignored: call logging is best-effort,
	// and an unserializable argument map just yields an empty Arguments field.
	argsJSON, _ := json.Marshal(arguments)
	args := string(argsJSON)

	startTime := time.Now()
	if !ok {
		errMsg := fmt.Sprintf("未知工具: %s", toolName)
		r.callLog.push(CallLogRecord{
			ToolName:   toolName,
			Arguments:  args,
			Error:      errMsg,
			Success:    false,
			DurationMs: int(time.Since(startTime).Milliseconds()),
		})
		return &ToolResult{ToolName: toolName, Success: false, Error: errMsg}, nil
	}

	logger.Printf("[工具执行] 调用工具 %s,参数: %v", toolName, arguments)
	result, err := executor.Execute(ctx, arguments)
	durationMs := int(time.Since(startTime).Milliseconds())

	// An executor returning (nil, nil) would otherwise cause a nil-pointer
	// panic below; treat it as a failed call instead.
	if err == nil && result == nil {
		err = fmt.Errorf("工具 %s 返回了空结果", toolName)
	}
	if err != nil {
		logger.Printf("[工具执行] 工具 %s 执行失败: %v", toolName, err)
		r.callLog.push(CallLogRecord{
			ToolName:   toolName,
			Arguments:  args,
			Error:      err.Error(),
			Success:    false,
			DurationMs: durationMs,
		})
		return &ToolResult{ToolName: toolName, Success: false, Error: err.Error()}, nil
	}

	if result.Success {
		logger.Printf("[工具执行] 工具 %s 执行成功 (数据长度: %d)", toolName, len(result.Data))
	} else {
		logger.Printf("[工具执行] 工具 %s 返回错误: %s", toolName, result.Error)
	}
	r.callLog.push(CallLogRecord{
		ToolName:   toolName,
		Arguments:  args,
		Output:     result.Data,
		Error:      result.Error,
		Success:    result.Success,
		DurationMs: durationMs,
	})
	return result, nil
}

// IsEnabled reports whether the tool system is enabled.
func (r *Registry) IsEnabled() bool {
	r.mu.RLock()
	defer r.mu.RUnlock()
	return r.enabled
}

// SetEnabled enables or disables the tool system.
func (r *Registry) SetEnabled(enabled bool) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.enabled = enabled
}

// HasTool reports whether a tool with the given name is registered.
func (r *Registry) HasTool(name string) bool {
	r.mu.RLock()
	defer r.mu.RUnlock()
	_, ok := r.tools[name]
	return ok
}

// ListTools returns the names of all registered tools in sorted order.
func (r *Registry) ListTools() []string {
	r.mu.RLock()
	defer r.mu.RUnlock()
	names := make([]string, 0, len(r.tools))
	for name := range r.tools {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}

// GetCallLogs returns recorded tool calls, newest first. An empty toolName
// matches every tool, and a limit <= 0 means no limit. The result is never nil.
func (r *Registry) GetCallLogs(toolName string, limit int) []CallLogRecord {
	if toolName == "" {
		// get already treats limit <= 0 (or limit > stored count) as "all".
		return r.callLog.get(limit)
	}

	// get(0) snapshots every record under the ring's own lock, so the ring's
	// size is never read without synchronization.
	all := r.callLog.get(0)
	filtered := make([]CallLogRecord, 0)
	for _, rec := range all {
		if rec.ToolName != toolName {
			continue
		}
		filtered = append(filtered, rec)
		if limit > 0 && len(filtered) >= limit {
			break
		}
	}
	return filtered
}

// GetCallStats summarizes the recorded tool calls.
//
// The returned map contains "total_calls", "success_count", "fail_count"
// (int), "success_rate" (float64 percentage, 0-100), "avg_duration_ms"
// (float64), and "by_tool": a slice of per-tool maps sorted by tool name, each
// with "tool_name", "count", "success_count", "fail_count" and
// "avg_duration_ms".
func (r *Registry) GetCallStats() map[string]interface{} {
	byTool := r.callLog.statsByTool()

	// Iterate in sorted name order so "by_tool" is deterministic.
	toolNames := make([]string, 0, len(byTool))
	for name := range byTool {
		toolNames = append(toolNames, name)
	}
	sort.Strings(toolNames)

	totalCalls, successCount, failCount, totalDurationMs := 0, 0, 0, 0
	toolStats := make([]map[string]interface{}, 0, len(byTool))
	for _, name := range toolNames {
		s := byTool[name]
		count := s["count"].(int)
		success := s["success_count"].(int)
		fail := s["fail_count"].(int)
		totalDur := s["total_duration_ms"].(int)

		avgDur := 0.0
		if count > 0 {
			avgDur = float64(totalDur) / float64(count)
		}
		// Expose the per-tool average instead of the raw total.
		s["avg_duration_ms"] = avgDur
		delete(s, "total_duration_ms")
		toolStats = append(toolStats, s)

		totalCalls += count
		successCount += success
		failCount += fail
		totalDurationMs += totalDur
	}

	avgDuration := 0.0
	successRate := 0.0
	if totalCalls > 0 {
		avgDuration = float64(totalDurationMs) / float64(totalCalls)
		successRate = float64(successCount) / float64(totalCalls) * 100
	}

	return map[string]interface{}{
		"total_calls":     totalCalls,
		"success_count":   successCount,
		"fail_count":      failCount,
		"success_rate":    successRate,
		"avg_duration_ms": avgDuration,
		"by_tool":         toolStats,
	}
}

// ToJSON serializes the tool definitions into the OpenAI-style "tools" array
// used in LLM requests: one {"type": "function", "function": {...}} entry per
// tool, in the name order produced by GetDefinitions.
func (r *Registry) ToJSON() ([]byte, error) {
	defs := r.GetDefinitions()
	tools := make([]map[string]interface{}, 0, len(defs))
	for _, d := range defs {
		tools = append(tools, map[string]interface{}{
			"type": "function",
			"function": map[string]interface{}{
				"name":        d.Name,
				"description": d.Description,
				"parameters":  d.Parameters,
			},
		})
	}
	return json.Marshal(tools)
}