78e3f450c2
- Fix: Session history flash (race condition + WS guard) - Fix: Chat background overlay + sidebar transparency - Fix: IoT device control (Chinese action names, status field) - Feat: Independent memory-service (port 8091, 13 endpoints) - Feat: Independent tool-engine service (port 8092, 13 tools) - Feat: Tool call logs with paginated DevTools panel - Feat: Thinking log records with DevTools panel - Feat: Future development roadmap document - Chore: Updated .gitignore, go.work, DevTools config - Chore: 5-service health check, project review docs
290 lines
7.9 KiB
Go
290 lines
7.9 KiB
Go
package store
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"time"
|
|
|
|
_ "github.com/lib/pq"
|
|
)
|
|
|
|
// CallLogRecord is a single persisted tool invocation. It records:
//   - ID: the row id.
//   - CallID and ToolName: what was called.
//   - Arguments: the raw JSON arguments, stored in a JSONB column.
//   - Output and Error: the result text, where Error is empty when there was none.
//   - Success and DurationMs: the outcome and elapsed time in milliseconds.
//   - UserID and SessionID: the user and session the call belongs to.
//   - CreatedAt: when the call was recorded.
//
// encoding/json emits fields in declaration order, so this field order and
// the tags together define the JSON shape returned to API clients.
type CallLogRecord struct {
	ID         int             `json:"id"`
	CallID     string          `json:"call_id"`
	ToolName   string          `json:"tool_name"`
	Arguments  json.RawMessage `json:"arguments"`
	Output     string          `json:"output"`
	Error      string          `json:"error"`
	Success    bool            `json:"success"`
	DurationMs int             `json:"duration_ms"`
	UserID     string          `json:"user_id"`
	SessionID  string          `json:"session_id"`
	CreatedAt  time.Time       `json:"created_at"`
}
|
|
|
|
// CallLogQuery holds the filter and paging parameters for CallLogStore.Query.
type CallLogQuery struct {
	ToolName string // exact tool name to match; empty means every tool
	Page     int    // 1-based page number
	Limit    int    // requested page size
}
|
|
|
|
// CallLogPageResult 分页结果
|
|
type CallLogPageResult struct {
|
|
Calls []CallLogRecord `json:"calls"`
|
|
Total int `json:"total"`
|
|
Page int `json:"page"`
|
|
Limit int `json:"limit"`
|
|
TotalPages int `json:"total_pages"`
|
|
}
|
|
|
|
// CallLogStats 调用统计
|
|
type CallLogStats struct {
|
|
TotalCalls int `json:"total_calls"`
|
|
SuccessCount int `json:"success_count"`
|
|
FailCount int `json:"fail_count"`
|
|
SuccessRate float64 `json:"success_rate"`
|
|
AvgDuration float64 `json:"avg_duration_ms"`
|
|
ByTool []ToolCallCount `json:"by_tool"`
|
|
}
|
|
|
|
// ToolCallCount holds the call statistics for a single tool. It has the total,
// successful and failed call counts, plus the mean duration in milliseconds.
type ToolCallCount struct {
	ToolName     string  `json:"tool_name"`
	Count        int     `json:"count"`
	SuccessCount int     `json:"success_count"`
	FailCount    int     `json:"fail_count"`
	AvgDuration  float64 `json:"avg_duration_ms"`
}
|
|
|
|
// CallLogStore persists tool-call logs in Postgres.
//
// NewCallLogStore returns a store with a nil db when no DB_URL is given or
// when the database is unreachable at startup. Every method checks for a nil
// db and short-circuits, so the rest of the service keeps working without
// persistence.
type CallLogStore struct {
	db *sql.DB // nil when persistence is disabled
}
|
|
|
|
// NewCallLogStore 创建调用日志存储并自动建表
|
|
func NewCallLogStore(dbURL string) (*CallLogStore, error) {
|
|
if dbURL == "" {
|
|
log.Println("[call-log-store] DB_URL 未设置,工具调用日志将不会持久化")
|
|
return &CallLogStore{}, nil
|
|
}
|
|
|
|
db, err := sql.Open("postgres", dbURL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("打开数据库连接失败: %w", err)
|
|
}
|
|
|
|
db.SetMaxOpenConns(5)
|
|
db.SetMaxIdleConns(2)
|
|
db.SetConnMaxLifetime(5 * time.Minute)
|
|
|
|
if err := db.Ping(); err != nil {
|
|
log.Printf("[call-log-store] 数据库连接失败: %v (将尝试继续运行)", err)
|
|
return &CallLogStore{}, nil
|
|
}
|
|
|
|
store := &CallLogStore{db: db}
|
|
if err := store.migrate(); err != nil {
|
|
log.Printf("[call-log-store] 数据库迁移失败: %v", err)
|
|
}
|
|
|
|
log.Println("[call-log-store] 数据库连接成功,表已就绪")
|
|
return store, nil
|
|
}
|
|
|
|
// migrate 创建表结构
|
|
func (s *CallLogStore) migrate() error {
|
|
if s.db == nil {
|
|
return nil
|
|
}
|
|
|
|
query := `
|
|
CREATE TABLE IF NOT EXISTS tool_call_logs (
|
|
id SERIAL PRIMARY KEY,
|
|
call_id TEXT NOT NULL,
|
|
tool_name TEXT NOT NULL,
|
|
arguments JSONB,
|
|
output TEXT,
|
|
error TEXT,
|
|
success BOOLEAN NOT NULL DEFAULT true,
|
|
duration_ms INTEGER,
|
|
user_id TEXT DEFAULT '',
|
|
session_id TEXT DEFAULT '',
|
|
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
|
|
);
|
|
|
|
CREATE INDEX IF NOT EXISTS idx_tcl_tool_name ON tool_call_logs(tool_name);
|
|
CREATE INDEX IF NOT EXISTS idx_tcl_created_at ON tool_call_logs(created_at);
|
|
CREATE INDEX IF NOT EXISTS idx_tcl_user_id ON tool_call_logs(user_id);
|
|
`
|
|
|
|
_, err := s.db.Exec(query)
|
|
return err
|
|
}
|
|
|
|
// Insert 插入一条调用记录
|
|
func (s *CallLogStore) Insert(record *CallLogRecord) error {
|
|
if s.db == nil {
|
|
return nil
|
|
}
|
|
|
|
argsJSON, _ := json.Marshal(record.Arguments)
|
|
_, err := s.db.Exec(
|
|
`INSERT INTO tool_call_logs (call_id, tool_name, arguments, output, error, success, duration_ms, user_id, session_id, created_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)`,
|
|
record.CallID, record.ToolName, argsJSON, record.Output, record.Error,
|
|
record.Success, record.DurationMs, record.UserID, record.SessionID, record.CreatedAt,
|
|
)
|
|
if err != nil {
|
|
log.Printf("[call-log-store] 插入记录失败: %v", err)
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Query 分页查询调用记录
|
|
func (s *CallLogStore) Query(q CallLogQuery) (*CallLogPageResult, error) {
|
|
if s.db == nil {
|
|
return &CallLogPageResult{Calls: []CallLogRecord{}, Total: 0, Page: q.Page, Limit: q.Limit, TotalPages: 0}, nil
|
|
}
|
|
|
|
if q.Page < 1 {
|
|
q.Page = 1
|
|
}
|
|
if q.Limit < 1 || q.Limit > 100 {
|
|
q.Limit = 20
|
|
}
|
|
|
|
// 构建 WHERE 条件
|
|
where := ""
|
|
whereArgs := []interface{}{}
|
|
argIdx := 1
|
|
|
|
if q.ToolName != "" {
|
|
where = fmt.Sprintf(" WHERE tool_name = $%d", argIdx)
|
|
whereArgs = append(whereArgs, q.ToolName)
|
|
argIdx++
|
|
}
|
|
|
|
// 计数
|
|
countQuery := "SELECT COUNT(*) FROM tool_call_logs" + where
|
|
var total int
|
|
if err := s.db.QueryRow(countQuery, whereArgs...).Scan(&total); err != nil {
|
|
return nil, fmt.Errorf("查询总数失败: %w", err)
|
|
}
|
|
|
|
// 分页查询
|
|
offset := (q.Page - 1) * q.Limit
|
|
querySql := fmt.Sprintf(
|
|
"SELECT id, call_id, tool_name, arguments, COALESCE(output,''), COALESCE(error,''), success, COALESCE(duration_ms,0), COALESCE(user_id,''), COALESCE(session_id,''), created_at FROM tool_call_logs%s ORDER BY created_at DESC LIMIT $%d OFFSET $%d",
|
|
where, argIdx, argIdx+1,
|
|
)
|
|
queryArgs := append(whereArgs, q.Limit, offset)
|
|
|
|
rows, err := s.db.Query(querySql, queryArgs...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("查询记录失败: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
calls := make([]CallLogRecord, 0)
|
|
for rows.Next() {
|
|
var r CallLogRecord
|
|
var argsJSON []byte
|
|
if err := rows.Scan(&r.ID, &r.CallID, &r.ToolName, &argsJSON, &r.Output, &r.Error, &r.Success, &r.DurationMs, &r.UserID, &r.SessionID, &r.CreatedAt); err != nil {
|
|
log.Printf("[call-log-store] 扫描行失败: %v", err)
|
|
continue
|
|
}
|
|
r.Arguments = argsJSON
|
|
calls = append(calls, r)
|
|
}
|
|
|
|
totalPages := (total + q.Limit - 1) / q.Limit
|
|
|
|
return &CallLogPageResult{
|
|
Calls: calls,
|
|
Total: total,
|
|
Page: q.Page,
|
|
Limit: q.Limit,
|
|
TotalPages: totalPages,
|
|
}, nil
|
|
}
|
|
|
|
// Stats 获取调用统计
|
|
func (s *CallLogStore) Stats() (*CallLogStats, error) {
|
|
if s.db == nil {
|
|
return &CallLogStats{}, nil
|
|
}
|
|
|
|
stats := &CallLogStats{}
|
|
|
|
// 总体统计
|
|
err := s.db.QueryRow(
|
|
"SELECT COUNT(*), COUNT(*) FILTER (WHERE success=true), COUNT(*) FILTER (WHERE success=false), COALESCE(AVG(duration_ms),0) FROM tool_call_logs",
|
|
).Scan(&stats.TotalCalls, &stats.SuccessCount, &stats.FailCount, &stats.AvgDuration)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("查询总体统计失败: %w", err)
|
|
}
|
|
|
|
if stats.TotalCalls > 0 {
|
|
stats.SuccessRate = float64(stats.SuccessCount) / float64(stats.TotalCalls) * 100
|
|
}
|
|
|
|
// 按工具统计
|
|
rows, err := s.db.Query(
|
|
"SELECT tool_name, COUNT(*), COUNT(*) FILTER (WHERE success=true), COUNT(*) FILTER (WHERE success=false), COALESCE(AVG(duration_ms),0) FROM tool_call_logs GROUP BY tool_name ORDER BY COUNT(*) DESC",
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("查询按工具统计失败: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
stats.ByTool = make([]ToolCallCount, 0)
|
|
for rows.Next() {
|
|
var tc ToolCallCount
|
|
if err := rows.Scan(&tc.ToolName, &tc.Count, &tc.SuccessCount, &tc.FailCount, &tc.AvgDuration); err != nil {
|
|
log.Printf("[call-log-store] 扫描工具统计失败: %v", err)
|
|
continue
|
|
}
|
|
stats.ByTool = append(stats.ByTool, tc)
|
|
}
|
|
|
|
return stats, nil
|
|
}
|
|
|
|
// Close 关闭数据库连接
|
|
func (s *CallLogStore) Close() {
|
|
if s.db != nil {
|
|
s.db.Close()
|
|
}
|
|
}
|
|
|
|
// DBUrlFromEnv returns the Postgres connection URL for the call-log store.
//
// DB_URL, when set, is returned verbatim. Otherwise the URL is assembled from
// DB_HOST, DB_PORT, DB_USER, DB_PASSWORD, DB_NAME and DB_SSLMODE. Each falls
// back to a development default when unset or empty.
//
// The user, password, database name and sslmode are percent-encoded, so
// characters such as '@', ':', '/', '#', '?' or '%' in a password cannot
// corrupt the URL. Host and port are inserted as-is, as before, so bracketed
// IPv6 hosts such as "[::1]" keep working.
func DBUrlFromEnv() string {
	// A full DB_URL takes precedence over the individual settings.
	if url := os.Getenv("DB_URL"); url != "" {
		return url
	}

	host := getEnv("DB_HOST", "localhost")
	port := getEnv("DB_PORT", "5432")
	user := getEnv("DB_USER", "cyrene")
	pass := getEnv("DB_PASSWORD", "change_me")
	dbname := getEnv("DB_NAME", "cyrene_ai")
	sslmode := getEnv("DB_SSLMODE", "disable")

	// escape percent-encodes every byte outside the RFC 3986 unreserved set
	// (ALPHA / DIGIT / "-" / "." / "_" / "~"). Plain values pass through
	// unchanged, so the default URL is byte-identical to the unescaped form.
	escape := func(s string) string {
		const hexDigits = "0123456789ABCDEF"
		buf := make([]byte, 0, len(s))
		for i := 0; i < len(s); i++ {
			c := s[i]
			switch {
			case 'a' <= c && c <= 'z', 'A' <= c && c <= 'Z', '0' <= c && c <= '9',
				c == '-', c == '.', c == '_', c == '~':
				buf = append(buf, c)
			default:
				buf = append(buf, '%', hexDigits[c>>4], hexDigits[c&0x0f])
			}
		}
		return string(buf)
	}

	return fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=%s",
		escape(user), escape(pass), host, port, escape(dbname), escape(sslmode))
}

// getEnv returns the value of the environment variable key. It returns
// fallback when the variable is unset or set to the empty string.
func getEnv(key, fallback string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	return fallback
}
|