Files
Cyrene/backend/ai-core/internal/tools/registry.go
T
AskaEth 71f0a1abdb feat: Go模块路径迁移 + Docker生产部署适配 + ethend Docker兼容
- 所有Go模块路径从 github.com/yourname/cyrene-ai 迁移到 git.yeij.top/AskaEth/Cyrene
- 5个Go Dockerfile添加 GOPROXY=https://goproxy.cn,direct 解决国内构建问题
- ai-core go.mod 添加 pkg/plugins replace 指令
- Caddyfile 简化为 http:// 通配 + handle 保留 /api 前缀
- ethend Dockerfile 适配 (npm install + 仅 COPY package.json)
- ethend 新增 RUNNING_IN_DOCKER 环境变量,健康检查改用Docker服务名
- ethend 数据库状态检查支持Docker hostname (postgres/redis/qdrant/minio)
- process-manager 新增 CONTAINER_SVC_MAP + Docker模式自动检测
- 统一 docker-compose.dev.db.yml 卷名 (pg_data/redis_data/qdrant_data/minio_data)
- docker-compose.yml ethend服务挂载docker.sock + 端口变量化
- 清理 .env 统一后的残留文件与提示信息

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-30 13:43:22 +08:00

304 lines
7.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package tools
import (
	"context"
	"encoding/json"
	"fmt"
	"sort"
	"sync"
	"time"

	"git.yeij.top/AskaEth/Cyrene/pkg/logger"
)
// ToolDefinition describes one tool to the LLM for function calling
// (name, human-readable description and the parameter specification).
type ToolDefinition struct {
	// Name is also the key under which Registry stores the tool.
	Name        string `json:"name"`
	Description string `json:"description"`
	// Parameters is the argument specification sent to the LLM verbatim;
	// presumably a JSON-Schema object (TODO confirm against tool implementations).
	Parameters map[string]interface{} `json:"parameters"`
}
// ToolResult is the outcome of a single tool execution.
type ToolResult struct {
	ToolName string `json:"tool_name"`
	Success  bool   `json:"success"`
	// Data carries the tool's output; omitted from JSON when empty.
	Data string `json:"data,omitempty"`
	// Error describes what went wrong; omitted from JSON when empty.
	Error string `json:"error,omitempty"`
}
// ToolExecutor 工具执行器接口
type ToolExecutor interface {
// Execute 执行工具调用
Execute(ctx context.Context, arguments map[string]interface{}) (*ToolResult, error)
// Definition 返回工具定义
Definition() ToolDefinition
}
// CallLogRecord is one entry in the tool-call log.
type CallLogRecord struct {
	CallID     string `json:"call_id"` // assigned by callLogRing.push; unique and increasing per ring
	ToolName   string `json:"tool_name"`
	Arguments  string `json:"arguments"`
	Output     string `json:"output"`
	Error      string `json:"error"`
	Success    bool   `json:"success"`
	DurationMs int    `json:"duration_ms"`
	Timestamp  int64  `json:"timestamp"` // Unix milliseconds, assigned by callLogRing.push
}

// callLogRing is a fixed-capacity ring buffer of call records guarded by a
// mutex. Once full, every push overwrites the oldest record.
type callLogRing struct {
	mu       sync.Mutex
	records  []CallLogRecord
	capacity int
	head     int   // slot the next push writes to
	size     int   // number of valid records, never more than capacity
	lastID   int64 // last numeric CallID handed out, used to keep IDs unique
}

// newCallLogRing creates a ring that retains up to capacity records.
// A non-positive capacity is raised to 1; otherwise push would divide by
// zero / index an empty slice.
func newCallLogRing(capacity int) *callLogRing {
	if capacity < 1 {
		capacity = 1
	}
	return &callLogRing{capacity: capacity, records: make([]CallLogRecord, capacity)}
}

