Files
Cyrene/backend/ai-core/internal/tools/registry.go
T
AskaEth 673ff752c5 feat: 插件-工具合并 — 创建 pkg/plugins 共享模块并移除 tool-engine
- 新增 backend/pkg/plugins/ 共享模块:SDK 接口、PluginManager、ToolRegistry(含环形缓冲区调用日志)
- 13 个通用插件从 plugin-manager 迁移至共享模块(import 路径统一)
- ai-core 切换至共享 ToolRegistry,进程内执行(零网络开销),包装 6 个专属工具
- plugin-manager 迁移至共享模块,保留管理 REST API
- 新增 DevTools 插件管理面板(侧边栏 → 🔌 插件管理)
- 移除 tool-engine 服务(从 go.work、DevTools 配置、编译系统)
- 工具调用记录 API 从 Tool-Engine 迁至 AI-Core(/api/v1/tools/calls)
- ai-core ContextStore 启动时从 PostgreSQL 恢复会话历史
- 清理所有过时引用和备份文件

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-25 20:52:39 +08:00

304 lines
7.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package tools
import (
"context"
"encoding/json"
"fmt"
"sync"
"time"
"github.com/yourname/cyrene-ai/pkg/logger"
)
// ToolDefinition describes one tool as it is advertised for LLM function
// calling: a name (also the key the Registry stores the executor under), a
// human-readable description, and a parameters object that is passed through
// unchanged.
type ToolDefinition struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description"`
	Parameters  map[string]interface{} `json:"parameters"`
}
// ToolResult is the outcome of a single tool call. Success reports whether the
// call worked; Data carries the tool's output and Error its failure message.
// Data and Error are left out of the JSON encoding when empty.
type ToolResult struct {
	ToolName string `json:"tool_name"`
	Success  bool   `json:"success"`
	Data     string `json:"data,omitempty"`
	Error    string `json:"error,omitempty"`
}
// ToolExecutor 工具执行器接口
type ToolExecutor interface {
// Execute 执行工具调用
Execute(ctx context.Context, arguments map[string]interface{}) (*ToolResult, error)
// Definition 返回工具定义
Definition() ToolDefinition
}
// CallLogRecord is one entry of the in-memory tool-call log.
//
// CallID and Timestamp are assigned by callLogRing.push; callers fill in the
// remaining fields.
type CallLogRecord struct {
	CallID     string `json:"call_id"`   // decimal string, strictly increasing within one ring
	ToolName   string `json:"tool_name"`
	Arguments  string `json:"arguments"` // serialized call arguments; may be empty
	Output     string `json:"output"`
	Error      string `json:"error"`
	Success    bool   `json:"success"`
	DurationMs int    `json:"duration_ms"`
	Timestamp  int64  `json:"timestamp"` // Unix time in milliseconds
}

// callLogRing is a fixed-capacity ring buffer of CallLogRecord guarded by a
// mutex. Once full, each push overwrites the oldest record.
type callLogRing struct {
	mu       sync.Mutex
	records  []CallLogRecord
	capacity int
	head     int   // slot the next push writes to
	size     int   // number of valid records, at most capacity
	lastID   int64 // last CallID issued, used to keep IDs unique
}

// newCallLogRing returns an empty ring that keeps at most capacity records.
// A capacity below 1 is raised to 1: make panics on a negative length and a
// zero capacity would make push divide by zero.
func newCallLogRing(capacity int) *callLogRing {
	if capacity < 1 {
		capacity = 1
	}
	return &callLogRing{capacity: capacity, records: make([]CallLogRecord, capacity)}
}

