package handler import ( "crypto/rand" "encoding/json" "fmt" "log" "net/http" "strconv" "strings" "time" "github.com/yourname/cyrene-ai/tool-engine/internal/model" "github.com/yourname/cyrene-ai/tool-engine/internal/service" "github.com/yourname/cyrene-ai/tool-engine/internal/store" ) // ToolHandler HTTP API 处理器 type ToolHandler struct { svc *service.ToolService callLogStore *store.CallLogStore } // NewToolHandler 创建工具处理器 func NewToolHandler(svc *service.ToolService, callLogStore *store.CallLogStore) *ToolHandler { return &ToolHandler{svc: svc, callLogStore: callLogStore} } // RegisterRoutes 注册所有路由到 mux func (h *ToolHandler) RegisterRoutes(mux *http.ServeMux) { // GET /api/v1/tools - 列出所有工具 mux.HandleFunc("/api/v1/tools", h.handleTools) // GET /api/v1/tools/ - 工具详情和单个执行 (带名称) mux.HandleFunc("/api/v1/tools/", h.handleToolByName) // POST /api/v1/tools/execute - 批量执行 mux.HandleFunc("/api/v1/tools/execute", h.handleBatchExecute) } // handleTools GET /api/v1/tools - 列出所有工具 func (h *ToolHandler) handleTools(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { writeError(w, http.StatusMethodNotAllowed, "method not allowed") return } tools := h.svc.ListTools() if tools == nil { tools = []model.ToolDefinition{} } writeJSON(w, http.StatusOK, map[string]interface{}{ "tools": tools, "total": len(tools), }) } // handleToolByName 处理 /api/v1/tools/{name} 和 /api/v1/tools/{name}/execute 和 /api/v1/tools/calls 和 /api/v1/tools/calls/stats func (h *ToolHandler) handleToolByName(w http.ResponseWriter, r *http.Request) { // 解析路径: /api/v1/tools/{name} 或 /api/v1/tools/{name}/execute path := strings.TrimPrefix(r.URL.Path, "/api/v1/tools/") parts := strings.SplitN(path, "/", 2) toolName := parts[0] if toolName == "" { writeError(w, http.StatusBadRequest, "缺少工具名称") return } // 处理 /api/v1/tools/calls/stats if toolName == "calls" && len(parts) == 2 && parts[1] == "stats" { if r.Method != http.MethodGet { writeError(w, http.StatusMethodNotAllowed, "method not allowed") return } h.handleCallStats(w, r) return } // 处理 /api/v1/tools/calls if toolName == "calls" { if r.Method != http.MethodGet { writeError(w, http.StatusMethodNotAllowed, "method not allowed") return } h.handleCallLogs(w, r) return } // 判断是否为执行请求 if len(parts) == 2 && parts[1] == "execute" { if r.Method != http.MethodPost { writeError(w, http.StatusMethodNotAllowed, "method not allowed, use POST") return } h.executeTool(w, r, toolName) return } // GET /api/v1/tools/{name} - 获取工具定义 if r.Method != http.MethodGet { writeError(w, http.StatusMethodNotAllowed, "method not allowed") return } def, ok := h.svc.GetTool(toolName) if !ok { writeError(w, http.StatusNotFound, "工具 "+toolName+" 不存在") return } writeJSON(w, http.StatusOK, def) } // executeTool POST /api/v1/tools/{name}/execute - 执行单个工具 func (h *ToolHandler) executeTool(w http.ResponseWriter, r *http.Request, toolName string) { var req model.ExecuteRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { writeError(w, http.StatusBadRequest, "请求体格式错误: "+err.Error()) return } if req.Arguments == nil { req.Arguments = make(map[string]interface{}) } startTime := time.Now() result, err := h.svc.Execute(r.Context(), toolName, req.Arguments) durationMs := int(time.Since(startTime).Milliseconds()) if err != nil { log.Printf("[tool-handler] 执行工具 %s 失败: %v", toolName, err) h.logCall(toolName, req.Arguments, "", err.Error(), false, durationMs, r) writeError(w, http.StatusInternalServerError, err.Error()) return } // 异步记录调用日志 h.logCall(toolName, req.Arguments, result.Output, result.Error, result.Error == "" && err == nil, durationMs, r) writeJSON(w, http.StatusOK, result) } // handleBatchExecute POST /api/v1/tools/execute - 批量执行 func (h *ToolHandler) handleBatchExecute(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { writeError(w, http.StatusMethodNotAllowed, "method not allowed, use POST") return } var req model.BatchExecuteRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { writeError(w, http.StatusBadRequest, "请求体格式错误: "+err.Error()) return } if len(req.Calls) == 0 { writeError(w, http.StatusBadRequest, "calls 不能为空") return } startTime := time.Now() response := h.svc.ExecuteBatch(r.Context(), req.Calls) batchDuration := int(time.Since(startTime).Milliseconds()) // 异步记录每个调用 for i, call := range req.Calls { var output, errStr string var success bool if i < len(response.Results) { output = response.Results[i].Output errStr = response.Results[i].Error success = errStr == "" } h.logCall(call.Name, call.Arguments, output, errStr, success, batchDuration, r) } writeJSON(w, http.StatusOK, response) } // newUUID generates a UUID v4 string using crypto/rand func newUUID() string { b := make([]byte, 16) _, _ = rand.Read(b) b[6] = (b[6] & 0x0f) | 0x40 // Version 4 b[8] = (b[8] & 0x3f) | 0x80 // Variant 10 return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:16]) } // logCall 异步记录工具调用日志 func (h *ToolHandler) logCall(toolName string, args map[string]interface{}, output, errStr string, success bool, durationMs int, r *http.Request) { if h.callLogStore == nil { return } callID := newUUID() userID := r.URL.Query().Get("user_id") sessionID := r.URL.Query().Get("session_id") go func() { argsJSON, _ := json.Marshal(args) record := &store.CallLogRecord{ CallID: callID, ToolName: toolName, Arguments: argsJSON, Output: output, Error: errStr, Success: success, DurationMs: durationMs, UserID: userID, SessionID: sessionID, CreatedAt: time.Now(), } if err := h.callLogStore.Insert(record); err != nil { log.Printf("[tool-handler] 记录调用日志失败: %v", err) } }() } // handleCallLogs GET /api/v1/tools/calls - 查询调用记录 func (h *ToolHandler) handleCallLogs(w http.ResponseWriter, r *http.Request) { if h.callLogStore == nil { writeJSON(w, http.StatusOK, map[string]interface{}{ "calls": []interface{}{}, "total": 0, "page": 1, "limit": 20, "total_pages": 0, }) return } q := r.URL.Query() page, _ := strconv.Atoi(q.Get("page")) if page < 1 { page = 1 } limit, _ := strconv.Atoi(q.Get("limit")) if limit < 1 || limit > 100 { limit = 20 } query := store.CallLogQuery{ ToolName: q.Get("tool_name"), Page: page, Limit: limit, } result, err := h.callLogStore.Query(query) if err != nil { log.Printf("[tool-handler] 查询调用记录失败: %v", err) writeError(w, http.StatusInternalServerError, "查询调用记录失败: "+err.Error()) return } writeJSON(w, http.StatusOK, result) } // handleCallStats GET /api/v1/tools/calls/stats - 调用统计 func (h *ToolHandler) handleCallStats(w http.ResponseWriter, r *http.Request) { if h.callLogStore == nil { writeJSON(w, http.StatusOK, map[string]interface{}{ "total_calls": 0, "success_count": 0, "fail_count": 0, "success_rate": 0, "avg_duration_ms": 0, "by_tool": []interface{}{}, }) return } stats, err := h.callLogStore.Stats() if err != nil { log.Printf("[tool-handler] 查询调用统计失败: %v", err) writeError(w, http.StatusInternalServerError, "查询调用统计失败: "+err.Error()) return } writeJSON(w, http.StatusOK, stats) } // writeJSON 写入 JSON 响应 func writeJSON(w http.ResponseWriter, status int, data interface{}) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) if err := json.NewEncoder(w).Encode(data); err != nil { log.Printf("[tool-handler] JSON 编码失败: %v", err) } } // writeError 写入错误响应 func writeError(w http.ResponseWriter, status int, message string) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) json.NewEncoder(w).Encode(map[string]string{ "error": message, }) }