// push stores rec as the newest record. Any CallID/Timestamp supplied by the
// caller is overwritten: CallID with a unique numeric ID, Timestamp with the
// current time in Unix milliseconds.
func (r *callLogRing) push(rec CallLogRecord) {
	r.mu.Lock()
	defer r.mu.Unlock()

	now := time.Now()
	// The ID is based on the wall clock in nanoseconds, but two pushes can
	// observe the same reading on a coarse clock (or after the clock steps
	// backwards). Bump past the previous ID so IDs stay unique and increasing.
	id := now.UnixNano()
	if id <= r.lastID {
		id = r.lastID + 1
	}
	r.lastID = id

	rec.CallID = fmt.Sprintf("%d", id)
	rec.Timestamp = now.UnixMilli()
	r.records[r.head] = rec
	r.head = (r.head + 1) % r.capacity
	if r.size < r.capacity {
		r.size++
	}
}

// newestIndex maps i (0 = most recent) to its slot in r.records.
// The caller must hold r.mu and ensure 0 <= i < r.size.
func (r *callLogRing) newestIndex(i int) int {
	idx := (r.head - 1 - i) % r.capacity
	if idx < 0 {
		idx += r.capacity
	}
	return idx
}

// get returns a copy of up to limit records, newest first. A limit <= 0 or
// larger than the number of stored records returns all stored records.
func (r *callLogRing) get(limit int) []CallLogRecord {
	r.mu.Lock()
	defer r.mu.Unlock()
	if limit <= 0 || limit > r.size {
		limit = r.size
	}
	result := make([]CallLogRecord, limit)
	for i := range result {
		result[i] = r.records[r.newestIndex(i)]
	}
	return result
}

