Files
Cyrene/backend/tool-engine/internal/store/call_log_store.go
T
AskaEth 78e3f450c2 feat: Round 5 - Memory Service, Tool Engine, Call Records, Thinking Logs
- Fix: Session history flash (race condition + WS guard)
- Fix: Chat background overlay + sidebar transparency
- Fix: IoT device control (Chinese action names, status field)
- Feat: Independent memory-service (port 8091, 13 endpoints)
- Feat: Independent tool-engine service (port 8092, 13 tools)
- Feat: Tool call logs with paginated DevTools panel
- Feat: Thinking log records with DevTools panel
- Feat: Future development roadmap document
- Chore: Updated .gitignore, go.work, DevTools config
- Chore: 5-service health check, project review docs
2026-05-18 20:05:14 +08:00

290 lines
7.9 KiB
Go

package store
import (
	"database/sql"
	"encoding/json"
	"fmt"
	"log"
	"net/url"
	"os"
	"time"

	_ "github.com/lib/pq"
)
// CallLogRecord is a single tool invocation as stored in, and read back
// from, the tool_call_logs table.
type CallLogRecord struct {
	// ID is the database-assigned row identifier.
	ID int `json:"id"`
	// CallID identifies the invocation.
	CallID string `json:"call_id"`
	// ToolName is the name of the tool that was called.
	ToolName string `json:"tool_name"`
	// Arguments holds the call arguments as raw JSON.
	Arguments json.RawMessage `json:"arguments"`
	// Output is the text the tool produced.
	Output string `json:"output"`
	// Error is the error text of the call, if any.
	Error string `json:"error"`
	// Success reports whether the call succeeded.
	Success bool `json:"success"`
	// DurationMs is the call duration in milliseconds.
	DurationMs int `json:"duration_ms"`
	// UserID and SessionID attribute the call to a user and a session.
	UserID    string `json:"user_id"`
	SessionID string `json:"session_id"`
	// CreatedAt is the timestamp stored with the record.
	CreatedAt time.Time `json:"created_at"`
}
// CallLogQuery carries the filter and paging parameters for
// CallLogStore.Query.
type CallLogQuery struct {
	// ToolName restricts the result to one tool; empty means no filter.
	ToolName string
	// Page is the 1-based page number and Limit the page size. Query treats
	// Page < 1 as 1 and a Limit outside 1..100 as 20.
	Page, Limit int
}
// CallLogPageResult is one page of call records together with the paging
// metadata needed to navigate the full result set.
type CallLogPageResult struct {
	// Calls holds the records on this page.
	Calls []CallLogRecord `json:"calls"`
	// Total is the number of records matching the query across all pages.
	Total int `json:"total"`
	// Page and Limit are the page number and page size used for the query.
	Page  int `json:"page"`
	Limit int `json:"limit"`
	// TotalPages is Total divided by Limit, rounded up.
	TotalPages int `json:"total_pages"`
}
// CallLogStats aggregates all recorded tool calls, overall and per tool.
type CallLogStats struct {
	// TotalCalls is the number of recorded calls; SuccessCount and FailCount
	// split it by outcome.
	TotalCalls   int `json:"total_calls"`
	SuccessCount int `json:"success_count"`
	FailCount    int `json:"fail_count"`
	// SuccessRate is the share of successful calls as a percentage (0-100).
	SuccessRate float64 `json:"success_rate"`
	// AvgDuration is the mean call duration in milliseconds.
	AvgDuration float64 `json:"avg_duration_ms"`
	// ByTool breaks the same figures down per tool.
	ByTool []ToolCallCount `json:"by_tool"`
}
// ToolCallCount holds the call statistics of a single tool.
type ToolCallCount struct {
	// ToolName is the tool these figures belong to.
	ToolName string `json:"tool_name"`
	// Count is the number of calls; SuccessCount and FailCount split it by
	// outcome.
	Count        int `json:"count"`
	SuccessCount int `json:"success_count"`
	FailCount    int `json:"fail_count"`
	// AvgDuration is the mean call duration in milliseconds.
	AvgDuration float64 `json:"avg_duration_ms"`
}
// CallLogStore persists tool-call logs in a SQL database.
//
// The db handle is nil when persistence is unavailable; every method checks
// for that and then does nothing (writes) or returns empty results (reads).
type CallLogStore struct {
	db *sql.DB
}
// NewCallLogStore opens the call-log database described by dbURL and makes
// sure the schema exists.
//
// The store degrades instead of failing:
//   - an empty dbURL, or a database that does not answer Ping, yields a
//     store without a database handle (nothing is persisted);
//   - a failed schema migration is logged and the connected store is still
//     returned.
//
// A non-nil error is returned only when sql.Open rejects its arguments.
func NewCallLogStore(dbURL string) (*CallLogStore, error) {
	if dbURL == "" {
		log.Println("[call-log-store] DB_URL 未设置,工具调用日志将不会持久化")
		return &CallLogStore{}, nil
	}

	db, err := sql.Open("postgres", dbURL)
	if err != nil {
		return nil, fmt.Errorf("打开数据库连接失败: %w", err)
	}
	db.SetMaxOpenConns(5)
	db.SetMaxIdleConns(2)
	db.SetConnMaxLifetime(5 * time.Minute)

	if err := db.Ping(); err != nil {
		log.Printf("[call-log-store] 数据库连接失败: %v (将尝试继续运行)", err)
		// The handle is not kept by the returned store, so release it here;
		// otherwise the pool created by sql.Open would never be closed.
		if cerr := db.Close(); cerr != nil {
			log.Printf("[call-log-store] 关闭数据库连接失败: %v", cerr)
		}
		return &CallLogStore{}, nil
	}

	store := &CallLogStore{db: db}
	if err := store.migrate(); err != nil {
		// Keep running with the connection, but do not claim the table is
		// ready when the migration did not succeed.
		log.Printf("[call-log-store] 数据库迁移失败: %v", err)
		return store, nil
	}
	log.Println("[call-log-store] 数据库连接成功,表已就绪")
	return store, nil
}
// migrate creates the tool_call_logs table and its indexes when they do not
// exist yet. All statements are sent to the database in a single Exec call.
// It does nothing and returns nil when the store has no database handle.
func (s *CallLogStore) migrate() error {
	if s.db == nil {
		return nil
	}

	const schema = `
CREATE TABLE IF NOT EXISTS tool_call_logs (
id SERIAL PRIMARY KEY,
call_id TEXT NOT NULL,
tool_name TEXT NOT NULL,
arguments JSONB,
output TEXT,
error TEXT,
success BOOLEAN NOT NULL DEFAULT true,
duration_ms INTEGER,
user_id TEXT DEFAULT '',
session_id TEXT DEFAULT '',
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_tcl_tool_name ON tool_call_logs(tool_name);
CREATE INDEX IF NOT EXISTS idx_tcl_created_at ON tool_call_logs(created_at);
CREATE INDEX IF NOT EXISTS idx_tcl_user_id ON tool_call_logs(user_id);
`

	if _, err := s.db.Exec(schema); err != nil {
		return err
	}
	return nil
}
// Insert persists one tool-call record.
//
// It does nothing and returns nil when the store has no database handle.
// A nil record is rejected with an error instead of panicking.
//
// Arguments that cannot be encoded as JSON are stored as SQL NULL so the
// rest of the record is still kept; the encoding problem is logged rather
// than silently dropped. A zero CreatedAt is replaced by the current time,
// because the INSERT always supplies created_at and the column default
// would therefore never apply.
//
// A database failure is logged and returned.
func (s *CallLogStore) Insert(record *CallLogRecord) error {
	if s.db == nil {
		return nil
	}
	if record == nil {
		return fmt.Errorf("调用记录为空")
	}

	argsJSON, err := json.Marshal(record.Arguments)
	if err != nil {
		log.Printf("[call-log-store] 参数序列化失败,将以 NULL 存储: %v", err)
		argsJSON = nil
	}

	createdAt := record.CreatedAt
	if createdAt.IsZero() {
		createdAt = time.Now()
	}

	_, err = s.db.Exec(
		`INSERT INTO tool_call_logs (call_id, tool_name, arguments, output, error, success, duration_ms, user_id, session_id, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)`,
		record.CallID, record.ToolName, argsJSON, record.Output, record.Error,
		record.Success, record.DurationMs, record.UserID, record.SessionID, createdAt,
	)
	if err != nil {
		log.Printf("[call-log-store] 插入记录失败: %v", err)
		return err
	}
	return nil
}
// Query returns one page of call records, newest first, optionally
// restricted to a single tool.
//
// q.Page below 1 is treated as 1 and q.Limit outside 1..100 as 20. The
// normalized values are echoed in the result, also when the store has no
// database handle (in which case the page is empty and Total is 0).
//
// A row that fails to scan is logged and skipped. A failure of the count
// query, the page query, or the row iteration itself is returned as an
// error, so a truncated page is never reported as complete.
func (s *CallLogStore) Query(q CallLogQuery) (*CallLogPageResult, error) {
	if q.Page < 1 {
		q.Page = 1
	}
	if q.Limit < 1 || q.Limit > 100 {
		q.Limit = 20
	}
	if s.db == nil {
		return &CallLogPageResult{Calls: []CallLogRecord{}, Total: 0, Page: q.Page, Limit: q.Limit, TotalPages: 0}, nil
	}

	// Build the optional WHERE clause; argIdx tracks the next free
	// placeholder so LIMIT/OFFSET can be numbered after it.
	where := ""
	whereArgs := []interface{}{}
	argIdx := 1
	if q.ToolName != "" {
		where = fmt.Sprintf(" WHERE tool_name = $%d", argIdx)
		whereArgs = append(whereArgs, q.ToolName)
		argIdx++
	}

	// Total number of matching rows, needed for TotalPages.
	countQuery := "SELECT COUNT(*) FROM tool_call_logs" + where
	var total int
	if err := s.db.QueryRow(countQuery, whereArgs...).Scan(&total); err != nil {
		return nil, fmt.Errorf("查询总数失败: %w", err)
	}

	// Fetch the requested page.
	offset := (q.Page - 1) * q.Limit
	querySql := fmt.Sprintf(
		"SELECT id, call_id, tool_name, arguments, COALESCE(output,''), COALESCE(error,''), success, COALESCE(duration_ms,0), COALESCE(user_id,''), COALESCE(session_id,''), created_at FROM tool_call_logs%s ORDER BY created_at DESC LIMIT $%d OFFSET $%d",
		where, argIdx, argIdx+1,
	)
	queryArgs := append(whereArgs, q.Limit, offset)
	rows, err := s.db.Query(querySql, queryArgs...)
	if err != nil {
		return nil, fmt.Errorf("查询记录失败: %w", err)
	}
	defer rows.Close()

	// At most q.Limit rows can come back, so size the slice once.
	calls := make([]CallLogRecord, 0, q.Limit)
	for rows.Next() {
		var r CallLogRecord
		var argsJSON []byte
		if err := rows.Scan(&r.ID, &r.CallID, &r.ToolName, &argsJSON, &r.Output, &r.Error, &r.Success, &r.DurationMs, &r.UserID, &r.SessionID, &r.CreatedAt); err != nil {
			log.Printf("[call-log-store] 扫描行失败: %v", err)
			continue
		}
		r.Arguments = argsJSON
		calls = append(calls, r)
	}
	// rows.Next also returns false when iteration aborts with an error.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("遍历记录失败: %w", err)
	}

	totalPages := (total + q.Limit - 1) / q.Limit
	return &CallLogPageResult{
		Calls:      calls,
		Total:      total,
		Page:       q.Page,
		Limit:      q.Limit,
		TotalPages: totalPages,
	}, nil
}
// Stats returns aggregate call statistics: overall totals, success rate and
// average duration, plus the same figures per tool ordered by call count
// (highest first).
//
// With no database handle it returns zeroed statistics. A per-tool row that
// fails to scan is logged and skipped. A failure of either query, or of the
// row iteration itself, is returned as an error.
func (s *CallLogStore) Stats() (*CallLogStats, error) {
	if s.db == nil {
		return &CallLogStats{}, nil
	}
	stats := &CallLogStats{}

	// Overall figures.
	err := s.db.QueryRow(
		"SELECT COUNT(*), COUNT(*) FILTER (WHERE success=true), COUNT(*) FILTER (WHERE success=false), COALESCE(AVG(duration_ms),0) FROM tool_call_logs",
	).Scan(&stats.TotalCalls, &stats.SuccessCount, &stats.FailCount, &stats.AvgDuration)
	if err != nil {
		return nil, fmt.Errorf("查询总体统计失败: %w", err)
	}
	// Guard the division: with no calls the rate stays 0.
	if stats.TotalCalls > 0 {
		stats.SuccessRate = float64(stats.SuccessCount) / float64(stats.TotalCalls) * 100
	}

	// Per-tool figures.
	rows, err := s.db.Query(
		"SELECT tool_name, COUNT(*), COUNT(*) FILTER (WHERE success=true), COUNT(*) FILTER (WHERE success=false), COALESCE(AVG(duration_ms),0) FROM tool_call_logs GROUP BY tool_name ORDER BY COUNT(*) DESC",
	)
	if err != nil {
		return nil, fmt.Errorf("查询按工具统计失败: %w", err)
	}
	defer rows.Close()

	stats.ByTool = make([]ToolCallCount, 0)
	for rows.Next() {
		var tc ToolCallCount
		if err := rows.Scan(&tc.ToolName, &tc.Count, &tc.SuccessCount, &tc.FailCount, &tc.AvgDuration); err != nil {
			log.Printf("[call-log-store] 扫描工具统计失败: %v", err)
			continue
		}
		stats.ByTool = append(stats.ByTool, tc)
	}
	// rows.Next also returns false when iteration aborts with an error.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("遍历工具统计失败: %w", err)
	}
	return stats, nil
}
// Close closes the underlying database handle, if there is one. Close has
// no return value, so a failure reported by the database is logged instead
// of being silently discarded.
func (s *CallLogStore) Close() {
	if s.db == nil {
		return
	}
	if err := s.db.Close(); err != nil {
		log.Printf("[call-log-store] 关闭数据库连接失败: %v", err)
	}
}
// DBUrlFromEnv returns the PostgreSQL connection URL taken from the
// environment.
//
// A non-empty DB_URL is returned unchanged. Otherwise the URL is assembled
// from DB_HOST, DB_PORT, DB_USER, DB_PASSWORD, DB_NAME and DB_SSLMODE, each
// with a development default. The parts are percent-encoded as required, so
// credentials or database names containing characters such as '@', '/', ':'
// or '#' still produce a valid URL. For values without such characters the
// result is the same string plain interpolation would give.
func DBUrlFromEnv() string {
	// An explicit DB_URL wins over the individual settings.
	if dsn := os.Getenv("DB_URL"); dsn != "" {
		return dsn
	}

	host := getEnv("DB_HOST", "localhost")
	port := getEnv("DB_PORT", "5432")
	user := getEnv("DB_USER", "cyrene")
	pass := getEnv("DB_PASSWORD", "change_me")
	dbname := getEnv("DB_NAME", "cyrene_ai")
	sslmode := getEnv("DB_SSLMODE", "disable")

	u := url.URL{
		Scheme:   "postgres",
		User:     url.UserPassword(user, pass),
		Host:     host + ":" + port,
		Path:     "/" + dbname,
		RawQuery: url.Values{"sslmode": {sslmode}}.Encode(),
	}
	return u.String()
}
// getEnv returns the value of the environment variable key, or fallback
// when the variable is unset or set to the empty string.
func getEnv(key, fallback string) string {
	value, found := os.LookupEnv(key)
	if !found || value == "" {
		return fallback
	}
	return value
}