package manager import ( "context" "encoding/json" "fmt" "sync" "time" "git.yeij.top/AskaEth/Cyrene-Plugins/sdk" ) // CtxKeyIsAdmin is the context key for the admin flag. type ctxKey string const CtxKeyIsAdmin ctxKey = "isAdmin" // adminOnlyTools lists tools that require admin permission to execute. var adminOnlyTools = map[string]bool{ "host_exec": true, "os_exec": true, "host_file": true, } // IsAdminFromCtx returns true if the context carries an admin flag. func IsAdminFromCtx(ctx context.Context) bool { v, _ := ctx.Value(CtxKeyIsAdmin).(bool) return v } // CallLogRecord 工具调用记录 type CallLogRecord struct { CallID string `json:"call_id"` ToolName string `json:"tool_name"` Arguments string `json:"arguments"` Output string `json:"output"` Error string `json:"error"` Success bool `json:"success"` DurationMs int `json:"duration_ms"` Timestamp int64 `json:"timestamp"` } // callLogRing 线程安全的环形缓冲区 type callLogRing struct { mu sync.Mutex records []CallLogRecord capacity int head int size int } func newCallLogRing(capacity int) *callLogRing { return &callLogRing{capacity: capacity, records: make([]CallLogRecord, capacity)} } func (r *callLogRing) push(rec CallLogRecord) { r.mu.Lock() defer r.mu.Unlock() rec.CallID = fmt.Sprintf("%d", time.Now().UnixNano()) rec.Timestamp = time.Now().UnixMilli() r.records[r.head] = rec r.head = (r.head + 1) % r.capacity if r.size < r.capacity { r.size++ } } func (r *callLogRing) getAll() []CallLogRecord { r.mu.Lock() defer r.mu.Unlock() result := make([]CallLogRecord, r.size) for i := 0; i < r.size; i++ { idx := (r.head - 1 - i) % r.capacity if idx < 0 { idx += r.capacity } result[i] = r.records[idx] } return result } func (r *callLogRing) statsByTool() map[string]map[string]interface{} { r.mu.Lock() defer r.mu.Unlock() byTool := make(map[string]map[string]interface{}) for i := 0; i < r.size; i++ { idx := (r.head - 1 - i) % r.capacity if idx < 0 { idx += r.capacity } rec := r.records[idx] if _, ok := byTool[rec.ToolName]; !ok { byTool[rec.ToolName] = map[string]interface{}{ "tool_name": rec.ToolName, "count": 0, "success_count": 0, "fail_count": 0, "total_duration_ms": 0, } } s := byTool[rec.ToolName] s["count"] = s["count"].(int) + 1 if rec.Success { s["success_count"] = s["success_count"].(int) + 1 } else { s["fail_count"] = s["fail_count"].(int) + 1 } s["total_duration_ms"] = s["total_duration_ms"].(int) + rec.DurationMs } return byTool } // ToolRegistry aggregates tool definitions from all running plugins and dispatches execution. type ToolRegistry struct { mu sync.RWMutex tools map[string]sdk.Tool // tool ID -> Tool callLog *callLogRing enabled bool } func NewToolRegistry() *ToolRegistry { return &ToolRegistry{ tools: make(map[string]sdk.Tool), callLog: newCallLogRing(500), enabled: true, } } // IsEnabled returns whether tool execution is enabled. func (r *ToolRegistry) IsEnabled() bool { r.mu.RLock() defer r.mu.RUnlock() return r.enabled } // SetEnabled enables or disables tool execution. func (r *ToolRegistry) SetEnabled(enabled bool) { r.mu.Lock() defer r.mu.Unlock() r.enabled = enabled } // DefinitionNames returns all registered tool names. func (r *ToolRegistry) DefinitionNames() []string { r.mu.RLock() defer r.mu.RUnlock() names := make([]string, 0, len(r.tools)) for id := range r.tools { names = append(names, id) } return names } func (r *ToolRegistry) Register(tool sdk.Tool) error { r.mu.Lock() defer r.mu.Unlock() id := tool.Definition().ID if _, exists := r.tools[id]; exists { return fmt.Errorf("tool %q already registered", id) } r.tools[id] = tool return nil } func (r *ToolRegistry) Unregister(toolID string) { r.mu.Lock() defer r.mu.Unlock() delete(r.tools, toolID) } func (r *ToolRegistry) Get(toolID string) (sdk.Tool, bool) { r.mu.RLock() defer r.mu.RUnlock() t, ok := r.tools[toolID] return t, ok } func (r *ToolRegistry) List() []sdk.Tool { r.mu.RLock() defer r.mu.RUnlock() result := make([]sdk.Tool, 0, len(r.tools)) for _, t := range r.tools { result = append(result, t) } return result } func (r *ToolRegistry) Definitions() []sdk.ToolDefinition { r.mu.RLock() defer r.mu.RUnlock() defs := make([]sdk.ToolDefinition, 0, len(r.tools)) for _, t := range r.tools { defs = append(defs, t.Definition()) } return defs } func (r *ToolRegistry) Execute(ctx context.Context, toolID string, args map[string]interface{}) (*sdk.ToolResult, error) { r.mu.RLock() tool, ok := r.tools[toolID] r.mu.RUnlock() startTime := time.Now() if !ok { r.callLog.push(CallLogRecord{ ToolName: toolID, Error: fmt.Sprintf("tool %q not found", toolID), Success: false, DurationMs: int(time.Since(startTime).Milliseconds()), }) return nil, fmt.Errorf("tool %q not found", toolID) } if err := tool.Validate(args); err != nil { r.callLog.push(CallLogRecord{ ToolName: toolID, Error: err.Error(), Success: false, DurationMs: int(time.Since(startTime).Milliseconds()), }) return &sdk.ToolResult{Success: false, Error: err.Error()}, nil } // Admin-only tools: deny non-admin callers. if adminOnlyTools[toolID] && !IsAdminFromCtx(ctx) { errMsg := fmt.Sprintf("工具 %s 仅限管理员使用", toolID) r.callLog.push(CallLogRecord{ ToolName: toolID, Error: errMsg, Success: false, DurationMs: int(time.Since(startTime).Milliseconds()), }) return &sdk.ToolResult{Success: false, Error: errMsg}, nil } result, err := tool.Execute(ctx, args) durationMs := int(time.Since(startTime).Milliseconds()) if err != nil { r.callLog.push(CallLogRecord{ ToolName: toolID, Error: err.Error(), Success: false, DurationMs: durationMs, }) return result, err } var argsJSON string if args != nil { if b, _ := json.Marshal(args); b != nil { argsJSON = string(b) } } r.callLog.push(CallLogRecord{ ToolName: toolID, Arguments: argsJSON, Output: result.Output, Error: result.Error, Success: result.Success, DurationMs: durationMs, }) return result, nil } // UnregisterAll removes all tools matching given IDs. func (r *ToolRegistry) UnregisterAll(toolIDs []string) { r.mu.Lock() defer r.mu.Unlock() for _, id := range toolIDs { delete(r.tools, id) } } // GetCallLogs 获取工具调用记录(最新在前,支持按工具名过滤、分页) func (r *ToolRegistry) GetCallLogs(toolName string, limit, offset int) ([]CallLogRecord, int) { all := r.callLog.getAll() // 过滤 var filtered []CallLogRecord if toolName == "" { filtered = all } else { filtered = make([]CallLogRecord, 0) for _, rec := range all { if rec.ToolName == toolName { filtered = append(filtered, rec) } } } total := len(filtered) // 分页 if offset >= len(filtered) { return []CallLogRecord{}, total } page := filtered[offset:] if limit > 0 && limit < len(page) { page = page[:limit] } return page, total } // GetCallStats 获取工具调用统计 func (r *ToolRegistry) GetCallStats() map[string]interface{} { byTool := r.callLog.statsByTool() totalCalls, successCount, failCount, totalDurationMs := 0, 0, 0, 0 toolStats := make([]map[string]interface{}, 0, len(byTool)) for _, s := range byTool { count := s["count"].(int) success := s["success_count"].(int) fail := s["fail_count"].(int) totalDur := s["total_duration_ms"].(int) avgDur := 0.0 if count > 0 { avgDur = float64(totalDur) / float64(count) } s["avg_duration_ms"] = avgDur delete(s, "total_duration_ms") toolStats = append(toolStats, s) totalCalls += count successCount += success failCount += fail totalDurationMs += totalDur } avgDuration := 0.0 if totalCalls > 0 { avgDuration = float64(totalDurationMs) / float64(totalCalls) } successRate := 0.0 if totalCalls > 0 { successRate = float64(successCount) / float64(totalCalls) * 100 } return map[string]interface{}{ "total_calls": totalCalls, "success_count": successCount, "fail_count": failCount, "success_rate": successRate, "avg_duration_ms": avgDuration, "by_tool": toolStats, } }