// statsByTool aggregates the stored records per tool name. Each value map
// holds "tool_name" (string) plus "count", "success_count", "fail_count" and
// "total_duration_ms" (all int).
func (r *callLogRing) statsByTool() map[string]map[string]interface{} {
	r.mu.Lock()
	defer r.mu.Unlock()
	byTool := make(map[string]map[string]interface{})
	// Order is irrelevant for aggregation. Before the buffer wraps the valid
	// records are exactly slots [0, size); after it wraps size == capacity, so
	// records[:size] covers every stored record either way.
	for _, rec := range r.records[:r.size] {
		s, ok := byTool[rec.ToolName]
		if !ok {
			s = map[string]interface{}{
				"tool_name": rec.ToolName, "count": 0, "success_count": 0, "fail_count": 0, "total_duration_ms": 0,
			}
			byTool[rec.ToolName] = s
		}
		s["count"] = s["count"].(int) + 1
		if rec.Success {
			s["success_count"] = s["success_count"].(int) + 1
		} else {
			s["fail_count"] = s["fail_count"].(int) + 1
		}
		s["total_duration_ms"] = s["total_duration_ms"].(int) + rec.DurationMs
	}
	return byTool
}
// Registry is the central catalogue of tools exposed to the LLM, together
// with an in-memory log of recent tool calls.
type Registry struct {
	// mu guards enabled and tools; callLog carries its own lock.
	mu      sync.RWMutex
	enabled bool
	tools   map[string]ToolExecutor
	callLog *callLogRing
}
// NewRegistry returns an empty Registry with the tool system enabled and a
// call log that keeps the 500 most recent calls.
func NewRegistry() *Registry {
	const callLogCapacity = 500
	reg := new(Registry)
	reg.enabled = true
	reg.tools = make(map[string]ToolExecutor)
	reg.callLog = newCallLogRing(callLogCapacity)
	return reg
}
// Register adds executor under the name reported by its Definition.
// Registering another executor with the same name replaces the previous one;
// the replacement is logged so accidental name clashes do not go unnoticed.
func (r *Registry) Register(executor ToolExecutor) {
	// Resolve the definition before locking so executor code never runs
	// while the registry's write lock is held.
	def := executor.Definition()

	r.mu.Lock()
	_, replaced := r.tools[def.Name]
	r.tools[def.Name] = executor
	r.mu.Unlock()

	if replaced {
		logger.Printf("[工具注册] 工具 %s 已存在,旧实现已被覆盖", def.Name)
	}
	logger.Printf("[工具注册] 已注册工具: %s", def.Name)
}
// GetDefinitions returns the definitions of all registered tools for LLM
// function calling, sorted by tool name.
//
// Sorting matters because the tools live in a map with random iteration
// order; without it, the tool list sent to the LLM would be reshuffled on
// every request.
func (r *Registry) GetDefinitions() []ToolDefinition {
	r.mu.RLock()
	defs := make([]ToolDefinition, 0, len(r.tools))
	for _, executor := range r.tools {
		defs = append(defs, executor.Definition())
	}
	r.mu.RUnlock()

	sort.Slice(defs, func(i, j int) bool { return defs[i].Name < defs[j].Name })
	return defs
}
// Execute runs the tool registered as toolName with arguments and appends the
// call to the call log.
//
// Failures are reported inside the returned ToolResult, never as a Go error:
// an unknown tool, an executor error, or an executor that returns a nil result
// all yield Success=false with Error set, and the error return is always nil.
// ctx is passed through to the executor unchanged.
func (r *Registry) Execute(ctx context.Context, toolName string, arguments map[string]interface{}) (*ToolResult, error) {
	r.mu.RLock()
	executor, ok := r.tools[toolName]
	r.mu.RUnlock()

	argsJSON := encodeCallArguments(arguments)
	startTime := time.Now()

	if !ok {
		errMsg := fmt.Sprintf("未知工具: %s", toolName)
		r.callLog.push(CallLogRecord{
			ToolName: toolName, Arguments: argsJSON, Error: errMsg, Success: false,
			DurationMs: int(time.Since(startTime).Milliseconds()),
		})
		return &ToolResult{ToolName: toolName, Success: false, Error: errMsg}, nil
	}

	logger.Printf("[工具执行] 调用工具 %s,参数: %v", toolName, arguments)
	result, err := executor.Execute(ctx, arguments)
	durationMs := int(time.Since(startTime).Milliseconds())

	// An executor returning (nil, nil) would otherwise cause a nil-pointer
	// panic below; treat it as a failed call instead.
	if err == nil && result == nil {
		err = fmt.Errorf("工具 %s 未返回结果", toolName)
	}
	if err != nil {
		logger.Printf("[工具执行] 工具 %s 执行失败: %v", toolName, err)
		r.callLog.push(CallLogRecord{
			ToolName: toolName, Arguments: argsJSON, Error: err.Error(), Success: false, DurationMs: durationMs,
		})
		return &ToolResult{ToolName: toolName, Success: false, Error: err.Error()}, nil
	}

	if result.Success {
		logger.Printf("[工具执行] 工具 %s 执行成功 (数据长度: %d)", toolName, len(result.Data))
	} else {
		logger.Printf("[工具执行] 工具 %s 返回错误: %s", toolName, result.Error)
	}
	r.callLog.push(CallLogRecord{
		ToolName: toolName, Arguments: argsJSON, Output: result.Data,
		Error: result.Error, Success: result.Success, DurationMs: durationMs,
	})
	return result, nil
}

