package store import ( "database/sql" "encoding/json" "fmt" "log" "os" "time" _ "github.com/lib/pq" ) // CallLogRecord 工具调用记录 type CallLogRecord struct { ID int `json:"id"` CallID string `json:"call_id"` ToolName string `json:"tool_name"` Arguments json.RawMessage `json:"arguments"` Output string `json:"output"` Error string `json:"error"` Success bool `json:"success"` DurationMs int `json:"duration_ms"` UserID string `json:"user_id"` SessionID string `json:"session_id"` CreatedAt time.Time `json:"created_at"` } // CallLogQuery 查询参数 type CallLogQuery struct { ToolName string Page int Limit int } // CallLogPageResult 分页结果 type CallLogPageResult struct { Calls []CallLogRecord `json:"calls"` Total int `json:"total"` Page int `json:"page"` Limit int `json:"limit"` TotalPages int `json:"total_pages"` } // CallLogStats 调用统计 type CallLogStats struct { TotalCalls int `json:"total_calls"` SuccessCount int `json:"success_count"` FailCount int `json:"fail_count"` SuccessRate float64 `json:"success_rate"` AvgDuration float64 `json:"avg_duration_ms"` ByTool []ToolCallCount `json:"by_tool"` } // ToolCallCount 按工具统计 type ToolCallCount struct { ToolName string `json:"tool_name"` Count int `json:"count"` SuccessCount int `json:"success_count"` FailCount int `json:"fail_count"` AvgDuration float64 `json:"avg_duration_ms"` } // CallLogStore 工具调用日志存储 type CallLogStore struct { db *sql.DB } // NewCallLogStore 创建调用日志存储并自动建表 func NewCallLogStore(dbURL string) (*CallLogStore, error) { if dbURL == "" { log.Println("[call-log-store] DB_URL 未设置,工具调用日志将不会持久化") return &CallLogStore{}, nil } db, err := sql.Open("postgres", dbURL) if err != nil { return nil, fmt.Errorf("打开数据库连接失败: %w", err) } db.SetMaxOpenConns(5) db.SetMaxIdleConns(2) db.SetConnMaxLifetime(5 * time.Minute) if err := db.Ping(); err != nil { log.Printf("[call-log-store] 数据库连接失败: %v (将尝试继续运行)", err) return &CallLogStore{}, nil } store := &CallLogStore{db: db} if err := store.migrate(); err != nil { log.Printf("[call-log-store] 数据库迁移失败: %v", err) } log.Println("[call-log-store] 数据库连接成功,表已就绪") return store, nil } // migrate 创建表结构 func (s *CallLogStore) migrate() error { if s.db == nil { return nil } query := ` CREATE TABLE IF NOT EXISTS tool_call_logs ( id SERIAL PRIMARY KEY, call_id TEXT NOT NULL, tool_name TEXT NOT NULL, arguments JSONB, output TEXT, error TEXT, success BOOLEAN NOT NULL DEFAULT true, duration_ms INTEGER, user_id TEXT DEFAULT '', session_id TEXT DEFAULT '', created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() ); CREATE INDEX IF NOT EXISTS idx_tcl_tool_name ON tool_call_logs(tool_name); CREATE INDEX IF NOT EXISTS idx_tcl_created_at ON tool_call_logs(created_at); CREATE INDEX IF NOT EXISTS idx_tcl_user_id ON tool_call_logs(user_id); ` _, err := s.db.Exec(query) return err } // Insert 插入一条调用记录 func (s *CallLogStore) Insert(record *CallLogRecord) error { if s.db == nil { return nil } argsJSON, _ := json.Marshal(record.Arguments) _, err := s.db.Exec( `INSERT INTO tool_call_logs (call_id, tool_name, arguments, output, error, success, duration_ms, user_id, session_id, created_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)`, record.CallID, record.ToolName, argsJSON, record.Output, record.Error, record.Success, record.DurationMs, record.UserID, record.SessionID, record.CreatedAt, ) if err != nil { log.Printf("[call-log-store] 插入记录失败: %v", err) return err } return nil } // Query 分页查询调用记录 func (s *CallLogStore) Query(q CallLogQuery) (*CallLogPageResult, error) { if s.db == nil { return &CallLogPageResult{Calls: []CallLogRecord{}, Total: 0, Page: q.Page, Limit: q.Limit, TotalPages: 0}, nil } if q.Page < 1 { q.Page = 1 } if q.Limit < 1 || q.Limit > 100 { q.Limit = 20 } // 构建 WHERE 条件 where := "" whereArgs := []interface{}{} argIdx := 1 if q.ToolName != "" { where = fmt.Sprintf(" WHERE tool_name = $%d", argIdx) whereArgs = append(whereArgs, q.ToolName) argIdx++ } // 计数 countQuery := "SELECT COUNT(*) FROM tool_call_logs" + where var total int if err := s.db.QueryRow(countQuery, whereArgs...).Scan(&total); err != nil { return nil, fmt.Errorf("查询总数失败: %w", err) } // 分页查询 offset := (q.Page - 1) * q.Limit querySql := fmt.Sprintf( "SELECT id, call_id, tool_name, arguments, COALESCE(output,''), COALESCE(error,''), success, COALESCE(duration_ms,0), COALESCE(user_id,''), COALESCE(session_id,''), created_at FROM tool_call_logs%s ORDER BY created_at DESC LIMIT $%d OFFSET $%d", where, argIdx, argIdx+1, ) queryArgs := append(whereArgs, q.Limit, offset) rows, err := s.db.Query(querySql, queryArgs...) if err != nil { return nil, fmt.Errorf("查询记录失败: %w", err) } defer rows.Close() calls := make([]CallLogRecord, 0) for rows.Next() { var r CallLogRecord var argsJSON []byte if err := rows.Scan(&r.ID, &r.CallID, &r.ToolName, &argsJSON, &r.Output, &r.Error, &r.Success, &r.DurationMs, &r.UserID, &r.SessionID, &r.CreatedAt); err != nil { log.Printf("[call-log-store] 扫描行失败: %v", err) continue } r.Arguments = argsJSON calls = append(calls, r) } totalPages := (total + q.Limit - 1) / q.Limit return &CallLogPageResult{ Calls: calls, Total: total, Page: q.Page, Limit: q.Limit, TotalPages: totalPages, }, nil } // Stats 获取调用统计 func (s *CallLogStore) Stats() (*CallLogStats, error) { if s.db == nil { return &CallLogStats{}, nil } stats := &CallLogStats{} // 总体统计 err := s.db.QueryRow( "SELECT COUNT(*), COUNT(*) FILTER (WHERE success=true), COUNT(*) FILTER (WHERE success=false), COALESCE(AVG(duration_ms),0) FROM tool_call_logs", ).Scan(&stats.TotalCalls, &stats.SuccessCount, &stats.FailCount, &stats.AvgDuration) if err != nil { return nil, fmt.Errorf("查询总体统计失败: %w", err) } if stats.TotalCalls > 0 { stats.SuccessRate = float64(stats.SuccessCount) / float64(stats.TotalCalls) * 100 } // 按工具统计 rows, err := s.db.Query( "SELECT tool_name, COUNT(*), COUNT(*) FILTER (WHERE success=true), COUNT(*) FILTER (WHERE success=false), COALESCE(AVG(duration_ms),0) FROM tool_call_logs GROUP BY tool_name ORDER BY COUNT(*) DESC", ) if err != nil { return nil, fmt.Errorf("查询按工具统计失败: %w", err) } defer rows.Close() stats.ByTool = make([]ToolCallCount, 0) for rows.Next() { var tc ToolCallCount if err := rows.Scan(&tc.ToolName, &tc.Count, &tc.SuccessCount, &tc.FailCount, &tc.AvgDuration); err != nil { log.Printf("[call-log-store] 扫描工具统计失败: %v", err) continue } stats.ByTool = append(stats.ByTool, tc) } return stats, nil } // Close 关闭数据库连接 func (s *CallLogStore) Close() { if s.db != nil { s.db.Close() } } // DBUrlFromEnv 从环境变量获取数据库连接 func DBUrlFromEnv() string { // 如果设置了 DB_URL 直接使用 if url := os.Getenv("DB_URL"); url != "" { return url } // 否则从单独的环境变量构建 host := getEnv("DB_HOST", "localhost") port := getEnv("DB_PORT", "5432") user := getEnv("DB_USER", "cyrene") pass := getEnv("DB_PASSWORD", "change_me") dbname := getEnv("DB_NAME", "cyrene_ai") sslmode := getEnv("DB_SSLMODE", "disable") return fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=%s", user, pass, host, port, dbname, sslmode) } func getEnv(key, fallback string) string { if v := os.Getenv(key); v != "" { return v } return fallback }