// push stores rec as the newest record, evicting the oldest one when the ring
// is full.
//
// It stamps rec.Timestamp and rec.CallID from a single clock reading. CallID
// is the Unix time in nanoseconds, bumped when necessary so it is strictly
// greater than the previous ID issued by this ring; otherwise two pushes
// within the clock's resolution (or across a clock step backwards) would
// share an ID.
func (r *callLogRing) push(rec CallLogRecord) {
	r.mu.Lock()
	defer r.mu.Unlock()

	now := time.Now()
	id := now.UnixNano()
	if id <= r.lastID {
		id = r.lastID + 1
	}
	r.lastID = id
	rec.CallID = fmt.Sprintf("%d", id)
	rec.Timestamp = now.UnixMilli()

	r.records[r.head] = rec
	r.head = (r.head + 1) % r.capacity
	if r.size < r.capacity {
		r.size++
	}
}

// slotFromNewest maps i (0 = newest record) to its index in r.records.
// The caller must hold r.mu and ensure 0 <= i < r.size; since
// r.size <= r.capacity, head-1-i is never below -capacity, so adding capacity
// keeps the operand of % non-negative.
func (r *callLogRing) slotFromNewest(i int) int {
	return (r.head - 1 - i + r.capacity) % r.capacity
}

// get returns a copy of up to limit records, newest first. A limit <= 0 or
// larger than the number of stored records returns all of them. The result is
// never nil.
func (r *callLogRing) get(limit int) []CallLogRecord {
	r.mu.Lock()
	defer r.mu.Unlock()

	if limit <= 0 || limit > r.size {
		limit = r.size
	}
	result := make([]CallLogRecord, limit)
	for i := range result {
		result[i] = r.records[r.slotFromNewest(i)]
	}
	return result
}

