87214b9441
Phase 1 (基础设施): - ThinkChain 思考链连续性 + 差异化思考提示词 (persistent) - AutonomousToolPolicy 工具安全策略 (safe/unsafe/conditional) - MessageScheduler 自适应消息节奏 (Idle/Available/Busy) - SessionEnrichmentStore 渐进式上下文丰富 (5层) - ConversationBus 事件总线 + ResponseCache (dedup) - pkg/logger 统一日志 + 所有 handler 替换 fmt.Printf - NPE 守卫/链路优化/数据库表修复/Go workspace Phase 2 (人格交互): - EmotionState/EmotionTracker 情感状态机 (5种心情, 情绪衰减) - ProactiveGuard 主动消息多维决策 (静默时段/紧急度/频率/校验) - Gateway↔ai-core 在线状态感知链路 (presence notification) - 离线思考频率控制 + 重连问候 + 离线消息排队 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
301 lines
8.2 KiB
Go
301 lines
8.2 KiB
Go
package handler
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/json"
|
|
"fmt"
|
|
"github.com/yourname/cyrene-ai/pkg/logger"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/yourname/cyrene-ai/tool-engine/internal/model"
|
|
"github.com/yourname/cyrene-ai/tool-engine/internal/service"
|
|
"github.com/yourname/cyrene-ai/tool-engine/internal/store"
|
|
)
|
|
|
|
// ToolHandler HTTP API 处理器
|
|
type ToolHandler struct {
|
|
svc *service.ToolService
|
|
callLogStore *store.CallLogStore
|
|
}
|
|
|
|
// NewToolHandler 创建工具处理器
|
|
func NewToolHandler(svc *service.ToolService, callLogStore *store.CallLogStore) *ToolHandler {
|
|
return &ToolHandler{svc: svc, callLogStore: callLogStore}
|
|
}
|
|
|
|
// RegisterRoutes 注册所有路由到 mux
|
|
func (h *ToolHandler) RegisterRoutes(mux *http.ServeMux) {
|
|
// GET /api/v1/tools - 列出所有工具
|
|
mux.HandleFunc("/api/v1/tools", h.handleTools)
|
|
// GET /api/v1/tools/ - 工具详情和单个执行 (带名称)
|
|
mux.HandleFunc("/api/v1/tools/", h.handleToolByName)
|
|
// POST /api/v1/tools/execute - 批量执行
|
|
mux.HandleFunc("/api/v1/tools/execute", h.handleBatchExecute)
|
|
}
|
|
|
|
// handleTools GET /api/v1/tools - 列出所有工具
|
|
func (h *ToolHandler) handleTools(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodGet {
|
|
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
|
return
|
|
}
|
|
|
|
tools := h.svc.ListTools()
|
|
if tools == nil {
|
|
tools = []model.ToolDefinition{}
|
|
}
|
|
|
|
writeJSON(w, http.StatusOK, map[string]interface{}{
|
|
"tools": tools,
|
|
"total": len(tools),
|
|
})
|
|
}
|
|
|
|
// handleToolByName 处理 /api/v1/tools/{name} 和 /api/v1/tools/{name}/execute 和 /api/v1/tools/calls 和 /api/v1/tools/calls/stats
|
|
func (h *ToolHandler) handleToolByName(w http.ResponseWriter, r *http.Request) {
|
|
// 解析路径: /api/v1/tools/{name} 或 /api/v1/tools/{name}/execute
|
|
path := strings.TrimPrefix(r.URL.Path, "/api/v1/tools/")
|
|
parts := strings.SplitN(path, "/", 2)
|
|
|
|
toolName := parts[0]
|
|
if toolName == "" {
|
|
writeError(w, http.StatusBadRequest, "缺少工具名称")
|
|
return
|
|
}
|
|
|
|
// 处理 /api/v1/tools/calls/stats
|
|
if toolName == "calls" && len(parts) == 2 && parts[1] == "stats" {
|
|
if r.Method != http.MethodGet {
|
|
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
|
return
|
|
}
|
|
h.handleCallStats(w, r)
|
|
return
|
|
}
|
|
|
|
// 处理 /api/v1/tools/calls
|
|
if toolName == "calls" {
|
|
if r.Method != http.MethodGet {
|
|
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
|
return
|
|
}
|
|
h.handleCallLogs(w, r)
|
|
return
|
|
}
|
|
|
|
// 判断是否为执行请求
|
|
if len(parts) == 2 && parts[1] == "execute" {
|
|
if r.Method != http.MethodPost {
|
|
writeError(w, http.StatusMethodNotAllowed, "method not allowed, use POST")
|
|
return
|
|
}
|
|
h.executeTool(w, r, toolName)
|
|
return
|
|
}
|
|
|
|
// GET /api/v1/tools/{name} - 获取工具定义
|
|
if r.Method != http.MethodGet {
|
|
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
|
return
|
|
}
|
|
|
|
def, ok := h.svc.GetTool(toolName)
|
|
if !ok {
|
|
writeError(w, http.StatusNotFound, "工具 "+toolName+" 不存在")
|
|
return
|
|
}
|
|
|
|
writeJSON(w, http.StatusOK, def)
|
|
}
|
|
|
|
// executeTool POST /api/v1/tools/{name}/execute - 执行单个工具
|
|
func (h *ToolHandler) executeTool(w http.ResponseWriter, r *http.Request, toolName string) {
|
|
var req model.ExecuteRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
writeError(w, http.StatusBadRequest, "请求体格式错误: "+err.Error())
|
|
return
|
|
}
|
|
|
|
if req.Arguments == nil {
|
|
req.Arguments = make(map[string]interface{})
|
|
}
|
|
|
|
startTime := time.Now()
|
|
result, err := h.svc.Execute(r.Context(), toolName, req.Arguments)
|
|
durationMs := int(time.Since(startTime).Milliseconds())
|
|
|
|
if err != nil {
|
|
logger.Printf("[tool-handler] 执行工具 %s 失败: %v", toolName, err)
|
|
h.logCall(toolName, req.Arguments, "", err.Error(), false, durationMs, r)
|
|
writeError(w, http.StatusInternalServerError, err.Error())
|
|
return
|
|
}
|
|
|
|
// 异步记录调用日志
|
|
h.logCall(toolName, req.Arguments, result.Output, result.Error, result.Error == "" && err == nil, durationMs, r)
|
|
|
|
writeJSON(w, http.StatusOK, result)
|
|
}
|
|
|
|
// handleBatchExecute POST /api/v1/tools/execute - 批量执行
|
|
func (h *ToolHandler) handleBatchExecute(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
writeError(w, http.StatusMethodNotAllowed, "method not allowed, use POST")
|
|
return
|
|
}
|
|
|
|
var req model.BatchExecuteRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
writeError(w, http.StatusBadRequest, "请求体格式错误: "+err.Error())
|
|
return
|
|
}
|
|
|
|
if len(req.Calls) == 0 {
|
|
writeError(w, http.StatusBadRequest, "calls 不能为空")
|
|
return
|
|
}
|
|
|
|
startTime := time.Now()
|
|
response := h.svc.ExecuteBatch(r.Context(), req.Calls)
|
|
batchDuration := int(time.Since(startTime).Milliseconds())
|
|
|
|
// 异步记录每个调用
|
|
for i, call := range req.Calls {
|
|
var output, errStr string
|
|
var success bool
|
|
if i < len(response.Results) {
|
|
output = response.Results[i].Output
|
|
errStr = response.Results[i].Error
|
|
success = errStr == ""
|
|
}
|
|
h.logCall(call.Name, call.Arguments, output, errStr, success, batchDuration, r)
|
|
}
|
|
|
|
writeJSON(w, http.StatusOK, response)
|
|
}
|
|
|
|
// newUUID returns a version 4 UUID in the canonical 8-4-4-4-12 lowercase hex
// form, filled from crypto/rand.
func newUUID() string {
	var raw [16]byte

	// The read error is deliberately ignored: the ID is only used to label
	// call-log records and this function has no error return.
	_, _ = rand.Read(raw[:])

	raw[6] = raw[6]&0x0f | 0x40 // high nibble 0100: version 4
	raw[8] = raw[8]&0x3f | 0x80 // top two bits 10: variant

	return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x",
		raw[0:4], raw[4:6], raw[6:8], raw[8:10], raw[10:16])
}
|
|
|
|
// logCall 异步记录工具调用日志
|
|
func (h *ToolHandler) logCall(toolName string, args map[string]interface{}, output, errStr string, success bool, durationMs int, r *http.Request) {
|
|
if h.callLogStore == nil {
|
|
return
|
|
}
|
|
|
|
callID := newUUID()
|
|
userID := r.URL.Query().Get("user_id")
|
|
sessionID := r.URL.Query().Get("session_id")
|
|
|
|
go func() {
|
|
argsJSON, _ := json.Marshal(args)
|
|
record := &store.CallLogRecord{
|
|
CallID: callID,
|
|
ToolName: toolName,
|
|
Arguments: argsJSON,
|
|
Output: output,
|
|
Error: errStr,
|
|
Success: success,
|
|
DurationMs: durationMs,
|
|
UserID: userID,
|
|
SessionID: sessionID,
|
|
CreatedAt: time.Now(),
|
|
}
|
|
if err := h.callLogStore.Insert(record); err != nil {
|
|
logger.Printf("[tool-handler] 记录调用日志失败: %v", err)
|
|
}
|
|
}()
|
|
}
|
|
|
|
// handleCallLogs GET /api/v1/tools/calls - 查询调用记录
|
|
func (h *ToolHandler) handleCallLogs(w http.ResponseWriter, r *http.Request) {
|
|
if h.callLogStore == nil {
|
|
writeJSON(w, http.StatusOK, map[string]interface{}{
|
|
"calls": []interface{}{},
|
|
"total": 0,
|
|
"page": 1,
|
|
"limit": 20,
|
|
"total_pages": 0,
|
|
})
|
|
return
|
|
}
|
|
|
|
q := r.URL.Query()
|
|
|
|
page, _ := strconv.Atoi(q.Get("page"))
|
|
if page < 1 {
|
|
page = 1
|
|
}
|
|
|
|
limit, _ := strconv.Atoi(q.Get("limit"))
|
|
if limit < 1 || limit > 100 {
|
|
limit = 20
|
|
}
|
|
|
|
query := store.CallLogQuery{
|
|
ToolName: q.Get("tool_name"),
|
|
Page: page,
|
|
Limit: limit,
|
|
}
|
|
|
|
result, err := h.callLogStore.Query(query)
|
|
if err != nil {
|
|
logger.Printf("[tool-handler] 查询调用记录失败: %v", err)
|
|
writeError(w, http.StatusInternalServerError, "查询调用记录失败: "+err.Error())
|
|
return
|
|
}
|
|
|
|
writeJSON(w, http.StatusOK, result)
|
|
}
|
|
|
|
// handleCallStats GET /api/v1/tools/calls/stats - 调用统计
|
|
func (h *ToolHandler) handleCallStats(w http.ResponseWriter, r *http.Request) {
|
|
if h.callLogStore == nil {
|
|
writeJSON(w, http.StatusOK, map[string]interface{}{
|
|
"total_calls": 0,
|
|
"success_count": 0,
|
|
"fail_count": 0,
|
|
"success_rate": 0,
|
|
"avg_duration_ms": 0,
|
|
"by_tool": []interface{}{},
|
|
})
|
|
return
|
|
}
|
|
|
|
stats, err := h.callLogStore.Stats()
|
|
if err != nil {
|
|
logger.Printf("[tool-handler] 查询调用统计失败: %v", err)
|
|
writeError(w, http.StatusInternalServerError, "查询调用统计失败: "+err.Error())
|
|
return
|
|
}
|
|
|
|
writeJSON(w, http.StatusOK, stats)
|
|
}
|
|
|
|
// writeJSON 写入 JSON 响应
|
|
func writeJSON(w http.ResponseWriter, status int, data interface{}) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(status)
|
|
if err := json.NewEncoder(w).Encode(data); err != nil {
|
|
logger.Printf("[tool-handler] JSON 编码失败: %v", err)
|
|
}
|
|
}
|
|
|
|
// writeError 写入错误响应
|
|
func writeError(w http.ResponseWriter, status int, message string) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(status)
|
|
json.NewEncoder(w).Encode(map[string]string{
|
|
"error": message,
|
|
})
|
|
}
|