// encodeCallArguments renders tool-call arguments for the call log as JSON.
// Logging is best-effort, so values encoding/json cannot encode (e.g. channels
// or funcs) fall back to %v formatting instead of being silently dropped.
func encodeCallArguments(arguments map[string]interface{}) string {
	b, err := json.Marshal(arguments)
	if err != nil {
		return fmt.Sprintf("%v", arguments)
	}
	return string(b)
}
// IsEnabled reports whether the tool system is currently switched on.
func (r *Registry) IsEnabled() bool {
	r.mu.RLock()
	on := r.enabled
	r.mu.RUnlock()
	return on
}
// SetEnabled switches the tool system on (true) or off (false).
func (r *Registry) SetEnabled(enabled bool) {
	r.mu.Lock()
	r.enabled = enabled
	r.mu.Unlock()
}
// HasTool reports whether a tool is registered under name.
func (r *Registry) HasTool(name string) bool {
	r.mu.RLock()
	_, found := r.tools[name]
	r.mu.RUnlock()
	return found
}
// ListTools returns the names of all registered tools in ascending order.
// The names are sorted because map iteration order is random and callers
// (e.g. listings shown to users) should see a stable order.
func (r *Registry) ListTools() []string {
	r.mu.RLock()
	names := make([]string, 0, len(r.tools))
	for name := range r.tools {
		names = append(names, name)
	}
	r.mu.RUnlock()

	sort.Strings(names)
	return names
}
// GetCallLogs returns recorded tool calls, newest first.
//
// A non-empty toolName keeps only calls to that tool. limit caps the number
// of records returned; limit <= 0 means no cap.
//
// The ring's size is never read here directly: get(0) / get(limit) clamp
// against it while holding the ring's own lock, avoiding a data race.
func (r *Registry) GetCallLogs(toolName string, limit int) []CallLogRecord {
	if toolName == "" {
		return r.callLog.get(limit)
	}

	filtered := make([]CallLogRecord, 0)
	for _, rec := range r.callLog.get(0) {
		if rec.ToolName != toolName {
			continue
		}
		filtered = append(filtered, rec)
		if limit > 0 && len(filtered) >= limit {
			break
		}
	}
	return filtered
}
// GetCallStats summarises the call log.
//
// The returned map contains total_calls, success_count and fail_count (int),
// success_rate (percentage, float64), avg_duration_ms (float64) and by_tool:
// one map per tool (tool_name, count, success_count, fail_count,
// avg_duration_ms), ordered by tool name.
func (r *Registry) GetCallStats() map[string]interface{} {
	byTool := r.callLog.statsByTool()

	// byTool is a map with random iteration order; walk it by sorted name so
	// by_tool comes out in the same order on every request.
	names := make([]string, 0, len(byTool))
	for name := range byTool {
		names = append(names, name)
	}
	sort.Strings(names)

	totalCalls, successCount, failCount, totalDurationMs := 0, 0, 0, 0
	toolStats := make([]map[string]interface{}, 0, len(byTool))
	for _, name := range names {
		s := byTool[name]
		count := s["count"].(int)
		totalDur := s["total_duration_ms"].(int)

		avgDur := 0.0
		if count > 0 {
			avgDur = float64(totalDur) / float64(count)
		}
		// Expose the per-tool average instead of the raw total.
		s["avg_duration_ms"] = avgDur
		delete(s, "total_duration_ms")
		toolStats = append(toolStats, s)

		totalCalls += count
		successCount += s["success_count"].(int)
		failCount += s["fail_count"].(int)
		totalDurationMs += totalDur
	}

	avgDuration, successRate := 0.0, 0.0
	if totalCalls > 0 {
		avgDuration = float64(totalDurationMs) / float64(totalCalls)
		successRate = float64(successCount) / float64(totalCalls) * 100
	}
	return map[string]interface{}{
		"total_calls": totalCalls, "success_count": successCount, "fail_count": failCount,
		"success_rate": successRate, "avg_duration_ms": avgDuration, "by_tool": toolStats,
	}
}
// ToJSON encodes every registered tool for an LLM request as a JSON array of
// {"type":"function","function":{"name":…,"description":…,"parameters":…}}.
func (r *Registry) ToJSON() ([]byte, error) {
	definitions := r.GetDefinitions()
	payload := make([]map[string]interface{}, len(definitions))
	for i, def := range definitions {
		fn := map[string]interface{}{
			"name":        def.Name,
			"description": def.Description,
			"parameters":  def.Parameters,
		}
		payload[i] = map[string]interface{}{"type": "function", "function": fn}
	}
	return json.Marshal(payload)
}