// statsByTool aggregates the stored records per tool name.
//
// Each per-tool map holds "tool_name" (string) and the int values "count",
// "success_count", "fail_count" and "total_duration_ms". The maps are freshly
// allocated on every call, so callers may modify them.
func (r *callLogRing) statsByTool() map[string]map[string]interface{} {
	type totals struct {
		count, success, fail, durationMs int
	}

	r.mu.Lock()
	defer r.mu.Unlock()

	// Aggregation does not depend on order, and push fills slots 0, 1, 2, ...
	// before it wraps, so the valid records are exactly records[:size].
	perTool := make(map[string]*totals)
	for i := range r.records[:r.size] {
		rec := &r.records[i]
		t, ok := perTool[rec.ToolName]
		if !ok {
			t = &totals{}
			perTool[rec.ToolName] = t
		}
		t.count++
		if rec.Success {
			t.success++
		} else {
			t.fail++
		}
		t.durationMs += rec.DurationMs
	}

	byTool := make(map[string]map[string]interface{}, len(perTool))
	for name, t := range perTool {
		byTool[name] = map[string]interface{}{
			"tool_name":         name,
			"count":             t.count,
			"success_count":     t.success,
			"fail_count":        t.fail,
			"total_duration_ms": t.durationMs,
		}
	}
	return byTool
}
// Registry maps tool names to their executors, holds an enabled switch for
// the tool system, and keeps an in-memory log of recent calls. The tools map
// and enabled flag are guarded by mu.
type Registry struct {
	mu      sync.RWMutex            // guards tools and enabled
	tools   map[string]ToolExecutor // tool name -> executor
	enabled bool                    // reported by IsEnabled, set by SetEnabled
	callLog *callLogRing            // has its own lock; not guarded by mu
}
// NewRegistry returns an empty Registry with the tool system enabled and a
// call log that retains the 500 most recent calls.
func NewRegistry() *Registry {
	reg := new(Registry)
	reg.tools = make(map[string]ToolExecutor)
	reg.enabled = true
	reg.callLog = newCallLogRing(500)
	return reg
}
// Register adds executor under the name reported by its Definition.
// Registering another executor with an existing name replaces the previous
// one; the replacement is logged so it does not pass unnoticed.
func (r *Registry) Register(executor ToolExecutor) {
	// Call the executor's Definition before taking the lock. It is
	// tool-supplied code, so it should not run while every registry reader is
	// blocked. An implementation that calls back into the registry would also
	// deadlock, because sync.RWMutex is not reentrant.
	def := executor.Definition()

	r.mu.Lock()
	_, replaced := r.tools[def.Name]
	r.tools[def.Name] = executor
	r.mu.Unlock()

	if replaced {
		logger.Printf("[工具注册] 工具 %s 已存在, 旧的实现已被覆盖", def.Name)
	}
	logger.Printf("[工具注册] 已注册工具: %s", def.Name)
}
// GetDefinitions returns the definition of every registered tool, for use in
// LLM function calling. The order follows map iteration, so it is not stable
// between calls. The result is never nil.
func (r *Registry) GetDefinitions() []ToolDefinition {
	r.mu.RLock()
	defer r.mu.RUnlock()

	out := make([]ToolDefinition, len(r.tools))
	i := 0
	for _, ex := range r.tools {
		out[i] = ex.Definition()
		i++
	}
	return out
}
// Execute runs the tool registered under toolName and appends the call to the
// registry's call log.
//
// Failures are reported in-band. The returned error is always nil, and the
// returned *ToolResult is never nil. The result has Success=false and a
// message in Error when:
//   - the tool is unknown,
//   - the executor returns an error or a nil result, or
//   - the executor panics.
//
// The registry's enabled flag is not consulted here.
func (r *Registry) Execute(ctx context.Context, toolName string, arguments map[string]interface{}) (*ToolResult, error) {
	r.mu.RLock()
	executor, ok := r.tools[toolName]
	r.mu.RUnlock()

	// Serialize the arguments once so every call-log record carries them,
	// including failed calls. Marshal fails for values JSON cannot represent;
	// fall back to %v rather than logging nothing.
	var argsText string
	if argsJSON, err := json.Marshal(arguments); err == nil {
		argsText = string(argsJSON)
	} else {
		argsText = fmt.Sprintf("%v", arguments)
	}

	startTime := time.Now()
	if !ok {
		errMsg := fmt.Sprintf("未知工具: %s", toolName)
		r.callLog.push(CallLogRecord{
			ToolName: toolName, Arguments: argsText, Error: errMsg, Success: false,
			DurationMs: int(time.Since(startTime).Milliseconds()),
		})
		return &ToolResult{ToolName: toolName, Success: false, Error: errMsg}, nil
	}

	logger.Printf("[工具执行] 调用工具 %s,参数: %v", toolName, arguments)
	result, err := func() (res *ToolResult, execErr error) {
		// The executor runs in this goroutine, so a panic inside it would
		// otherwise unwind through the registry into the caller. Convert it
		// into an ordinary failure instead.
		defer func() {
			if p := recover(); p != nil {
				res, execErr = nil, fmt.Errorf("工具 panic: %v", p)
			}
		}()
		return executor.Execute(ctx, arguments)
	}()
	if err == nil && result == nil {
		// An executor returning (nil, nil) would make result.Success below
		// dereference a nil pointer.
		err = fmt.Errorf("工具 %s 未返回结果", toolName)
	}
	durationMs := int(time.Since(startTime).Milliseconds())

	if err != nil {
		logger.Printf("[工具执行] 工具 %s 执行失败: %v", toolName, err)
		r.callLog.push(CallLogRecord{
			ToolName: toolName, Arguments: argsText, Error: err.Error(), Success: false, DurationMs: durationMs,
		})
		return &ToolResult{ToolName: toolName, Success: false, Error: err.Error()}, nil
	}

	if result.Success {
		logger.Printf("[工具执行] 工具 %s 执行成功 (数据长度: %d)", toolName, len(result.Data))
	} else {
		logger.Printf("[工具执行] 工具 %s 返回错误: %s", toolName, result.Error)
	}
	r.callLog.push(CallLogRecord{
		ToolName: toolName, Arguments: argsText, Output: result.Data,
		Error: result.Error, Success: result.Success, DurationMs: durationMs,
	})
	return result, nil
}
// IsEnabled reports whether the tool system is currently switched on.
func (r *Registry) IsEnabled() bool {
	r.mu.RLock()
	on := r.enabled
	r.mu.RUnlock()
	return on
}
// SetEnabled switches the tool system on or off; the value is what IsEnabled
// subsequently reports.
func (r *Registry) SetEnabled(enabled bool) {
	r.mu.Lock()
	r.enabled = enabled
	r.mu.Unlock()
}
// HasTool reports whether a tool with the given name is registered.
func (r *Registry) HasTool(name string) bool {
	r.mu.RLock()
	_, registered := r.tools[name]
	r.mu.RUnlock()
	return registered
}
// ListTools returns the names of all registered tools in map-iteration
// (unspecified) order. The result is never nil.
func (r *Registry) ListTools() []string {
	r.mu.RLock()
	defer r.mu.RUnlock()

	names := make([]string, len(r.tools))
	pos := 0
	for toolName := range r.tools {
		names[pos] = toolName
		pos++
	}
	return names
}
// GetCallLogs returns recorded tool calls, newest first. A non-empty toolName
// restricts the result to that tool's calls. A positive limit caps the number
// of records returned; limit <= 0 means no cap. The result is never nil.
func (r *Registry) GetCallLogs(toolName string, limit int) []CallLogRecord {
	// callLog.get applies the "newest first, limit <= 0 or too large means
	// everything" rule while holding the ring's own lock. Never read
	// callLog.size from here: that field is guarded by the ring's mutex, and
	// reading it unlocked races with concurrent pushes.
	if toolName == "" {
		return r.callLog.get(limit)
	}

	filtered := make([]CallLogRecord, 0)
	for _, rec := range r.callLog.get(0) {
		if rec.ToolName != toolName {
			continue
		}
		filtered = append(filtered, rec)
		if limit > 0 && len(filtered) >= limit {
			break
		}
	}
	return filtered
}
// GetCallStats summarizes the records currently held in the call log.
//
// The result has these top-level keys:
//   - total_calls, success_count, fail_count (int)
//   - success_rate (percentage, float64)
//   - avg_duration_ms (float64)
//   - by_tool: one map per tool, in unspecified order, with tool_name, count,
//     success_count, fail_count and avg_duration_ms.
func (r *Registry) GetCallStats() map[string]interface{} {
	var calls, succeeded, failed, durationSum int
	perTool := r.callLog.statsByTool()
	toolStats := make([]map[string]interface{}, 0, len(perTool))

	for _, entry := range perTool {
		n := entry["count"].(int)
		dur := entry["total_duration_ms"].(int)

		calls += n
		succeeded += entry["success_count"].(int)
		failed += entry["fail_count"].(int)
		durationSum += dur

		var mean float64
		if n > 0 {
			mean = float64(dur) / float64(n)
		}
		// Expose the per-tool average instead of the raw total.
		delete(entry, "total_duration_ms")
		entry["avg_duration_ms"] = mean
		toolStats = append(toolStats, entry)
	}

	var meanDuration, rate float64
	if calls > 0 {
		meanDuration = float64(durationSum) / float64(calls)
		rate = float64(succeeded) / float64(calls) * 100
	}

	return map[string]interface{}{
		"total_calls":     calls,
		"success_count":   succeeded,
		"fail_count":      failed,
		"success_rate":    rate,
		"avg_duration_ms": meanDuration,
		"by_tool":         toolStats,
	}
}
// ToJSON encodes every registered tool definition as a JSON array for an LLM
// request. Each element wraps a definition as
// {"type":"function","function":{"name":…,"description":…,"parameters":…}}.
// With no tools registered it encodes an empty array ("[]").
func (r *Registry) ToJSON() ([]byte, error) {
	definitions := r.GetDefinitions()
	payload := make([]map[string]interface{}, len(definitions))
	for i, def := range definitions {
		fn := map[string]interface{}{
			"name":        def.Name,
			"description": def.Description,
			"parameters":  def.Parameters,
		}
		payload[i] = map[string]interface{}{
			"type":     "function",
			"function": fn,
		}
	}
	return json.Marshal(payload)
}