feat: Round 5 - Memory Service, Tool Engine, Call Records, Thinking Logs
- Fix: Session history flash (race condition + WS guard) - Fix: Chat background overlay + sidebar transparency - Fix: IoT device control (Chinese action names, status field) - Feat: Independent memory-service (port 8091, 13 endpoints) - Feat: Independent tool-engine service (port 8092, 13 tools) - Feat: Tool call logs with paginated DevTools panel - Feat: Thinking log records with DevTools panel - Feat: Future development roadmap document - Chore: Updated .gitignore, go.work, DevTools config - Chore: 5-service health check, project review docs
This commit is contained in:
+9
-2
@@ -19,6 +19,13 @@ devtools/package-lock.json
|
||||
|
||||
# Go 编译二进制
|
||||
backend/ai-core/main
|
||||
backend/ai-core/cmd/main
|
||||
backend/gateway/main
|
||||
backend/cmd
|
||||
backend/iot-debug-service/main
|
||||
backend/gateway/cmd/main
|
||||
backend/iot-debug-service/main
|
||||
backend/iot-debug-service/cmd/main
|
||||
backend/memory-service/cmd/main
|
||||
backend/tool-engine/cmd/tool-engine
|
||||
|
||||
# Stale binary (old build artifact)
|
||||
backend/cmd
|
||||
@@ -121,6 +121,12 @@ func main() {
|
||||
thinkerCfg := background.DefaultThinkerConfig()
|
||||
adminUserID := "admin_admin"
|
||||
adminSessionID := "admin-session-main"
|
||||
|
||||
// 创建记忆服务 HTTP 客户端(用于持久化思考日志到 memory-service)
|
||||
memServiceURL := getEnv("MEMORY_SERVICE_URL", "http://localhost:8091")
|
||||
memClient := memory.NewClient(memServiceURL)
|
||||
log.Printf("记忆服务客户端已就绪: %s", memServiceURL)
|
||||
|
||||
thinker := background.NewThinker(
|
||||
thinkerCfg,
|
||||
personaLoader,
|
||||
@@ -133,6 +139,7 @@ func main() {
|
||||
convStore,
|
||||
adminUserID,
|
||||
adminSessionID,
|
||||
memClient,
|
||||
)
|
||||
thinker.Start()
|
||||
defer thinker.Stop()
|
||||
|
||||
@@ -50,6 +50,9 @@ type Thinker struct {
|
||||
adminUserID string // 管理员用户 ID
|
||||
adminSessionID string // 管理员主对话 session ID
|
||||
|
||||
// 记忆服务 HTTP 客户端(用于持久化思考日志)
|
||||
memClient *memory.Client
|
||||
|
||||
pendingThoughts []*PendingThought
|
||||
|
||||
lastUserMessage time.Time
|
||||
@@ -90,6 +93,7 @@ func NewThinker(
|
||||
convStore *ctxbuild.ConversationStore,
|
||||
adminUserID string,
|
||||
adminSessionID string,
|
||||
memClient *memory.Client,
|
||||
) *Thinker {
|
||||
return &Thinker{
|
||||
enabled: cfg.Enabled,
|
||||
@@ -106,6 +110,7 @@ func NewThinker(
|
||||
convStore: convStore,
|
||||
adminUserID: adminUserID,
|
||||
adminSessionID: adminSessionID,
|
||||
memClient: memClient,
|
||||
pendingThoughts: make([]*PendingThought, 0),
|
||||
lastUserMessage: time.Now(),
|
||||
stopCh: make(chan struct{}),
|
||||
@@ -281,6 +286,7 @@ func (t *Thinker) performThink() {
|
||||
maxToolRounds := 3
|
||||
var finalContent string
|
||||
var totalToolCalls int
|
||||
var toolCallRecords []map[string]interface{}
|
||||
|
||||
for round := 0; round <= maxToolRounds; round++ {
|
||||
resp, err := t.llmAdapter.ChatWithTools(ctx, messages, openAITools)
|
||||
@@ -327,6 +333,10 @@ func (t *Thinker) performThink() {
|
||||
})
|
||||
|
||||
totalToolCalls++
|
||||
toolCallRecords = append(toolCallRecords, map[string]interface{}{
|
||||
"name": tc.Name,
|
||||
"args": args,
|
||||
})
|
||||
}
|
||||
|
||||
// 最后一轮:即使有 tool_calls 也强制停止
|
||||
@@ -348,8 +358,16 @@ func (t *Thinker) performThink() {
|
||||
return
|
||||
}
|
||||
|
||||
// 8. 存储思考结果
|
||||
t.storeThought(finalContent)
|
||||
// 序列化工具调用记录
|
||||
toolCallsJSON := "[]"
|
||||
if len(toolCallRecords) > 0 {
|
||||
if data, err := json.Marshal(toolCallRecords); err == nil {
|
||||
toolCallsJSON = string(data)
|
||||
}
|
||||
}
|
||||
|
||||
// 8. 存储思考结果(内存队列 + 持久化到 memory-service)
|
||||
t.storeThought(finalContent, toolCallsJSON, totalToolCalls)
|
||||
|
||||
log.Printf("[后台思考] 完成 (内容长度=%d, 工具调用=%d次)", len(finalContent), totalToolCalls)
|
||||
|
||||
@@ -475,11 +493,9 @@ func (t *Thinker) buildOpenAITools() []llm.OpenAITool {
|
||||
return result
|
||||
}
|
||||
|
||||
// storeThought 存储思考结果到待推送队列
|
||||
func (t *Thinker) storeThought(content string) {
|
||||
// storeThought 存储思考结果到待推送队列,并异步持久化到 memory-service
|
||||
func (t *Thinker) storeThought(content string, toolCallsJSON string, toolCallCount int) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
t.pendingThoughts = append(t.pendingThoughts, &PendingThought{
|
||||
Content: content,
|
||||
CreatedAt: time.Now(),
|
||||
@@ -490,8 +506,22 @@ func (t *Thinker) storeThought(content string) {
|
||||
if len(t.pendingThoughts) > 10 {
|
||||
t.pendingThoughts = t.pendingThoughts[len(t.pendingThoughts)-10:]
|
||||
}
|
||||
t.mu.Unlock()
|
||||
|
||||
log.Printf("[后台思考] 思考已存储 (当前累积 %d 条待推送思考)", len(t.pendingThoughts))
|
||||
|
||||
// 异步持久化到 memory-service (不阻塞思考循环)
|
||||
if t.memClient != nil {
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := t.memClient.SaveThinkingLog(ctx, t.adminUserID, content, toolCallsJSON, toolCallCount, len(content)); err != nil {
|
||||
log.Printf("[后台思考] 持久化思考日志失败: %v", err)
|
||||
} else {
|
||||
log.Printf("[后台思考] 思考日志已持久化 (长度=%d, 工具调用=%d)", len(content), toolCallCount)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// extractMemoriesFromThinking 从思考结果中提取记忆(异步执行)
|
||||
|
||||
@@ -0,0 +1,309 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/model"
|
||||
)
|
||||
|
||||
// Client 记忆服务 HTTP 客户端
|
||||
// ai-core 通过此客户端调用独立的 memory-service
|
||||
type Client struct {
|
||||
baseURL string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewClient 创建记忆服务客户端
|
||||
func NewClient(baseURL string) *Client {
|
||||
return &Client{
|
||||
baseURL: baseURL,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 15 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Ping 检查记忆服务是否可用
|
||||
func (c *Client) Ping(ctx context.Context) error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/api/v1/health", nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("记忆服务健康检查失败: %d", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Save 保存记忆
|
||||
func (c *Client) Save(ctx context.Context, entry *model.MemoryEntry) error {
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"user_id": entry.UserID,
|
||||
"content": entry.Content,
|
||||
"summary": entry.Summary,
|
||||
"category": string(entry.Category),
|
||||
"priority": int(entry.Priority),
|
||||
"importance": entry.Importance,
|
||||
"keywords": entry.Keywords,
|
||||
"session_id": entry.SessionID,
|
||||
"source": entry.Source,
|
||||
})
|
||||
|
||||
resp, err := c.doRequest(ctx, http.MethodPost, c.baseURL+"/api/v1/memories", body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("保存记忆失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("保存记忆失败 (%d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
// 解析返回以获取 ID 和 CreatedAt
|
||||
var result struct {
|
||||
Memory *model.MemoryEntry `json:"memory"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err == nil && result.Memory != nil {
|
||||
entry.ID = result.Memory.ID
|
||||
entry.CreatedAt = result.Memory.CreatedAt
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query 按条件查询记忆
|
||||
func (c *Client) Query(ctx context.Context, q model.MemoryQuery) ([]model.MemoryEntry, error) {
|
||||
url := fmt.Sprintf("%s/api/v1/memories?user_id=%s", c.baseURL, q.UserID)
|
||||
if q.Category != "" {
|
||||
url += "&category=" + string(q.Category)
|
||||
}
|
||||
if q.MinImportance > 0 {
|
||||
url += fmt.Sprintf("&min_importance=%d", q.MinImportance)
|
||||
}
|
||||
if q.Limit > 0 {
|
||||
url += fmt.Sprintf("&limit=%d", q.Limit)
|
||||
}
|
||||
|
||||
resp, err := c.doRequest(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询记忆失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result struct {
|
||||
Memories []model.MemoryEntry `json:"memories"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("解析查询结果失败: %w", err)
|
||||
}
|
||||
return result.Memories, nil
|
||||
}
|
||||
|
||||
// QueryByText 语义查询(POST /api/v1/memories/query)
|
||||
func (c *Client) QueryByText(ctx context.Context, userID, queryText, category string, minImportance, limit int) ([]model.MemoryEntry, error) {
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"query_text": queryText,
|
||||
"category": category,
|
||||
"min_importance": minImportance,
|
||||
"limit": limit,
|
||||
})
|
||||
|
||||
resp, err := c.doRequest(ctx, http.MethodPost, c.baseURL+"/api/v1/memories/query", body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("语义查询失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("语义查询失败 (%d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Memories []model.MemoryEntry `json:"memories"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("解析查询结果失败: %w", err)
|
||||
}
|
||||
return result.Memories, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取记忆
|
||||
func (c *Client) GetByID(ctx context.Context, id string) (*model.MemoryEntry, error) {
|
||||
resp, err := c.doRequest(ctx, http.MethodGet, c.baseURL+"/api/v1/memories/"+id, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取记忆失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("获取记忆失败 (%d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Memory model.MemoryEntry `json:"memory"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("解析获取结果失败: %w", err)
|
||||
}
|
||||
return &result.Memory, nil
|
||||
}
|
||||
|
||||
// Delete 删除记忆
|
||||
func (c *Client) Delete(ctx context.Context, id string) error {
|
||||
resp, err := c.doRequest(ctx, http.MethodDelete, c.baseURL+"/api/v1/memories/"+id, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除记忆失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("删除记忆失败 (%d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMemoriesByCategory 按分类获取记忆
|
||||
func (c *Client) GetMemoriesByCategory(ctx context.Context, userID string, category model.MemoryCategory) ([]model.MemoryEntry, error) {
|
||||
return c.Query(ctx, model.MemoryQuery{
|
||||
UserID: userID,
|
||||
Category: category,
|
||||
Limit: 50,
|
||||
})
|
||||
}
|
||||
|
||||
// ConsolidateMemories 合并相似记忆
|
||||
func (c *Client) ConsolidateMemories(ctx context.Context, userID string) (int, error) {
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"user_id": userID,
|
||||
})
|
||||
|
||||
resp, err := c.doRequest(ctx, http.MethodPost, c.baseURL+"/api/v1/memories/consolidate", body)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("合并记忆失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result struct {
|
||||
Merged int `json:"merged"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return 0, fmt.Errorf("解析合并结果失败: %w", err)
|
||||
}
|
||||
return result.Merged, nil
|
||||
}
|
||||
|
||||
// DecayMemories 衰减旧记忆
|
||||
func (c *Client) DecayMemories(ctx context.Context, userID string) (int, int, error) {
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"user_id": userID,
|
||||
})
|
||||
|
||||
resp, err := c.doRequest(ctx, http.MethodPost, c.baseURL+"/api/v1/memories/decay", body)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("衰减记忆失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result struct {
|
||||
Decayed int `json:"decayed"`
|
||||
Deleted int `json:"deleted"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return 0, 0, fmt.Errorf("解析衰减结果失败: %w", err)
|
||||
}
|
||||
return result.Decayed, result.Deleted, nil
|
||||
}
|
||||
|
||||
// GetCategories 获取用户类别统计
|
||||
func (c *Client) GetCategories(ctx context.Context, userID string) (map[string]int, error) {
|
||||
url := fmt.Sprintf("%s/api/v1/memories/categories?user_id=%s", c.baseURL, userID)
|
||||
resp, err := c.doRequest(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取类别统计失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result struct {
|
||||
Categories map[string]int `json:"categories"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("解析类别统计失败: %w", err)
|
||||
}
|
||||
return result.Categories, nil
|
||||
}
|
||||
|
||||
// SaveThinkingLog 持久化自主思考日志到 memory-service
|
||||
func (c *Client) SaveThinkingLog(ctx context.Context, userID, content, toolCalls string, toolCallCount, contentLength int) error {
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"content": content,
|
||||
"tool_calls": toolCalls,
|
||||
"tool_call_count": toolCallCount,
|
||||
"content_length": contentLength,
|
||||
})
|
||||
|
||||
resp, err := c.doRequest(ctx, http.MethodPost, c.baseURL+"/api/v1/thinking", body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("保存思考日志失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("保存思考日志失败 (%d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsReady 检查记忆服务是否就绪
|
||||
func (c *Client) IsReady() bool {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
return c.Ping(ctx) == nil
|
||||
}
|
||||
|
||||
// doRequest 内部 HTTP 请求辅助方法
|
||||
func (c *Client) doRequest(ctx context.Context, method, url string, body []byte) (*http.Response, error) {
|
||||
var reqBody io.Reader
|
||||
if body != nil {
|
||||
reqBody = bytes.NewReader(body)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if body != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Printf("[memory-client] HTTP 请求失败 %s %s: %v", method, url, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
@@ -277,8 +277,8 @@ reflection_guidelines:
|
||||
- question: "开拓者的情绪是否有变化?"
|
||||
action: "如果情绪变好,说明陪伴有效;如果变差,思考如何改进"
|
||||
periodic:
|
||||
- frequency: "每10轮对话一次"
|
||||
actions:
|
||||
- "回顾最近的记忆,检查是否有矛盾之处"
|
||||
- "总结开拓者最近的生活状态和情绪趋势"
|
||||
- "思考如何在下次对话中创造惊喜或温暖"
|
||||
frequency: "每10轮对话一次"
|
||||
actions:
|
||||
- "回顾最近的记忆,检查是否有矛盾之处"
|
||||
- "总结开拓者最近的生活状态和情绪趋势"
|
||||
- "思考如何在下次对话中创造惊喜或温暖"
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// IoTControlTool IoT 设备控制工具
|
||||
@@ -53,6 +54,73 @@ func (t *IoTControlTool) Definition() ToolDefinition {
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeAction 标准化 action 参数,支持中文别名、power 参数等
|
||||
func normalizeAction(arguments map[string]interface{}) string {
|
||||
action, _ := arguments["action"].(string)
|
||||
|
||||
// 如果 action 为空,检查 power/status 参数
|
||||
if action == "" {
|
||||
// power 参数: "off"/"关"/"关闭" → turn_off, "on"/"开"/"打开" → turn_on
|
||||
if pv, ok := arguments["power"]; ok {
|
||||
switch v := pv.(type) {
|
||||
case string:
|
||||
switch strings.ToLower(strings.TrimSpace(v)) {
|
||||
case "off", "false", "关", "关闭":
|
||||
return "turn_off"
|
||||
case "on", "true", "开", "打开", "开启":
|
||||
return "turn_on"
|
||||
}
|
||||
case bool:
|
||||
if !v {
|
||||
return "turn_off"
|
||||
}
|
||||
return "turn_on"
|
||||
}
|
||||
}
|
||||
// status 参数同理
|
||||
if sv, ok := arguments["status"]; ok {
|
||||
switch v := sv.(type) {
|
||||
case string:
|
||||
switch strings.ToLower(strings.TrimSpace(v)) {
|
||||
case "off", "false", "关", "关闭":
|
||||
return "turn_off"
|
||||
case "on", "true", "开", "打开", "开启":
|
||||
return "turn_on"
|
||||
}
|
||||
case bool:
|
||||
if !v {
|
||||
return "turn_off"
|
||||
}
|
||||
return "turn_on"
|
||||
}
|
||||
}
|
||||
// 默认 toggle
|
||||
return "toggle"
|
||||
}
|
||||
|
||||
// 标准化中文 action 名
|
||||
switch strings.ToLower(strings.TrimSpace(action)) {
|
||||
case "打开", "开启", "开":
|
||||
return "turn_on"
|
||||
case "关闭", "关":
|
||||
return "turn_off"
|
||||
case "切换":
|
||||
return "toggle"
|
||||
case "设置温度", "调温度", "set_temp":
|
||||
return "set_temperature"
|
||||
case "设置亮度", "调亮度", "set_light":
|
||||
return "set_brightness"
|
||||
case "设置位置", "调位置":
|
||||
return "set_position"
|
||||
case "设置模式", "调模式", "切换模式":
|
||||
return "set_mode"
|
||||
case "设置颜色", "调颜色", "换颜色":
|
||||
return "set_color"
|
||||
}
|
||||
|
||||
return action
|
||||
}
|
||||
|
||||
// Execute 执行设备控制
|
||||
func (t *IoTControlTool) Execute(ctx context.Context, arguments map[string]interface{}) (*ToolResult, error) {
|
||||
if t.iotClient == nil {
|
||||
@@ -69,7 +137,7 @@ func (t *IoTControlTool) Execute(ctx context.Context, arguments map[string]inter
|
||||
deviceID, _ = arguments["entity_id"].(string)
|
||||
}
|
||||
|
||||
action, _ := arguments["action"].(string)
|
||||
action := normalizeAction(arguments)
|
||||
|
||||
if deviceID == "" {
|
||||
return &ToolResult{
|
||||
@@ -79,8 +147,10 @@ func (t *IoTControlTool) Execute(ctx context.Context, arguments map[string]inter
|
||||
}, nil
|
||||
}
|
||||
|
||||
if action == "" {
|
||||
action = "toggle"
|
||||
// 先获取设备名用于友好的返回消息(失败不影响后续流程)
|
||||
deviceName := deviceID
|
||||
if dev, err := t.iotClient.GetDevice(deviceID); err == nil {
|
||||
deviceName = dev.Name
|
||||
}
|
||||
|
||||
// 处理属性设置类操作
|
||||
@@ -95,51 +165,10 @@ func (t *IoTControlTool) Execute(ctx context.Context, arguments map[string]inter
|
||||
return t.handleSetMode(deviceID, arguments)
|
||||
case "set_color":
|
||||
return t.handleSetColor(deviceID, arguments)
|
||||
}
|
||||
|
||||
// 处理开关类操作:需要获取当前设备状态
|
||||
currentDevice, err := t.iotClient.GetDevice(deviceID)
|
||||
if err != nil {
|
||||
return &ToolResult{
|
||||
ToolName: "iot_control",
|
||||
Success: false,
|
||||
Error: fmt.Sprintf("获取设备状态失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "turn_on":
|
||||
// 如果设备已经开启,不需要操作
|
||||
if currentDevice.Status == "on" || currentDevice.Status == "open" || currentDevice.Status == "unlocked" {
|
||||
return &ToolResult{
|
||||
ToolName: "iot_control",
|
||||
Success: true,
|
||||
Data: fmt.Sprintf("设备 %s (%s) 已经处于开启状态,无需操作。", currentDevice.Name, deviceID),
|
||||
}, nil
|
||||
}
|
||||
if err := t.iotClient.ToggleDevice(deviceID); err != nil {
|
||||
return &ToolResult{
|
||||
ToolName: "iot_control",
|
||||
Success: false,
|
||||
Error: fmt.Sprintf("打开设备失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
return &ToolResult{
|
||||
ToolName: "iot_control",
|
||||
Success: true,
|
||||
Data: fmt.Sprintf("已打开设备: %s", currentDevice.Name),
|
||||
}, nil
|
||||
|
||||
case "turn_off":
|
||||
// 如果设备已经关闭,不需要操作
|
||||
if currentDevice.Status == "off" || currentDevice.Status == "closed" || currentDevice.Status == "locked" {
|
||||
return &ToolResult{
|
||||
ToolName: "iot_control",
|
||||
Success: true,
|
||||
Data: fmt.Sprintf("设备 %s (%s) 已经处于关闭状态,无需操作。", currentDevice.Name, deviceID),
|
||||
}, nil
|
||||
}
|
||||
if err := t.iotClient.ToggleDevice(deviceID); err != nil {
|
||||
// 声明式关闭:使用 SetDeviceProperty status/off 而非 toggle
|
||||
// 即使设备已经关闭,SetProperty 也会幂等处理
|
||||
if err := t.iotClient.SetDeviceProperty(deviceID, "status", "off"); err != nil {
|
||||
return &ToolResult{
|
||||
ToolName: "iot_control",
|
||||
Success: false,
|
||||
@@ -149,9 +178,22 @@ func (t *IoTControlTool) Execute(ctx context.Context, arguments map[string]inter
|
||||
return &ToolResult{
|
||||
ToolName: "iot_control",
|
||||
Success: true,
|
||||
Data: fmt.Sprintf("已关闭设备: %s", currentDevice.Name),
|
||||
Data: fmt.Sprintf("已关闭设备: %s", deviceName),
|
||||
}, nil
|
||||
case "turn_on":
|
||||
// 声明式打开:使用 SetDeviceProperty status/on 而非 toggle
|
||||
if err := t.iotClient.SetDeviceProperty(deviceID, "status", "on"); err != nil {
|
||||
return &ToolResult{
|
||||
ToolName: "iot_control",
|
||||
Success: false,
|
||||
Error: fmt.Sprintf("打开设备失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
return &ToolResult{
|
||||
ToolName: "iot_control",
|
||||
Success: true,
|
||||
Data: fmt.Sprintf("已打开设备: %s", deviceName),
|
||||
}, nil
|
||||
|
||||
default: // "toggle"
|
||||
if err := t.iotClient.ToggleDevice(deviceID); err != nil {
|
||||
return &ToolResult{
|
||||
@@ -167,7 +209,7 @@ func (t *IoTControlTool) Execute(ctx context.Context, arguments map[string]inter
|
||||
return &ToolResult{
|
||||
ToolName: "iot_control",
|
||||
Success: true,
|
||||
Data: fmt.Sprintf("已成功切换设备 %s 的状态。", currentDevice.Name),
|
||||
Data: fmt.Sprintf("已成功切换设备 %s 的状态。", deviceName),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,177 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ToolEngineClient 工具引擎 HTTP 客户端
|
||||
// 将工具执行请求转发到独立的 tool-engine 微服务
|
||||
type ToolEngineClient struct {
|
||||
baseURL string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// toolEngineToolDef 来自 tool-engine 的工具定义响应
|
||||
type toolEngineToolDef struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters map[string]interface{} `json:"parameters"`
|
||||
}
|
||||
|
||||
// toolEngineResult 来自 tool-engine 的工具执行结果
|
||||
type toolEngineResult struct {
|
||||
ID string `json:"id"`
|
||||
Output string `json:"output"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// NewToolEngineClient 创建工具引擎客户端
|
||||
func NewToolEngineClient(baseURL string) *ToolEngineClient {
|
||||
return &ToolEngineClient{
|
||||
baseURL: baseURL,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetDefinitions 从 tool-engine 获取所有工具定义
|
||||
func (c *ToolEngineClient) GetDefinitions(ctx context.Context) ([]ToolDefinition, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", c.baseURL+"/api/v1/tools", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求工具列表失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("获取工具列表返回状态码 %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Tools []toolEngineToolDef `json:"tools"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("解析工具列表失败: %w", err)
|
||||
}
|
||||
|
||||
defs := make([]ToolDefinition, 0, len(result.Tools))
|
||||
for _, t := range result.Tools {
|
||||
defs = append(defs, ToolDefinition{
|
||||
Name: t.Name,
|
||||
Description: t.Description,
|
||||
Parameters: t.Parameters,
|
||||
})
|
||||
}
|
||||
|
||||
log.Printf("[tool-engine-client] 从 tool-engine 获取了 %d 个工具定义", len(defs))
|
||||
return defs, nil
|
||||
}
|
||||
|
||||
// Execute 通过 tool-engine 执行工具调用
|
||||
func (c *ToolEngineClient) Execute(ctx context.Context, toolName string, arguments map[string]interface{}) (*ToolResult, error) {
|
||||
body, err := json.Marshal(map[string]interface{}{
|
||||
"arguments": arguments,
|
||||
})
|
||||
if err != nil {
|
||||
return &ToolResult{
|
||||
ToolName: toolName,
|
||||
Success: false,
|
||||
Error: fmt.Sprintf("序列化参数失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/v1/tools/%s/execute", c.baseURL, toolName)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return &ToolResult{
|
||||
ToolName: toolName,
|
||||
Success: false,
|
||||
Error: fmt.Sprintf("创建请求失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return &ToolResult{
|
||||
ToolName: toolName,
|
||||
Success: false,
|
||||
Error: fmt.Sprintf("请求 tool-engine 失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return &ToolResult{
|
||||
ToolName: toolName,
|
||||
Success: false,
|
||||
Error: fmt.Sprintf("工具 %s 不存在", toolName),
|
||||
}, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return &ToolResult{
|
||||
ToolName: toolName,
|
||||
Success: false,
|
||||
Error: fmt.Sprintf("tool-engine 返回状态码 %d: %s", resp.StatusCode, string(respBody)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
var result toolEngineResult
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return &ToolResult{
|
||||
ToolName: toolName,
|
||||
Success: false,
|
||||
Error: fmt.Sprintf("解析 tool-engine 响应失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
return &ToolResult{
|
||||
ToolName: toolName,
|
||||
Success: false,
|
||||
Error: result.Error,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &ToolResult{
|
||||
ToolName: toolName,
|
||||
Success: true,
|
||||
Data: result.Output,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// HealthCheck 检查 tool-engine 服务是否可用
|
||||
func (c *ToolEngineClient) HealthCheck(ctx context.Context) error {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", c.baseURL+"/api/v1/health", nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建健康检查请求失败: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tool-engine 不可达: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("tool-engine 健康检查返回状态码 %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -39,9 +39,15 @@ type Config struct {
|
||||
// AI-Core 服务
|
||||
AICoreURL string
|
||||
|
||||
// Memory 服务
|
||||
MemoryServiceURL string
|
||||
|
||||
// IoT 调试服务
|
||||
IoTDebugServiceURL string
|
||||
|
||||
// Tool-Engine 工具引擎服务
|
||||
ToolEngineURL string
|
||||
|
||||
// LLM (透传给AI-Core,Gateway可能也需要)
|
||||
LLMAPIURL string
|
||||
LLMAPIKey string
|
||||
@@ -85,8 +91,12 @@ func Load() *Config {
|
||||
|
||||
AICoreURL: getEnv("AI_CORE_URL", "http://localhost:8081"),
|
||||
|
||||
MemoryServiceURL: getEnv("MEMORY_SERVICE_URL", "http://localhost:8091"),
|
||||
|
||||
IoTDebugServiceURL: getEnv("IOT_DEBUG_SERVICE_URL", "http://localhost:8083"),
|
||||
|
||||
ToolEngineURL: getEnv("TOOL_ENGINE_URL", "http://localhost:8092"),
|
||||
|
||||
LLMAPIURL: getEnv("LLM_API_URL", "https://api.openai.com/v1"),
|
||||
LLMAPIKey: getEnv("LLM_API_KEY", ""),
|
||||
LLMModel: getEnv("LLM_MODEL", "gpt-4o"),
|
||||
|
||||
@@ -15,23 +15,23 @@ import (
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
|
||||
)
|
||||
|
||||
// MemoryHandler 记忆查询处理器 — 代理到 AI-Core
|
||||
// MemoryHandler 记忆查询处理器 — 代理到 Memory-Service
|
||||
type MemoryHandler struct {
|
||||
aiCoreURL string
|
||||
client *http.Client
|
||||
memoryServiceURL string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewMemoryHandler 创建记忆处理器
|
||||
func NewMemoryHandler(aiCoreURL string) *MemoryHandler {
|
||||
func NewMemoryHandler(memoryServiceURL string) *MemoryHandler {
|
||||
return &MemoryHandler{
|
||||
aiCoreURL: aiCoreURL,
|
||||
memoryServiceURL: memoryServiceURL,
|
||||
client: &http.Client{
|
||||
Timeout: 15 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Query 搜索用户记忆 — 代理 GET /api/v1/memory/search?user_id=...&q=...
|
||||
// Query 搜索用户记忆 — 代理 POST /api/v1/memories/query
|
||||
// 管理员可通过 user_id 查询参数查询任意用户的记忆
|
||||
func (h *MemoryHandler) Query(c *gin.Context) {
|
||||
authUserID := middleware.GetUserID(c)
|
||||
@@ -48,16 +48,28 @@ func (h *MemoryHandler) Query(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/v1/memory/search?user_id=%s&q=%s",
|
||||
h.aiCoreURL, userID, query)
|
||||
// 使用 memory-service 的 POST /api/v1/memories/query 端点
|
||||
reqBody, _ := json.Marshal(map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"query_text": query,
|
||||
"limit": 10,
|
||||
})
|
||||
|
||||
resp, err := h.client.Get(url)
|
||||
url := fmt.Sprintf("%s/api/v1/memories/query", h.memoryServiceURL)
|
||||
httpReq, err := http.NewRequest("POST", url, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
log.Printf("[memory] AI-Core 不可达 (Query): %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "构建请求失败"})
|
||||
return
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := h.client.Do(httpReq)
|
||||
if err != nil {
|
||||
log.Printf("[memory] Memory-Service 不可达 (Query): %v", err)
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": fmt.Sprintf("AI-Core 不可达: %v", err),
|
||||
"errorType": "ai_core_unreachable",
|
||||
"hint": "AI-Core 服务未启动或不可达,请先在「服务管理」面板中启动 AI-Core",
|
||||
"error": fmt.Sprintf("Memory-Service 不可达: %v", err),
|
||||
"errorType": "memory_service_unreachable",
|
||||
"hint": "Memory-Service 服务未启动或不可达,请先在「服务管理」面板中启动 Memory-Service",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -70,7 +82,7 @@ func (h *MemoryHandler) Query(c *gin.Context) {
|
||||
c.JSON(resp.StatusCode, result)
|
||||
}
|
||||
|
||||
// List 列出用户所有记忆 — 代理 GET /api/v1/memory?user_id=...
|
||||
// List 列出用户所有记忆 — 代理 GET /api/v1/memories?user_id=...
|
||||
// 管理员可通过 user_id 查询参数查询任意用户的记忆
|
||||
func (h *MemoryHandler) List(c *gin.Context) {
|
||||
authUserID := middleware.GetUserID(c)
|
||||
@@ -81,15 +93,15 @@ func (h *MemoryHandler) List(c *gin.Context) {
|
||||
userID = authUserID
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/v1/memory?user_id=%s", h.aiCoreURL, userID)
|
||||
url := fmt.Sprintf("%s/api/v1/memories?user_id=%s", h.memoryServiceURL, userID)
|
||||
|
||||
resp, err := h.client.Get(url)
|
||||
if err != nil {
|
||||
log.Printf("[memory] AI-Core 不可达 (List): %v", err)
|
||||
log.Printf("[memory] Memory-Service 不可达 (List): %v", err)
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": fmt.Sprintf("AI-Core 不可达: %v", err),
|
||||
"errorType": "ai_core_unreachable",
|
||||
"hint": "AI-Core 服务未启动或不可达,请先在「服务管理」面板中启动 AI-Core",
|
||||
"error": fmt.Sprintf("Memory-Service 不可达: %v", err),
|
||||
"errorType": "memory_service_unreachable",
|
||||
"hint": "Memory-Service 服务未启动或不可达,请先在「服务管理」面板中启动 Memory-Service",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -102,7 +114,7 @@ func (h *MemoryHandler) List(c *gin.Context) {
|
||||
c.JSON(resp.StatusCode, result)
|
||||
}
|
||||
|
||||
// Add 手动添加记忆 — 代理 POST /api/v1/memory
|
||||
// Add 手动添加记忆 — 代理 POST /api/v1/memories
|
||||
// 管理员可通过请求体中的 user_id 字段为任意用户添加记忆
|
||||
func (h *MemoryHandler) Add(c *gin.Context) {
|
||||
authUserID := middleware.GetUserID(c)
|
||||
@@ -132,16 +144,16 @@ func (h *MemoryHandler) Add(c *gin.Context) {
|
||||
userID = req.UserID
|
||||
}
|
||||
|
||||
// 转发到 AI-Core
|
||||
aiReq := map[string]interface{}{
|
||||
// 转发到 Memory-Service
|
||||
memReq := map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"content": req.Content,
|
||||
"category": req.Category,
|
||||
"priority": req.Priority,
|
||||
}
|
||||
reqBody, _ := json.Marshal(aiReq)
|
||||
reqBody, _ := json.Marshal(memReq)
|
||||
|
||||
url := fmt.Sprintf("%s/api/v1/memory", h.aiCoreURL)
|
||||
url := fmt.Sprintf("%s/api/v1/memories", h.memoryServiceURL)
|
||||
httpReq, err := http.NewRequest("POST", url, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "构建请求失败"})
|
||||
@@ -151,11 +163,11 @@ func (h *MemoryHandler) Add(c *gin.Context) {
|
||||
|
||||
resp, err := h.client.Do(httpReq)
|
||||
if err != nil {
|
||||
log.Printf("[memory] AI-Core 不可达 (Add): %v", err)
|
||||
log.Printf("[memory] Memory-Service 不可达 (Add): %v", err)
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": fmt.Sprintf("AI-Core 不可达: %v", err),
|
||||
"errorType": "ai_core_unreachable",
|
||||
"hint": "AI-Core 服务未启动或不可达,请先在「服务管理」面板中启动 AI-Core",
|
||||
"error": fmt.Sprintf("Memory-Service 不可达: %v", err),
|
||||
"errorType": "memory_service_unreachable",
|
||||
"hint": "Memory-Service 服务未启动或不可达,请先在「服务管理」面板中启动 Memory-Service",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -168,7 +180,7 @@ func (h *MemoryHandler) Add(c *gin.Context) {
|
||||
c.JSON(resp.StatusCode, result)
|
||||
}
|
||||
|
||||
// Delete 删除单条记忆 — 代理 DELETE /api/v1/memory?id=...
|
||||
// Delete 删除单条记忆 — 代理 DELETE /api/v1/memories/:id
|
||||
func (h *MemoryHandler) Delete(c *gin.Context) {
|
||||
memoryID := c.Query("id")
|
||||
if memoryID == "" {
|
||||
@@ -176,7 +188,7 @@ func (h *MemoryHandler) Delete(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/v1/memory?id=%s", h.aiCoreURL, memoryID)
|
||||
url := fmt.Sprintf("%s/api/v1/memories/%s", h.memoryServiceURL, memoryID)
|
||||
|
||||
req, err := http.NewRequest("DELETE", url, nil)
|
||||
if err != nil {
|
||||
@@ -186,11 +198,11 @@ func (h *MemoryHandler) Delete(c *gin.Context) {
|
||||
|
||||
resp, err := h.client.Do(req)
|
||||
if err != nil {
|
||||
log.Printf("[memory] AI-Core 不可达 (Delete): %v", err)
|
||||
log.Printf("[memory] Memory-Service 不可达 (Delete): %v", err)
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": fmt.Sprintf("AI-Core 不可达: %v", err),
|
||||
"errorType": "ai_core_unreachable",
|
||||
"hint": "AI-Core 服务未启动或不可达,请先在「服务管理」面板中启动 AI-Core",
|
||||
"error": fmt.Sprintf("Memory-Service 不可达: %v", err),
|
||||
"errorType": "memory_service_unreachable",
|
||||
"hint": "Memory-Service 服务未启动或不可达,请先在「服务管理」面板中启动 Memory-Service",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ func Setup(r *gin.Engine, hub *ws.Hub, cfg *config.Config, sessionStore *store.S
|
||||
// 初始化处理器
|
||||
authHandler := handler.NewAuthHandler(cfg)
|
||||
sessionHandler := handler.NewSessionHandler(hub, sessionStore)
|
||||
memoryHandler := handler.NewMemoryHandler(cfg.AICoreURL)
|
||||
memoryHandler := handler.NewMemoryHandler(cfg.MemoryServiceURL)
|
||||
chatHandler := handler.NewChatHandler(cfg, hub)
|
||||
webhookHandler := handler.NewWebhookHandler(cfg, hub)
|
||||
|
||||
|
||||
@@ -4,4 +4,6 @@ use (
|
||||
./ai-core
|
||||
./gateway
|
||||
./iot-debug-service
|
||||
./memory-service
|
||||
./tool-engine
|
||||
)
|
||||
|
||||
@@ -227,7 +227,7 @@ func (ds *DeviceStore) Toggle(id string) (*Device, error) {
|
||||
return &cp, nil
|
||||
}
|
||||
|
||||
// SetProperty 设置设备属性(温度、亮度、位置、模式、颜色等)
|
||||
// SetProperty 设置设备属性(状态、温度、亮度、位置、模式、颜色等)
|
||||
func (ds *DeviceStore) SetProperty(id, field string, value interface{}) (*Device, error) {
|
||||
ds.mu.Lock()
|
||||
defer ds.mu.Unlock()
|
||||
@@ -240,6 +240,73 @@ func (ds *DeviceStore) SetProperty(id, field string, value interface{}) (*Device
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
|
||||
switch field {
|
||||
case "status", "power":
|
||||
// 声明式电源控制:支持 "on"/"off"/"open"/"closed"/"locked"/"unlocked"
|
||||
// 同时支持中文值: "开"/"关"/"打开"/"关闭"
|
||||
// 支持布尔值: true/false → on/off
|
||||
var newStatus string
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
newStatus = normalizeStatus(v, d.Type)
|
||||
if newStatus == "" {
|
||||
return nil, fmt.Errorf("无效的状态值: %s (设备类型 %s 不支持此状态)", v, d.Type)
|
||||
}
|
||||
case bool:
|
||||
if v {
|
||||
newStatus = "on"
|
||||
} else {
|
||||
newStatus = "off"
|
||||
}
|
||||
case float64:
|
||||
if v == 0 {
|
||||
newStatus = "off"
|
||||
} else {
|
||||
newStatus = "on"
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("status 需要字符串、布尔值或数字")
|
||||
}
|
||||
// 对于非 on/off 设备类型(curtain/lock),转换为对应状态
|
||||
switch d.Type {
|
||||
case TypeCurtain:
|
||||
if newStatus == "on" {
|
||||
newStatus = "open"
|
||||
} else if newStatus == "off" {
|
||||
newStatus = "closed"
|
||||
}
|
||||
case TypeLock:
|
||||
if newStatus == "on" {
|
||||
newStatus = "unlocked"
|
||||
} else if newStatus == "off" {
|
||||
newStatus = "locked"
|
||||
}
|
||||
}
|
||||
|
||||
if d.Status == newStatus {
|
||||
// 状态未变化,直接返回
|
||||
cp := *d
|
||||
return &cp, nil
|
||||
}
|
||||
oldStatus := d.Status
|
||||
d.Status = newStatus
|
||||
// 根据设备类型设置关联属性
|
||||
switch d.Type {
|
||||
case TypeLight:
|
||||
if newStatus == "off" {
|
||||
d.Brightness = 0
|
||||
} else if d.Brightness == 0 {
|
||||
d.Brightness = 80
|
||||
}
|
||||
case TypeCurtain:
|
||||
if newStatus == "closed" {
|
||||
d.Position = 0
|
||||
} else {
|
||||
d.Position = 100
|
||||
}
|
||||
}
|
||||
d.LastUpdated = now
|
||||
ds.addHistory(id, "status", oldStatus, newStatus)
|
||||
|
||||
case "temperature":
|
||||
v, ok := toFloat64(value)
|
||||
if !ok {
|
||||
@@ -323,6 +390,36 @@ func (ds *DeviceStore) SetProperty(id, field string, value interface{}) (*Device
|
||||
return &cp, nil
|
||||
}
|
||||
|
||||
// normalizeStatus 标准化电源状态值
|
||||
// 支持英文: "on"/"off"/"open"/"closed"/"locked"/"unlocked"
|
||||
// 支持中文: "开"/"关"/"打开"/"关闭"/"开启"/"解锁"/"锁定"/"上锁"
|
||||
// 支持布尔兼容: "true"/"false"
|
||||
// 返回空字符串表示无效值
|
||||
func normalizeStatus(v string, dType DeviceType) string {
|
||||
switch strings.ToLower(strings.TrimSpace(v)) {
|
||||
case "on", "true", "开", "打开", "开启":
|
||||
return "on"
|
||||
case "off", "false", "关", "关闭":
|
||||
return "off"
|
||||
case "open":
|
||||
if dType == TypeCurtain {
|
||||
return "open"
|
||||
}
|
||||
return "on"
|
||||
case "closed":
|
||||
if dType == TypeCurtain {
|
||||
return "closed"
|
||||
}
|
||||
return "off"
|
||||
case "locked", "锁定", "上锁":
|
||||
return "locked"
|
||||
case "unlocked", "解锁":
|
||||
return "unlocked"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// toFloat64 将 interface{} 转换为 float64
|
||||
func toFloat64(v interface{}) (float64, bool) {
|
||||
switch val := v.(type) {
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
# Build stage
|
||||
FROM golang:1.24-alpine AS builder
|
||||
|
||||
WORKDIR /app
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -o /memory-service ./cmd/
|
||||
|
||||
# Runtime stage
|
||||
FROM alpine:3.21
|
||||
|
||||
RUN apk --no-cache add ca-certificates
|
||||
WORKDIR /app
|
||||
COPY --from=builder /memory-service .
|
||||
|
||||
EXPOSE 8091
|
||||
|
||||
ENTRYPOINT ["./memory-service"]
|
||||
@@ -0,0 +1,76 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/yourname/cyrene-ai/memory-service/internal/config"
|
||||
"github.com/yourname/cyrene-ai/memory-service/internal/handler"
|
||||
"github.com/yourname/cyrene-ai/memory-service/internal/service"
|
||||
"github.com/yourname/cyrene-ai/memory-service/internal/store"
|
||||
)
|
||||
|
||||
func main() {
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
log.Println("🧠 Memory-Service 启动中...")
|
||||
|
||||
// 加载配置
|
||||
cfg := config.Load()
|
||||
|
||||
log.Printf("配置: 端口=%s, 数据库=%s...", cfg.Port, maskDBURL(cfg.DatabaseURL))
|
||||
|
||||
// 初始化数据库存储
|
||||
memStore := store.NewStore(cfg.DatabaseURL)
|
||||
defer memStore.Close()
|
||||
|
||||
// 初始化服务层
|
||||
svc := service.NewMemoryService(memStore)
|
||||
|
||||
// 初始化 HTTP 处理器
|
||||
h := handler.NewMemoryHandler(svc)
|
||||
|
||||
// 注册路由
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
// 健康检查端点
|
||||
mux.HandleFunc("/api/v1/health", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
status := "ok"
|
||||
if !memStore.IsReady() {
|
||||
status = "degraded"
|
||||
}
|
||||
w.Write([]byte(`{"status":"` + status + `","service":"memory-service"}`))
|
||||
})
|
||||
|
||||
// 启动 HTTP 服务
|
||||
srv := &http.Server{
|
||||
Addr: ":" + cfg.Port,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Printf("🚀 Memory-Service 已启动在端口 %s", cfg.Port)
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("服务启动失败: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 优雅关闭
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
log.Println("正在关闭 Memory-Service...")
|
||||
srv.Close()
|
||||
log.Println("Memory-Service 已关闭")
|
||||
}
|
||||
|
||||
func maskDBURL(url string) string {
|
||||
if len(url) > 30 {
|
||||
return url[:30] + "..."
|
||||
}
|
||||
return url
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
module github.com/yourname/cyrene-ai/memory-service
|
||||
|
||||
go 1.26.2
|
||||
|
||||
require github.com/lib/pq v1.10.9
|
||||
@@ -0,0 +1,2 @@
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
@@ -0,0 +1,45 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Config 记忆服务配置
|
||||
type Config struct {
|
||||
Port string
|
||||
DatabaseURL string
|
||||
}
|
||||
|
||||
// Load 从环境变量加载配置
|
||||
func Load() *Config {
|
||||
return &Config{
|
||||
Port: getEnv("PORT", "8091"),
|
||||
DatabaseURL: buildDatabaseURL(),
|
||||
}
|
||||
}
|
||||
|
||||
// buildDatabaseURL 构建 PostgreSQL 连接字符串
|
||||
func buildDatabaseURL() string {
|
||||
// 优先使用 DB_URL 环境变量(简化模式)
|
||||
if url := os.Getenv("DB_URL"); url != "" {
|
||||
return url
|
||||
}
|
||||
|
||||
host := getEnv("POSTGRES_HOST", "localhost")
|
||||
port := getEnv("POSTGRES_PORT", "5432")
|
||||
user := getEnv("POSTGRES_USER", "cyrene")
|
||||
password := getEnv("POSTGRES_PASSWORD", "change_me")
|
||||
dbname := getEnv("POSTGRES_DB", "cyrene_ai")
|
||||
sslmode := getEnv("POSTGRES_SSLMODE", "disable")
|
||||
|
||||
return fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=%s",
|
||||
user, password, host, port, dbname, sslmode)
|
||||
}
|
||||
|
||||
func getEnv(key, fallback string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
@@ -0,0 +1,542 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/yourname/cyrene-ai/memory-service/internal/model"
|
||||
"github.com/yourname/cyrene-ai/memory-service/internal/service"
|
||||
)
|
||||
|
||||
// MemoryHandler HTTP API 处理器
|
||||
type MemoryHandler struct {
|
||||
svc *service.MemoryService
|
||||
}
|
||||
|
||||
// NewMemoryHandler 创建记忆处理器
|
||||
func NewMemoryHandler(svc *service.MemoryService) *MemoryHandler {
|
||||
return &MemoryHandler{svc: svc}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册所有路由到 mux
|
||||
func (h *MemoryHandler) RegisterRoutes(mux *http.ServeMux) {
|
||||
// POST /api/v1/memories - 创建/保存记忆
|
||||
mux.HandleFunc("/api/v1/memories", h.handleMemories)
|
||||
// GET/DELETE/PUT /api/v1/memories/... (带 ID)
|
||||
mux.HandleFunc("/api/v1/memories/", h.handleMemoryByID)
|
||||
// POST /api/v1/memories/query - 语义查询
|
||||
mux.HandleFunc("/api/v1/memories/query", h.handleQuery)
|
||||
// POST /api/v1/memories/consolidate - 合并相似记忆
|
||||
mux.HandleFunc("/api/v1/memories/consolidate", h.handleConsolidate)
|
||||
// POST /api/v1/memories/decay - 衰减旧记忆
|
||||
mux.HandleFunc("/api/v1/memories/decay", h.handleDecay)
|
||||
// GET /api/v1/memories/categories - 获取类别统计
|
||||
mux.HandleFunc("/api/v1/memories/categories", h.handleCategories)
|
||||
// 自主思考日志 API
|
||||
mux.HandleFunc("/api/v1/thinking", h.handleThinking)
|
||||
mux.HandleFunc("/api/v1/thinking/", h.handleThinkingByID)
|
||||
mux.HandleFunc("/api/v1/thinking/stats", h.handleThinkingStats)
|
||||
}
|
||||
|
||||
// handleMemories 处理 /api/v1/memories
|
||||
// GET - 列出用户记忆 (?user_id=xxx&category=xxx&min_importance=xxx&limit=xxx)
|
||||
// POST - 创建记忆
|
||||
func (h *MemoryHandler) handleMemories(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
h.listMemories(w, r)
|
||||
case http.MethodPost:
|
||||
h.createMemory(w, r)
|
||||
default:
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
// listMemories GET /api/v1/memories?user_id=xxx
|
||||
func (h *MemoryHandler) listMemories(w http.ResponseWriter, r *http.Request) {
|
||||
userID := r.URL.Query().Get("user_id")
|
||||
if userID == "" {
|
||||
writeError(w, http.StatusBadRequest, "缺少 user_id 参数")
|
||||
return
|
||||
}
|
||||
|
||||
category := r.URL.Query().Get("category")
|
||||
limit := queryInt(r, "limit", 50)
|
||||
minImportance := queryInt(r, "min_importance", 0)
|
||||
|
||||
memories, err := h.svc.ListMemories(r.Context(), userID, category, minImportance, limit)
|
||||
if err != nil {
|
||||
log.Printf("[memory-handler] 列出记忆失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if memories == nil {
|
||||
memories = []model.MemoryEntry{}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"memories": memories,
|
||||
"total": len(memories),
|
||||
})
|
||||
}
|
||||
|
||||
// createMemory POST /api/v1/memories
|
||||
func (h *MemoryHandler) createMemory(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
UserID string `json:"user_id"`
|
||||
Content string `json:"content"`
|
||||
Summary string `json:"summary"`
|
||||
Category string `json:"category"`
|
||||
Priority int `json:"priority"`
|
||||
Importance int `json:"importance"`
|
||||
Keywords []string `json:"keywords"`
|
||||
SessionID string `json:"session_id"`
|
||||
Source string `json:"source"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.UserID == "" || req.Content == "" {
|
||||
writeError(w, http.StatusBadRequest, "缺少 user_id 或 content")
|
||||
return
|
||||
}
|
||||
|
||||
entry := &model.MemoryEntry{
|
||||
UserID: req.UserID,
|
||||
Content: req.Content,
|
||||
Summary: req.Summary,
|
||||
Category: model.MemoryCategory(req.Category),
|
||||
Priority: model.MemoryPriority(req.Priority),
|
||||
Importance: req.Importance,
|
||||
Keywords: req.Keywords,
|
||||
SessionID: req.SessionID,
|
||||
Source: req.Source,
|
||||
}
|
||||
|
||||
if err := h.svc.CreateMemory(r.Context(), entry); err != nil {
|
||||
log.Printf("[memory-handler] 创建记忆失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, map[string]interface{}{
|
||||
"status": "saved",
|
||||
"memory": entry,
|
||||
})
|
||||
}
|
||||
|
||||
// handleMemoryByID 处理 /api/v1/memories/{id}
|
||||
// GET - 获取单个记忆
|
||||
// PUT - 更新记忆
|
||||
// DELETE - 删除记忆
|
||||
func (h *MemoryHandler) handleMemoryByID(w http.ResponseWriter, r *http.Request) {
|
||||
id := strings.TrimPrefix(r.URL.Path, "/api/v1/memories/")
|
||||
// 排除子路径 (query, consolidate, decay, categories)
|
||||
switch id {
|
||||
case "query", "consolidate", "decay", "categories":
|
||||
return // 这些有自己独立的处理器
|
||||
}
|
||||
|
||||
if id == "" {
|
||||
writeError(w, http.StatusBadRequest, "缺少记忆 ID")
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
h.getMemory(w, r, id)
|
||||
case http.MethodPut:
|
||||
h.updateMemory(w, r, id)
|
||||
case http.MethodDelete:
|
||||
h.deleteMemory(w, r, id)
|
||||
default:
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
// getMemory GET /api/v1/memories/:id
|
||||
func (h *MemoryHandler) getMemory(w http.ResponseWriter, r *http.Request, id string) {
|
||||
entry, err := h.svc.GetMemory(r.Context(), id)
|
||||
if err != nil {
|
||||
log.Printf("[memory-handler] 获取记忆失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if entry == nil {
|
||||
writeError(w, http.StatusNotFound, "记忆不存在")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"memory": entry,
|
||||
})
|
||||
}
|
||||
|
||||
// updateMemory PUT /api/v1/memories/:id
|
||||
func (h *MemoryHandler) updateMemory(w http.ResponseWriter, r *http.Request, id string) {
|
||||
var req struct {
|
||||
Content string `json:"content"`
|
||||
Summary string `json:"summary"`
|
||||
Category string `json:"category"`
|
||||
Priority int `json:"priority"`
|
||||
Importance int `json:"importance"`
|
||||
Keywords []string `json:"keywords"`
|
||||
Source string `json:"source"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
entry := &model.MemoryEntry{
|
||||
ID: id,
|
||||
Content: req.Content,
|
||||
Summary: req.Summary,
|
||||
Category: model.MemoryCategory(req.Category),
|
||||
Priority: model.MemoryPriority(req.Priority),
|
||||
Importance: req.Importance,
|
||||
Keywords: req.Keywords,
|
||||
Source: req.Source,
|
||||
}
|
||||
|
||||
if err := h.svc.UpdateMemory(r.Context(), entry); err != nil {
|
||||
log.Printf("[memory-handler] 更新记忆失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"status": "updated",
|
||||
"memory_id": id,
|
||||
})
|
||||
}
|
||||
|
||||
// deleteMemory DELETE /api/v1/memories/:id
|
||||
func (h *MemoryHandler) deleteMemory(w http.ResponseWriter, r *http.Request, id string) {
|
||||
if err := h.svc.DeleteMemory(r.Context(), id); err != nil {
|
||||
log.Printf("[memory-handler] 删除记忆失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"status": "deleted",
|
||||
"memory_id": id,
|
||||
})
|
||||
}
|
||||
|
||||
// handleQuery POST /api/v1/memories/query
|
||||
func (h *MemoryHandler) handleQuery(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
UserID string `json:"user_id"`
|
||||
QueryText string `json:"query_text"`
|
||||
Category string `json:"category"`
|
||||
MinImportance int `json:"min_importance"`
|
||||
Limit int `json:"limit"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.UserID == "" {
|
||||
writeError(w, http.StatusBadRequest, "缺少 user_id")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Limit <= 0 {
|
||||
req.Limit = 10
|
||||
}
|
||||
|
||||
memories, err := h.svc.QueryMemories(r.Context(), req.UserID, req.QueryText, req.Category, req.MinImportance, req.Limit)
|
||||
if err != nil {
|
||||
log.Printf("[memory-handler] 查询记忆失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if memories == nil {
|
||||
memories = []model.MemoryEntry{}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"user_id": req.UserID,
|
||||
"query": req.QueryText,
|
||||
"memories": memories,
|
||||
"total": len(memories),
|
||||
})
|
||||
}
|
||||
|
||||
// handleConsolidate POST /api/v1/memories/consolidate
|
||||
func (h *MemoryHandler) handleConsolidate(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.UserID == "" {
|
||||
writeError(w, http.StatusBadRequest, "缺少 user_id")
|
||||
return
|
||||
}
|
||||
|
||||
merged, err := h.svc.ConsolidateMemories(r.Context(), req.UserID)
|
||||
if err != nil {
|
||||
log.Printf("[memory-handler] 合并记忆失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"status": "consolidated",
|
||||
"user_id": req.UserID,
|
||||
"merged": merged,
|
||||
"message": "记忆整理完成",
|
||||
})
|
||||
}
|
||||
|
||||
// handleDecay POST /api/v1/memories/decay
|
||||
func (h *MemoryHandler) handleDecay(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.UserID == "" {
|
||||
writeError(w, http.StatusBadRequest, "缺少 user_id")
|
||||
return
|
||||
}
|
||||
|
||||
decayed, deleted, err := h.svc.DecayMemories(r.Context(), req.UserID)
|
||||
if err != nil {
|
||||
log.Printf("[memory-handler] 衰减记忆失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"status": "decayed",
|
||||
"user_id": req.UserID,
|
||||
"decayed": decayed,
|
||||
"deleted": deleted,
|
||||
"message": "记忆衰减完成",
|
||||
})
|
||||
}
|
||||
|
||||
// handleCategories GET /api/v1/memories/categories?user_id=xxx
|
||||
func (h *MemoryHandler) handleCategories(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
userID := r.URL.Query().Get("user_id")
|
||||
if userID == "" {
|
||||
writeError(w, http.StatusBadRequest, "缺少 user_id 参数")
|
||||
return
|
||||
}
|
||||
|
||||
categories, err := h.svc.GetCategories(r.Context(), userID)
|
||||
if err != nil {
|
||||
log.Printf("[memory-handler] 获取分类统计失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"categories": categories,
|
||||
})
|
||||
}
|
||||
|
||||
// --- 辅助函数 ---
|
||||
|
||||
func writeJSON(w http.ResponseWriter, status int, data interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
||||
func writeError(w http.ResponseWriter, status int, message string) {
|
||||
writeJSON(w, status, map[string]interface{}{
|
||||
"error": message,
|
||||
})
|
||||
}
|
||||
|
||||
func queryInt(r *http.Request, key string, fallback int) int {
|
||||
v := r.URL.Query().Get(key)
|
||||
if v == "" {
|
||||
return fallback
|
||||
}
|
||||
var result int
|
||||
for _, c := range v {
|
||||
if c < '0' || c > '9' {
|
||||
return fallback
|
||||
}
|
||||
result = result*10 + int(c-'0')
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// handleThinking 处理 /api/v1/thinking
|
||||
// GET - 分页查询思考日志 (?user_id=xxx&limit=xxx&offset=xxx)
|
||||
// POST - 保存思考日志
|
||||
func (h *MemoryHandler) handleThinking(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
h.listThinkingLogs(w, r)
|
||||
case http.MethodPost:
|
||||
h.createThinkingLog(w, r)
|
||||
default:
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
// createThinkingLog POST /api/v1/thinking
|
||||
func (h *MemoryHandler) createThinkingLog(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
UserID string `json:"user_id"`
|
||||
Content string `json:"content"`
|
||||
ToolCalls string `json:"tool_calls"`
|
||||
ToolCallCount int `json:"tool_call_count"`
|
||||
ContentLength int `json:"content_length"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "请求格式错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.Content == "" {
|
||||
writeError(w, http.StatusBadRequest, "缺少 content")
|
||||
return
|
||||
}
|
||||
|
||||
tl := &model.ThinkingLog{
|
||||
UserID: req.UserID,
|
||||
Content: req.Content,
|
||||
ToolCalls: req.ToolCalls,
|
||||
ToolCallCount: req.ToolCallCount,
|
||||
ContentLength: req.ContentLength,
|
||||
}
|
||||
|
||||
if err := h.svc.SaveThinkingLog(r.Context(), tl); err != nil {
|
||||
log.Printf("[memory-handler] 保存思考日志失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, map[string]interface{}{
|
||||
"status": "saved",
|
||||
"thinking": tl,
|
||||
})
|
||||
}
|
||||
|
||||
// listThinkingLogs GET /api/v1/thinking?user_id=xxx&limit=xxx&offset=xxx
|
||||
func (h *MemoryHandler) listThinkingLogs(w http.ResponseWriter, r *http.Request) {
|
||||
userID := r.URL.Query().Get("user_id")
|
||||
limit := queryInt(r, "limit", 20)
|
||||
offset := queryInt(r, "offset", 0)
|
||||
|
||||
logs, err := h.svc.QueryThinkingLogs(r.Context(), model.ThinkingQuery{
|
||||
UserID: userID,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("[memory-handler] 查询思考日志失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if logs == nil {
|
||||
logs = []model.ThinkingLog{}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"logs": logs,
|
||||
"total": len(logs),
|
||||
})
|
||||
}
|
||||
|
||||
// handleThinkingByID 处理 /api/v1/thinking/{id}
|
||||
// GET - 获取单条思考日志
|
||||
func (h *MemoryHandler) handleThinkingByID(w http.ResponseWriter, r *http.Request) {
|
||||
id := strings.TrimPrefix(r.URL.Path, "/api/v1/thinking/")
|
||||
// 排除子路径
|
||||
if id == "stats" || id == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method != http.MethodGet {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
thinkingLog, err := h.svc.GetThinkingLogByID(r.Context(), id)
|
||||
if err != nil {
|
||||
log.Printf("[memory-handler] 获取思考日志失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if thinkingLog == nil {
|
||||
writeError(w, http.StatusNotFound, "思考日志不存在")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"thinking": thinkingLog,
|
||||
})
|
||||
}
|
||||
|
||||
// handleThinkingStats GET /api/v1/thinking/stats
|
||||
func (h *MemoryHandler) handleThinkingStats(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.svc.GetThinkingStats(r.Context())
|
||||
if err != nil {
|
||||
log.Printf("[memory-handler] 获取思考日志统计失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"stats": stats,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,227 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MemoryPriority 记忆优先级
|
||||
type MemoryPriority int
|
||||
|
||||
const (
|
||||
MemoryTemp MemoryPriority = 0 // 临时记忆 (会话内)
|
||||
MemoryNormal MemoryPriority = 1 // 普通记忆
|
||||
MemoryImportant MemoryPriority = 2 // 重要记忆
|
||||
MemoryCore MemoryPriority = 3 // 核心记忆 (永远保留)
|
||||
)
|
||||
|
||||
// String 返回优先级的中文描述
|
||||
func (p MemoryPriority) String() string {
|
||||
switch p {
|
||||
case MemoryCore:
|
||||
return "核心"
|
||||
case MemoryImportant:
|
||||
return "重要"
|
||||
case MemoryNormal:
|
||||
return "普通"
|
||||
case MemoryTemp:
|
||||
return "临时"
|
||||
default:
|
||||
return "未知"
|
||||
}
|
||||
}
|
||||
|
||||
// MemoryCategory 记忆分类
|
||||
type MemoryCategory string
|
||||
|
||||
const (
|
||||
CategoryUserPreference MemoryCategory = "user_preference" // 用户偏好 (食物、颜色、习惯)
|
||||
CategoryPersonalInfo MemoryCategory = "personal_info" // 个人信息 (姓名、年龄、职业)
|
||||
CategoryConversation MemoryCategory = "conversation" // 对话摘要
|
||||
CategoryKnowledge MemoryCategory = "knowledge" // 知识性信息
|
||||
CategoryEvent MemoryCategory = "event" // 事件记录
|
||||
CategoryTask MemoryCategory = "task" // 任务/计划
|
||||
CategoryRelationship MemoryCategory = "relationship" // 关系信息
|
||||
|
||||
// 向后兼容的旧分类别名
|
||||
CategoryPreference = CategoryUserPreference
|
||||
CategoryFact = CategoryPersonalInfo
|
||||
CategoryHabit = CategoryUserPreference
|
||||
CategoryOther = CategoryKnowledge
|
||||
)
|
||||
|
||||
// CategoryDisplayName 返回分类的中文显示名
|
||||
func (c MemoryCategory) DisplayName() string {
|
||||
switch c {
|
||||
case CategoryUserPreference:
|
||||
return "用户偏好"
|
||||
case CategoryPersonalInfo:
|
||||
return "个人信息"
|
||||
case CategoryConversation:
|
||||
return "对话摘要"
|
||||
case CategoryKnowledge:
|
||||
return "知识信息"
|
||||
case CategoryEvent:
|
||||
return "事件记录"
|
||||
case CategoryTask:
|
||||
return "任务计划"
|
||||
case CategoryRelationship:
|
||||
return "关系情感"
|
||||
default:
|
||||
return "其他"
|
||||
}
|
||||
}
|
||||
|
||||
// MemoryEntry 记忆条目
|
||||
type MemoryEntry struct {
|
||||
ID string `json:"id" db:"id"`
|
||||
UserID string `json:"user_id" db:"user_id"`
|
||||
Content string `json:"content" db:"content"`
|
||||
Summary string `json:"summary" db:"summary"` // 简短摘要
|
||||
Category MemoryCategory `json:"category" db:"category"`
|
||||
Priority MemoryPriority `json:"priority" db:"priority"`
|
||||
Importance int `json:"importance" db:"importance"` // 重要程度 1-10
|
||||
Keywords []string `json:"keywords" db:"keywords"` // 关键词标签
|
||||
SessionID string `json:"session_id" db:"session_id"` // 来源会话
|
||||
Source string `json:"source" db:"source"` // 来源 (conversation/thinking)
|
||||
Embedding []float32 `json:"-" db:"embedding"` // 向量 (pgvector)
|
||||
AccessCount int `json:"access_count" db:"access_count"`
|
||||
LastAccess time.Time `json:"last_access" db:"last_access"`
|
||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at" db:"updated_at"` // 最后更新时间
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty" db:"expires_at"` // 临时记忆过期时间
|
||||
}
|
||||
|
||||
// KeywordsJSON 将关键词序列化为 JSON 字符串(用于数据库存储)
|
||||
func (e *MemoryEntry) KeywordsJSON() string {
|
||||
if len(e.Keywords) == 0 {
|
||||
return "[]"
|
||||
}
|
||||
data, _ := json.Marshal(e.Keywords)
|
||||
return string(data)
|
||||
}
|
||||
|
||||
// ParseKeywords 从 JSON 字符串解析关键词
|
||||
func ParseKeywords(raw string) []string {
|
||||
if raw == "" || raw == "[]" {
|
||||
return nil
|
||||
}
|
||||
var keywords []string
|
||||
if err := json.Unmarshal([]byte(raw), &keywords); err != nil {
|
||||
return nil
|
||||
}
|
||||
return keywords
|
||||
}
|
||||
|
||||
// SimilarityScore 计算两个记忆条目的简单文本相似度(基于词汇重叠)
|
||||
// 返回值 0.0 - 1.0
|
||||
func (e *MemoryEntry) SimilarityScore(other *MemoryEntry) float64 {
|
||||
if e.Content == other.Content {
|
||||
return 1.0
|
||||
}
|
||||
|
||||
// 基于关键词的重叠度
|
||||
if len(e.Keywords) > 0 && len(other.Keywords) > 0 {
|
||||
keywordSet := make(map[string]bool, len(e.Keywords))
|
||||
for _, k := range e.Keywords {
|
||||
keywordSet[k] = true
|
||||
}
|
||||
overlap := 0
|
||||
for _, k := range other.Keywords {
|
||||
if keywordSet[k] {
|
||||
overlap++
|
||||
}
|
||||
}
|
||||
keywordScore := float64(overlap) / float64(max(len(e.Keywords), len(other.Keywords)))
|
||||
if keywordScore > 0.6 {
|
||||
return keywordScore
|
||||
}
|
||||
}
|
||||
|
||||
// 基于内容的字符级 Jaccard 相似度
|
||||
return jaccardSimilarity(e.Content, other.Content)
|
||||
}
|
||||
|
||||
// jaccardSimilarity 计算两个字符串的 Jaccard 相似度
|
||||
func jaccardSimilarity(a, b string) float64 {
|
||||
if a == b {
|
||||
return 1.0
|
||||
}
|
||||
if len(a) == 0 || len(b) == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// 使用 bigram 分词
|
||||
bigramsA := make(map[string]int)
|
||||
runesA := []rune(a)
|
||||
for i := 0; i < len(runesA)-1; i++ {
|
||||
bigramsA[string(runesA[i:i+2])]++
|
||||
}
|
||||
|
||||
bigramsB := make(map[string]int)
|
||||
runesB := []rune(b)
|
||||
for i := 0; i < len(runesB)-1; i++ {
|
||||
bigramsB[string(runesB[i:i+2])]++
|
||||
}
|
||||
|
||||
intersection := 0
|
||||
for bg, countA := range bigramsA {
|
||||
if countB, ok := bigramsB[bg]; ok {
|
||||
intersection += min(countA, countB)
|
||||
}
|
||||
}
|
||||
|
||||
union := 0
|
||||
allBigrams := make(map[string]bool)
|
||||
for bg := range bigramsA {
|
||||
allBigrams[bg] = true
|
||||
}
|
||||
for bg := range bigramsB {
|
||||
allBigrams[bg] = true
|
||||
}
|
||||
for bg := range allBigrams {
|
||||
union += max(bigramsA[bg], bigramsB[bg])
|
||||
}
|
||||
|
||||
if union == 0 {
|
||||
return 0.0
|
||||
}
|
||||
return float64(intersection) / float64(union)
|
||||
}
|
||||
|
||||
// MemoryQuery 记忆查询参数
|
||||
type MemoryQuery struct {
|
||||
UserID string
|
||||
Query string // 查询文本
|
||||
Category MemoryCategory
|
||||
Priority MemoryPriority
|
||||
MinImportance int // 最低重要程度筛选
|
||||
Limit int
|
||||
Offset int
|
||||
}
|
||||
|
||||
// ThinkingLog 自主思考日志
|
||||
type ThinkingLog struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Content string `json:"content"`
|
||||
ToolCalls string `json:"tool_calls"` // JSON 数组
|
||||
ToolCallCount int `json:"tool_call_count"`
|
||||
ContentLength int `json:"content_length"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// ThinkingQuery 思考日志查询参数
|
||||
type ThinkingQuery struct {
|
||||
UserID string
|
||||
Limit int
|
||||
Offset int
|
||||
}
|
||||
|
||||
// ThinkingStats 思考日志统计
|
||||
type ThinkingStats struct {
|
||||
TotalLogs int `json:"total_logs"`
|
||||
TotalToolCalls int `json:"total_tool_calls"`
|
||||
AvgContentLen float64 `json:"avg_content_length"`
|
||||
LatestAt string `json:"latest_at"`
|
||||
}
|
||||
@@ -0,0 +1,316 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/yourname/cyrene-ai/memory-service/internal/model"
|
||||
"github.com/yourname/cyrene-ai/memory-service/internal/store"
|
||||
)
|
||||
|
||||
// deDupThreshold 去重相似度阈值
|
||||
const deDupThreshold = 0.75
|
||||
|
||||
// MemoryService 记忆业务逻辑
|
||||
type MemoryService struct {
|
||||
store *store.Store
|
||||
}
|
||||
|
||||
// NewMemoryService 创建记忆服务
|
||||
func NewMemoryService(s *store.Store) *MemoryService {
|
||||
return &MemoryService{store: s}
|
||||
}
|
||||
|
||||
// CreateMemory 创建/保存记忆
|
||||
func (svc *MemoryService) CreateMemory(ctx context.Context, entry *model.MemoryEntry) error {
|
||||
if entry.UserID == "" {
|
||||
return fmt.Errorf("user_id 不能为空")
|
||||
}
|
||||
if entry.Content == "" {
|
||||
return fmt.Errorf("content 不能为空")
|
||||
}
|
||||
if entry.Category == "" {
|
||||
entry.Category = model.CategoryKnowledge
|
||||
}
|
||||
if entry.Importance < 1 {
|
||||
entry.Importance = 5
|
||||
}
|
||||
if entry.Priority < 0 || entry.Priority > 3 {
|
||||
entry.Priority = model.MemoryNormal
|
||||
}
|
||||
if entry.Source == "" {
|
||||
entry.Source = "manual"
|
||||
}
|
||||
|
||||
// 去重检查
|
||||
similar, err := svc.findSimilar(ctx, entry.UserID, entry)
|
||||
if err == nil && similar != nil {
|
||||
// 合并到已有记忆
|
||||
return svc.mergeMemory(ctx, similar, entry)
|
||||
}
|
||||
|
||||
return svc.store.Save(ctx, entry)
|
||||
}
|
||||
|
||||
// GetMemory 获取单个记忆
|
||||
func (svc *MemoryService) GetMemory(ctx context.Context, id string) (*model.MemoryEntry, error) {
|
||||
return svc.store.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// ListMemories 列出用户所有记忆
|
||||
func (svc *MemoryService) ListMemories(ctx context.Context, userID string, category string, minImportance int, limit int) ([]model.MemoryEntry, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
|
||||
q := model.MemoryQuery{
|
||||
UserID: userID,
|
||||
MinImportance: minImportance,
|
||||
Limit: limit,
|
||||
}
|
||||
if category != "" {
|
||||
q.Category = model.MemoryCategory(category)
|
||||
}
|
||||
|
||||
return svc.store.Query(ctx, q)
|
||||
}
|
||||
|
||||
// UpdateMemory 更新记忆
|
||||
func (svc *MemoryService) UpdateMemory(ctx context.Context, entry *model.MemoryEntry) error {
|
||||
if entry.ID == "" {
|
||||
return fmt.Errorf("id 不能为空")
|
||||
}
|
||||
return svc.store.Update(ctx, entry)
|
||||
}
|
||||
|
||||
// DeleteMemory 删除记忆
|
||||
func (svc *MemoryService) DeleteMemory(ctx context.Context, id string) error {
|
||||
return svc.store.Delete(ctx, id)
|
||||
}
|
||||
|
||||
// QueryMemories 语义查询 + 关键词匹配
|
||||
func (svc *MemoryService) QueryMemories(ctx context.Context, userID string, queryText string, category string, minImportance int, limit int) ([]model.MemoryEntry, error) {
|
||||
if limit <= 0 {
|
||||
limit = 10
|
||||
}
|
||||
|
||||
var allEntries []model.MemoryEntry
|
||||
seen := make(map[string]bool)
|
||||
|
||||
// 1. 关键词匹配检索
|
||||
keywordEntries, err := svc.store.SearchByKeyword(ctx, userID, queryText, limit*2)
|
||||
if err == nil {
|
||||
for _, e := range keywordEntries {
|
||||
if !seen[e.ID] {
|
||||
seen[e.ID] = true
|
||||
allEntries = append(allEntries, e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 补充按分类/重要性查询
|
||||
q := model.MemoryQuery{
|
||||
UserID: userID,
|
||||
MinImportance: minImportance,
|
||||
Limit: limit,
|
||||
}
|
||||
if category != "" {
|
||||
q.Category = model.MemoryCategory(category)
|
||||
}
|
||||
categoryEntries, err := svc.store.Query(ctx, q)
|
||||
if err == nil {
|
||||
for _, e := range categoryEntries {
|
||||
if !seen[e.ID] {
|
||||
seen[e.ID] = true
|
||||
allEntries = append(allEntries, e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 在应用层做内容匹配过滤
|
||||
queryLower := strings.ToLower(queryText)
|
||||
var matched []model.MemoryEntry
|
||||
for _, entry := range allEntries {
|
||||
contentLower := strings.ToLower(entry.Content)
|
||||
summaryLower := strings.ToLower(entry.Summary)
|
||||
|
||||
if strings.Contains(contentLower, queryLower) || strings.Contains(summaryLower, queryLower) {
|
||||
matched = append(matched, entry)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, kw := range entry.Keywords {
|
||||
if strings.Contains(queryLower, strings.ToLower(kw)) ||
|
||||
strings.Contains(strings.ToLower(kw), queryLower) {
|
||||
matched = append(matched, entry)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(matched) == 0 {
|
||||
matched = allEntries
|
||||
}
|
||||
|
||||
// 4. 去重合并
|
||||
matched = svc.deduplicate(matched)
|
||||
|
||||
// 5. 按重要性降序
|
||||
sortByImportance(matched)
|
||||
|
||||
if len(matched) > limit {
|
||||
matched = matched[:limit]
|
||||
}
|
||||
|
||||
return matched, nil
|
||||
}
|
||||
|
||||
// ConsolidateMemories 合并相似记忆
|
||||
func (svc *MemoryService) ConsolidateMemories(ctx context.Context, userID string) (int, error) {
|
||||
return svc.store.ConsolidateMemories(ctx, userID)
|
||||
}
|
||||
|
||||
// DecayMemories 衰减旧记忆
|
||||
func (svc *MemoryService) DecayMemories(ctx context.Context, userID string) (int, int, error) {
|
||||
return svc.store.DecayMemories(ctx, userID)
|
||||
}
|
||||
|
||||
// GetCategories 获取用户分类统计
|
||||
func (svc *MemoryService) GetCategories(ctx context.Context, userID string) (map[string]int, error) {
|
||||
return svc.store.GetCategories(ctx, userID)
|
||||
}
|
||||
|
||||
// findSimilar 查找相似记忆
|
||||
func (svc *MemoryService) findSimilar(ctx context.Context, userID string, newMem *model.MemoryEntry) (*model.MemoryEntry, error) {
|
||||
existing, err := svc.store.Query(ctx, model.MemoryQuery{
|
||||
UserID: userID,
|
||||
Limit: 100,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range existing {
|
||||
score := existing[i].SimilarityScore(newMem)
|
||||
if score >= deDupThreshold {
|
||||
return &existing[i], nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// mergeMemory 合并新记忆到已有记忆
|
||||
func (svc *MemoryService) mergeMemory(ctx context.Context, existing *model.MemoryEntry, newMem *model.MemoryEntry) error {
|
||||
// 更新内容(如果新内容更有价值)
|
||||
if newMem.Importance > existing.Importance || len(newMem.Content) > len(existing.Content) {
|
||||
existing.Content = newMem.Content
|
||||
existing.Summary = newMem.Summary
|
||||
}
|
||||
|
||||
// 合并关键词
|
||||
keywordSet := make(map[string]bool)
|
||||
for _, k := range existing.Keywords {
|
||||
keywordSet[k] = true
|
||||
}
|
||||
for _, k := range newMem.Keywords {
|
||||
keywordSet[k] = true
|
||||
}
|
||||
mergedKeywords := make([]string, 0, len(keywordSet))
|
||||
for k := range keywordSet {
|
||||
mergedKeywords = append(mergedKeywords, k)
|
||||
}
|
||||
existing.Keywords = mergedKeywords
|
||||
|
||||
// 取最高重要性
|
||||
if newMem.Importance > existing.Importance {
|
||||
existing.Importance = newMem.Importance
|
||||
}
|
||||
|
||||
// 取最高优先级
|
||||
if newMem.Priority > existing.Priority {
|
||||
existing.Priority = newMem.Priority
|
||||
}
|
||||
|
||||
existing.AccessCount++
|
||||
existing.Source = "merged"
|
||||
|
||||
log.Printf("[memory-service] 合并记忆 [%s|%d★]: %s (去重)", existing.Category, existing.Importance, existing.Summary)
|
||||
return svc.store.Update(ctx, existing)
|
||||
}
|
||||
|
||||
// deduplicate 去重合并
|
||||
func (svc *MemoryService) deduplicate(entries []model.MemoryEntry) []model.MemoryEntry {
|
||||
if len(entries) < 2 {
|
||||
return entries
|
||||
}
|
||||
|
||||
result := make([]model.MemoryEntry, 0, len(entries))
|
||||
discarded := make(map[int]bool)
|
||||
|
||||
for i := 0; i < len(entries); i++ {
|
||||
if discarded[i] {
|
||||
continue
|
||||
}
|
||||
for j := i + 1; j < len(entries); j++ {
|
||||
if discarded[j] {
|
||||
continue
|
||||
}
|
||||
score := entries[i].SimilarityScore(&entries[j])
|
||||
if score >= deDupThreshold {
|
||||
if entries[j].Importance > entries[i].Importance ||
|
||||
(entries[j].Importance == entries[i].Importance && entries[j].Priority > entries[i].Priority) {
|
||||
discarded[i] = true
|
||||
break
|
||||
} else {
|
||||
discarded[j] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if !discarded[i] {
|
||||
result = append(result, entries[i])
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// sortByImportance 按 Importance 降序, Priority 降序排列
|
||||
func sortByImportance(entries []model.MemoryEntry) {
|
||||
for i := 0; i < len(entries); i++ {
|
||||
for j := i + 1; j < len(entries); j++ {
|
||||
if entries[j].Importance > entries[i].Importance ||
|
||||
(entries[j].Importance == entries[i].Importance && entries[j].Priority > entries[i].Priority) {
|
||||
entries[i], entries[j] = entries[j], entries[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SaveThinkingLog 保存自主思考日志
|
||||
func (svc *MemoryService) SaveThinkingLog(ctx context.Context, tl *model.ThinkingLog) error {
|
||||
if tl.Content == "" {
|
||||
return fmt.Errorf("content 不能为空")
|
||||
}
|
||||
if tl.UserID == "" {
|
||||
tl.UserID = "admin_admin"
|
||||
}
|
||||
return svc.store.SaveThinkingLog(ctx, tl)
|
||||
}
|
||||
|
||||
// QueryThinkingLogs 分页查询思考日志
|
||||
func (svc *MemoryService) QueryThinkingLogs(ctx context.Context, q model.ThinkingQuery) ([]model.ThinkingLog, error) {
|
||||
return svc.store.QueryThinkingLogs(ctx, q)
|
||||
}
|
||||
|
||||
// GetThinkingLogByID 获取单条思考日志
|
||||
func (svc *MemoryService) GetThinkingLogByID(ctx context.Context, id string) (*model.ThinkingLog, error) {
|
||||
return svc.store.GetThinkingLogByID(ctx, id)
|
||||
}
|
||||
|
||||
// GetThinkingStats 获取思考日志统计
|
||||
func (svc *MemoryService) GetThinkingStats(ctx context.Context) (*model.ThinkingStats, error) {
|
||||
return svc.store.GetThinkingStats(ctx)
|
||||
}
|
||||
@@ -0,0 +1,765 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/yourname/cyrene-ai/memory-service/internal/model"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
// deDupThreshold 去重相似度阈值
|
||||
const deDupThreshold = 0.75
|
||||
|
||||
const reconnectInterval = 30 * time.Second
|
||||
|
||||
// Store 记忆持久化存储(PostgreSQL + pgvector)
|
||||
type Store struct {
|
||||
databaseURL string
|
||||
mu sync.RWMutex
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// errDBNotReady 数据库未就绪时返回的友好错误
|
||||
var errDBNotReady = fmt.Errorf("记忆系统未就绪: 数据库连接不可用,正在后台重试连接")
|
||||
|
||||
// NewStore 创建记忆存储
|
||||
// 连接失败时不返回 error,而是启动后台重连循环
|
||||
func NewStore(connStr string) *Store {
|
||||
s := &Store{
|
||||
databaseURL: connStr,
|
||||
}
|
||||
|
||||
// 尝试初始连接
|
||||
if err := s.Reconnect(); err != nil {
|
||||
log.Printf("[memory-service] ⚠ 记忆存储初始化: 数据库连接失败 (%v),将在后台每30秒重试", err)
|
||||
} else {
|
||||
log.Println("[memory-service] 记忆存储已就绪")
|
||||
}
|
||||
|
||||
// 启动后台重连 goroutine
|
||||
go s.reconnectLoop()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// reconnectLoop 后台重连循环
|
||||
func (s *Store) reconnectLoop() {
|
||||
ticker := time.NewTicker(reconnectInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
s.mu.RLock()
|
||||
ready := s.db != nil
|
||||
s.mu.RUnlock()
|
||||
|
||||
if ready {
|
||||
// 数据库已连接,检查连接是否仍然有效
|
||||
s.mu.RLock()
|
||||
db := s.db
|
||||
s.mu.RUnlock()
|
||||
if db != nil {
|
||||
if err := db.Ping(); err != nil {
|
||||
log.Printf("[memory-service] ⚠ 数据库连接丢失: %v,开始重连", err)
|
||||
s.mu.Lock()
|
||||
if s.db != nil {
|
||||
s.db.Close()
|
||||
s.db = nil
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !s.IsReady() {
|
||||
if err := s.Reconnect(); err != nil {
|
||||
log.Printf("[memory-service] ⚠ 数据库重连失败: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reconnect 尝试重连数据库并执行迁移
|
||||
func (s *Store) Reconnect() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// 如果已有有效连接,先检查
|
||||
if s.db != nil {
|
||||
if err := s.db.Ping(); err == nil {
|
||||
return nil // 仍然有效
|
||||
}
|
||||
// 连接已失效,关闭旧连接
|
||||
s.db.Close()
|
||||
s.db = nil
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", s.databaseURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("连接数据库失败: %w", err)
|
||||
}
|
||||
|
||||
db.SetMaxOpenConns(25)
|
||||
db.SetMaxIdleConns(5)
|
||||
db.SetConnMaxLifetime(5 * time.Minute)
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
db.Close()
|
||||
return fmt.Errorf("数据库ping失败: %w", err)
|
||||
}
|
||||
|
||||
s.db = db
|
||||
|
||||
// 执行建表迁移
|
||||
if err := s.migrate(); err != nil {
|
||||
log.Printf("[memory-service] ⚠ 数据库迁移失败: %v", err)
|
||||
s.db.Close()
|
||||
s.db = nil
|
||||
return fmt.Errorf("数据库迁移失败: %w", err)
|
||||
}
|
||||
|
||||
log.Println("[memory-service] ✅ 数据库重连成功,记忆系统已就绪")
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsReady 返回数据库是否可用
|
||||
func (s *Store) IsReady() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.db != nil
|
||||
}
|
||||
|
||||
// getDB 获取当前数据库连接(带读锁保护)
|
||||
func (s *Store) getDB() *sql.DB {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.db
|
||||
}
|
||||
|
||||
// migrate 创建表结构
|
||||
func (s *Store) migrate() error {
|
||||
queries := []string{
|
||||
`CREATE EXTENSION IF NOT EXISTS vector`,
|
||||
`CREATE TABLE IF NOT EXISTS memory_entries (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id VARCHAR(64) NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
summary TEXT DEFAULT '',
|
||||
category VARCHAR(32) DEFAULT 'knowledge',
|
||||
priority INT DEFAULT 1,
|
||||
importance INT DEFAULT 5,
|
||||
keywords TEXT DEFAULT '[]',
|
||||
session_id VARCHAR(64) DEFAULT '',
|
||||
source TEXT DEFAULT 'conversation',
|
||||
embedding vector(1536),
|
||||
access_count INT DEFAULT 0,
|
||||
last_access TIMESTAMPTZ DEFAULT NOW(),
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
expires_at TIMESTAMPTZ
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_me_user_id ON memory_entries(user_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_me_category ON memory_entries(category)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_me_priority ON memory_entries(priority)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_me_importance ON memory_entries(importance)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_me_user_priority ON memory_entries(user_id, priority DESC)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_me_user_importance ON memory_entries(user_id, importance DESC)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_me_source ON memory_entries(source)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_me_category_importance ON memory_entries(category, importance DESC)`,
|
||||
`CREATE TABLE IF NOT EXISTS thinking_logs (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id VARCHAR(64) NOT NULL DEFAULT 'admin_admin',
|
||||
content TEXT NOT NULL,
|
||||
tool_calls TEXT DEFAULT '[]',
|
||||
tool_call_count INT DEFAULT 0,
|
||||
content_length INT DEFAULT 0,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_tl_user_id ON thinking_logs(user_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_tl_created_at ON thinking_logs(created_at DESC)`,
|
||||
}
|
||||
|
||||
for _, q := range queries {
|
||||
if _, err := s.db.Exec(q); err != nil {
|
||||
return fmt.Errorf("执行迁移 '%s' 失败: %w", q[:min(50, len(q))], err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Save 保存记忆
|
||||
func (s *Store) Save(ctx context.Context, entry *model.MemoryEntry) error {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return errDBNotReady
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
if entry.Source == "" {
|
||||
entry.Source = "conversation"
|
||||
}
|
||||
if entry.Importance == 0 {
|
||||
entry.Importance = 5
|
||||
}
|
||||
|
||||
query := `INSERT INTO memory_entries (user_id, content, summary, category, priority, importance, keywords, session_id, source, embedding, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
|
||||
RETURNING id, created_at`
|
||||
|
||||
var embedding interface{}
|
||||
if len(entry.Embedding) > 0 {
|
||||
vec := make([]float64, len(entry.Embedding))
|
||||
for i, v := range entry.Embedding {
|
||||
vec[i] = float64(v)
|
||||
}
|
||||
embedding = fmt.Sprintf("[%s]", joinFloats(vec))
|
||||
}
|
||||
|
||||
return db.QueryRowContext(ctx, query,
|
||||
entry.UserID, entry.Content, entry.Summary,
|
||||
string(entry.Category), int(entry.Priority),
|
||||
entry.Importance, entry.KeywordsJSON(),
|
||||
entry.SessionID, entry.Source, embedding, entry.ExpiresAt,
|
||||
).Scan(&entry.ID, &entry.CreatedAt)
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取记忆
|
||||
func (s *Store) GetByID(ctx context.Context, id string) (*model.MemoryEntry, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return nil, errDBNotReady
|
||||
}
|
||||
|
||||
query := `SELECT id, user_id, content, summary, category, priority, importance, keywords,
|
||||
session_id, source, access_count, last_access, created_at, updated_at, expires_at
|
||||
FROM memory_entries WHERE id = $1`
|
||||
|
||||
entry := &model.MemoryEntry{}
|
||||
var category, keywordsRaw string
|
||||
err := db.QueryRowContext(ctx, query, id).Scan(
|
||||
&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
|
||||
&category, &entry.Priority, &entry.Importance, &keywordsRaw,
|
||||
&entry.SessionID, &entry.Source, &entry.AccessCount, &entry.LastAccess,
|
||||
&entry.CreatedAt, &entry.UpdatedAt, &entry.ExpiresAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询记忆失败: %w", err)
|
||||
}
|
||||
entry.Category = model.MemoryCategory(category)
|
||||
entry.Keywords = model.ParseKeywords(keywordsRaw)
|
||||
|
||||
// 更新访问计数
|
||||
go s.incrementAccess(context.Background(), id)
|
||||
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
// Query 按条件查询记忆
|
||||
func (s *Store) Query(ctx context.Context, q model.MemoryQuery) ([]model.MemoryEntry, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return nil, errDBNotReady
|
||||
}
|
||||
|
||||
if q.Limit <= 0 {
|
||||
q.Limit = 10
|
||||
}
|
||||
|
||||
query := `SELECT id, user_id, content, summary, category, priority, importance, keywords,
|
||||
session_id, source, access_count, last_access, created_at, updated_at, expires_at
|
||||
FROM memory_entries WHERE user_id = $1`
|
||||
args := []interface{}{q.UserID}
|
||||
argIdx := 2
|
||||
|
||||
if q.Category != "" {
|
||||
query += fmt.Sprintf(" AND category = $%d", argIdx)
|
||||
args = append(args, string(q.Category))
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if q.Priority >= 0 {
|
||||
query += fmt.Sprintf(" AND priority >= $%d", argIdx)
|
||||
args = append(args, int(q.Priority))
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if q.MinImportance > 0 {
|
||||
query += fmt.Sprintf(" AND importance >= $%d", argIdx)
|
||||
args = append(args, q.MinImportance)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
query += fmt.Sprintf(" ORDER BY priority DESC, importance DESC, created_at DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1)
|
||||
args = append(args, q.Limit, q.Offset)
|
||||
|
||||
rows, err := db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询记忆失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanMemoryRows(rows)
|
||||
}
|
||||
|
||||
// Delete 删除记忆
|
||||
func (s *Store) Delete(ctx context.Context, id string) error {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return errDBNotReady
|
||||
}
|
||||
_, err := db.ExecContext(ctx, `DELETE FROM memory_entries WHERE id = $1`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// PurgeExpired 清理过期记忆
|
||||
func (s *Store) PurgeExpired(ctx context.Context) (int64, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return 0, errDBNotReady
|
||||
}
|
||||
result, err := db.ExecContext(ctx,
|
||||
`DELETE FROM memory_entries WHERE expires_at IS NOT NULL AND expires_at < NOW()`)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
// SearchByVector 向量相似度搜索
|
||||
func (s *Store) SearchByVector(ctx context.Context, userID string, embedding []float64, limit int) ([]model.MemoryEntry, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return nil, errDBNotReady
|
||||
}
|
||||
|
||||
if limit <= 0 {
|
||||
limit = 5
|
||||
}
|
||||
|
||||
vecStr := fmt.Sprintf("[%s]", joinFloats(embedding))
|
||||
query := `SELECT id, user_id, content, summary, category, priority, importance, keywords,
|
||||
session_id, source, access_count, last_access, created_at, updated_at, expires_at,
|
||||
1 - (embedding <=> $1) AS similarity
|
||||
FROM memory_entries
|
||||
WHERE user_id = $2 AND embedding IS NOT NULL
|
||||
ORDER BY embedding <=> $1
|
||||
LIMIT $3`
|
||||
|
||||
rows, err := db.QueryContext(ctx, query, vecStr, userID, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("向量搜索失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var entries []model.MemoryEntry
|
||||
for rows.Next() {
|
||||
var entry model.MemoryEntry
|
||||
var category, keywordsRaw string
|
||||
var similarity float64
|
||||
if err := rows.Scan(
|
||||
&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
|
||||
&category, &entry.Priority, &entry.Importance, &keywordsRaw,
|
||||
&entry.SessionID, &entry.Source, &entry.AccessCount, &entry.LastAccess,
|
||||
&entry.CreatedAt, &entry.UpdatedAt, &entry.ExpiresAt,
|
||||
&similarity,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("扫描向量搜索结果失败: %w", err)
|
||||
}
|
||||
entry.Category = model.MemoryCategory(category)
|
||||
entry.Keywords = model.ParseKeywords(keywordsRaw)
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
// SearchByKeyword 关键词匹配查询
|
||||
func (s *Store) SearchByKeyword(ctx context.Context, userID, keyword string, limit int) ([]model.MemoryEntry, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return nil, errDBNotReady
|
||||
}
|
||||
|
||||
if limit <= 0 {
|
||||
limit = 20
|
||||
}
|
||||
|
||||
query := `SELECT id, user_id, content, summary, category, priority, importance, keywords,
|
||||
session_id, source, access_count, last_access, created_at, updated_at, expires_at
|
||||
FROM memory_entries
|
||||
WHERE user_id = $1 AND (content ILIKE $2 OR summary ILIKE $2 OR keywords ILIKE $2)
|
||||
ORDER BY priority DESC, importance DESC
|
||||
LIMIT $3`
|
||||
|
||||
likePattern := "%" + keyword + "%"
|
||||
rows, err := db.QueryContext(ctx, query, userID, likePattern, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("关键词搜索失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanMemoryRows(rows)
|
||||
}
|
||||
|
||||
// Update 更新记忆
|
||||
func (s *Store) Update(ctx context.Context, entry *model.MemoryEntry) error {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return errDBNotReady
|
||||
}
|
||||
|
||||
query := `UPDATE memory_entries SET content = $1, summary = $2, category = $3, priority = $4,
|
||||
importance = $5, keywords = $6, source = $7, updated_at = NOW()
|
||||
WHERE id = $8`
|
||||
|
||||
_, err := db.ExecContext(ctx, query,
|
||||
entry.Content, entry.Summary, string(entry.Category), int(entry.Priority),
|
||||
entry.Importance, entry.KeywordsJSON(), entry.Source, entry.ID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetCategories 获取用户所有分类及计数
|
||||
func (s *Store) GetCategories(ctx context.Context, userID string) (map[string]int, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return nil, errDBNotReady
|
||||
}
|
||||
|
||||
rows, err := db.QueryContext(ctx,
|
||||
`SELECT category, COUNT(*) FROM memory_entries WHERE user_id = $1 GROUP BY category ORDER BY category`,
|
||||
userID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询分类统计失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
categories := make(map[string]int)
|
||||
for rows.Next() {
|
||||
var cat string
|
||||
var count int
|
||||
if err := rows.Scan(&cat, &count); err != nil {
|
||||
return nil, fmt.Errorf("扫描分类统计失败: %w", err)
|
||||
}
|
||||
categories[cat] = count
|
||||
}
|
||||
|
||||
return categories, rows.Err()
|
||||
}
|
||||
|
||||
// Count 获取用户的记忆总数
|
||||
func (s *Store) Count(ctx context.Context, userID string) (int, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return 0, errDBNotReady
|
||||
}
|
||||
|
||||
var count int
|
||||
err := db.QueryRowContext(ctx,
|
||||
`SELECT COUNT(*) FROM memory_entries WHERE user_id = $1`,
|
||||
userID,
|
||||
).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("统计记忆失败: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// ConsolidateMemories 记忆整理:合并相似记忆
|
||||
func (s *Store) ConsolidateMemories(ctx context.Context, userID string) (int, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return 0, errDBNotReady
|
||||
}
|
||||
|
||||
// 获取用户所有记忆
|
||||
allMems, err := s.Query(ctx, model.MemoryQuery{
|
||||
UserID: userID,
|
||||
Limit: 500,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("查询记忆失败: %w", err)
|
||||
}
|
||||
|
||||
if len(allMems) < 2 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
merged := 0
|
||||
for i := 0; i < len(allMems); i++ {
|
||||
if allMems[i].ID == "" {
|
||||
continue
|
||||
}
|
||||
for j := i + 1; j < len(allMems); j++ {
|
||||
if allMems[j].ID == "" {
|
||||
continue
|
||||
}
|
||||
score := allMems[i].SimilarityScore(&allMems[j])
|
||||
if score >= deDupThreshold {
|
||||
keep, discard := &allMems[i], &allMems[j]
|
||||
if discard.Importance > keep.Importance || discard.Priority > keep.Priority {
|
||||
keep, discard = discard, keep
|
||||
}
|
||||
|
||||
// 合并关键词
|
||||
keywordSet := make(map[string]bool)
|
||||
for _, k := range keep.Keywords {
|
||||
keywordSet[k] = true
|
||||
}
|
||||
for _, k := range discard.Keywords {
|
||||
keywordSet[k] = true
|
||||
}
|
||||
mergedKeywords := make([]string, 0, len(keywordSet))
|
||||
for k := range keywordSet {
|
||||
mergedKeywords = append(mergedKeywords, k)
|
||||
}
|
||||
keep.Keywords = mergedKeywords
|
||||
|
||||
if keep.Importance < 10 {
|
||||
keep.Importance++
|
||||
}
|
||||
keep.Source = "consolidated"
|
||||
|
||||
if err := s.Update(ctx, keep); err != nil {
|
||||
log.Printf("[memory-service] 合并更新记忆 %s 失败: %v", keep.ID, err)
|
||||
continue
|
||||
}
|
||||
if err := s.Delete(ctx, discard.ID); err != nil {
|
||||
log.Printf("[memory-service] 合并删除记忆 %s 失败: %v", discard.ID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
discard.ID = ""
|
||||
merged++
|
||||
log.Printf("[memory-service] 合并相似记忆: %s <- %s (相似度 %.0f%%)",
|
||||
keep.ID[:min(8, len(keep.ID))], discard.ID[:min(8, len(discard.ID))], score*100)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if merged > 0 {
|
||||
log.Printf("[memory-service] 记忆整理完成: 用户 %s 合并 %d 条相似记忆", userID, merged)
|
||||
}
|
||||
return merged, nil
|
||||
}
|
||||
|
||||
// DecayMemories 记忆衰减:降低长期未访问的低重要性记忆
|
||||
func (s *Store) DecayMemories(ctx context.Context, userID string) (int, int, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return 0, 0, errDBNotReady
|
||||
}
|
||||
|
||||
result1, err := db.ExecContext(ctx, `
|
||||
UPDATE memory_entries SET priority = GREATEST(priority - 1, 0), updated_at = NOW()
|
||||
WHERE user_id = $1
|
||||
AND access_count < 3
|
||||
AND last_access < NOW() - INTERVAL '30 days'
|
||||
AND importance < 3
|
||||
AND priority > 0
|
||||
AND category NOT IN ('personal_info', 'user_preference')
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("衰减低活跃记忆失败: %w", err)
|
||||
}
|
||||
|
||||
decayed1, _ := result1.RowsAffected()
|
||||
|
||||
result2, err := db.ExecContext(ctx, `
|
||||
DELETE FROM memory_entries
|
||||
WHERE user_id = $1
|
||||
AND priority = 0
|
||||
AND access_count = 0
|
||||
AND last_access < NOW() - INTERVAL '14 days'
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("清理临时记忆失败: %w", err)
|
||||
}
|
||||
|
||||
deleted2, _ := result2.RowsAffected()
|
||||
|
||||
total := decayed1 + deleted2
|
||||
if total > 0 {
|
||||
log.Printf("[memory-service] 记忆衰减完成: 用户 %s 降级 %d 条, 删除 %d 条过期临时记忆",
|
||||
userID, decayed1, deleted2)
|
||||
}
|
||||
|
||||
return int(decayed1), int(deleted2), nil
|
||||
}
|
||||
|
||||
func (s *Store) incrementAccess(ctx context.Context, id string) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
db.ExecContext(ctx,
|
||||
`UPDATE memory_entries SET access_count = access_count + 1, last_access = NOW() WHERE id = $1`, id)
|
||||
}
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func (s *Store) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.db != nil {
|
||||
return s.db.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// scanMemoryRows 扫描记忆行(通用方法)
|
||||
func scanMemoryRows(rows *sql.Rows) ([]model.MemoryEntry, error) {
|
||||
var entries []model.MemoryEntry
|
||||
for rows.Next() {
|
||||
var entry model.MemoryEntry
|
||||
var category, keywordsRaw string
|
||||
if err := rows.Scan(
|
||||
&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
|
||||
&category, &entry.Priority, &entry.Importance, &keywordsRaw,
|
||||
&entry.SessionID, &entry.Source, &entry.AccessCount, &entry.LastAccess,
|
||||
&entry.CreatedAt, &entry.UpdatedAt, &entry.ExpiresAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("扫描记忆行失败: %w", err)
|
||||
}
|
||||
entry.Category = model.MemoryCategory(category)
|
||||
entry.Keywords = model.ParseKeywords(keywordsRaw)
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
// joinFloats 将 float64 切片转为逗号分隔字符串
|
||||
func joinFloats(vec []float64) string {
|
||||
if len(vec) == 0 {
|
||||
return ""
|
||||
}
|
||||
s := fmt.Sprintf("%f", vec[0])
|
||||
for i := 1; i < len(vec); i++ {
|
||||
s += fmt.Sprintf(",%f", vec[i])
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// SaveThinkingLog 保存自主思考日志
|
||||
func (s *Store) SaveThinkingLog(ctx context.Context, log *model.ThinkingLog) error {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return errDBNotReady
|
||||
}
|
||||
|
||||
if log.UserID == "" {
|
||||
log.UserID = "admin_admin"
|
||||
}
|
||||
if log.ToolCalls == "" {
|
||||
log.ToolCalls = "[]"
|
||||
}
|
||||
|
||||
query := `INSERT INTO thinking_logs (user_id, content, tool_calls, tool_call_count, content_length)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
RETURNING id, created_at`
|
||||
|
||||
return db.QueryRowContext(ctx, query,
|
||||
log.UserID, log.Content, log.ToolCalls,
|
||||
log.ToolCallCount, log.ContentLength,
|
||||
).Scan(&log.ID, &log.CreatedAt)
|
||||
}
|
||||
|
||||
// QueryThinkingLogs 分页查询思考日志
|
||||
func (s *Store) QueryThinkingLogs(ctx context.Context, q model.ThinkingQuery) ([]model.ThinkingLog, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return nil, errDBNotReady
|
||||
}
|
||||
|
||||
if q.Limit <= 0 {
|
||||
q.Limit = 20
|
||||
}
|
||||
|
||||
query := `SELECT id, user_id, content, tool_calls, tool_call_count, content_length, created_at
|
||||
FROM thinking_logs`
|
||||
args := []interface{}{}
|
||||
argIdx := 1
|
||||
|
||||
if q.UserID != "" {
|
||||
query += fmt.Sprintf(" WHERE user_id = $%d", argIdx)
|
||||
args = append(args, q.UserID)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
query += fmt.Sprintf(" ORDER BY created_at DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1)
|
||||
args = append(args, q.Limit, q.Offset)
|
||||
|
||||
rows, err := db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询思考日志失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var logs []model.ThinkingLog
|
||||
for rows.Next() {
|
||||
var tl model.ThinkingLog
|
||||
if err := rows.Scan(&tl.ID, &tl.UserID, &tl.Content, &tl.ToolCalls,
|
||||
&tl.ToolCallCount, &tl.ContentLength, &tl.CreatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描思考日志行失败: %w", err)
|
||||
}
|
||||
logs = append(logs, tl)
|
||||
}
|
||||
return logs, rows.Err()
|
||||
}
|
||||
|
||||
// GetThinkingLogByID 根据ID获取单条思考日志
|
||||
func (s *Store) GetThinkingLogByID(ctx context.Context, id string) (*model.ThinkingLog, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return nil, errDBNotReady
|
||||
}
|
||||
|
||||
query := `SELECT id, user_id, content, tool_calls, tool_call_count, content_length, created_at
|
||||
FROM thinking_logs WHERE id = $1`
|
||||
|
||||
tl := &model.ThinkingLog{}
|
||||
err := db.QueryRowContext(ctx, query, id).Scan(
|
||||
&tl.ID, &tl.UserID, &tl.Content, &tl.ToolCalls,
|
||||
&tl.ToolCallCount, &tl.ContentLength, &tl.CreatedAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询思考日志失败: %w", err)
|
||||
}
|
||||
return tl, nil
|
||||
}
|
||||
|
||||
// GetThinkingStats 获取思考日志统计信息
|
||||
func (s *Store) GetThinkingStats(ctx context.Context) (*model.ThinkingStats, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return nil, errDBNotReady
|
||||
}
|
||||
|
||||
query := `SELECT
|
||||
COALESCE(COUNT(*), 0),
|
||||
COALESCE(SUM(tool_call_count), 0),
|
||||
COALESCE(AVG(content_length), 0),
|
||||
COALESCE(MAX(created_at)::TEXT, '')
|
||||
FROM thinking_logs`
|
||||
|
||||
stats := &model.ThinkingStats{}
|
||||
err := db.QueryRowContext(ctx, query).Scan(
|
||||
&stats.TotalLogs, &stats.TotalToolCalls,
|
||||
&stats.AvgContentLen, &stats.LatestAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询思考日志统计失败: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
# Build stage
|
||||
FROM golang:1.24-alpine AS builder
|
||||
|
||||
WORKDIR /app
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -o /tool-engine ./cmd/
|
||||
|
||||
# Runtime stage
|
||||
FROM alpine:3.21
|
||||
|
||||
RUN apk --no-cache add ca-certificates
|
||||
WORKDIR /app
|
||||
COPY --from=builder /tool-engine .
|
||||
|
||||
EXPOSE 8092
|
||||
|
||||
ENTRYPOINT ["./tool-engine"]
|
||||
@@ -0,0 +1,82 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/yourname/cyrene-ai/tool-engine/internal/config"
|
||||
"github.com/yourname/cyrene-ai/tool-engine/internal/handler"
|
||||
"github.com/yourname/cyrene-ai/tool-engine/internal/service"
|
||||
"github.com/yourname/cyrene-ai/tool-engine/internal/store"
|
||||
"github.com/yourname/cyrene-ai/tool-engine/internal/tools"
|
||||
)
|
||||
|
||||
func main() {
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
log.Println("🔧 Tool-Engine 启动中...")
|
||||
|
||||
// 加载配置
|
||||
cfg := config.Load()
|
||||
|
||||
log.Printf("配置: 端口=%s, IoT服务=%s, 数据目录=%s, DB=%s", cfg.Port, cfg.IoTServiceURL, cfg.DataDir, cfg.DBUrl)
|
||||
|
||||
// 初始化调用日志存储
|
||||
callLogStore, err := store.NewCallLogStore(cfg.DBUrl)
|
||||
if err != nil {
|
||||
log.Printf("[main] 初始化调用日志存储失败: %v", err)
|
||||
callLogStore = nil
|
||||
}
|
||||
|
||||
// 初始化 IoT 客户端
|
||||
var iotClient tools.IoTClientInterface
|
||||
if cfg.IoTServiceURL != "" {
|
||||
iotClient = tools.NewIoTClient(cfg.IoTServiceURL)
|
||||
log.Printf("[main] IoT 客户端已初始化: %s", cfg.IoTServiceURL)
|
||||
} else {
|
||||
log.Println("[main] IoT 服务 URL 未配置,IoT 工具将不可用")
|
||||
}
|
||||
|
||||
// 初始化服务层
|
||||
svc := service.NewToolService(iotClient, cfg.DataDir)
|
||||
|
||||
// 初始化 HTTP 处理器
|
||||
h := handler.NewToolHandler(svc, callLogStore)
|
||||
|
||||
// 注册路由
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
// 健康检查端点
|
||||
mux.HandleFunc("/api/v1/health", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"status":"ok","service":"tool-engine"}`))
|
||||
})
|
||||
|
||||
// 启动 HTTP 服务
|
||||
srv := &http.Server{
|
||||
Addr: ":" + cfg.Port,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Printf("🚀 Tool-Engine 已启动在端口 %s", cfg.Port)
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("服务启动失败: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 优雅关闭
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
log.Println("正在关闭 Tool-Engine...")
|
||||
|
||||
if callLogStore != nil {
|
||||
callLogStore.Close()
|
||||
}
|
||||
srv.Close()
|
||||
log.Println("Tool-Engine 已关闭")
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
module github.com/yourname/cyrene-ai/tool-engine
|
||||
|
||||
go 1.26.2
|
||||
|
||||
require github.com/lib/pq v1.10.9
|
||||
@@ -0,0 +1,2 @@
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
@@ -0,0 +1,30 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
)
|
||||
|
||||
// Config 工具引擎服务配置
|
||||
type Config struct {
|
||||
Port string
|
||||
IoTServiceURL string
|
||||
DataDir string
|
||||
DBUrl string
|
||||
}
|
||||
|
||||
// Load 从环境变量加载配置
|
||||
func Load() *Config {
|
||||
return &Config{
|
||||
Port: getEnv("PORT", "8092"),
|
||||
IoTServiceURL: getEnv("IOT_SERVICE_URL", "http://localhost:8083"),
|
||||
DataDir: getEnv("DATA_DIR", "/tmp/cyrene_data"),
|
||||
DBUrl: getEnv("DB_URL", ""),
|
||||
}
|
||||
}
|
||||
|
||||
func getEnv(key, fallback string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
@@ -0,0 +1,300 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/yourname/cyrene-ai/tool-engine/internal/model"
|
||||
"github.com/yourname/cyrene-ai/tool-engine/internal/service"
|
||||
"github.com/yourname/cyrene-ai/tool-engine/internal/store"
|
||||
)
|
||||
|
||||
// ToolHandler HTTP API 处理器
|
||||
type ToolHandler struct {
|
||||
svc *service.ToolService
|
||||
callLogStore *store.CallLogStore
|
||||
}
|
||||
|
||||
// NewToolHandler 创建工具处理器
|
||||
func NewToolHandler(svc *service.ToolService, callLogStore *store.CallLogStore) *ToolHandler {
|
||||
return &ToolHandler{svc: svc, callLogStore: callLogStore}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册所有路由到 mux
|
||||
func (h *ToolHandler) RegisterRoutes(mux *http.ServeMux) {
|
||||
// GET /api/v1/tools - 列出所有工具
|
||||
mux.HandleFunc("/api/v1/tools", h.handleTools)
|
||||
// GET /api/v1/tools/ - 工具详情和单个执行 (带名称)
|
||||
mux.HandleFunc("/api/v1/tools/", h.handleToolByName)
|
||||
// POST /api/v1/tools/execute - 批量执行
|
||||
mux.HandleFunc("/api/v1/tools/execute", h.handleBatchExecute)
|
||||
}
|
||||
|
||||
// handleTools GET /api/v1/tools - 列出所有工具
|
||||
func (h *ToolHandler) handleTools(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
tools := h.svc.ListTools()
|
||||
if tools == nil {
|
||||
tools = []model.ToolDefinition{}
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"tools": tools,
|
||||
"total": len(tools),
|
||||
})
|
||||
}
|
||||
|
||||
// handleToolByName 处理 /api/v1/tools/{name} 和 /api/v1/tools/{name}/execute 和 /api/v1/tools/calls 和 /api/v1/tools/calls/stats
|
||||
func (h *ToolHandler) handleToolByName(w http.ResponseWriter, r *http.Request) {
|
||||
// 解析路径: /api/v1/tools/{name} 或 /api/v1/tools/{name}/execute
|
||||
path := strings.TrimPrefix(r.URL.Path, "/api/v1/tools/")
|
||||
parts := strings.SplitN(path, "/", 2)
|
||||
|
||||
toolName := parts[0]
|
||||
if toolName == "" {
|
||||
writeError(w, http.StatusBadRequest, "缺少工具名称")
|
||||
return
|
||||
}
|
||||
|
||||
// 处理 /api/v1/tools/calls/stats
|
||||
if toolName == "calls" && len(parts) == 2 && parts[1] == "stats" {
|
||||
if r.Method != http.MethodGet {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
h.handleCallStats(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// 处理 /api/v1/tools/calls
|
||||
if toolName == "calls" {
|
||||
if r.Method != http.MethodGet {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
h.handleCallLogs(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// 判断是否为执行请求
|
||||
if len(parts) == 2 && parts[1] == "execute" {
|
||||
if r.Method != http.MethodPost {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed, use POST")
|
||||
return
|
||||
}
|
||||
h.executeTool(w, r, toolName)
|
||||
return
|
||||
}
|
||||
|
||||
// GET /api/v1/tools/{name} - 获取工具定义
|
||||
if r.Method != http.MethodGet {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
def, ok := h.svc.GetTool(toolName)
|
||||
if !ok {
|
||||
writeError(w, http.StatusNotFound, "工具 "+toolName+" 不存在")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, def)
|
||||
}
|
||||
|
||||
// executeTool POST /api/v1/tools/{name}/execute - 执行单个工具
|
||||
func (h *ToolHandler) executeTool(w http.ResponseWriter, r *http.Request, toolName string) {
|
||||
var req model.ExecuteRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "请求体格式错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.Arguments == nil {
|
||||
req.Arguments = make(map[string]interface{})
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
result, err := h.svc.Execute(r.Context(), toolName, req.Arguments)
|
||||
durationMs := int(time.Since(startTime).Milliseconds())
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[tool-handler] 执行工具 %s 失败: %v", toolName, err)
|
||||
h.logCall(toolName, req.Arguments, "", err.Error(), false, durationMs, r)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 异步记录调用日志
|
||||
h.logCall(toolName, req.Arguments, result.Output, result.Error, result.Error == "" && err == nil, durationMs, r)
|
||||
|
||||
writeJSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
// handleBatchExecute POST /api/v1/tools/execute - 批量执行
|
||||
func (h *ToolHandler) handleBatchExecute(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed, use POST")
|
||||
return
|
||||
}
|
||||
|
||||
var req model.BatchExecuteRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "请求体格式错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Calls) == 0 {
|
||||
writeError(w, http.StatusBadRequest, "calls 不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
response := h.svc.ExecuteBatch(r.Context(), req.Calls)
|
||||
batchDuration := int(time.Since(startTime).Milliseconds())
|
||||
|
||||
// 异步记录每个调用
|
||||
for i, call := range req.Calls {
|
||||
var output, errStr string
|
||||
var success bool
|
||||
if i < len(response.Results) {
|
||||
output = response.Results[i].Output
|
||||
errStr = response.Results[i].Error
|
||||
success = errStr == ""
|
||||
}
|
||||
h.logCall(call.Name, call.Arguments, output, errStr, success, batchDuration, r)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, response)
|
||||
}
|
||||
|
||||
// newUUID generates a UUID v4 string using crypto/rand
|
||||
func newUUID() string {
|
||||
b := make([]byte, 16)
|
||||
_, _ = rand.Read(b)
|
||||
b[6] = (b[6] & 0x0f) | 0x40 // Version 4
|
||||
b[8] = (b[8] & 0x3f) | 0x80 // Variant 10
|
||||
return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:16])
|
||||
}
|
||||
|
||||
// logCall 异步记录工具调用日志
|
||||
func (h *ToolHandler) logCall(toolName string, args map[string]interface{}, output, errStr string, success bool, durationMs int, r *http.Request) {
|
||||
if h.callLogStore == nil {
|
||||
return
|
||||
}
|
||||
|
||||
callID := newUUID()
|
||||
userID := r.URL.Query().Get("user_id")
|
||||
sessionID := r.URL.Query().Get("session_id")
|
||||
|
||||
go func() {
|
||||
argsJSON, _ := json.Marshal(args)
|
||||
record := &store.CallLogRecord{
|
||||
CallID: callID,
|
||||
ToolName: toolName,
|
||||
Arguments: argsJSON,
|
||||
Output: output,
|
||||
Error: errStr,
|
||||
Success: success,
|
||||
DurationMs: durationMs,
|
||||
UserID: userID,
|
||||
SessionID: sessionID,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := h.callLogStore.Insert(record); err != nil {
|
||||
log.Printf("[tool-handler] 记录调用日志失败: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// handleCallLogs GET /api/v1/tools/calls - 查询调用记录
|
||||
func (h *ToolHandler) handleCallLogs(w http.ResponseWriter, r *http.Request) {
|
||||
if h.callLogStore == nil {
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"calls": []interface{}{},
|
||||
"total": 0,
|
||||
"page": 1,
|
||||
"limit": 20,
|
||||
"total_pages": 0,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
q := r.URL.Query()
|
||||
|
||||
page, _ := strconv.Atoi(q.Get("page"))
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
|
||||
limit, _ := strconv.Atoi(q.Get("limit"))
|
||||
if limit < 1 || limit > 100 {
|
||||
limit = 20
|
||||
}
|
||||
|
||||
query := store.CallLogQuery{
|
||||
ToolName: q.Get("tool_name"),
|
||||
Page: page,
|
||||
Limit: limit,
|
||||
}
|
||||
|
||||
result, err := h.callLogStore.Query(query)
|
||||
if err != nil {
|
||||
log.Printf("[tool-handler] 查询调用记录失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, "查询调用记录失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
// handleCallStats GET /api/v1/tools/calls/stats - 调用统计
|
||||
func (h *ToolHandler) handleCallStats(w http.ResponseWriter, r *http.Request) {
|
||||
if h.callLogStore == nil {
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"total_calls": 0,
|
||||
"success_count": 0,
|
||||
"fail_count": 0,
|
||||
"success_rate": 0,
|
||||
"avg_duration_ms": 0,
|
||||
"by_tool": []interface{}{},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.callLogStore.Stats()
|
||||
if err != nil {
|
||||
log.Printf("[tool-handler] 查询调用统计失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, "查询调用统计失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, stats)
|
||||
}
|
||||
|
||||
// writeJSON 写入 JSON 响应
|
||||
func writeJSON(w http.ResponseWriter, status int, data interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
if err := json.NewEncoder(w).Encode(data); err != nil {
|
||||
log.Printf("[tool-handler] JSON 编码失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// writeError 写入错误响应
|
||||
func writeError(w http.ResponseWriter, status int, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": message,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package model
|
||||
|
||||
// ToolDefinition 工具定义(用于 LLM function calling)
|
||||
type ToolDefinition struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters map[string]interface{} `json:"parameters"`
|
||||
}
|
||||
|
||||
// ToolCall 工具调用请求
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]interface{} `json:"arguments"`
|
||||
}
|
||||
|
||||
// ToolResult 工具执行结果
|
||||
type ToolResult struct {
|
||||
ID string `json:"id"`
|
||||
Output string `json:"output"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// ExecuteRequest 单个工具执行请求
|
||||
type ExecuteRequest struct {
|
||||
Arguments map[string]interface{} `json:"arguments"`
|
||||
}
|
||||
|
||||
// BatchExecuteRequest 批量执行请求
|
||||
type BatchExecuteRequest struct {
|
||||
Calls []ToolCall `json:"calls"`
|
||||
}
|
||||
|
||||
// BatchExecuteResponse 批量执行响应
|
||||
type BatchExecuteResponse struct {
|
||||
Results []ToolResult `json:"results"`
|
||||
}
|
||||
@@ -0,0 +1,123 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/yourname/cyrene-ai/tool-engine/internal/model"
|
||||
"github.com/yourname/cyrene-ai/tool-engine/internal/tools"
|
||||
)
|
||||
|
||||
// ToolService 工具执行引擎
|
||||
type ToolService struct {
|
||||
registry map[string]tools.Tool
|
||||
}
|
||||
|
||||
// NewToolService 创建工具服务,注册所有工具
|
||||
func NewToolService(iotClient tools.IoTClientInterface, dataDir string) *ToolService {
|
||||
svc := &ToolService{
|
||||
registry: make(map[string]tools.Tool),
|
||||
}
|
||||
|
||||
// 注册所有 13 个工具
|
||||
svc.Register(tools.NewCalculatorTool())
|
||||
svc.Register(tools.NewDateTimeTool())
|
||||
svc.Register(tools.NewTextTool())
|
||||
svc.Register(tools.NewCryptoTool())
|
||||
svc.Register(tools.NewRandomTool())
|
||||
svc.Register(tools.NewMarkdownTool())
|
||||
svc.Register(tools.NewJSONTool())
|
||||
svc.Register(tools.NewFileTool(dataDir))
|
||||
svc.Register(tools.NewHTTPTool())
|
||||
svc.Register(tools.NewWebSearchTool())
|
||||
svc.Register(tools.NewWebFetchTool())
|
||||
|
||||
// IoT 工具(需要 IoT 客户端)
|
||||
if iotClient != nil {
|
||||
svc.Register(tools.NewIoTQueryTool(iotClient))
|
||||
svc.Register(tools.NewIoTControlTool(iotClient))
|
||||
} else {
|
||||
log.Println("[tool-service] IoT 客户端未配置,跳过 IoT 工具注册")
|
||||
}
|
||||
|
||||
return svc
|
||||
}
|
||||
|
||||
// Register 注册工具
|
||||
func (s *ToolService) Register(tool tools.Tool) {
|
||||
def := tool.Definition()
|
||||
s.registry[def.Name] = tool
|
||||
log.Printf("[tool-service] 已注册工具: %s", def.Name)
|
||||
}
|
||||
|
||||
// ListTools 获取所有工具定义
|
||||
func (s *ToolService) ListTools() []model.ToolDefinition {
|
||||
defs := make([]model.ToolDefinition, 0, len(s.registry))
|
||||
for _, tool := range s.registry {
|
||||
defs = append(defs, tool.Definition())
|
||||
}
|
||||
return defs
|
||||
}
|
||||
|
||||
// GetTool 获取单个工具定义
|
||||
func (s *ToolService) GetTool(name string) (*model.ToolDefinition, bool) {
|
||||
tool, ok := s.registry[name]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
def := tool.Definition()
|
||||
return &def, true
|
||||
}
|
||||
|
||||
// Execute 执行单个工具
|
||||
func (s *ToolService) Execute(ctx context.Context, name string, arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
tool, ok := s.registry[name]
|
||||
if !ok {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("工具 %s 不存在", name),
|
||||
}, nil
|
||||
}
|
||||
|
||||
log.Printf("[tool-service] 执行工具: %s", name)
|
||||
result, err := tool.Execute(ctx, arguments)
|
||||
if err != nil {
|
||||
log.Printf("[tool-service] 工具 %s 执行错误: %v", name, err)
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("执行工具 %s 失败: %v", name, err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
log.Printf("[tool-service] 工具 %s 返回错误: %s", name, result.Error)
|
||||
} else {
|
||||
log.Printf("[tool-service] 工具 %s 执行成功", name)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ExecuteBatch 批量执行工具调用
|
||||
func (s *ToolService) ExecuteBatch(ctx context.Context, calls []model.ToolCall) *model.BatchExecuteResponse {
|
||||
results := make([]model.ToolResult, 0, len(calls))
|
||||
|
||||
for _, call := range calls {
|
||||
result, err := s.Execute(ctx, call.Name, call.Arguments)
|
||||
if err != nil {
|
||||
results = append(results, model.ToolResult{
|
||||
ID: call.ID,
|
||||
Output: "",
|
||||
Error: err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
result.ID = call.ID
|
||||
results = append(results, *result)
|
||||
}
|
||||
|
||||
return &model.BatchExecuteResponse{
|
||||
Results: results,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,289 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
// CallLogRecord 工具调用记录
|
||||
type CallLogRecord struct {
|
||||
ID int `json:"id"`
|
||||
CallID string `json:"call_id"`
|
||||
ToolName string `json:"tool_name"`
|
||||
Arguments json.RawMessage `json:"arguments"`
|
||||
Output string `json:"output"`
|
||||
Error string `json:"error"`
|
||||
Success bool `json:"success"`
|
||||
DurationMs int `json:"duration_ms"`
|
||||
UserID string `json:"user_id"`
|
||||
SessionID string `json:"session_id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// CallLogQuery 查询参数
|
||||
type CallLogQuery struct {
|
||||
ToolName string
|
||||
Page int
|
||||
Limit int
|
||||
}
|
||||
|
||||
// CallLogPageResult 分页结果
|
||||
type CallLogPageResult struct {
|
||||
Calls []CallLogRecord `json:"calls"`
|
||||
Total int `json:"total"`
|
||||
Page int `json:"page"`
|
||||
Limit int `json:"limit"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
}
|
||||
|
||||
// CallLogStats 调用统计
|
||||
type CallLogStats struct {
|
||||
TotalCalls int `json:"total_calls"`
|
||||
SuccessCount int `json:"success_count"`
|
||||
FailCount int `json:"fail_count"`
|
||||
SuccessRate float64 `json:"success_rate"`
|
||||
AvgDuration float64 `json:"avg_duration_ms"`
|
||||
ByTool []ToolCallCount `json:"by_tool"`
|
||||
}
|
||||
|
||||
// ToolCallCount 按工具统计
|
||||
type ToolCallCount struct {
|
||||
ToolName string `json:"tool_name"`
|
||||
Count int `json:"count"`
|
||||
SuccessCount int `json:"success_count"`
|
||||
FailCount int `json:"fail_count"`
|
||||
AvgDuration float64 `json:"avg_duration_ms"`
|
||||
}
|
||||
|
||||
// CallLogStore 工具调用日志存储
|
||||
type CallLogStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewCallLogStore 创建调用日志存储并自动建表
|
||||
func NewCallLogStore(dbURL string) (*CallLogStore, error) {
|
||||
if dbURL == "" {
|
||||
log.Println("[call-log-store] DB_URL 未设置,工具调用日志将不会持久化")
|
||||
return &CallLogStore{}, nil
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", dbURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("打开数据库连接失败: %w", err)
|
||||
}
|
||||
|
||||
db.SetMaxOpenConns(5)
|
||||
db.SetMaxIdleConns(2)
|
||||
db.SetConnMaxLifetime(5 * time.Minute)
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
log.Printf("[call-log-store] 数据库连接失败: %v (将尝试继续运行)", err)
|
||||
return &CallLogStore{}, nil
|
||||
}
|
||||
|
||||
store := &CallLogStore{db: db}
|
||||
if err := store.migrate(); err != nil {
|
||||
log.Printf("[call-log-store] 数据库迁移失败: %v", err)
|
||||
}
|
||||
|
||||
log.Println("[call-log-store] 数据库连接成功,表已就绪")
|
||||
return store, nil
|
||||
}
|
||||
|
||||
// migrate 创建表结构
|
||||
func (s *CallLogStore) migrate() error {
|
||||
if s.db == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
query := `
|
||||
CREATE TABLE IF NOT EXISTS tool_call_logs (
|
||||
id SERIAL PRIMARY KEY,
|
||||
call_id TEXT NOT NULL,
|
||||
tool_name TEXT NOT NULL,
|
||||
arguments JSONB,
|
||||
output TEXT,
|
||||
error TEXT,
|
||||
success BOOLEAN NOT NULL DEFAULT true,
|
||||
duration_ms INTEGER,
|
||||
user_id TEXT DEFAULT '',
|
||||
session_id TEXT DEFAULT '',
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_tcl_tool_name ON tool_call_logs(tool_name);
|
||||
CREATE INDEX IF NOT EXISTS idx_tcl_created_at ON tool_call_logs(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_tcl_user_id ON tool_call_logs(user_id);
|
||||
`
|
||||
|
||||
_, err := s.db.Exec(query)
|
||||
return err
|
||||
}
|
||||
|
||||
// Insert 插入一条调用记录
|
||||
func (s *CallLogStore) Insert(record *CallLogRecord) error {
|
||||
if s.db == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
argsJSON, _ := json.Marshal(record.Arguments)
|
||||
_, err := s.db.Exec(
|
||||
`INSERT INTO tool_call_logs (call_id, tool_name, arguments, output, error, success, duration_ms, user_id, session_id, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)`,
|
||||
record.CallID, record.ToolName, argsJSON, record.Output, record.Error,
|
||||
record.Success, record.DurationMs, record.UserID, record.SessionID, record.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("[call-log-store] 插入记录失败: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query 分页查询调用记录
|
||||
func (s *CallLogStore) Query(q CallLogQuery) (*CallLogPageResult, error) {
|
||||
if s.db == nil {
|
||||
return &CallLogPageResult{Calls: []CallLogRecord{}, Total: 0, Page: q.Page, Limit: q.Limit, TotalPages: 0}, nil
|
||||
}
|
||||
|
||||
if q.Page < 1 {
|
||||
q.Page = 1
|
||||
}
|
||||
if q.Limit < 1 || q.Limit > 100 {
|
||||
q.Limit = 20
|
||||
}
|
||||
|
||||
// 构建 WHERE 条件
|
||||
where := ""
|
||||
whereArgs := []interface{}{}
|
||||
argIdx := 1
|
||||
|
||||
if q.ToolName != "" {
|
||||
where = fmt.Sprintf(" WHERE tool_name = $%d", argIdx)
|
||||
whereArgs = append(whereArgs, q.ToolName)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
// 计数
|
||||
countQuery := "SELECT COUNT(*) FROM tool_call_logs" + where
|
||||
var total int
|
||||
if err := s.db.QueryRow(countQuery, whereArgs...).Scan(&total); err != nil {
|
||||
return nil, fmt.Errorf("查询总数失败: %w", err)
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (q.Page - 1) * q.Limit
|
||||
querySql := fmt.Sprintf(
|
||||
"SELECT id, call_id, tool_name, arguments, COALESCE(output,''), COALESCE(error,''), success, COALESCE(duration_ms,0), COALESCE(user_id,''), COALESCE(session_id,''), created_at FROM tool_call_logs%s ORDER BY created_at DESC LIMIT $%d OFFSET $%d",
|
||||
where, argIdx, argIdx+1,
|
||||
)
|
||||
queryArgs := append(whereArgs, q.Limit, offset)
|
||||
|
||||
rows, err := s.db.Query(querySql, queryArgs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询记录失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
calls := make([]CallLogRecord, 0)
|
||||
for rows.Next() {
|
||||
var r CallLogRecord
|
||||
var argsJSON []byte
|
||||
if err := rows.Scan(&r.ID, &r.CallID, &r.ToolName, &argsJSON, &r.Output, &r.Error, &r.Success, &r.DurationMs, &r.UserID, &r.SessionID, &r.CreatedAt); err != nil {
|
||||
log.Printf("[call-log-store] 扫描行失败: %v", err)
|
||||
continue
|
||||
}
|
||||
r.Arguments = argsJSON
|
||||
calls = append(calls, r)
|
||||
}
|
||||
|
||||
totalPages := (total + q.Limit - 1) / q.Limit
|
||||
|
||||
return &CallLogPageResult{
|
||||
Calls: calls,
|
||||
Total: total,
|
||||
Page: q.Page,
|
||||
Limit: q.Limit,
|
||||
TotalPages: totalPages,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Stats 获取调用统计
|
||||
func (s *CallLogStore) Stats() (*CallLogStats, error) {
|
||||
if s.db == nil {
|
||||
return &CallLogStats{}, nil
|
||||
}
|
||||
|
||||
stats := &CallLogStats{}
|
||||
|
||||
// 总体统计
|
||||
err := s.db.QueryRow(
|
||||
"SELECT COUNT(*), COUNT(*) FILTER (WHERE success=true), COUNT(*) FILTER (WHERE success=false), COALESCE(AVG(duration_ms),0) FROM tool_call_logs",
|
||||
).Scan(&stats.TotalCalls, &stats.SuccessCount, &stats.FailCount, &stats.AvgDuration)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询总体统计失败: %w", err)
|
||||
}
|
||||
|
||||
if stats.TotalCalls > 0 {
|
||||
stats.SuccessRate = float64(stats.SuccessCount) / float64(stats.TotalCalls) * 100
|
||||
}
|
||||
|
||||
// 按工具统计
|
||||
rows, err := s.db.Query(
|
||||
"SELECT tool_name, COUNT(*), COUNT(*) FILTER (WHERE success=true), COUNT(*) FILTER (WHERE success=false), COALESCE(AVG(duration_ms),0) FROM tool_call_logs GROUP BY tool_name ORDER BY COUNT(*) DESC",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询按工具统计失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
stats.ByTool = make([]ToolCallCount, 0)
|
||||
for rows.Next() {
|
||||
var tc ToolCallCount
|
||||
if err := rows.Scan(&tc.ToolName, &tc.Count, &tc.SuccessCount, &tc.FailCount, &tc.AvgDuration); err != nil {
|
||||
log.Printf("[call-log-store] 扫描工具统计失败: %v", err)
|
||||
continue
|
||||
}
|
||||
stats.ByTool = append(stats.ByTool, tc)
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func (s *CallLogStore) Close() {
|
||||
if s.db != nil {
|
||||
s.db.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// DBUrlFromEnv 从环境变量获取数据库连接
|
||||
func DBUrlFromEnv() string {
|
||||
// 如果设置了 DB_URL 直接使用
|
||||
if url := os.Getenv("DB_URL"); url != "" {
|
||||
return url
|
||||
}
|
||||
|
||||
// 否则从单独的环境变量构建
|
||||
host := getEnv("DB_HOST", "localhost")
|
||||
port := getEnv("DB_PORT", "5432")
|
||||
user := getEnv("DB_USER", "cyrene")
|
||||
pass := getEnv("DB_PASSWORD", "change_me")
|
||||
dbname := getEnv("DB_NAME", "cyrene_ai")
|
||||
sslmode := getEnv("DB_SSLMODE", "disable")
|
||||
|
||||
return fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=%s", user, pass, host, port, dbname, sslmode)
|
||||
}
|
||||
|
||||
func getEnv(key, fallback string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
@@ -0,0 +1,342 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/yourname/cyrene-ai/tool-engine/internal/model"
|
||||
)
|
||||
|
||||
// CalculatorTool performs safe mathematical expression evaluation.
|
||||
type CalculatorTool struct{}
|
||||
|
||||
// NewCalculatorTool creates a calculator tool.
|
||||
func NewCalculatorTool() *CalculatorTool {
|
||||
return &CalculatorTool{}
|
||||
}
|
||||
|
||||
// Definition returns the tool definition for LLM function calling.
|
||||
func (t *CalculatorTool) Definition() model.ToolDefinition {
|
||||
return model.ToolDefinition{
|
||||
Name: "calculator",
|
||||
Description: "执行数学计算。用于精确计算数学表达式,支持四则运算、三角函数、对数、幂运算等。适用于LLM不擅长的复杂计算场景。",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"expression": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "数学表达式,如 \"2 + 3 * 4\"、\"sqrt(16) + sin(pi/2)\"。支持运算符: + - * / % ^。支持函数: sqrt, sin, cos, tan, abs, floor, ceil, round, log, ln, pow。支持常量: pi, e。",
|
||||
},
|
||||
},
|
||||
"required": []string{"expression"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Execute evaluates a mathematical expression.
|
||||
func (t *CalculatorTool) Execute(ctx context.Context, arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
expression, ok := arguments["expression"].(string)
|
||||
if !ok || strings.TrimSpace(expression) == "" {
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Error: "缺少 expression 参数",
|
||||
}, nil
|
||||
}
|
||||
|
||||
result, err := evaluate(expression)
|
||||
if err != nil {
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Error: fmt.Sprintf("计算错误: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: fmt.Sprintf("表达式: %s\n结果: %s", expression, formatResult(result)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func formatResult(v float64) string {
|
||||
if v == math.Trunc(v) && math.Abs(v) < 1e15 {
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
}
|
||||
return strconv.FormatFloat(v, 'g', -1, 64)
|
||||
}
|
||||
|
||||
type tokenKind int
|
||||
|
||||
const (
|
||||
tokNumber tokenKind = iota
|
||||
tokIdent
|
||||
tokOp
|
||||
tokLParen
|
||||
tokRParen
|
||||
tokComma
|
||||
tokEOF
|
||||
)
|
||||
|
||||
type token struct {
|
||||
kind tokenKind
|
||||
value string
|
||||
}
|
||||
|
||||
type lexer struct {
|
||||
input []rune
|
||||
pos int
|
||||
}
|
||||
|
||||
func newLexer(s string) *lexer {
|
||||
return &lexer{input: []rune(s), pos: 0}
|
||||
}
|
||||
|
||||
func (l *lexer) next() token {
|
||||
l.skipWhitespace()
|
||||
if l.pos >= len(l.input) {
|
||||
return token{kind: tokEOF}
|
||||
}
|
||||
|
||||
ch := l.input[l.pos]
|
||||
|
||||
if unicode.IsDigit(ch) || ch == '.' {
|
||||
start := l.pos
|
||||
hasDot := ch == '.'
|
||||
l.pos++
|
||||
for l.pos < len(l.input) && (unicode.IsDigit(l.input[l.pos]) || l.input[l.pos] == '.') {
|
||||
if l.input[l.pos] == '.' {
|
||||
if hasDot {
|
||||
break
|
||||
}
|
||||
hasDot = true
|
||||
}
|
||||
l.pos++
|
||||
}
|
||||
return token{kind: tokNumber, value: string(l.input[start:l.pos])}
|
||||
}
|
||||
|
||||
if unicode.IsLetter(ch) || ch == '_' {
|
||||
start := l.pos
|
||||
l.pos++
|
||||
for l.pos < len(l.input) && (unicode.IsLetter(l.input[l.pos]) || unicode.IsDigit(l.input[l.pos]) || l.input[l.pos] == '_') {
|
||||
l.pos++
|
||||
}
|
||||
return token{kind: tokIdent, value: string(l.input[start:l.pos])}
|
||||
}
|
||||
|
||||
switch ch {
|
||||
case '+', '-', '*', '/', '%', '^':
|
||||
l.pos++
|
||||
return token{kind: tokOp, value: string(ch)}
|
||||
case '(':
|
||||
l.pos++
|
||||
return token{kind: tokLParen}
|
||||
case ')':
|
||||
l.pos++
|
||||
return token{kind: tokRParen}
|
||||
case ',':
|
||||
l.pos++
|
||||
return token{kind: tokComma}
|
||||
}
|
||||
|
||||
return token{kind: tokEOF}
|
||||
}
|
||||
|
||||
func (l *lexer) skipWhitespace() {
|
||||
for l.pos < len(l.input) && unicode.IsSpace(l.input[l.pos]) {
|
||||
l.pos++
|
||||
}
|
||||
}
|
||||
|
||||
type parser struct {
|
||||
lex *lexer
|
||||
cur token
|
||||
peek token
|
||||
}
|
||||
|
||||
func newParser(lex *lexer) *parser {
|
||||
p := &parser{lex: lex}
|
||||
p.cur = lex.next()
|
||||
p.peek = lex.next()
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *parser) advance() {
|
||||
p.cur = p.peek
|
||||
p.peek = p.lex.next()
|
||||
}
|
||||
|
||||
func evaluate(expr string) (float64, error) {
|
||||
lex := newLexer(expr)
|
||||
par := newParser(lex)
|
||||
result, err := par.parseExpression()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if par.cur.kind != tokEOF {
|
||||
return 0, fmt.Errorf("表达式末尾存在意外字符")
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (p *parser) parseExpression() (float64, error) {
|
||||
left, err := p.parseTerm()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for p.cur.kind == tokOp && (p.cur.value == "+" || p.cur.value == "-") {
|
||||
op := p.cur.value
|
||||
p.advance()
|
||||
right, err := p.parseTerm()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if op == "+" {
|
||||
left += right
|
||||
} else {
|
||||
left -= right
|
||||
}
|
||||
}
|
||||
return left, nil
|
||||
}
|
||||
|
||||
func (p *parser) parseTerm() (float64, error) {
|
||||
left, err := p.parseUnary()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for p.cur.kind == tokOp && (p.cur.value == "*" || p.cur.value == "/" || p.cur.value == "%" || p.cur.value == "^") {
|
||||
op := p.cur.value
|
||||
p.advance()
|
||||
right, err := p.parseUnary()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
switch op {
|
||||
case "*":
|
||||
left *= right
|
||||
case "/":
|
||||
if right == 0 {
|
||||
return 0, fmt.Errorf("除数不能为零")
|
||||
}
|
||||
left /= right
|
||||
case "%":
|
||||
left = math.Mod(left, right)
|
||||
case "^":
|
||||
left = math.Pow(left, right)
|
||||
}
|
||||
}
|
||||
return left, nil
|
||||
}
|
||||
|
||||
func (p *parser) parseUnary() (float64, error) {
|
||||
if p.cur.kind == tokOp && p.cur.value == "-" {
|
||||
p.advance()
|
||||
val, err := p.parseUnary()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return -val, nil
|
||||
}
|
||||
if p.cur.kind == tokOp && p.cur.value == "+" {
|
||||
p.advance()
|
||||
return p.parseUnary()
|
||||
}
|
||||
return p.parseAtom()
|
||||
}
|
||||
|
||||
func (p *parser) parseAtom() (float64, error) {
|
||||
switch p.cur.kind {
|
||||
case tokNumber:
|
||||
val, err := strconv.ParseFloat(p.cur.value, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("无效数字: %s", p.cur.value)
|
||||
}
|
||||
p.advance()
|
||||
return val, nil
|
||||
|
||||
case tokIdent:
|
||||
name := strings.ToLower(p.cur.value)
|
||||
p.advance()
|
||||
|
||||
switch name {
|
||||
case "pi":
|
||||
return math.Pi, nil
|
||||
case "e":
|
||||
return math.E, nil
|
||||
}
|
||||
|
||||
if p.cur.kind != tokLParen {
|
||||
return 0, fmt.Errorf("未知标识符: %s (如果是函数需要加括号)", name)
|
||||
}
|
||||
p.advance()
|
||||
|
||||
arg, err := p.parseExpression()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if p.cur.kind != tokRParen {
|
||||
return 0, fmt.Errorf("函数 %s 缺少右括号", name)
|
||||
}
|
||||
p.advance()
|
||||
|
||||
return applyFunc(name, arg)
|
||||
|
||||
case tokLParen:
|
||||
p.advance()
|
||||
val, err := p.parseExpression()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if p.cur.kind != tokRParen {
|
||||
return 0, fmt.Errorf("缺少右括号")
|
||||
}
|
||||
p.advance()
|
||||
return val, nil
|
||||
|
||||
default:
|
||||
return 0, fmt.Errorf("意外的 token: %v", p.cur.value)
|
||||
}
|
||||
}
|
||||
|
||||
func applyFunc(name string, arg float64) (float64, error) {
|
||||
switch name {
|
||||
case "sqrt":
|
||||
if arg < 0 {
|
||||
return 0, fmt.Errorf("sqrt 参数不能为负数")
|
||||
}
|
||||
return math.Sqrt(arg), nil
|
||||
case "sin":
|
||||
return math.Sin(arg), nil
|
||||
case "cos":
|
||||
return math.Cos(arg), nil
|
||||
case "tan":
|
||||
return math.Tan(arg), nil
|
||||
case "abs":
|
||||
return math.Abs(arg), nil
|
||||
case "floor":
|
||||
return math.Floor(arg), nil
|
||||
case "ceil":
|
||||
return math.Ceil(arg), nil
|
||||
case "round":
|
||||
return math.Round(arg), nil
|
||||
case "log":
|
||||
if arg <= 0 {
|
||||
return 0, fmt.Errorf("log 参数必须大于0")
|
||||
}
|
||||
return math.Log10(arg), nil
|
||||
case "ln":
|
||||
if arg <= 0 {
|
||||
return 0, fmt.Errorf("ln 参数必须大于0")
|
||||
}
|
||||
return math.Log(arg), nil
|
||||
case "pow":
|
||||
return 0, fmt.Errorf("pow 需要两个参数,请使用 ^ 运算符代替")
|
||||
default:
|
||||
return 0, fmt.Errorf("未知函数: %s", name)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"hash"
|
||||
"net/url"
|
||||
|
||||
"github.com/yourname/cyrene-ai/tool-engine/internal/model"
|
||||
)
|
||||
|
||||
// CryptoTool provides cryptographic and encoding utilities for the LLM.
|
||||
type CryptoTool struct{}
|
||||
|
||||
// NewCryptoTool creates a crypto/encoding tool.
|
||||
func NewCryptoTool() *CryptoTool {
|
||||
return &CryptoTool{}
|
||||
}
|
||||
|
||||
// Definition returns the tool definition for LLM function calling.
|
||||
func (t *CryptoTool) Definition() model.ToolDefinition {
|
||||
return model.ToolDefinition{
|
||||
Name: "crypto",
|
||||
Description: "加密哈希与编码工具。计算MD5/SHA哈希值,执行Base64编码/解码,URL编码/解码。",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []string{"hash", "base64_encode", "base64_decode", "url_encode", "url_decode"},
|
||||
"description": "操作类型。hash: 计算哈希值;base64_encode: Base64编码;base64_decode: Base64解码;url_encode: URL编码;url_decode: URL解码",
|
||||
},
|
||||
"input": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "输入数据,需要处理的字符串",
|
||||
},
|
||||
"algorithm": map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []string{"md5", "sha1", "sha256", "sha512"},
|
||||
"description": "哈希算法(用于 hash 操作),默认 sha256",
|
||||
},
|
||||
},
|
||||
"required": []string{"action", "input"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Execute performs crypto/encoding operations.
|
||||
func (t *CryptoTool) Execute(ctx context.Context, arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
action, ok := arguments["action"].(string)
|
||||
if !ok || action == "" {
|
||||
return &model.ToolResult{ID: "", Error: "缺少 action 参数"}, nil
|
||||
}
|
||||
|
||||
input, ok := arguments["input"].(string)
|
||||
if !ok {
|
||||
return &model.ToolResult{ID: "", Error: "缺少 input 参数"}, nil
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "hash":
|
||||
return t.handleHash(arguments)
|
||||
case "base64_encode":
|
||||
return t.handleBase64Encode(input)
|
||||
case "base64_decode":
|
||||
return t.handleBase64Decode(input)
|
||||
case "url_encode":
|
||||
return t.handleURLEncode(input)
|
||||
case "url_decode":
|
||||
return t.handleURLDecode(input)
|
||||
default:
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Error: fmt.Sprintf("未知操作: %s,支持: hash, base64_encode, base64_decode, url_encode, url_decode", action),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t *CryptoTool) handleHash(arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
input, _ := arguments["input"].(string)
|
||||
algorithm, _ := arguments["algorithm"].(string)
|
||||
if algorithm == "" {
|
||||
algorithm = "sha256"
|
||||
}
|
||||
|
||||
var h hash.Hash
|
||||
switch algorithm {
|
||||
case "md5":
|
||||
h = md5.New()
|
||||
case "sha1":
|
||||
h = sha1.New()
|
||||
case "sha256":
|
||||
h = sha256.New()
|
||||
case "sha512":
|
||||
h = sha512.New()
|
||||
default:
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Error: fmt.Sprintf("不支持的哈希算法: %s,支持: md5, sha1, sha256, sha512", algorithm),
|
||||
}, nil
|
||||
}
|
||||
|
||||
h.Write([]byte(input))
|
||||
hashBytes := h.Sum(nil)
|
||||
hashHex := fmt.Sprintf("%x", hashBytes)
|
||||
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: fmt.Sprintf("哈希算法: %s\n输入长度: %d 字节\n哈希值 (hex): %s\n哈希长度: %d 位",
|
||||
algorithm, len(input), hashHex, len(hashBytes)*8),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *CryptoTool) handleBase64Encode(input string) (*model.ToolResult, error) {
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(input))
|
||||
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: fmt.Sprintf("Base64 编码结果:\n原始 (%d 字节): %s\n编码 (%d 字符): %s",
|
||||
len(input), truncate(input, 100), len(encoded), encoded),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *CryptoTool) handleBase64Decode(input string) (*model.ToolResult, error) {
|
||||
decoded, err := base64.StdEncoding.DecodeString(input)
|
||||
if err != nil {
|
||||
decoded, err = base64.RawStdEncoding.DecodeString(input)
|
||||
if err != nil {
|
||||
decoded, err = base64.URLEncoding.DecodeString(input)
|
||||
if err != nil {
|
||||
decoded, err = base64.RawURLEncoding.DecodeString(input)
|
||||
if err != nil {
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Error: "Base64 解码失败: 输入不是有效的 Base64 字符串",
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: fmt.Sprintf("Base64 解码结果:\n原始 (%d 字符): %s\n解码 (%d 字节): %s",
|
||||
len(input), truncate(input, 100), len(decoded), truncate(string(decoded), 200)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *CryptoTool) handleURLEncode(input string) (*model.ToolResult, error) {
|
||||
encoded := url.QueryEscape(input)
|
||||
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: fmt.Sprintf("URL 编码结果:\n原始 (%d 字节): %s\n编码 (%d 字节): %s",
|
||||
len(input), truncate(input, 100), len(encoded), encoded),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *CryptoTool) handleURLDecode(input string) (*model.ToolResult, error) {
|
||||
decoded, err := url.QueryUnescape(input)
|
||||
if err != nil {
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Error: fmt.Sprintf("URL 解码失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: fmt.Sprintf("URL 解码结果:\n原始 (%d 字节): %s\n解码 (%d 字节): %s",
|
||||
len(input), truncate(input, 100), len(decoded), truncate(decoded, 200)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func truncate(s string, maxLen int) string {
|
||||
runes := []rune(s)
|
||||
if len(runes) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return string(runes[:maxLen]) + "..."
|
||||
}
|
||||
@@ -0,0 +1,360 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"github.com/yourname/cyrene-ai/tool-engine/internal/model"
|
||||
)
|
||||
|
||||
// DateTimeTool provides date/time operations for the LLM.
|
||||
type DateTimeTool struct{}
|
||||
|
||||
// NewDateTimeTool creates a date/time tool.
|
||||
func NewDateTimeTool() *DateTimeTool {
|
||||
return &DateTimeTool{}
|
||||
}
|
||||
|
||||
// Definition returns the tool definition for LLM function calling.
|
||||
func (t *DateTimeTool) Definition() model.ToolDefinition {
|
||||
return model.ToolDefinition{
|
||||
Name: "datetime",
|
||||
Description: "日期时间工具。获取当前时间、格式化日期、日期加减、计算日期差、查看可用时区。",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []string{"now", "format", "add", "diff", "timezone_list"},
|
||||
"description": "操作类型。now: 获取当前时间;format: 格式化日期;add: 日期加减;diff: 计算两个日期的差值;timezone_list: 列出常用时区",
|
||||
},
|
||||
"format": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "日期格式串(Go风格)。默认 \"2006-01-02 15:04:05\"。常用: \"2006-01-02\"(仅日期)、\"15:04:05\"(仅时间)",
|
||||
},
|
||||
"timezone": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "时区标识,如 \"Asia/Shanghai\"、\"America/New_York\"、\"UTC\"。默认使用服务器本地时区",
|
||||
},
|
||||
"date": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "基准日期,格式为 \"2006-01-02 15:04:05\" 或 \"2006-01-02\"",
|
||||
},
|
||||
"duration": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "时长字符串,如 \"24h\"、\"7d\"、\"30m\"、\"1h30m\"。支持单位: s(秒), m(分钟), h(小时), d(天), w(周), M(月), y(年)",
|
||||
},
|
||||
"date2": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "第二个日期(用于 diff 操作),格式同 date",
|
||||
},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Execute performs date/time operations.
|
||||
func (t *DateTimeTool) Execute(ctx context.Context, arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
action, ok := arguments["action"].(string)
|
||||
if !ok || action == "" {
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Error: "缺少 action 参数",
|
||||
}, nil
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "now":
|
||||
return t.handleNow(arguments)
|
||||
case "format":
|
||||
return t.handleFormat(arguments)
|
||||
case "add":
|
||||
return t.handleAdd(arguments)
|
||||
case "diff":
|
||||
return t.handleDiff(arguments)
|
||||
case "timezone_list":
|
||||
return t.handleTimezoneList()
|
||||
default:
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Error: fmt.Sprintf("未知操作: %s,支持: now, format, add, diff, timezone_list", action),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t *DateTimeTool) handleNow(arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
tz, err := t.getTimezone(arguments)
|
||||
if err != nil {
|
||||
return &model.ToolResult{ID: "", Error: err.Error()}, nil
|
||||
}
|
||||
|
||||
format := t.getFormat(arguments)
|
||||
now := time.Now().In(tz)
|
||||
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: fmt.Sprintf("当前时间: %s\n时区: %s\nUnix时间戳: %d",
|
||||
now.Format(format), tz.String(), now.Unix()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *DateTimeTool) handleFormat(arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
dateStr, _ := arguments["date"].(string)
|
||||
if dateStr == "" {
|
||||
return &model.ToolResult{ID: "", Error: "format 操作需要 date 参数"}, nil
|
||||
}
|
||||
|
||||
parsed, err := t.parseDate(dateStr)
|
||||
if err != nil {
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("日期解析失败: %v", err)}, nil
|
||||
}
|
||||
|
||||
tz, err := t.getTimezone(arguments)
|
||||
if err != nil {
|
||||
return &model.ToolResult{ID: "", Error: err.Error()}, nil
|
||||
}
|
||||
|
||||
format := t.getFormat(arguments)
|
||||
formatted := parsed.In(tz).Format(format)
|
||||
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: fmt.Sprintf("原始: %s\n格式化: %s\n时区: %s", dateStr, formatted, tz.String()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *DateTimeTool) handleAdd(arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
durationStr, _ := arguments["duration"].(string)
|
||||
if durationStr == "" {
|
||||
return &model.ToolResult{ID: "", Error: "add 操作需要 duration 参数"}, nil
|
||||
}
|
||||
|
||||
dateStr, _ := arguments["date"].(string)
|
||||
var base time.Time
|
||||
if dateStr != "" {
|
||||
var err error
|
||||
base, err = t.parseDate(dateStr)
|
||||
if err != nil {
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("日期解析失败: %v", err)}, nil
|
||||
}
|
||||
} else {
|
||||
tz, _ := t.getTimezone(arguments)
|
||||
base = time.Now().In(tz)
|
||||
}
|
||||
|
||||
dur, err := t.parseDuration(durationStr)
|
||||
if err != nil {
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("时长解析失败: %v", err)}, nil
|
||||
}
|
||||
|
||||
tz, _ := t.getTimezone(arguments)
|
||||
|
||||
result := base.In(tz)
|
||||
|
||||
months := extractDurationUnit(durationStr, 'M')
|
||||
years := extractDurationUnit(durationStr, 'y')
|
||||
|
||||
if months != 0 || years != 0 {
|
||||
result = result.AddDate(years, months, 0)
|
||||
}
|
||||
|
||||
if dur != 0 {
|
||||
result = result.Add(dur)
|
||||
}
|
||||
|
||||
format := t.getFormat(arguments)
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: fmt.Sprintf("基准日期: %s\n操作: %s\n结果: %s",
|
||||
base.In(tz).Format(format), durationStr, result.Format(format)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *DateTimeTool) handleDiff(arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
dateStr, _ := arguments["date"].(string)
|
||||
date2Str, _ := arguments["date2"].(string)
|
||||
|
||||
if dateStr == "" || date2Str == "" {
|
||||
return &model.ToolResult{ID: "", Error: "diff 操作需要 date 和 date2 参数"}, nil
|
||||
}
|
||||
|
||||
d1, err := t.parseDate(dateStr)
|
||||
if err != nil {
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("date 解析失败: %v", err)}, nil
|
||||
}
|
||||
|
||||
d2, err := t.parseDate(date2Str)
|
||||
if err != nil {
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("date2 解析失败: %v", err)}, nil
|
||||
}
|
||||
|
||||
diff := d2.Sub(d1)
|
||||
absDiff := diff
|
||||
if absDiff < 0 {
|
||||
absDiff = -absDiff
|
||||
}
|
||||
|
||||
days := int(absDiff.Hours() / 24)
|
||||
hours := int(absDiff.Hours()) % 24
|
||||
minutes := int(absDiff.Minutes()) % 60
|
||||
seconds := int(absDiff.Seconds()) % 60
|
||||
|
||||
sign := ""
|
||||
if diff < 0 {
|
||||
sign = "-"
|
||||
}
|
||||
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: fmt.Sprintf("日期1: %s\n日期2: %s\n差值: %s%d天 %d小时 %d分钟 %d秒 (总计 %s%.0f秒)",
|
||||
dateStr, date2Str, sign, days, hours, minutes, seconds, sign, absDiff.Seconds()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *DateTimeTool) handleTimezoneList() (*model.ToolResult, error) {
|
||||
zones := []string{
|
||||
"UTC",
|
||||
"Asia/Shanghai (北京时间)",
|
||||
"Asia/Tokyo (东京时间)",
|
||||
"Asia/Seoul (首尔时间)",
|
||||
"Asia/Singapore (新加坡时间)",
|
||||
"Asia/Kolkata (印度时间)",
|
||||
"Asia/Dubai (迪拜时间)",
|
||||
"Europe/London (伦敦时间)",
|
||||
"Europe/Paris (巴黎时间)",
|
||||
"Europe/Berlin (柏林时间)",
|
||||
"Europe/Moscow (莫斯科时间)",
|
||||
"America/New_York (纽约时间)",
|
||||
"America/Chicago (芝加哥时间)",
|
||||
"America/Denver (丹佛时间)",
|
||||
"America/Los_Angeles (洛杉矶时间)",
|
||||
"America/Sao_Paulo (圣保罗时间)",
|
||||
"Australia/Sydney (悉尼时间)",
|
||||
"Pacific/Auckland (奥克兰时间)",
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
result.WriteString("常用时区列表:\n\n")
|
||||
for i, z := range zones {
|
||||
result.WriteString(fmt.Sprintf(" %2d. %s\n", i+1, z))
|
||||
}
|
||||
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: result.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *DateTimeTool) getTimezone(arguments map[string]interface{}) (*time.Location, error) {
|
||||
tzName, _ := arguments["timezone"].(string)
|
||||
if tzName == "" {
|
||||
return time.Local, nil
|
||||
}
|
||||
loc, err := time.LoadLocation(tzName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("无效时区: %s", tzName)
|
||||
}
|
||||
return loc, nil
|
||||
}
|
||||
|
||||
func (t *DateTimeTool) getFormat(arguments map[string]interface{}) string {
|
||||
format, _ := arguments["format"].(string)
|
||||
if format == "" {
|
||||
return "2006-01-02 15:04:05"
|
||||
}
|
||||
return format
|
||||
}
|
||||
|
||||
func (t *DateTimeTool) parseDate(s string) (time.Time, error) {
|
||||
formats := []string{
|
||||
"2006-01-02 15:04:05",
|
||||
"2006-01-02T15:04:05Z",
|
||||
"2006-01-02T15:04:05",
|
||||
"2006-01-02",
|
||||
"2006/01/02 15:04:05",
|
||||
"2006/01/02",
|
||||
time.RFC3339,
|
||||
time.RFC3339Nano,
|
||||
}
|
||||
for _, f := range formats {
|
||||
if t, err := time.Parse(f, s); err == nil {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
return time.Time{}, fmt.Errorf("无法解析日期: %s", s)
|
||||
}
|
||||
|
||||
func (t *DateTimeTool) parseDuration(s string) (time.Duration, error) {
|
||||
if d, err := time.ParseDuration(s); err == nil {
|
||||
return d, nil
|
||||
}
|
||||
|
||||
var total time.Duration
|
||||
remaining := s
|
||||
|
||||
for len(remaining) > 0 {
|
||||
numStart := 0
|
||||
for numStart < len(remaining) && !unicode.IsDigit(rune(remaining[numStart])) && remaining[numStart] != '-' {
|
||||
numStart++
|
||||
}
|
||||
if numStart >= len(remaining) {
|
||||
break
|
||||
}
|
||||
|
||||
numEnd := numStart
|
||||
for numEnd < len(remaining) && (unicode.IsDigit(rune(remaining[numEnd])) || remaining[numEnd] == '.') {
|
||||
numEnd++
|
||||
}
|
||||
|
||||
val, err := strconv.ParseFloat(remaining[numStart:numEnd], 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("无效时长数字: %s", remaining[numStart:numEnd])
|
||||
}
|
||||
|
||||
unitEnd := numEnd
|
||||
for unitEnd < len(remaining) && unicode.IsLetter(rune(remaining[unitEnd])) {
|
||||
unitEnd++
|
||||
}
|
||||
unit := remaining[numEnd:unitEnd]
|
||||
|
||||
switch unit {
|
||||
case "s":
|
||||
total += time.Duration(val * float64(time.Second))
|
||||
case "m":
|
||||
total += time.Duration(val * float64(time.Minute))
|
||||
case "h":
|
||||
total += time.Duration(val * float64(time.Hour))
|
||||
case "d":
|
||||
total += time.Duration(val * 24 * float64(time.Hour))
|
||||
case "w":
|
||||
total += time.Duration(val * 7 * 24 * float64(time.Hour))
|
||||
}
|
||||
|
||||
remaining = remaining[unitEnd:]
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func extractDurationUnit(s string, unit byte) int {
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] == unit {
|
||||
j := i - 1
|
||||
for j >= 0 && (unicode.IsDigit(rune(s[j])) || s[j] == '.') {
|
||||
j--
|
||||
}
|
||||
numStr := s[j+1 : i]
|
||||
val, err := strconv.Atoi(numStr)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return val
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
@@ -0,0 +1,234 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/yourname/cyrene-ai/tool-engine/internal/model"
|
||||
)
|
||||
|
||||
// FileTool provides sandboxed file system operations for the LLM.
|
||||
type FileTool struct {
|
||||
dataDir string
|
||||
}
|
||||
|
||||
// NewFileTool creates a file operation tool with the given data directory.
|
||||
func NewFileTool(dataDir string) *FileTool {
|
||||
if dataDir == "" {
|
||||
dataDir = "/tmp/cyrene_data"
|
||||
}
|
||||
return &FileTool{dataDir: dataDir}
|
||||
}
|
||||
|
||||
// Definition returns the tool definition for LLM function calling.
|
||||
func (t *FileTool) Definition() model.ToolDefinition {
|
||||
return model.ToolDefinition{
|
||||
Name: "file_ops",
|
||||
Description: "文件操作工具。在服务端安全沙盒内读写文件、列出目录、检查文件是否存在、删除文件。所有操作限制在数据目录内,无法访问系统文件。",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []string{"read", "write", "list", "exists", "delete"},
|
||||
"description": "操作类型。read: 读取文件;write: 写入文件(覆盖或创建);list: 列出目录内容;exists: 检查路径是否存在;delete: 删除文件",
|
||||
},
|
||||
"path": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "文件或目录路径(相对于数据目录),如 \"notes/todo.txt\"",
|
||||
},
|
||||
"content": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "写入内容(write 操作时必需)",
|
||||
},
|
||||
},
|
||||
"required": []string{"action", "path"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Execute performs file operations.
|
||||
func (t *FileTool) Execute(ctx context.Context, arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
action, ok := arguments["action"].(string)
|
||||
if !ok || action == "" {
|
||||
return &model.ToolResult{ID: "", Error: "缺少 action 参数"}, nil
|
||||
}
|
||||
|
||||
relPath, ok := arguments["path"].(string)
|
||||
if !ok || relPath == "" {
|
||||
return &model.ToolResult{ID: "", Error: "缺少 path 参数"}, nil
|
||||
}
|
||||
|
||||
safePath, err := t.resolveSafePath(relPath)
|
||||
if err != nil {
|
||||
return &model.ToolResult{ID: "", Error: err.Error()}, nil
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "read":
|
||||
return t.handleRead(safePath, relPath)
|
||||
case "write":
|
||||
content, _ := arguments["content"].(string)
|
||||
return t.handleWrite(safePath, relPath, content)
|
||||
case "list":
|
||||
return t.handleList(safePath, relPath)
|
||||
case "exists":
|
||||
return t.handleExists(safePath, relPath)
|
||||
case "delete":
|
||||
return t.handleDelete(safePath, relPath)
|
||||
default:
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Error: fmt.Sprintf("未知操作: %s,支持: read, write, list, exists, delete", action),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t *FileTool) resolveSafePath(relPath string) (string, error) {
|
||||
clean := filepath.Clean(relPath)
|
||||
|
||||
if err := os.MkdirAll(t.dataDir, 0755); err != nil {
|
||||
return "", fmt.Errorf("创建数据目录失败: %v", err)
|
||||
}
|
||||
|
||||
abs := filepath.Join(t.dataDir, clean)
|
||||
|
||||
realPath, err := filepath.EvalSymlinks(abs)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
if !strings.HasPrefix(filepath.Clean(abs), filepath.Clean(t.dataDir)+string(filepath.Separator)) &&
|
||||
filepath.Clean(abs) != filepath.Clean(t.dataDir) {
|
||||
return "", fmt.Errorf("路径穿越检测: %s 不在允许的数据目录内", relPath)
|
||||
}
|
||||
return abs, nil
|
||||
}
|
||||
return "", fmt.Errorf("路径解析失败: %v", err)
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(realPath, filepath.Clean(t.dataDir)+string(filepath.Separator)) &&
|
||||
realPath != filepath.Clean(t.dataDir) {
|
||||
return "", fmt.Errorf("路径穿越检测: %s 不在允许的数据目录内", relPath)
|
||||
}
|
||||
|
||||
return realPath, nil
|
||||
}
|
||||
|
||||
func (t *FileTool) handleRead(absPath, relPath string) (*model.ToolResult, error) {
|
||||
const maxSize = 100 * 1024
|
||||
|
||||
info, err := os.Stat(absPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("文件不存在: %s", relPath)}, nil
|
||||
}
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("读取文件失败: %v", err)}, nil
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("路径是目录,不能用 read 操作: %s", relPath)}, nil
|
||||
}
|
||||
|
||||
if info.Size() > maxSize {
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("文件过大 (%d bytes),超过限制 (%d bytes)", info.Size(), maxSize)}, nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(absPath)
|
||||
if err != nil {
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("读取文件失败: %v", err)}, nil
|
||||
}
|
||||
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: fmt.Sprintf("文件: %s\n大小: %d bytes\n---\n%s", relPath, len(data), string(data)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *FileTool) handleWrite(absPath, relPath, content string) (*model.ToolResult, error) {
|
||||
dir := filepath.Dir(absPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("创建目录失败: %v", err)}, nil
|
||||
}
|
||||
|
||||
if err := os.WriteFile(absPath, []byte(content), 0644); err != nil {
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("写入文件失败: %v", err)}, nil
|
||||
}
|
||||
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: fmt.Sprintf("已写入文件: %s (%d bytes)", relPath, len(content)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *FileTool) handleList(absPath, relPath string) (*model.ToolResult, error) {
|
||||
entries, err := os.ReadDir(absPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("目录不存在: %s", relPath)}, nil
|
||||
}
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("读取目录失败: %v", err)}, nil
|
||||
}
|
||||
|
||||
if len(entries) == 0 {
|
||||
return &model.ToolResult{ID: "", Output: fmt.Sprintf("目录: %s\n(空目录)", relPath)}, nil
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
result.WriteString(fmt.Sprintf("目录: %s\n共 %d 项:\n", relPath, len(entries)))
|
||||
for _, entry := range entries {
|
||||
icon := "📄"
|
||||
if entry.IsDir() {
|
||||
icon = "📁"
|
||||
}
|
||||
info, _ := entry.Info()
|
||||
size := ""
|
||||
if info != nil && !entry.IsDir() {
|
||||
size = fmt.Sprintf(" (%d bytes)", info.Size())
|
||||
}
|
||||
result.WriteString(fmt.Sprintf(" %s %s%s\n", icon, entry.Name(), size))
|
||||
}
|
||||
|
||||
return &model.ToolResult{ID: "", Output: result.String()}, nil
|
||||
}
|
||||
|
||||
func (t *FileTool) handleExists(absPath, relPath string) (*model.ToolResult, error) {
|
||||
info, err := os.Stat(absPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return &model.ToolResult{ID: "", Output: fmt.Sprintf("路径不存在: %s", relPath)}, nil
|
||||
}
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("检查路径失败: %v", err)}, nil
|
||||
}
|
||||
|
||||
kind := "文件"
|
||||
if info.IsDir() {
|
||||
kind = "目录"
|
||||
}
|
||||
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: fmt.Sprintf("路径存在: %s (%s, %d bytes)", relPath, kind, info.Size()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *FileTool) handleDelete(absPath, relPath string) (*model.ToolResult, error) {
|
||||
info, err := os.Stat(absPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("文件不存在: %s", relPath)}, nil
|
||||
}
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("删除文件失败: %v", err)}, nil
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("不能删除目录(安全限制): %s", relPath)}, nil
|
||||
}
|
||||
|
||||
if err := os.Remove(absPath); err != nil {
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("删除文件失败: %v", err)}, nil
|
||||
}
|
||||
|
||||
return &model.ToolResult{ID: "", Output: fmt.Sprintf("已删除文件: %s", relPath)}, nil
|
||||
}
|
||||
@@ -0,0 +1,157 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/yourname/cyrene-ai/tool-engine/internal/model"
|
||||
)
|
||||
|
||||
// HTTPTool sends arbitrary HTTP requests, more flexible than web_fetch.
|
||||
type HTTPTool struct {
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewHTTPTool creates an HTTP request tool.
|
||||
func NewHTTPTool() *HTTPTool {
|
||||
return &HTTPTool{
|
||||
client: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Definition returns the tool definition for LLM function calling.
|
||||
func (t *HTTPTool) Definition() model.ToolDefinition {
|
||||
return model.ToolDefinition{
|
||||
Name: "http_request",
|
||||
Description: "发送任意HTTP请求。比web_fetch更灵活,支持自定义请求方法、请求头和请求体。返回状态码、响应头和响应体。",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"url": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "请求URL,必须是完整的 http:// 或 https:// 链接",
|
||||
},
|
||||
"method": map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []string{"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"},
|
||||
"description": "HTTP方法,默认GET",
|
||||
},
|
||||
"headers": map[string]interface{}{
|
||||
"type": "object",
|
||||
"description": "请求头,键值对格式,如 {\"Content-Type\": \"application/json\", \"Authorization\": \"Bearer token123\"}",
|
||||
},
|
||||
"body": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "请求体内容",
|
||||
},
|
||||
"timeout": map[string]interface{}{
|
||||
"type": "number",
|
||||
"description": "超时秒数,默认10秒",
|
||||
},
|
||||
},
|
||||
"required": []string{"url"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Execute sends an HTTP request.
|
||||
func (t *HTTPTool) Execute(ctx context.Context, arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
url, ok := arguments["url"].(string)
|
||||
if !ok || url == "" {
|
||||
return &model.ToolResult{ID: "", Error: "缺少 url 参数"}, nil
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") {
|
||||
return &model.ToolResult{ID: "", Error: "仅支持 http:// 或 https:// 链接"}, nil
|
||||
}
|
||||
|
||||
method, _ := arguments["method"].(string)
|
||||
if method == "" {
|
||||
method = "GET"
|
||||
}
|
||||
method = strings.ToUpper(method)
|
||||
|
||||
validMethods := map[string]bool{
|
||||
"GET": true, "POST": true, "PUT": true, "DELETE": true,
|
||||
"PATCH": true, "HEAD": true, "OPTIONS": true,
|
||||
}
|
||||
if !validMethods[method] {
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("不支持的HTTP方法: %s", method)}, nil
|
||||
}
|
||||
|
||||
timeoutSec := 10.0
|
||||
if timeoutVal, ok := arguments["timeout"].(float64); ok && timeoutVal > 0 {
|
||||
timeoutSec = timeoutVal
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: time.Duration(timeoutSec * float64(time.Second)),
|
||||
}
|
||||
|
||||
var bodyReader io.Reader
|
||||
bodyStr, _ := arguments["body"].(string)
|
||||
if bodyStr != "" {
|
||||
bodyReader = strings.NewReader(bodyStr)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, bodyReader)
|
||||
if err != nil {
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("创建请求失败: %v", err)}, nil
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyreneBot/1.0)")
|
||||
|
||||
if headersRaw, ok := arguments["headers"].(map[string]interface{}); ok {
|
||||
for k, v := range headersRaw {
|
||||
val, ok := v.(string)
|
||||
if !ok {
|
||||
val = fmt.Sprintf("%v", v)
|
||||
}
|
||||
req.Header.Set(k, val)
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("请求失败: %v", err)}, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
const maxBodySize = 50 * 1024
|
||||
bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxBodySize)))
|
||||
if err != nil {
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("读取响应失败: %v", err)}, nil
|
||||
}
|
||||
|
||||
var headerLines []string
|
||||
for k, vals := range resp.Header {
|
||||
for _, v := range vals {
|
||||
headerLines = append(headerLines, fmt.Sprintf("%s: %s", k, v))
|
||||
}
|
||||
}
|
||||
headersStr := strings.Join(headerLines, "\n")
|
||||
|
||||
bodyTruncated := ""
|
||||
if len(bodyBytes) > maxBodySize {
|
||||
bodyTruncated = fmt.Sprintf("\n... [响应体已截断,原大小约 %d bytes]", len(bodyBytes))
|
||||
}
|
||||
|
||||
result := fmt.Sprintf(
|
||||
"请求: %s %s\n状态: %d %s\n响应头:\n%s\n\n响应体 (%d bytes):\n%s%s",
|
||||
method, url,
|
||||
resp.StatusCode, resp.Status,
|
||||
headersStr,
|
||||
len(bodyBytes), string(bodyBytes), bodyTruncated,
|
||||
)
|
||||
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: result,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,191 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// IoTClient IoT 调试服务 HTTP 客户端
|
||||
type IoTClient struct {
|
||||
baseURL string
|
||||
client *http.Client
|
||||
|
||||
// 缓存控制
|
||||
mu sync.RWMutex
|
||||
cache []IoTDevice
|
||||
cacheTime time.Time
|
||||
cacheTTL time.Duration
|
||||
}
|
||||
|
||||
// NewIoTClient 创建 IoT 客户端
|
||||
func NewIoTClient(baseURL string) *IoTClient {
|
||||
return &IoTClient{
|
||||
baseURL: baseURL,
|
||||
client: &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
cacheTTL: 60 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllDevices 获取所有设备列表(带缓存)
|
||||
func (c *IoTClient) GetAllDevices() ([]IoTDevice, error) {
|
||||
// 检查缓存
|
||||
c.mu.RLock()
|
||||
if c.cache != nil && time.Since(c.cacheTime) < c.cacheTTL {
|
||||
devices := make([]IoTDevice, len(c.cache))
|
||||
copy(devices, c.cache)
|
||||
c.mu.RUnlock()
|
||||
return devices, nil
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
// 请求 API
|
||||
resp, err := c.client.Get(c.baseURL + "/api/v1/devices")
|
||||
if err != nil {
|
||||
log.Printf("[IoT客户端] 请求失败: %v", err)
|
||||
return nil, fmt.Errorf("获取设备列表失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("获取设备列表返回状态码 %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Devices []IoTDevice `json:"devices"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("解析设备列表失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新缓存
|
||||
c.mu.Lock()
|
||||
c.cache = result.Devices
|
||||
c.cacheTime = time.Now()
|
||||
c.mu.Unlock()
|
||||
|
||||
return result.Devices, nil
|
||||
}
|
||||
|
||||
// GetDevice 获取单个设备详情
|
||||
func (c *IoTClient) GetDevice(id string) (*IoTDevice, error) {
|
||||
resp, err := c.client.Get(c.baseURL + "/api/v1/devices/" + id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取设备 %s 失败: %w", id, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return nil, fmt.Errorf("设备 %s 不存在", id)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("获取设备 %s 返回状态码 %d", id, resp.StatusCode)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Device IoTDevice `json:"device"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("解析设备信息失败: %w", err)
|
||||
}
|
||||
|
||||
return &result.Device, nil
|
||||
}
|
||||
|
||||
// ToggleDevice 切换设备开关状态
|
||||
func (c *IoTClient) ToggleDevice(id string) error {
|
||||
req, err := http.NewRequest(http.MethodPost, c.baseURL+"/api/v1/devices/"+id+"/toggle", nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建切换请求失败: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("切换设备 %s 失败: %w", id, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return fmt.Errorf("设备 %s 不存在", id)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("切换设备 %s 返回状态码 %d", id, resp.StatusCode)
|
||||
}
|
||||
|
||||
// 切换后清除缓存,确保下次查询获取最新状态
|
||||
c.mu.Lock()
|
||||
c.cache = nil
|
||||
c.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDeviceProperty 设置设备属性(温度、亮度、位置、模式、颜色等)
|
||||
func (c *IoTClient) SetDeviceProperty(id string, field string, value interface{}) error {
|
||||
body, err := json.Marshal(map[string]interface{}{
|
||||
"field": field,
|
||||
"value": value,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, c.baseURL+"/api/v1/devices/"+id+"/set", nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建设置请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Body = io.NopCloser(bytes.NewReader(body))
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("设置设备 %s 属性失败: %w", id, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return fmt.Errorf("设备 %s 不存在", id)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errResp struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
json.NewDecoder(resp.Body).Decode(&errResp)
|
||||
if errResp.Error != "" {
|
||||
return fmt.Errorf("设置设备 %s 属性失败: %s", id, errResp.Error)
|
||||
}
|
||||
return fmt.Errorf("设置设备 %s 属性返回状态码 %d", id, resp.StatusCode)
|
||||
}
|
||||
|
||||
// 修改后清除缓存
|
||||
c.mu.Lock()
|
||||
c.cache = nil
|
||||
c.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDevicesForContext 获取设备状态摘要(供上下文注入使用,失败不报错)
|
||||
func (c *IoTClient) GetDevicesForContext() []IoTDevice {
|
||||
devices, err := c.GetAllDevices()
|
||||
if err != nil {
|
||||
log.Printf("[IoT客户端] 获取设备状态摘要失败: %v", err)
|
||||
return nil
|
||||
}
|
||||
return devices
|
||||
}
|
||||
|
||||
// InvalidateCache 使缓存失效
|
||||
func (c *IoTClient) InvalidateCache() {
|
||||
c.mu.Lock()
|
||||
c.cache = nil
|
||||
c.mu.Unlock()
|
||||
}
|
||||
@@ -0,0 +1,439 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/yourname/cyrene-ai/tool-engine/internal/model"
|
||||
)
|
||||
|
||||
// IoTControlTool IoT 设备控制工具
|
||||
type IoTControlTool struct {
|
||||
iotClient IoTClientInterface
|
||||
}
|
||||
|
||||
// NewIoTControlTool 创建 IoT 控制工具
|
||||
func NewIoTControlTool(iotClient IoTClientInterface) *IoTControlTool {
|
||||
return &IoTControlTool{iotClient: iotClient}
|
||||
}
|
||||
|
||||
// Definition 返回工具定义
|
||||
func (t *IoTControlTool) Definition() model.ToolDefinition {
|
||||
return model.ToolDefinition{
|
||||
Name: "iot_control",
|
||||
Description: "【仅当开拓者明确要求控制设备时才使用此工具】控制家中智能设备。可以开关灯光、空调、窗帘、门锁等设备,也可以调节温度、亮度、位置、模式、颜色等属性。" +
|
||||
"\n⚠️ 重要约束:" +
|
||||
"\n - 不要在开拓者只是询问设备状态时调用此工具(查询设备请用 iot_query)" +
|
||||
"\n - 不要自行决定执行操作,必须等开拓者明确说出「打开」「关闭」「调到」「设置」等控制指令" +
|
||||
"\n - 不要因为之前对话中提到过某个设备就主动控制它" +
|
||||
"\n支持的操作:toggle(切换开关状态)、turn_on(打开设备)、turn_off(关闭设备)、" +
|
||||
"set_temperature(设置空调温度,需要 value 参数,单位°C)、" +
|
||||
"set_brightness(设置灯光亮度,需要 value 参数,0-100)、" +
|
||||
"set_position(设置窗帘位置,需要 value 参数,0-100,0=关闭 100=全开)、" +
|
||||
"set_mode(设置空调模式,需要 value 参数,可选值: cool/heat/auto)、" +
|
||||
"set_color(设置灯光颜色,需要 value 参数,可选值: warm_white/cool_white/colorful)",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"device_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "要控制的设备ID。可选值: light-livingroom, light-bedroom, ac-livingroom, ac-bedroom, curtain-livingroom, lock-door",
|
||||
},
|
||||
"action": map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []string{"toggle", "turn_on", "turn_off", "set_temperature", "set_brightness", "set_position", "set_mode", "set_color"},
|
||||
"description": "要执行的操作。toggle:切换开关状态;turn_on:打开设备;turn_off:关闭设备;set_temperature:设置空调温度(需配合value参数);set_brightness:设置灯光亮度(需配合value参数);set_position:设置窗帘位置(需配合value参数);set_mode:设置空调模式(需配合value参数);set_color:设置灯光颜色(需配合value参数)",
|
||||
},
|
||||
"value": map[string]interface{}{
|
||||
"type": "number",
|
||||
"description": "操作的值。set_temperature 时表示目标温度(°C),set_brightness 时表示亮度百分比(0-100),set_position 时表示窗帘开合程度(0-100)。action 为 set_temperature/set_brightness/set_position 时必须提供。set_mode 时为字符串(cool/heat/auto),set_color 时为字符串(warm_white/cool_white/colorful)",
|
||||
},
|
||||
},
|
||||
"required": []string{"device_id", "action"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeAction 标准化 action 参数,支持中文别名、power 参数等
|
||||
func normalizeAction(arguments map[string]interface{}) string {
|
||||
action, _ := arguments["action"].(string)
|
||||
|
||||
// 如果 action 为空,检查 power/status 参数
|
||||
if action == "" {
|
||||
// power 参数: "off"/"关"/"关闭" → turn_off, "on"/"开"/"打开" → turn_on
|
||||
if pv, ok := arguments["power"]; ok {
|
||||
switch v := pv.(type) {
|
||||
case string:
|
||||
switch strings.ToLower(strings.TrimSpace(v)) {
|
||||
case "off", "false", "关", "关闭":
|
||||
return "turn_off"
|
||||
case "on", "true", "开", "打开", "开启":
|
||||
return "turn_on"
|
||||
}
|
||||
case bool:
|
||||
if !v {
|
||||
return "turn_off"
|
||||
}
|
||||
return "turn_on"
|
||||
}
|
||||
}
|
||||
// status 参数同理
|
||||
if sv, ok := arguments["status"]; ok {
|
||||
switch v := sv.(type) {
|
||||
case string:
|
||||
switch strings.ToLower(strings.TrimSpace(v)) {
|
||||
case "off", "false", "关", "关闭":
|
||||
return "turn_off"
|
||||
case "on", "true", "开", "打开", "开启":
|
||||
return "turn_on"
|
||||
}
|
||||
case bool:
|
||||
if !v {
|
||||
return "turn_off"
|
||||
}
|
||||
return "turn_on"
|
||||
}
|
||||
}
|
||||
// 默认 toggle
|
||||
return "toggle"
|
||||
}
|
||||
|
||||
// 标准化中文 action 名
|
||||
switch strings.ToLower(strings.TrimSpace(action)) {
|
||||
case "打开", "开启", "开":
|
||||
return "turn_on"
|
||||
case "关闭", "关":
|
||||
return "turn_off"
|
||||
case "切换":
|
||||
return "toggle"
|
||||
case "设置温度", "调温度", "set_temp":
|
||||
return "set_temperature"
|
||||
case "设置亮度", "调亮度", "set_light":
|
||||
return "set_brightness"
|
||||
case "设置位置", "调位置":
|
||||
return "set_position"
|
||||
case "设置模式", "调模式", "切换模式":
|
||||
return "set_mode"
|
||||
case "设置颜色", "调颜色", "换颜色":
|
||||
return "set_color"
|
||||
}
|
||||
|
||||
return action
|
||||
}
|
||||
|
||||
// Execute 执行设备控制
|
||||
func (t *IoTControlTool) Execute(ctx context.Context, arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
if t.iotClient == nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: "IoT 客户端未初始化",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 参数别名:entity_id → device_id
|
||||
deviceID, _ := arguments["device_id"].(string)
|
||||
if deviceID == "" {
|
||||
deviceID, _ = arguments["entity_id"].(string)
|
||||
}
|
||||
|
||||
action := normalizeAction(arguments)
|
||||
|
||||
if deviceID == "" {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: "缺少设备ID(请使用 device_id 参数)",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 先获取设备名用于友好的返回消息(失败不影响后续流程)
|
||||
deviceName := deviceID
|
||||
if dev, err := t.iotClient.GetDevice(deviceID); err == nil {
|
||||
deviceName = dev.Name
|
||||
}
|
||||
|
||||
// 处理属性设置类操作
|
||||
switch action {
|
||||
case "set_temperature":
|
||||
return t.handleSetTemperature(deviceID, arguments)
|
||||
case "set_brightness":
|
||||
return t.handleSetBrightness(deviceID, arguments)
|
||||
case "set_position":
|
||||
return t.handleSetPosition(deviceID, arguments)
|
||||
case "set_mode":
|
||||
return t.handleSetMode(deviceID, arguments)
|
||||
case "set_color":
|
||||
return t.handleSetColor(deviceID, arguments)
|
||||
case "turn_off":
|
||||
// 声明式关闭:使用 SetDeviceProperty status/off 而非 toggle
|
||||
// 即使设备已经关闭,SetProperty 也会幂等处理
|
||||
if err := t.iotClient.SetDeviceProperty(deviceID, "status", "off"); err != nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("关闭设备失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
return &model.ToolResult{
|
||||
Output: fmt.Sprintf("已关闭设备: %s", deviceName),
|
||||
Error: "",
|
||||
}, nil
|
||||
case "turn_on":
|
||||
// 声明式打开:使用 SetDeviceProperty status/on 而非 toggle
|
||||
if err := t.iotClient.SetDeviceProperty(deviceID, "status", "on"); err != nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("打开设备失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
return &model.ToolResult{
|
||||
Output: fmt.Sprintf("已打开设备: %s", deviceName),
|
||||
Error: "",
|
||||
}, nil
|
||||
default: // "toggle"
|
||||
if err := t.iotClient.ToggleDevice(deviceID); err != nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("操作设备失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 获取切换后的状态
|
||||
updatedDevice, err := t.iotClient.GetDevice(deviceID)
|
||||
if err != nil {
|
||||
return &model.ToolResult{
|
||||
Output: fmt.Sprintf("已成功切换设备 %s 的状态。", deviceName),
|
||||
Error: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &model.ToolResult{
|
||||
Output: fmt.Sprintf("已成功操作设备: %s\n当前状态: %s", updatedDevice.Name, formatDeviceLine(*updatedDevice)),
|
||||
Error: "",
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// extractValue 从 arguments 中提取 value 参数(支持 value/Value 及数字/字符串类型)
|
||||
func extractValue(arguments map[string]interface{}) interface{} {
|
||||
if v, ok := arguments["value"]; ok {
|
||||
return v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleSetTemperature 处理设置温度
|
||||
func (t *IoTControlTool) handleSetTemperature(deviceID string, arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
val := extractValue(arguments)
|
||||
if val == nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: "缺少 value 参数,请指定目标温度(如 24)",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 先获取当前设备信息
|
||||
currentDevice, err := t.iotClient.GetDevice(deviceID)
|
||||
if err != nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("获取设备状态失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
temperature, ok := toFloat64(val)
|
||||
if !ok {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("温度值无效: %v", val),
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := t.iotClient.SetDeviceProperty(deviceID, "temperature", temperature); err != nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("设置温度失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &model.ToolResult{
|
||||
Output: fmt.Sprintf("已将 %s 温度从 %.1f°C 调整为 %.1f°C", currentDevice.Name, currentDevice.Temperature, temperature),
|
||||
Error: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// handleSetBrightness 处理设置亮度
|
||||
func (t *IoTControlTool) handleSetBrightness(deviceID string, arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
val := extractValue(arguments)
|
||||
if val == nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: "缺少 value 参数,请指定亮度值(0-100)",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 先获取当前设备信息
|
||||
currentDevice, err := t.iotClient.GetDevice(deviceID)
|
||||
if err != nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("获取设备状态失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
brightness, ok := toFloat64(val)
|
||||
if !ok {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("亮度值无效: %v", val),
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := t.iotClient.SetDeviceProperty(deviceID, "brightness", brightness); err != nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("设置亮度失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &model.ToolResult{
|
||||
Output: fmt.Sprintf("已将 %s 亮度调整为 %d%%", currentDevice.Name, int(brightness)),
|
||||
Error: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// handleSetPosition 处理设置窗帘位置
|
||||
func (t *IoTControlTool) handleSetPosition(deviceID string, arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
val := extractValue(arguments)
|
||||
if val == nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: "缺少 value 参数,请指定位置值(0=关闭, 100=全开)",
|
||||
}, nil
|
||||
}
|
||||
|
||||
currentDevice, err := t.iotClient.GetDevice(deviceID)
|
||||
if err != nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("获取设备状态失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
position, ok := toFloat64(val)
|
||||
if !ok {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("位置值无效: %v", val),
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := t.iotClient.SetDeviceProperty(deviceID, "position", position); err != nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("设置窗帘位置失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &model.ToolResult{
|
||||
Output: fmt.Sprintf("已将 %s 窗帘调整为 %d%%", currentDevice.Name, int(position)),
|
||||
Error: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// handleSetMode 处理设置空调模式
|
||||
func (t *IoTControlTool) handleSetMode(deviceID string, arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
val := extractValue(arguments)
|
||||
if val == nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: "缺少 value 参数,请指定模式(cool/heat/auto)",
|
||||
}, nil
|
||||
}
|
||||
|
||||
mode, ok := val.(string)
|
||||
if !ok {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("模式值无效: %v", val),
|
||||
}, nil
|
||||
}
|
||||
|
||||
currentDevice, err := t.iotClient.GetDevice(deviceID)
|
||||
if err != nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("获取设备状态失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := t.iotClient.SetDeviceProperty(deviceID, "mode", mode); err != nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("设置模式失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &model.ToolResult{
|
||||
Output: fmt.Sprintf("已将 %s 模式切换为 %s", currentDevice.Name, mode),
|
||||
Error: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// handleSetColor 处理设置灯光颜色
|
||||
func (t *IoTControlTool) handleSetColor(deviceID string, arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
val := extractValue(arguments)
|
||||
if val == nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: "缺少 value 参数,请指定颜色(warm_white/cool_white/colorful)",
|
||||
}, nil
|
||||
}
|
||||
|
||||
color, ok := val.(string)
|
||||
if !ok {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("颜色值无效: %v", val),
|
||||
}, nil
|
||||
}
|
||||
|
||||
currentDevice, err := t.iotClient.GetDevice(deviceID)
|
||||
if err != nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("获取设备状态失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := t.iotClient.SetDeviceProperty(deviceID, "color", color); err != nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("设置颜色失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &model.ToolResult{
|
||||
Output: fmt.Sprintf("已将 %s 灯光颜色切换为 %s", currentDevice.Name, color),
|
||||
Error: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// toFloat64 将 interface{} 转换为 float64
|
||||
func toFloat64(v interface{}) (float64, bool) {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
return val, true
|
||||
case float32:
|
||||
return float64(val), true
|
||||
case int:
|
||||
return float64(val), true
|
||||
case int64:
|
||||
return float64(val), true
|
||||
case json.Number:
|
||||
f, err := val.Float64()
|
||||
return f, err == nil
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/yourname/cyrene-ai/tool-engine/internal/model"
|
||||
)
|
||||
|
||||
// IoTQueryTool IoT 设备查询工具
|
||||
type IoTQueryTool struct {
|
||||
iotClient IoTClientInterface
|
||||
}
|
||||
|
||||
// NewIoTQueryTool 创建 IoT 查询工具
|
||||
func NewIoTQueryTool(iotClient IoTClientInterface) *IoTQueryTool {
|
||||
return &IoTQueryTool{iotClient: iotClient}
|
||||
}
|
||||
|
||||
// Definition 返回工具定义
|
||||
func (t *IoTQueryTool) Definition() model.ToolDefinition {
|
||||
return model.ToolDefinition{
|
||||
Name: "iot_query",
|
||||
Description: "查询家中智能设备状态。注意:当前设备状态通常已自动注入到系统提示词中,你通常不需要调用此工具即可回答设备状态问题。只有在设备状态信息陈旧或明显不完整时才调用此工具刷新。",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"device_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "要查询的设备ID(可选,不填则返回所有设备)。可选值: light-livingroom, light-bedroom, ac-livingroom, ac-bedroom, curtain-livingroom, sensor-temperature, sensor-humidity, lock-door",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Execute 执行查询
|
||||
func (t *IoTQueryTool) Execute(ctx context.Context, arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
if t.iotClient == nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: "IoT 客户端未初始化",
|
||||
}, nil
|
||||
}
|
||||
|
||||
deviceID, _ := arguments["device_id"].(string)
|
||||
|
||||
if deviceID != "" {
|
||||
// 查询单个设备
|
||||
device, err := t.iotClient.GetDevice(deviceID)
|
||||
if err != nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: err.Error(),
|
||||
}, nil
|
||||
}
|
||||
return &model.ToolResult{
|
||||
Output: formatSingleDevice(device),
|
||||
Error: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 查询所有设备
|
||||
devices, err := t.iotClient.GetAllDevices()
|
||||
if err != nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
result.WriteString(fmt.Sprintf("当前共有 %d 台智能设备:\n\n", len(devices)))
|
||||
for _, d := range devices {
|
||||
result.WriteString(formatDeviceLine(d) + "\n")
|
||||
}
|
||||
return &model.ToolResult{
|
||||
Output: result.String(),
|
||||
Error: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func formatSingleDevice(d *IoTDevice) string {
|
||||
return fmt.Sprintf("设备: %s (%s)\n状态: %s", d.Name, d.Type, formatDeviceLine(*d))
|
||||
}
|
||||
|
||||
func formatDeviceLine(d IoTDevice) string {
|
||||
switch d.Type {
|
||||
case "light":
|
||||
if d.Status == "on" {
|
||||
return fmt.Sprintf("💡 %s: 开启 (亮度%d%%, %s)", d.Name, d.Brightness, d.Color)
|
||||
}
|
||||
return fmt.Sprintf("💡 %s: 关闭", d.Name)
|
||||
case "ac":
|
||||
if d.Status == "on" {
|
||||
mode := d.Mode
|
||||
switch mode {
|
||||
case "cool":
|
||||
mode = "制冷"
|
||||
case "heat":
|
||||
mode = "制热"
|
||||
case "auto":
|
||||
mode = "自动"
|
||||
}
|
||||
return fmt.Sprintf("❄️ %s: 运行中 (%s %.0f°C)", d.Name, mode, d.Temperature)
|
||||
}
|
||||
return fmt.Sprintf("❄️ %s: 关闭", d.Name)
|
||||
case "curtain":
|
||||
if d.Status == "open" {
|
||||
return fmt.Sprintf("🪟 %s: 已打开", d.Name)
|
||||
}
|
||||
return fmt.Sprintf("🪟 %s: 已关闭", d.Name)
|
||||
case "sensor":
|
||||
unit := d.Unit
|
||||
if unit == "celsius" {
|
||||
unit = "°C"
|
||||
} else if unit == "percent" {
|
||||
unit = "%"
|
||||
}
|
||||
return fmt.Sprintf("🌡️ %s: %.1f%s", d.Name, d.Value, unit)
|
||||
case "lock":
|
||||
status := "已锁定"
|
||||
if d.Status == "unlocked" {
|
||||
status = "已解锁"
|
||||
}
|
||||
return fmt.Sprintf("🔒 %s: %s (电量%d%%)", d.Name, status, d.Battery)
|
||||
default:
|
||||
return fmt.Sprintf("%s: %s", d.Name, d.Status)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,187 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/yourname/cyrene-ai/tool-engine/internal/model"
|
||||
)
|
||||
|
||||
// JSONTool provides JSON parsing, querying, and validation for the LLM.
|
||||
type JSONTool struct{}
|
||||
|
||||
// NewJSONTool creates a JSON processing tool.
|
||||
func NewJSONTool() *JSONTool {
|
||||
return &JSONTool{}
|
||||
}
|
||||
|
||||
// Definition returns the tool definition for LLM function calling.
|
||||
func (t *JSONTool) Definition() model.ToolDefinition {
|
||||
return model.ToolDefinition{
|
||||
Name: "json_ops",
|
||||
Description: "JSON处理工具。解析JSON字符串并格式化输出、用简单路径查询JSON字段、验证JSON是否合法。",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []string{"parse", "query", "validate"},
|
||||
"description": "操作类型。parse: 解析JSON并格式化输出;query: 用路径查询JSON中的值(如\"users.0.name\"表示取users数组第0个元素的name字段);validate: 验证JSON字符串是否合法",
|
||||
},
|
||||
"json_string": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "JSON字符串",
|
||||
},
|
||||
"path": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "查询路径(query操作时使用)。支持点分隔和数组索引,如 \"users.0.name\"、\"data.list.2.title\"",
|
||||
},
|
||||
},
|
||||
"required": []string{"action", "json_string"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Execute performs JSON operations.
|
||||
func (t *JSONTool) Execute(ctx context.Context, arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
action, ok := arguments["action"].(string)
|
||||
if !ok || action == "" {
|
||||
return &model.ToolResult{ID: "", Error: "缺少 action 参数"}, nil
|
||||
}
|
||||
|
||||
jsonStr, ok := arguments["json_string"].(string)
|
||||
if !ok || jsonStr == "" {
|
||||
return &model.ToolResult{ID: "", Error: "缺少 json_string 参数"}, nil
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "parse":
|
||||
return t.handleParse(jsonStr)
|
||||
case "query":
|
||||
path, _ := arguments["path"].(string)
|
||||
return t.handleQuery(jsonStr, path)
|
||||
case "validate":
|
||||
return t.handleValidate(jsonStr)
|
||||
default:
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Error: fmt.Sprintf("未知操作: %s,支持: parse, query, validate", action),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t *JSONTool) handleParse(jsonStr string) (*model.ToolResult, error) {
|
||||
var data interface{}
|
||||
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("JSON解析失败: %v", err)}, nil
|
||||
}
|
||||
|
||||
pretty, err := json.MarshalIndent(data, "", " ")
|
||||
if err != nil {
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("JSON格式化失败: %v", err)}, nil
|
||||
}
|
||||
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: fmt.Sprintf("解析成功\n格式化输出:\n%s", string(pretty)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *JSONTool) handleQuery(jsonStr, path string) (*model.ToolResult, error) {
|
||||
if path == "" {
|
||||
return &model.ToolResult{ID: "", Error: "query 操作需要 path 参数"}, nil
|
||||
}
|
||||
|
||||
var data interface{}
|
||||
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("JSON解析失败: %v", err)}, nil
|
||||
}
|
||||
|
||||
value, err := queryPath(data, path)
|
||||
if err != nil {
|
||||
return &model.ToolResult{ID: "", Error: err.Error()}, nil
|
||||
}
|
||||
|
||||
pretty, err := json.MarshalIndent(value, "", " ")
|
||||
if err != nil {
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: fmt.Sprintf("路径: %s\n值: %v", path, value),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: fmt.Sprintf("路径: %s\n值:\n%s", path, string(pretty)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *JSONTool) handleValidate(jsonStr string) (*model.ToolResult, error) {
|
||||
var data interface{}
|
||||
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
|
||||
errStr := err.Error()
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: fmt.Sprintf("❌ JSON不合法\n错误: %s", errStr),
|
||||
}, nil
|
||||
}
|
||||
|
||||
typeName := "object"
|
||||
switch data.(type) {
|
||||
case []interface{}:
|
||||
typeName = "array"
|
||||
case string:
|
||||
typeName = "string"
|
||||
case float64:
|
||||
typeName = "number"
|
||||
case bool:
|
||||
typeName = "boolean"
|
||||
case nil:
|
||||
typeName = "null"
|
||||
}
|
||||
|
||||
size := len(jsonStr)
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: fmt.Sprintf("✅ JSON合法\n类型: %s\n大小: %d bytes", typeName, size),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func queryPath(data interface{}, path string) (interface{}, error) {
|
||||
path = strings.TrimPrefix(path, "$.")
|
||||
if path == "" || path == "$" {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
parts := strings.Split(path, ".")
|
||||
current := data
|
||||
|
||||
for _, part := range parts {
|
||||
switch v := current.(type) {
|
||||
case map[string]interface{}:
|
||||
var ok bool
|
||||
current, ok = v[part]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("路径 '%s' 中字段 '%s' 不存在", path, part)
|
||||
}
|
||||
|
||||
case []interface{}:
|
||||
idx, err := strconv.Atoi(part)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("路径 '%s' 中 '%s' 不是有效的数组索引", path, part)
|
||||
}
|
||||
if idx < 0 || idx >= len(v) {
|
||||
return nil, fmt.Errorf("路径 '%s' 中索引 %d 越界(数组长度 %d)", path, idx, len(v))
|
||||
}
|
||||
current = v[idx]
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("路径 '%s' 中无法继续导航:'%s' 不是对象或数组", path, part)
|
||||
}
|
||||
}
|
||||
|
||||
return current, nil
|
||||
}
|
||||
@@ -0,0 +1,348 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/yourname/cyrene-ai/tool-engine/internal/model"
|
||||
)
|
||||
|
||||
// MarkdownTool provides Markdown processing utilities for the LLM.
|
||||
type MarkdownTool struct{}
|
||||
|
||||
// NewMarkdownTool creates a Markdown processing tool.
|
||||
func NewMarkdownTool() *MarkdownTool {
|
||||
return &MarkdownTool{}
|
||||
}
|
||||
|
||||
// Definition returns the tool definition for LLM function calling.
|
||||
func (t *MarkdownTool) Definition() model.ToolDefinition {
|
||||
return model.ToolDefinition{
|
||||
Name: "markdown",
|
||||
Description: "Markdown处理工具。将Markdown转为HTML、提取纯文本、提取链接/代码块、生成目录。用于处理Markdown格式的文档内容。",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []string{"to_html", "to_text", "extract_links", "extract_code", "table_of_contents"},
|
||||
"description": "操作类型。to_html: 转换为HTML;to_text: 提取纯文本;extract_links: 提取所有链接;extract_code: 提取所有代码块;table_of_contents: 生成目录",
|
||||
},
|
||||
"markdown": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Markdown格式文本,需要处理的Markdown内容",
|
||||
},
|
||||
},
|
||||
"required": []string{"action", "markdown"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Execute performs Markdown processing operations.
|
||||
func (t *MarkdownTool) Execute(ctx context.Context, arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
action, ok := arguments["action"].(string)
|
||||
if !ok || action == "" {
|
||||
return &model.ToolResult{ID: "", Error: "缺少 action 参数"}, nil
|
||||
}
|
||||
|
||||
md, ok := arguments["markdown"].(string)
|
||||
if !ok || strings.TrimSpace(md) == "" {
|
||||
return &model.ToolResult{ID: "", Error: "缺少 markdown 参数或内容为空"}, nil
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "to_html":
|
||||
return t.handleToHTML(md)
|
||||
case "to_text":
|
||||
return t.handleToText(md)
|
||||
case "extract_links":
|
||||
return t.handleExtractLinks(md)
|
||||
case "extract_code":
|
||||
return t.handleExtractCode(md)
|
||||
case "table_of_contents":
|
||||
return t.handleTableOfContents(md)
|
||||
default:
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Error: fmt.Sprintf("未知操作: %s,支持: to_html, to_text, extract_links, extract_code, table_of_contents", action),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t *MarkdownTool) handleToHTML(md string) (*model.ToolResult, error) {
|
||||
html := md
|
||||
|
||||
codeBlocks := make([]string, 0)
|
||||
reFence := regexp.MustCompile("(?s)```[^`]*```")
|
||||
html = reFence.ReplaceAllStringFunc(html, func(match string) string {
|
||||
codeBlocks = append(codeBlocks, match)
|
||||
return fmt.Sprintf("\x00CODEBLOCK%d\x00", len(codeBlocks)-1)
|
||||
})
|
||||
|
||||
inlineCodes := make([]string, 0)
|
||||
reInlineCode := regexp.MustCompile("`[^`]+`")
|
||||
html = reInlineCode.ReplaceAllStringFunc(html, func(match string) string {
|
||||
inlineCodes = append(inlineCodes, match)
|
||||
return fmt.Sprintf("\x00INLINECODE%d\x00", len(inlineCodes)-1)
|
||||
})
|
||||
|
||||
reImage := regexp.MustCompile(`!\[([^\]]*)\]\(([^)]+)\)`)
|
||||
html = reImage.ReplaceAllString(html, `<img src="$2" alt="$1">`)
|
||||
|
||||
reLink := regexp.MustCompile(`\[([^\]]+)\]\(([^)]+)\)`)
|
||||
html = reLink.ReplaceAllString(html, `<a href="$2">$1</a>`)
|
||||
|
||||
reBold := regexp.MustCompile(`\*\*([^*]+)\*\*`)
|
||||
html = reBold.ReplaceAllString(html, `<strong>$1</strong>`)
|
||||
reBold2 := regexp.MustCompile(`__([^_]+)__`)
|
||||
html = reBold2.ReplaceAllString(html, `<strong>$1</strong>`)
|
||||
|
||||
reItalic := regexp.MustCompile(`\*([^*]+)\*`)
|
||||
html = reItalic.ReplaceAllString(html, `<em>$1</em>`)
|
||||
reItalic2 := regexp.MustCompile(`_([^_]+)_`)
|
||||
html = reItalic2.ReplaceAllString(html, `<em>$1</em>`)
|
||||
|
||||
reStrike := regexp.MustCompile(`~~([^~]+)~~`)
|
||||
html = reStrike.ReplaceAllString(html, `<del>$1</del>`)
|
||||
|
||||
reH6 := regexp.MustCompile(`(?m)^######\s+(.+)$`)
|
||||
html = reH6.ReplaceAllString(html, `<h6>$1</h6>`)
|
||||
reH5 := regexp.MustCompile(`(?m)^#####\s+(.+)$`)
|
||||
html = reH5.ReplaceAllString(html, `<h5>$1</h5>`)
|
||||
reH4 := regexp.MustCompile(`(?m)^####\s+(.+)$`)
|
||||
html = reH4.ReplaceAllString(html, `<h4>$1</h4>`)
|
||||
reH3 := regexp.MustCompile(`(?m)^###\s+(.+)$`)
|
||||
html = reH3.ReplaceAllString(html, `<h3>$1</h3>`)
|
||||
reH2 := regexp.MustCompile(`(?m)^##\s+(.+)$`)
|
||||
html = reH2.ReplaceAllString(html, `<h2>$1</h2>`)
|
||||
reH1 := regexp.MustCompile(`(?m)^#\s+(.+)$`)
|
||||
html = reH1.ReplaceAllString(html, `<h1>$1</h1>`)
|
||||
|
||||
reHR := regexp.MustCompile(`(?m)^(---|\*\*\*|___)\s*$`)
|
||||
html = reHR.ReplaceAllString(html, `<hr>`)
|
||||
|
||||
html = t.processLists(html, `(?m)^[\-*]\s+`, "ul")
|
||||
html = t.processLists(html, `(?m)^\d+\.\s+`, "ol")
|
||||
|
||||
reBlockquote := regexp.MustCompile(`(?m)^>\s?(.+)$`)
|
||||
html = reBlockquote.ReplaceAllString(html, `<blockquote>$1</blockquote>`)
|
||||
|
||||
html = t.wrapParagraphs(html)
|
||||
|
||||
for i, cb := range codeBlocks {
|
||||
content := strings.TrimPrefix(cb, "```")
|
||||
content = strings.TrimSuffix(content, "```")
|
||||
lang := ""
|
||||
content = strings.TrimSpace(content)
|
||||
if idx := strings.Index(content, "\n"); idx > 0 {
|
||||
lang = strings.TrimSpace(content[:idx])
|
||||
content = strings.TrimSpace(content[idx+1:])
|
||||
}
|
||||
if lang != "" {
|
||||
html = strings.ReplaceAll(html, fmt.Sprintf("\x00CODEBLOCK%d\x00", i),
|
||||
fmt.Sprintf(`<pre><code class="language-%s">%s</code></pre>`, lang, escapeHTML(content)))
|
||||
} else {
|
||||
html = strings.ReplaceAll(html, fmt.Sprintf("\x00CODEBLOCK%d\x00", i),
|
||||
fmt.Sprintf("<pre><code>%s</code></pre>", escapeHTML(content)))
|
||||
}
|
||||
}
|
||||
|
||||
for i, ic := range inlineCodes {
|
||||
content := strings.Trim(ic, "`")
|
||||
html = strings.ReplaceAll(html, fmt.Sprintf("\x00INLINECODE%d\x00", i),
|
||||
fmt.Sprintf("<code>%s</code>", escapeHTML(content)))
|
||||
}
|
||||
|
||||
return &model.ToolResult{ID: "", Output: html}, nil
|
||||
}
|
||||
|
||||
func (t *MarkdownTool) handleToText(md string) (*model.ToolResult, error) {
|
||||
text := md
|
||||
|
||||
reFence := regexp.MustCompile("(?s)```[^`]*```")
|
||||
text = reFence.ReplaceAllString(text, "[代码块]")
|
||||
|
||||
reInlineCode := regexp.MustCompile("`[^`]+`")
|
||||
text = reInlineCode.ReplaceAllString(text, "[代码]")
|
||||
|
||||
reImage := regexp.MustCompile(`!\[([^\]]*)\]\([^)]+\)`)
|
||||
text = reImage.ReplaceAllString(text, "$1")
|
||||
|
||||
reLink := regexp.MustCompile(`\[([^\]]+)\]\([^)]+\)`)
|
||||
text = reLink.ReplaceAllString(text, "$1")
|
||||
|
||||
text = regexp.MustCompile(`\*\*([^*]+)\*\*`).ReplaceAllString(text, "$1")
|
||||
text = regexp.MustCompile(`__([^_]+)__`).ReplaceAllString(text, "$1")
|
||||
text = regexp.MustCompile(`\*([^*]+)\*`).ReplaceAllString(text, "$1")
|
||||
text = regexp.MustCompile(`_([^_]+)_`).ReplaceAllString(text, "$1")
|
||||
|
||||
text = regexp.MustCompile(`~~([^~]+)~~`).ReplaceAllString(text, "$1")
|
||||
|
||||
text = regexp.MustCompile(`(?m)^#{1,6}\s+`).ReplaceAllString(text, "")
|
||||
|
||||
text = regexp.MustCompile(`(?m)^(---|\*\*\*|___)\s*$`).ReplaceAllString(text, "")
|
||||
|
||||
text = regexp.MustCompile(`(?m)^[\-*]\s+`).ReplaceAllString(text, "")
|
||||
text = regexp.MustCompile(`(?m)^\d+\.\s+`).ReplaceAllString(text, "")
|
||||
|
||||
text = regexp.MustCompile(`(?m)^>\s?`).ReplaceAllString(text, "")
|
||||
|
||||
text = regexp.MustCompile(`\n{3,}`).ReplaceAllString(text, "\n\n")
|
||||
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: fmt.Sprintf("纯文本提取结果 (%d 字符):\n\n%s",
|
||||
len([]rune(text)), strings.TrimSpace(text)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *MarkdownTool) handleExtractLinks(md string) (*model.ToolResult, error) {
|
||||
reLink := regexp.MustCompile(`\[([^\]]+)\]\(([^)]+)\)`)
|
||||
matches := reLink.FindAllStringSubmatch(md, -1)
|
||||
|
||||
if len(matches) == 0 {
|
||||
return &model.ToolResult{ID: "", Output: "未找到任何链接"}, nil
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
result.WriteString(fmt.Sprintf("提取链接 (共 %d 个):\n\n", len(matches)))
|
||||
for i, m := range matches {
|
||||
result.WriteString(fmt.Sprintf("%d. [%s](%s)\n - 文本: %s\n - URL: %s\n\n",
|
||||
i+1, m[1], m[2], m[1], m[2]))
|
||||
}
|
||||
|
||||
return &model.ToolResult{ID: "", Output: strings.TrimSpace(result.String())}, nil
|
||||
}
|
||||
|
||||
func (t *MarkdownTool) handleExtractCode(md string) (*model.ToolResult, error) {
|
||||
reFence := regexp.MustCompile("(?s)```([^`]*)```")
|
||||
matches := reFence.FindAllStringSubmatch(md, -1)
|
||||
|
||||
if len(matches) == 0 {
|
||||
return &model.ToolResult{ID: "", Output: "未找到任何代码块"}, nil
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
result.WriteString(fmt.Sprintf("提取代码块 (共 %d 个):\n\n", len(matches)))
|
||||
for i, m := range matches {
|
||||
content := strings.TrimSpace(m[1])
|
||||
lang := ""
|
||||
if idx := strings.Index(content, "\n"); idx > 0 {
|
||||
lang = strings.TrimSpace(content[:idx])
|
||||
content = strings.TrimSpace(content[idx+1:])
|
||||
}
|
||||
|
||||
result.WriteString(fmt.Sprintf("--- 代码块 %d", i+1))
|
||||
if lang != "" {
|
||||
result.WriteString(fmt.Sprintf(" (语言: %s)", lang))
|
||||
}
|
||||
result.WriteString(fmt.Sprintf(" ---\n%s\n\n", truncateText(content, 500)))
|
||||
}
|
||||
|
||||
return &model.ToolResult{ID: "", Output: strings.TrimSpace(result.String())}, nil
|
||||
}
|
||||
|
||||
func (t *MarkdownTool) handleTableOfContents(md string) (*model.ToolResult, error) {
|
||||
reHeading := regexp.MustCompile(`(?m)^(#{1,6})\s+(.+)$`)
|
||||
matches := reHeading.FindAllStringSubmatch(md, -1)
|
||||
|
||||
if len(matches) == 0 {
|
||||
return &model.ToolResult{ID: "", Output: "未找到任何标题,无法生成目录"}, nil
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
result.WriteString(fmt.Sprintf("文档目录 (共 %d 个标题):\n\n", len(matches)))
|
||||
for _, m := range matches {
|
||||
level := len(m[1])
|
||||
title := strings.TrimSpace(m[2])
|
||||
indent := strings.Repeat(" ", level-1)
|
||||
result.WriteString(fmt.Sprintf("%s%s %s\n", indent, strings.Repeat("#", level), title))
|
||||
}
|
||||
|
||||
return &model.ToolResult{ID: "", Output: result.String()}, nil
|
||||
}
|
||||
|
||||
func (t *MarkdownTool) processLists(html, itemPattern, listTag string) string {
|
||||
reItem := regexp.MustCompile(itemPattern + `(.+)$`)
|
||||
lines := strings.Split(html, "\n")
|
||||
result := make([]string, 0, len(lines))
|
||||
|
||||
inList := false
|
||||
for _, line := range lines {
|
||||
if reItem.MatchString(line) {
|
||||
content := reItem.ReplaceAllString(line, "$1")
|
||||
if !inList {
|
||||
result = append(result, fmt.Sprintf("<%s>", listTag))
|
||||
inList = true
|
||||
}
|
||||
result = append(result, fmt.Sprintf("<li>%s</li>", content))
|
||||
} else {
|
||||
if inList {
|
||||
result = append(result, fmt.Sprintf("</%s>", listTag))
|
||||
inList = false
|
||||
}
|
||||
result = append(result, line)
|
||||
}
|
||||
}
|
||||
if inList {
|
||||
result = append(result, fmt.Sprintf("</%s>", listTag))
|
||||
}
|
||||
|
||||
return strings.Join(result, "\n")
|
||||
}
|
||||
|
||||
func (t *MarkdownTool) wrapParagraphs(html string) string {
|
||||
lines := strings.Split(html, "\n")
|
||||
result := make([]string, 0, len(lines))
|
||||
|
||||
skipTags := map[string]bool{
|
||||
"<h1>": true, "<h2>": true, "<h3>": true, "<h4>": true, "<h5>": true, "<h6>": true,
|
||||
"<hr>": true, "<ul>": true, "</ul>": true, "<ol>": true, "</ol>": true,
|
||||
"<li>": true, "</li>": true, "<blockquote>": true, "</blockquote>": true,
|
||||
"<pre>": true, "</pre>": true, "<img": true,
|
||||
}
|
||||
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "" {
|
||||
result = append(result, line)
|
||||
continue
|
||||
}
|
||||
|
||||
isTag := false
|
||||
for tag := range skipTags {
|
||||
if strings.HasPrefix(trimmed, tag) {
|
||||
isTag = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !isTag {
|
||||
result = append(result, fmt.Sprintf("<p>%s</p>", trimmed))
|
||||
} else {
|
||||
result = append(result, line)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(result, "\n")
|
||||
}
|
||||
|
||||
func escapeHTML(s string) string {
|
||||
replacer := strings.NewReplacer(
|
||||
"&", "&"+"amp;",
|
||||
"<", "&"+"lt;",
|
||||
">", "&"+"gt;",
|
||||
"\"", "&"+"quot;",
|
||||
)
|
||||
return replacer.Replace(s)
|
||||
}
|
||||
|
||||
func truncateText(s string, maxLen int) string {
|
||||
runes := []rune(s)
|
||||
if len(runes) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return string(runes[:maxLen]) + "..."
|
||||
}
|
||||
@@ -0,0 +1,318 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/big"
|
||||
mathrand "math/rand"
|
||||
"strings"
|
||||
|
||||
"github.com/yourname/cyrene-ai/tool-engine/internal/model"
|
||||
)
|
||||
|
||||
// RandomTool provides random generation utilities for the LLM.
|
||||
type RandomTool struct{}
|
||||
|
||||
// NewRandomTool creates a random generation tool.
|
||||
func NewRandomTool() *RandomTool {
|
||||
return &RandomTool{}
|
||||
}
|
||||
|
||||
// Definition returns the tool definition for LLM function calling.
|
||||
func (t *RandomTool) Definition() model.ToolDefinition {
|
||||
return model.ToolDefinition{
|
||||
Name: "random",
|
||||
Description: "随机生成工具。生成随机数、UUID、安全密码,或从列表中随机选取/打乱元素。",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []string{"number", "uuid", "password", "pick", "shuffle"},
|
||||
"description": "操作类型。number: 生成随机整数;uuid: 生成UUID v4;password: 生成安全密码;pick: 从列表随机选取;shuffle: 随机打乱列表",
|
||||
},
|
||||
"min": map[string]interface{}{
|
||||
"type": "number",
|
||||
"description": "随机数最小值(用于 number 操作),默认 0",
|
||||
},
|
||||
"max": map[string]interface{}{
|
||||
"type": "number",
|
||||
"description": "随机数最大值(用于 number 操作),默认 100",
|
||||
},
|
||||
"length": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "密码长度(用于 password 操作),默认 16",
|
||||
},
|
||||
"items": map[string]interface{}{
|
||||
"type": "array",
|
||||
"description": "列表项(用于 pick/shuffle 操作),字符串数组",
|
||||
"items": map[string]interface{}{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"count": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "选取数量(用于 pick 操作),默认 1",
|
||||
},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Execute performs random generation operations.
|
||||
func (t *RandomTool) Execute(ctx context.Context, arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
action, ok := arguments["action"].(string)
|
||||
if !ok || action == "" {
|
||||
return &model.ToolResult{ID: "", Error: "缺少 action 参数"}, nil
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "number":
|
||||
return t.handleNumber(arguments)
|
||||
case "uuid":
|
||||
return t.handleUUID()
|
||||
case "password":
|
||||
return t.handlePassword(arguments)
|
||||
case "pick":
|
||||
return t.handlePick(arguments)
|
||||
case "shuffle":
|
||||
return t.handleShuffle(arguments)
|
||||
default:
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Error: fmt.Sprintf("未知操作: %s,支持: number, uuid, password, pick, shuffle", action),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t *RandomTool) handleNumber(arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
minVal := getFloatArg(arguments, "min", 0)
|
||||
maxVal := getFloatArg(arguments, "max", 100)
|
||||
|
||||
if minVal > maxVal {
|
||||
minVal, maxVal = maxVal, minVal
|
||||
}
|
||||
|
||||
minI := int64(minVal)
|
||||
maxI := int64(maxVal)
|
||||
|
||||
rangeVal := maxI - minI + 1
|
||||
if rangeVal <= 0 {
|
||||
return &model.ToolResult{ID: "", Error: "无效的数值范围"}, nil
|
||||
}
|
||||
|
||||
n, err := rand.Int(rand.Reader, big.NewInt(rangeVal))
|
||||
if err != nil {
|
||||
result := minI + mathrand.Int63n(rangeVal)
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: fmt.Sprintf("随机整数 [%d, %d]: %d", minI, maxI, result),
|
||||
}, nil
|
||||
}
|
||||
|
||||
result := minI + n.Int64()
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: fmt.Sprintf("随机整数 [%d, %d]: %d", minI, maxI, result),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *RandomTool) handleUUID() (*model.ToolResult, error) {
|
||||
uuid := make([]byte, 16)
|
||||
_, err := rand.Read(uuid)
|
||||
if err != nil {
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("生成UUID失败: %v", err)}, nil
|
||||
}
|
||||
|
||||
uuid[6] = (uuid[6] & 0x0f) | 0x40
|
||||
uuid[8] = (uuid[8] & 0x3f) | 0x80
|
||||
|
||||
uuidStr := fmt.Sprintf("%08x-%04x-%04x-%04x-%012x",
|
||||
uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:16])
|
||||
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: fmt.Sprintf("UUID v4: %s", uuidStr),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *RandomTool) handlePassword(arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
length := getIntArg(arguments, "length", 16)
|
||||
if length < 4 {
|
||||
length = 16
|
||||
}
|
||||
if length > 128 {
|
||||
length = 128
|
||||
}
|
||||
|
||||
uppercase := "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
lowercase := "abcdefghijklmnopqrstuvwxyz"
|
||||
digits := "0123456789"
|
||||
symbols := "!@#$%^&*()_+-=[]{}|;:,.<>?"
|
||||
|
||||
allChars := uppercase + lowercase + digits + symbols
|
||||
|
||||
password := make([]byte, length)
|
||||
|
||||
password[0] = uppercase[secureIndex(len(uppercase))]
|
||||
password[1] = lowercase[secureIndex(len(lowercase))]
|
||||
password[2] = digits[secureIndex(len(digits))]
|
||||
password[3] = symbols[secureIndex(len(symbols))]
|
||||
|
||||
for i := 4; i < length; i++ {
|
||||
password[i] = allChars[secureIndex(len(allChars))]
|
||||
}
|
||||
|
||||
shuffleBytes(password)
|
||||
|
||||
passwordStr := string(password)
|
||||
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: fmt.Sprintf("安全密码 (长度: %d):\n%s\n\n字符集: 大写字母 + 小写字母 + 数字 + 特殊符号",
|
||||
length, passwordStr),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *RandomTool) handlePick(arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
items := getStringSliceArg(arguments, "items")
|
||||
if len(items) == 0 {
|
||||
return &model.ToolResult{ID: "", Error: "缺少 items 参数或列表为空"}, nil
|
||||
}
|
||||
|
||||
count := getIntArg(arguments, "count", 1)
|
||||
if count < 1 {
|
||||
count = 1
|
||||
}
|
||||
if count > len(items) {
|
||||
count = len(items)
|
||||
}
|
||||
|
||||
indices := make([]int, len(items))
|
||||
for i := range indices {
|
||||
indices[i] = i
|
||||
}
|
||||
shuffleInts(indices)
|
||||
|
||||
picked := make([]string, 0, count)
|
||||
for i := 0; i < count; i++ {
|
||||
picked = append(picked, items[indices[i]])
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
result.WriteString(fmt.Sprintf("从 %d 个选项中随机选取 %d 个:\n", len(items), count))
|
||||
for i, p := range picked {
|
||||
result.WriteString(fmt.Sprintf(" %d. %s\n", i+1, p))
|
||||
}
|
||||
|
||||
return &model.ToolResult{ID: "", Output: result.String()}, nil
|
||||
}
|
||||
|
||||
func (t *RandomTool) handleShuffle(arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
items := getStringSliceArg(arguments, "items")
|
||||
if len(items) == 0 {
|
||||
return &model.ToolResult{ID: "", Error: "缺少 items 参数或列表为空"}, nil
|
||||
}
|
||||
|
||||
shuffled := make([]string, len(items))
|
||||
copy(shuffled, items)
|
||||
shuffleStrings(shuffled)
|
||||
|
||||
var result strings.Builder
|
||||
result.WriteString(fmt.Sprintf("随机打乱结果 (共 %d 项):\n", len(shuffled)))
|
||||
for i, s := range shuffled {
|
||||
result.WriteString(fmt.Sprintf(" %d. %s\n", i+1, s))
|
||||
}
|
||||
|
||||
return &model.ToolResult{ID: "", Output: result.String()}, nil
|
||||
}
|
||||
|
||||
// --- Helper functions ---
|
||||
|
||||
func getFloatArg(arguments map[string]interface{}, key string, fallback float64) float64 {
|
||||
if v, ok := arguments[key]; ok {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
return val
|
||||
case int:
|
||||
return float64(val)
|
||||
case int64:
|
||||
return float64(val)
|
||||
case json.Number:
|
||||
f, err := val.Float64()
|
||||
if err == nil {
|
||||
return f
|
||||
}
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func getIntArg(arguments map[string]interface{}, key string, fallback int) int {
|
||||
if v, ok := arguments[key]; ok {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
return int(val)
|
||||
case int:
|
||||
return val
|
||||
case int64:
|
||||
return int(val)
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func getStringSliceArg(arguments map[string]interface{}, key string) []string {
|
||||
if v, ok := arguments[key]; ok {
|
||||
switch val := v.(type) {
|
||||
case []interface{}:
|
||||
result := make([]string, 0, len(val))
|
||||
for _, item := range val {
|
||||
if s, ok := item.(string); ok {
|
||||
result = append(result, s)
|
||||
} else {
|
||||
result = append(result, fmt.Sprintf("%v", item))
|
||||
}
|
||||
}
|
||||
return result
|
||||
case []string:
|
||||
return val
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func secureIndex(max int) int {
|
||||
if max <= 1 {
|
||||
return 0
|
||||
}
|
||||
n, err := rand.Int(rand.Reader, big.NewInt(int64(max)))
|
||||
if err != nil {
|
||||
return mathrand.Intn(max)
|
||||
}
|
||||
return int(n.Int64())
|
||||
}
|
||||
|
||||
func shuffleBytes(data []byte) {
|
||||
for i := len(data) - 1; i > 0; i-- {
|
||||
j := secureIndex(i + 1)
|
||||
data[i], data[j] = data[j], data[i]
|
||||
}
|
||||
}
|
||||
|
||||
func shuffleInts(data []int) {
|
||||
for i := len(data) - 1; i > 0; i-- {
|
||||
j := secureIndex(i + 1)
|
||||
data[i], data[j] = data[j], data[i]
|
||||
}
|
||||
}
|
||||
|
||||
func shuffleStrings(data []string) {
|
||||
for i := len(data) - 1; i > 0; i-- {
|
||||
j := secureIndex(i + 1)
|
||||
data[i], data[j] = data[j], data[i]
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/yourname/cyrene-ai/tool-engine/internal/model"
|
||||
)
|
||||
|
||||
// Tool 工具接口
|
||||
type Tool interface {
|
||||
Definition() model.ToolDefinition
|
||||
Execute(ctx context.Context, arguments map[string]interface{}) (*model.ToolResult, error)
|
||||
}
|
||||
|
||||
// IoTClientFactory 用于创建 IoT 客户端的工厂函数类型
|
||||
type IoTClientFactory func() IoTClientInterface
|
||||
|
||||
// IoTClientInterface IoT 客户端接口(解耦对 ai-core 的依赖)
|
||||
type IoTClientInterface interface {
|
||||
GetAllDevices() ([]IoTDevice, error)
|
||||
GetDevice(id string) (*IoTDevice, error)
|
||||
ToggleDevice(id string) error
|
||||
SetDeviceProperty(id string, field string, value interface{}) error
|
||||
}
|
||||
|
||||
// IoTDevice IoT 设备结构体
|
||||
type IoTDevice struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
Brightness int `json:"brightness,omitempty"`
|
||||
Color string `json:"color,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Position int `json:"position,omitempty"`
|
||||
Value float64 `json:"value,omitempty"`
|
||||
Unit string `json:"unit,omitempty"`
|
||||
Battery int `json:"battery,omitempty"`
|
||||
LastUpdated string `json:"last_updated"`
|
||||
}
|
||||
@@ -0,0 +1,295 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/yourname/cyrene-ai/tool-engine/internal/model"
|
||||
)
|
||||
|
||||
// TextTool provides text processing operations for the LLM.
|
||||
type TextTool struct{}
|
||||
|
||||
// NewTextTool creates a text processing tool.
|
||||
func NewTextTool() *TextTool {
|
||||
return &TextTool{}
|
||||
}
|
||||
|
||||
// Definition returns the tool definition for LLM function calling.
|
||||
func (t *TextTool) Definition() model.ToolDefinition {
|
||||
return model.ToolDefinition{
|
||||
Name: "text",
|
||||
Description: "文本处理工具。统计文本、生成摘要、翻译文本、正则提取信息。用于处理用户提供的文本内容。",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []string{"count", "summarize", "translate", "extract"},
|
||||
"description": "操作类型。count: 统计字符/单词/行/段落数;summarize: 提取首段+关键句生成简单摘要;translate: 翻译文本(需指定target_lang);extract: 正则提取邮箱/电话/URL等",
|
||||
},
|
||||
"text": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "输入文本,需要处理的文本内容",
|
||||
},
|
||||
"target_lang": map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []string{"en", "zh", "ja", "ko", "fr", "de"},
|
||||
"description": "翻译目标语言代码。en: 英语, zh: 中文, ja: 日语, ko: 韩语, fr: 法语, de: 德语",
|
||||
},
|
||||
"pattern": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "正则表达式模式,用于 extract 操作。常用预设: email(邮箱), phone(电话), url(网址)",
|
||||
},
|
||||
},
|
||||
"required": []string{"action", "text"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Execute performs text processing operations.
|
||||
func (t *TextTool) Execute(ctx context.Context, arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
action, ok := arguments["action"].(string)
|
||||
if !ok || action == "" {
|
||||
return &model.ToolResult{ID: "", Error: "缺少 action 参数"}, nil
|
||||
}
|
||||
|
||||
text, ok := arguments["text"].(string)
|
||||
if !ok || strings.TrimSpace(text) == "" {
|
||||
return &model.ToolResult{ID: "", Error: "缺少 text 参数或文本为空"}, nil
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "count":
|
||||
return t.handleCount(text)
|
||||
case "summarize":
|
||||
return t.handleSummarize(text)
|
||||
case "translate":
|
||||
return t.handleTranslate(arguments)
|
||||
case "extract":
|
||||
return t.handleExtract(arguments)
|
||||
default:
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Error: fmt.Sprintf("未知操作: %s,支持: count, summarize, translate, extract", action),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TextTool) handleCount(text string) (*model.ToolResult, error) {
|
||||
charCount := len([]rune(text))
|
||||
byteCount := len(text)
|
||||
|
||||
words := strings.Fields(text)
|
||||
wordCount := len(words)
|
||||
|
||||
lines := strings.Split(text, "\n")
|
||||
lineCount := len(lines)
|
||||
|
||||
paragraphs := regexp.MustCompile(`\n\s*\n`).Split(text, -1)
|
||||
paraCount := 0
|
||||
for _, p := range paragraphs {
|
||||
if strings.TrimSpace(p) != "" {
|
||||
paraCount++
|
||||
}
|
||||
}
|
||||
|
||||
chineseCount := 0
|
||||
for _, r := range text {
|
||||
if unicode.Is(unicode.Han, r) {
|
||||
chineseCount++
|
||||
}
|
||||
}
|
||||
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: fmt.Sprintf("文本统计结果:\n- 字符数 (含空格): %d\n- 字符数 (不含空格): %d\n- 字节数: %d\n- 单词数: %d\n- 行数: %d\n- 段落数: %d\n- 中文字符数: %d",
|
||||
charCount, len([]rune(strings.ReplaceAll(text, " ", ""))),
|
||||
byteCount, wordCount, lineCount, paraCount, chineseCount),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *TextTool) handleSummarize(text string) (*model.ToolResult, error) {
|
||||
var result strings.Builder
|
||||
result.WriteString("文本摘要:\n\n")
|
||||
|
||||
paragraphs := regexp.MustCompile(`\n\s*\n`).Split(text, -1)
|
||||
var firstPara string
|
||||
for _, p := range paragraphs {
|
||||
if trimmed := strings.TrimSpace(p); trimmed != "" {
|
||||
firstPara = trimmed
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if firstPara != "" {
|
||||
result.WriteString("【首段】\n")
|
||||
runes := []rune(firstPara)
|
||||
if len(runes) > 300 {
|
||||
firstPara = string(runes[:300]) + "..."
|
||||
}
|
||||
result.WriteString(firstPara)
|
||||
result.WriteString("\n\n")
|
||||
}
|
||||
|
||||
sentences := t.splitSentences(text)
|
||||
keySentences := t.extractKeySentences(sentences, 5)
|
||||
|
||||
if len(keySentences) > 0 {
|
||||
result.WriteString("【关键句】\n")
|
||||
for i, s := range keySentences {
|
||||
result.WriteString(fmt.Sprintf("%d. %s\n", i+1, s))
|
||||
}
|
||||
}
|
||||
|
||||
lines := strings.Split(text, "\n")
|
||||
words := strings.Fields(text)
|
||||
result.WriteString(fmt.Sprintf("\n【概况】共 %d 段、%d 句、%d 词、%d 行",
|
||||
len(paragraphs), len(sentences), len(words), len(lines)))
|
||||
|
||||
return &model.ToolResult{ID: "", Output: result.String()}, nil
|
||||
}
|
||||
|
||||
func (t *TextTool) splitSentences(text string) []string {
|
||||
re := regexp.MustCompile(`[^。!?.!?\n]+[。!?.!?\n]?`)
|
||||
return re.FindAllString(text, -1)
|
||||
}
|
||||
|
||||
func (t *TextTool) extractKeySentences(sentences []string, maxCount int) []string {
|
||||
type scored struct {
|
||||
text string
|
||||
score int
|
||||
}
|
||||
|
||||
var scoredList []scored
|
||||
keywords := []string{"重要", "关键", "核心", "主要", "首先", "最后", "因此", "所以", "总结",
|
||||
"important", "key", "critical", "significant", "therefore", "conclusion", "summary"}
|
||||
|
||||
for _, s := range sentences {
|
||||
trimmed := strings.TrimSpace(s)
|
||||
if len([]rune(trimmed)) < 10 {
|
||||
continue
|
||||
}
|
||||
|
||||
score := len([]rune(trimmed))
|
||||
lower := strings.ToLower(trimmed)
|
||||
for _, kw := range keywords {
|
||||
if strings.Contains(lower, kw) {
|
||||
score += 50
|
||||
}
|
||||
}
|
||||
scoredList = append(scoredList, scored{text: trimmed, score: score})
|
||||
}
|
||||
|
||||
for i := 0; i < len(scoredList); i++ {
|
||||
for j := i + 1; j < len(scoredList); j++ {
|
||||
if scoredList[j].score > scoredList[i].score {
|
||||
scoredList[i], scoredList[j] = scoredList[j], scoredList[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := make([]string, 0, maxCount)
|
||||
for i := 0; i < len(scoredList) && i < maxCount; i++ {
|
||||
result = append(result, scoredList[i].text)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (t *TextTool) handleTranslate(arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
text, _ := arguments["text"].(string)
|
||||
targetLang, _ := arguments["target_lang"].(string)
|
||||
if targetLang == "" {
|
||||
targetLang = "zh"
|
||||
}
|
||||
|
||||
langNames := map[string]string{
|
||||
"en": "英语",
|
||||
"zh": "中文",
|
||||
"ja": "日语",
|
||||
"ko": "韩语",
|
||||
"fr": "法语",
|
||||
"de": "德语",
|
||||
}
|
||||
|
||||
langName, ok := langNames[targetLang]
|
||||
if !ok {
|
||||
langName = targetLang
|
||||
}
|
||||
|
||||
return &model.ToolResult{
|
||||
ID: "",
|
||||
Output: fmt.Sprintf("【翻译请求】\n目标语言: %s (%s)\n原文 (%d 字符):\n---\n%s\n---\n\n提示: 实际翻译由LLM完成,请基于以上原文和目标语言进行翻译。",
|
||||
langName, targetLang, len([]rune(text)), text),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *TextTool) handleExtract(arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
text, _ := arguments["text"].(string)
|
||||
pattern, _ := arguments["pattern"].(string)
|
||||
|
||||
presets := map[string]string{
|
||||
"email": `[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}`,
|
||||
"phone": `(?:\+?86[\-\s]?)?1[3-9]\d{9}`,
|
||||
"url": `https?://[^\s<>"{}|\\^` + "`" + `\[\]]+`,
|
||||
}
|
||||
|
||||
if preset, ok := presets[strings.ToLower(pattern)]; ok {
|
||||
pattern = preset
|
||||
}
|
||||
|
||||
if pattern == "" {
|
||||
var result strings.Builder
|
||||
result.WriteString("文本提取结果:\n\n")
|
||||
|
||||
for name, p := range presets {
|
||||
re, err := regexp.Compile(p)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
matches := re.FindAllString(text, -1)
|
||||
if len(matches) > 0 {
|
||||
result.WriteString(fmt.Sprintf("【%s】(共 %d 个):\n", name, len(matches)))
|
||||
seen := make(map[string]bool)
|
||||
for _, m := range matches {
|
||||
if !seen[m] {
|
||||
result.WriteString(fmt.Sprintf(" - %s\n", m))
|
||||
seen[m] = true
|
||||
}
|
||||
}
|
||||
result.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
if result.Len() == len("文本提取结果:\n\n") {
|
||||
return &model.ToolResult{ID: "", Output: "未提取到匹配的内容(邮箱、电话、URL)"}, nil
|
||||
}
|
||||
|
||||
return &model.ToolResult{ID: "", Output: result.String()}, nil
|
||||
}
|
||||
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return &model.ToolResult{ID: "", Error: fmt.Sprintf("正则表达式无效: %v", err)}, nil
|
||||
}
|
||||
|
||||
matches := re.FindAllString(text, -1)
|
||||
if len(matches) == 0 {
|
||||
return &model.ToolResult{ID: "", Output: fmt.Sprintf("未找到匹配模式 '%s' 的内容", pattern)}, nil
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
result.WriteString(fmt.Sprintf("正则提取结果 (模式: %s, 共 %d 个匹配):\n", pattern, len(matches)))
|
||||
seen := make(map[string]bool)
|
||||
for _, m := range matches {
|
||||
if !seen[m] {
|
||||
result.WriteString(fmt.Sprintf(" - %s\n", m))
|
||||
seen[m] = true
|
||||
}
|
||||
}
|
||||
|
||||
return &model.ToolResult{ID: "", Output: result.String()}, nil
|
||||
}
|
||||
@@ -0,0 +1,154 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/yourname/cyrene-ai/tool-engine/internal/model"
|
||||
)
|
||||
|
||||
// WebFetchTool 网络访问工具 - 允许昔涟获取网页内容
|
||||
type WebFetchTool struct {
|
||||
client *http.Client
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// NewWebFetchTool 创建网络访问工具
|
||||
func NewWebFetchTool() *WebFetchTool {
|
||||
return &WebFetchTool{
|
||||
client: &http.Client{
|
||||
Timeout: 15 * time.Second,
|
||||
},
|
||||
timeout: 15 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// Definition 返回工具定义
|
||||
func (t *WebFetchTool) Definition() model.ToolDefinition {
|
||||
return model.ToolDefinition{
|
||||
Name: "web_fetch",
|
||||
Description: "获取指定URL的网页内容。用于查阅新闻、文档、资料等。返回纯文本摘要(前2000字符)。仅支持 HTTP/HTTPS URL。",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"url": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "要获取的网页URL,必须是完整的 http:// 或 https:// 链接",
|
||||
},
|
||||
},
|
||||
"required": []string{"url"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Execute 执行网页获取
|
||||
func (t *WebFetchTool) Execute(ctx context.Context, arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
rawURL, ok := arguments["url"].(string)
|
||||
if !ok || rawURL == "" {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: "缺少 url 参数",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 安全检查:只允许 HTTP/HTTPS
|
||||
if !strings.HasPrefix(rawURL, "http://") && !strings.HasPrefix(rawURL, "https://") {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: "仅支持 http:// 或 https:// 链接",
|
||||
}, nil
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", rawURL, nil)
|
||||
if err != nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("创建请求失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 模拟常见浏览器 User-Agent,避免被拒
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyreneBot/1.0; +https://github.com/AskaEth/Cyrene)")
|
||||
req.Header.Set("Accept", "text/html,text/plain,*/*")
|
||||
|
||||
resp, err := t.client.Do(req)
|
||||
if err != nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("请求失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("HTTP %d", resp.StatusCode),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 限制读取大小(最多 100KB)
|
||||
limitedReader := io.LimitReader(resp.Body, 100*1024)
|
||||
body, err := io.ReadAll(limitedReader)
|
||||
if err != nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("读取响应失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 提取纯文本摘要(去除 HTML 标签)
|
||||
text := extractText(string(body))
|
||||
|
||||
// 截断到 2000 字符
|
||||
if len([]rune(text)) > 2000 {
|
||||
runes := []rune(text)
|
||||
text = string(runes[:2000]) + "\n\n... [内容已截断,共" + fmt.Sprintf("%d", len(runes)) + "字符]"
|
||||
}
|
||||
|
||||
result := fmt.Sprintf("URL: %s\n状态: %d\n内容类型: %s\n\n%s",
|
||||
rawURL, resp.StatusCode, resp.Header.Get("Content-Type"), text)
|
||||
|
||||
return &model.ToolResult{
|
||||
Output: result,
|
||||
Error: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractText 从 HTML/文本中提取纯文本
|
||||
func extractText(raw string) string {
|
||||
// 简单的 HTML 标签去除
|
||||
text := raw
|
||||
inTag := false
|
||||
var result []rune
|
||||
for _, r := range text {
|
||||
if r == '<' {
|
||||
inTag = true
|
||||
continue
|
||||
}
|
||||
if r == '>' {
|
||||
inTag = false
|
||||
continue
|
||||
}
|
||||
if !inTag {
|
||||
result = append(result, r)
|
||||
}
|
||||
}
|
||||
|
||||
// 去除多余空白
|
||||
trimmed := strings.TrimSpace(string(result))
|
||||
// 压缩连续空行
|
||||
lines := strings.Split(trimmed, "\n")
|
||||
var cleanLines []string
|
||||
for _, line := range lines {
|
||||
trimLine := strings.TrimSpace(line)
|
||||
if trimLine != "" {
|
||||
cleanLines = append(cleanLines, trimLine)
|
||||
}
|
||||
}
|
||||
return strings.Join(cleanLines, "\n")
|
||||
}
|
||||
@@ -0,0 +1,223 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/yourname/cyrene-ai/tool-engine/internal/model"
|
||||
)
|
||||
|
||||
// WebSearchTool 网页搜索工具 - 基于 DuckDuckGo Instant Answer API
|
||||
type WebSearchTool struct {
|
||||
client *http.Client
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// NewWebSearchTool 创建网页搜索工具
|
||||
func NewWebSearchTool() *WebSearchTool {
|
||||
return &WebSearchTool{
|
||||
client: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
timeout: 10 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// Definition 返回工具定义
|
||||
func (t *WebSearchTool) Definition() model.ToolDefinition {
|
||||
return model.ToolDefinition{
|
||||
Name: "web_search",
|
||||
Description: "搜索互联网信息。用于查找新闻、资料、知识等。返回搜索结果摘要(最多5条)。",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"query": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "搜索关键词",
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// duckDuckGoResponse DuckDuckGo API 响应
|
||||
type duckDuckGoResponse struct {
|
||||
AbstractText string `json:"AbstractText"`
|
||||
AbstractURL string `json:"AbstractURL"`
|
||||
AbstractSource string `json:"AbstractSource"`
|
||||
Heading string `json:"Heading"`
|
||||
Answer string `json:"Answer"`
|
||||
AnswerType string `json:"AnswerType"`
|
||||
RelatedTopics []duckDuckGoRelated `json:"RelatedTopics"`
|
||||
Results []duckDuckGoResult `json:"Results"`
|
||||
}
|
||||
|
||||
type duckDuckGoRelated struct {
|
||||
Text string `json:"Text"`
|
||||
FirstURL string `json:"FirstURL"`
|
||||
}
|
||||
|
||||
type duckDuckGoResult struct {
|
||||
Text string `json:"Text"`
|
||||
FirstURL string `json:"FirstURL"`
|
||||
}
|
||||
|
||||
// Execute 执行网页搜索
|
||||
func (t *WebSearchTool) Execute(ctx context.Context, arguments map[string]interface{}) (*model.ToolResult, error) {
|
||||
query, ok := arguments["query"].(string)
|
||||
if !ok || query == "" {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: "缺少 query 参数",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 使用 DuckDuckGo Instant Answer API
|
||||
apiURL := fmt.Sprintf("https://api.duckduckgo.com/?q=%s&format=json&no_html=1&skip_disambig=1",
|
||||
url.QueryEscape(query))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil)
|
||||
if err != nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("创建请求失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyreneBot/1.0)")
|
||||
|
||||
resp, err := t.client.Do(req)
|
||||
if err != nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("请求失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("HTTP %d", resp.StatusCode),
|
||||
}, nil
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 500*1024))
|
||||
if err != nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("读取响应失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
var ddg duckDuckGoResponse
|
||||
if err := json.Unmarshal(body, &ddg); err != nil {
|
||||
return &model.ToolResult{
|
||||
Output: "",
|
||||
Error: fmt.Sprintf("解析响应失败: %v", err),
|
||||
}, nil
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
result.WriteString(fmt.Sprintf("搜索关键词: %s\n\n", query))
|
||||
|
||||
// 1. 如果有即时答案
|
||||
if ddg.Answer != "" {
|
||||
result.WriteString(fmt.Sprintf("📌 即时答案: %s\n\n", ddg.Answer))
|
||||
}
|
||||
|
||||
// 2. 摘要
|
||||
if ddg.AbstractText != "" {
|
||||
abstract := ddg.AbstractText
|
||||
if len([]rune(abstract)) > 500 {
|
||||
runes := []rune(abstract)
|
||||
abstract = string(runes[:500]) + "..."
|
||||
}
|
||||
result.WriteString(fmt.Sprintf("摘要: %s\n", abstract))
|
||||
if ddg.AbstractURL != "" {
|
||||
result.WriteString(fmt.Sprintf("来源: %s\n", ddg.AbstractURL))
|
||||
}
|
||||
result.WriteString("\n")
|
||||
}
|
||||
|
||||
// 3. 相关话题
|
||||
topics := ddg.RelatedTopics
|
||||
if len(ddg.Results) > 0 {
|
||||
// 优先用 Results
|
||||
count := 0
|
||||
for _, r := range ddg.Results {
|
||||
if count >= 5 {
|
||||
break
|
||||
}
|
||||
if r.Text != "" {
|
||||
text := stripHTML(r.Text)
|
||||
if len([]rune(text)) > 200 {
|
||||
runes := []rune(text)
|
||||
text = string(runes[:200]) + "..."
|
||||
}
|
||||
result.WriteString(fmt.Sprintf("\n🔗 %s\n", text))
|
||||
if r.FirstURL != "" {
|
||||
result.WriteString(fmt.Sprintf(" %s\n", r.FirstURL))
|
||||
}
|
||||
count++
|
||||
}
|
||||
}
|
||||
} else {
|
||||
count := 0
|
||||
for _, topic := range topics {
|
||||
if count >= 5 {
|
||||
break
|
||||
}
|
||||
if topic.Text != "" {
|
||||
text := stripHTML(topic.Text)
|
||||
if len([]rune(text)) > 200 {
|
||||
runes := []rune(text)
|
||||
text = string(runes[:200]) + "..."
|
||||
}
|
||||
result.WriteString(fmt.Sprintf("\n🔗 %s\n", text))
|
||||
if topic.FirstURL != "" {
|
||||
result.WriteString(fmt.Sprintf(" %s\n", topic.FirstURL))
|
||||
}
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if result.Len() == 0 {
|
||||
result.WriteString("未找到相关结果。")
|
||||
}
|
||||
|
||||
return &model.ToolResult{
|
||||
Output: result.String(),
|
||||
Error: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// stripHTML 去除 HTML 标签
|
||||
func stripHTML(s string) string {
|
||||
inTag := false
|
||||
var result []rune
|
||||
for _, r := range s {
|
||||
if r == '<' {
|
||||
inTag = true
|
||||
continue
|
||||
}
|
||||
if r == '>' {
|
||||
inTag = false
|
||||
// 替换常见块级标签为空格
|
||||
result = append(result, ' ')
|
||||
continue
|
||||
}
|
||||
if !inTag {
|
||||
result = append(result, r)
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(string(result))
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
你要经常在 ./docs/progress/ 目录下新建进度 md 文件以便于后续对话继承开发进度。
|
||||
你要经常在 ./docs/decisions/ 目录下新建决策 md 文件以便于后续对话继承开发决策。
|
||||
你要经常在 ./docs/tasks/ 目录下新建每次任务 md 文件以便于后续对话回顾开发任务。
|
||||
|
||||
你同样要经常在 ./docs/ 目录下按照规范新建一些便于后续开发参考的 md 文件或文件夹(文件命名要求统一格式 YYYY-MM-DD.HH-mm-SS-topic.md)。
|
||||
每次开启新对话或处理新任务前可以看看这些文件。
|
||||
你可以在思考过程中或任务过程中随时新建/修改/删除这些文件,动作可以频繁一点。
|
||||
已经实现并通过调试确定完善的功能对应的 md 文件记得做好统一标记避免后续频繁阅读。
|
||||
|
||||
调试功能的时候你可以在终端挂一个 devtools.sh 通过 curl 启动所有服务 然后通过 curl 等工具去调试实现的这些功能。devtools 提供的 API 可以启动各前后端服务。牢记。
|
||||
|
||||
在你觉得用户要求的某个功能已经完全修复或编写并验证成功后,可以向当前分支(如 dev)推送。
|
||||
禁止推送 docs/ 文件夹和编译后的二进制内容。
|
||||
你在测试长脚本或命令的时候可以在项目根目录临时创建test文件夹并新建脚本文件,用完记得删。
|
||||
|
||||
@@ -450,10 +450,16 @@ input[type="range"] { accent-color: var(--accent); padding: 0; }
|
||||
<span class="nav-icon">🏠</span><span class="nav-label">IoT 设备</span>
|
||||
<span class="nav-badge" id="iot-badge" style="display:none">0</span>
|
||||
</button>
|
||||
<button class="nav-item" data-panel="toolCalls">
|
||||
<span class="nav-icon">🔧</span><span class="nav-label">工具调用</span>
|
||||
</button>
|
||||
<button class="nav-item" data-panel="database">
|
||||
<span class="nav-icon">🗄️</span><span class="nav-label">数据库监看</span>
|
||||
<span class="nav-badge" id="db-badge" style="display:none">●</span>
|
||||
</button>
|
||||
<button class="nav-item" data-panel="thinking">
|
||||
<span class="nav-icon">💭</span><span class="nav-label">自主思考</span>
|
||||
</button>
|
||||
</nav>
|
||||
<div class="sidebar-footer">
|
||||
<span id="ws-dot" class="disconnected"></span>
|
||||
@@ -482,6 +488,10 @@ input[type="range"] { accent-color: var(--accent); padding: 0; }
|
||||
<div class="panel" id="panel-performance"></div>
|
||||
<!-- 数据库监看 -->
|
||||
<div class="panel" id="panel-database"></div>
|
||||
<!-- 工具调用记录 -->
|
||||
<div class="panel" id="panel-toolCalls"></div>
|
||||
<!-- 自主思考日志 -->
|
||||
<div class="panel" id="panel-thinking"></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -695,6 +705,7 @@ function switchPanel(name) {
|
||||
const titles = {
|
||||
dashboard: '🏠 仪表盘', memory: '🧠 记忆管理', sessions: '💬 会话监看',
|
||||
services: '🖥 服务管理', iot: '🏠 IoT 设备控制', performance: '📊 性能监控', database: '🗄️ 数据库监看',
|
||||
toolCalls: '🔧 工具调用记录', thinking: '💭 自主思考',
|
||||
};
|
||||
document.getElementById('panel-title').textContent = titles[name] || name;
|
||||
|
||||
@@ -714,6 +725,8 @@ function switchPanel(name) {
|
||||
case 'iot': renderIoTPanel(); stopSessionsAutoRefresh(); stopDashboardAutoRefresh(); stopDbAutoRefresh(); startIoTRefresh(); break;
|
||||
case 'performance': renderPerformancePanel(); stopSessionsAutoRefresh(); stopDashboardAutoRefresh(); stopDbAutoRefresh(); stopIoTRefresh(); break;
|
||||
case 'database': renderDatabasePanel(); stopSessionsAutoRefresh(); stopDashboardAutoRefresh(); startDbAutoRefresh(); stopIoTRefresh(); break;
|
||||
case 'toolCalls': renderToolCallsPanel(); stopSessionsAutoRefresh(); stopDashboardAutoRefresh(); stopDbAutoRefresh(); stopIoTRefresh(); break;
|
||||
case 'thinking': renderThinkingPanel(); stopSessionsAutoRefresh(); stopDashboardAutoRefresh(); stopDbAutoRefresh(); stopIoTRefresh(); break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2284,6 +2297,337 @@ async function controlDB(action) {
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 面板8: 工具调用记录 ==========
|
||||
async function renderToolCallsPanel() {
|
||||
var container = document.getElementById('panel-toolCalls');
|
||||
|
||||
if (!STATE.toolCallsPage) STATE.toolCallsPage = 1;
|
||||
if (!STATE.toolCallsFilter) STATE.toolCallsFilter = '';
|
||||
if (!STATE.toolCallsAutoRefresh) STATE.toolCallsAutoRefresh = null;
|
||||
if (!STATE.toolCallsLimit) STATE.toolCallsLimit = 20;
|
||||
|
||||
var actionsEl = document.getElementById('panel-actions');
|
||||
var autoRefreshOn = STATE.toolCallsAutoRefresh !== null;
|
||||
actionsEl.innerHTML = '<label style="display:flex;align-items:center;gap:6px;font-size:12px;color:var(--text2);cursor:pointer;">' +
|
||||
'<input type="checkbox" id="toolcalls-autorefresh" ' + (autoRefreshOn ? 'checked' : '') + ' onchange="toggleToolCallsAutoRefresh(this.checked)">' +
|
||||
'自动刷新</label>';
|
||||
|
||||
var statsData = null;
|
||||
try {
|
||||
var statsResp = await fetch('/api/tool-calls/stats');
|
||||
if (statsResp.ok) statsData = await statsResp.json();
|
||||
} catch(e) {}
|
||||
|
||||
var callsData = null;
|
||||
try {
|
||||
var params = '?page=' + STATE.toolCallsPage + '&limit=' + STATE.toolCallsLimit;
|
||||
if (STATE.toolCallsFilter) params += '&tool_name=' + encodeURIComponent(STATE.toolCallsFilter);
|
||||
var callsResp = await fetch('/api/tool-calls' + params);
|
||||
if (callsResp.ok) callsData = await callsResp.json();
|
||||
} catch(e) {}
|
||||
|
||||
if (!callsData || callsData.error) {
|
||||
container.innerHTML = '<div class="empty-state"><div class="icon">⚠️</div>' +
|
||||
(callsData && callsData.error ? escHtml(callsData.error) : '无法连接到 Tool-Engine 服务,请在「服务管理」中启动 Tool-Engine') +
|
||||
(callsData && callsData.hint ? '<br><small>' + escHtml(callsData.hint) + '</small>' : '') +
|
||||
'</div>';
|
||||
return;
|
||||
}
|
||||
|
||||
var totalCalls = statsData ? statsData.total_calls : callsData.total;
|
||||
var successRate = statsData ? (statsData.success_rate || 0).toFixed(1) : 0;
|
||||
var avgDuration = statsData ? (statsData.avg_duration_ms || 0).toFixed(0) : 0;
|
||||
var byTool = statsData && statsData.by_tool ? statsData.by_tool : [];
|
||||
|
||||
var toolNamesSet = {};
|
||||
if (byTool.length > 0) {
|
||||
for (var i = 0; i < byTool.length; i++) { toolNamesSet[byTool[i].tool_name] = true; }
|
||||
}
|
||||
if (callsData.calls) {
|
||||
for (var i = 0; i < callsData.calls.length; i++) { toolNamesSet[callsData.calls[i].tool_name] = true; }
|
||||
}
|
||||
var toolNames = Object.keys(toolNamesSet).sort();
|
||||
|
||||
var statsCardsHtml = '<div class="cards-grid cards-4" style="margin-bottom:14px;">' +
|
||||
'<div class="stat-card accent"><div class="stat-value">' + totalCalls + '</div><div class="stat-label">总调用</div></div>' +
|
||||
'<div class="stat-card green"><div class="stat-value">' + successRate + '%</div><div class="stat-label">成功率</div></div>' +
|
||||
'<div class="stat-card blue"><div class="stat-value">' + avgDuration + 'ms</div><div class="stat-label">平均耗时</div></div>' +
|
||||
'<div class="stat-card orange"><div class="stat-value">' + toolNames.length + '</div><div class="stat-label">工具数</div></div>' +
|
||||
'</div>';
|
||||
|
||||
var filterHtml = '<div style="display:flex;gap:10px;align-items:center;margin-bottom:14px;flex-wrap:wrap;">' +
|
||||
'<select id="toolcalls-filter-select" onchange="changeToolCallsFilter(this.value)" style="width:auto;min-width:160px;">' +
|
||||
'<option value="">全部工具</option>';
|
||||
for (var i = 0; i < toolNames.length; i++) {
|
||||
filterHtml += '<option value="' + escHtml(toolNames[i]) + '"' + (STATE.toolCallsFilter === toolNames[i] ? ' selected' : '') + '>' + escHtml(toolNames[i]) + '</option>';
|
||||
}
|
||||
filterHtml += '</select>' +
|
||||
'<button class="btn btn-sm" onclick="refreshToolCallsPanel()">🔄 刷新</button>' +
|
||||
'</div>';
|
||||
|
||||
var distHtml = '';
|
||||
if (byTool.length > 0) {
|
||||
distHtml = '<div style="display:flex;gap:8px;flex-wrap:wrap;margin-bottom:14px;">';
|
||||
for (var i = 0; i < byTool.length; i++) {
|
||||
var t = byTool[i];
|
||||
distHtml += '<span class="badge" style="background:var(--bg3);font-size:11px;cursor:pointer;" onclick="changeToolCallsFilter(\'' + escHtml(t.tool_name) + '\')" title="' + escHtml(t.tool_name) + ': ' + t.count + ' 次, 成功率 ' + (t.count > 0 ? (t.success_count / t.count * 100).toFixed(0) : 0) + '%">' +
|
||||
escHtml(t.tool_name) + ' ' + t.count + '</span>';
|
||||
}
|
||||
distHtml += '</div>';
|
||||
}
|
||||
|
||||
var tableHtml = '<div class="table-wrap"><table>' +
|
||||
'<thead><tr><th style="width:100px;">时间</th><th style="width:110px;">工具</th><th>参数</th><th>结果</th><th style="width:50px;">状态</th><th style="width:55px;">耗时</th></tr></thead><tbody>';
|
||||
|
||||
if (callsData.calls && callsData.calls.length > 0) {
|
||||
for (var i = 0; i < callsData.calls.length; i++) {
|
||||
var call = callsData.calls[i];
|
||||
var argsStr = '';
|
||||
try {
|
||||
if (typeof call.arguments === 'string') {
|
||||
var parsed = JSON.parse(call.arguments);
|
||||
argsStr = JSON.stringify(parsed);
|
||||
} else if (call.arguments) {
|
||||
argsStr = JSON.stringify(call.arguments);
|
||||
}
|
||||
} catch(e) { argsStr = String(call.arguments || ''); }
|
||||
if (argsStr.length > 60) argsStr = argsStr.slice(0, 57) + '...';
|
||||
|
||||
var outputStr = (call.output || call.error || '');
|
||||
if (outputStr.length > 80) outputStr = outputStr.slice(0, 77) + '...';
|
||||
|
||||
var timeStr = call.created_at ? new Date(call.created_at).toLocaleTimeString('zh-CN', {hour12: false}) : '—';
|
||||
var statusIcon = call.success ? '✅' : '❌';
|
||||
var statusColor = call.success ? 'var(--green)' : 'var(--red)';
|
||||
var durationStr = call.duration_ms ? call.duration_ms + 'ms' : '—';
|
||||
var callId = 'tc-' + call.id;
|
||||
|
||||
tableHtml += '<tr class="toolcall-row" data-callid="' + callId + '" onclick="toggleToolCallExpand(\'' + callId + '\')" style="cursor:pointer;">' +
|
||||
'<td style="font-size:11px;color:var(--text2);">' + escHtml(timeStr) + '</td>' +
|
||||
'<td><span class="badge" style="background:var(--bg3);">' + escHtml(call.tool_name) + '</span></td>' +
|
||||
'<td style="font-family:monospace;font-size:11px;max-width:200px;overflow:hidden;text-overflow:ellipsis;white-space:nowrap;">' + escHtml(argsStr) + '</td>' +
|
||||
'<td style="font-family:monospace;font-size:11px;max-width:250px;overflow:hidden;text-overflow:ellipsis;white-space:nowrap;color:' + (call.success ? 'var(--text2)' : 'var(--red)') + ';">' + escHtml(outputStr) + '</td>' +
|
||||
'<td style="text-align:center;color:' + statusColor + ';">' + statusIcon + '</td>' +
|
||||
'<td style="text-align:right;font-size:11px;color:var(--text2);">' + durationStr + '</td>' +
|
||||
'</tr>';
|
||||
|
||||
// Expand row showing full args and output
|
||||
var argsDisplay = call.arguments;
|
||||
if (typeof argsDisplay === 'string') {
|
||||
try { argsDisplay = JSON.parse(argsDisplay); } catch(e) {}
|
||||
}
|
||||
tableHtml += '<tr class="toolcall-expand" id="' + callId + '-expand" style="display:none;"><td colspan="6" style="background:var(--bg);padding:12px 24px;">' +
|
||||
'<div style="display:grid;grid-template-columns:1fr 1fr;gap:12px;">' +
|
||||
'<div><strong style="color:var(--text2);font-size:11px;">完整参数:</strong><pre style="background:var(--bg3);padding:8px;border-radius:4px;font-size:11px;margin-top:4px;max-height:200px;overflow:auto;white-space:pre-wrap;">' + escHtml(JSON.stringify(argsDisplay, null, 2)) + '</pre></div>' +
|
||||
'<div><strong style="color:var(--text2);font-size:11px;">完整输出:</strong><pre style="background:var(--bg3);padding:8px;border-radius:4px;font-size:11px;margin-top:4px;max-height:200px;overflow:auto;white-space:pre-wrap;color:' + (call.success ? 'var(--text)' : 'var(--red)') + ';">' + escHtml(call.output || call.error || '') + '</pre></div>' +
|
||||
'</div>' +
|
||||
'<div style="margin-top:8px;font-size:11px;color:var(--text3);">Call ID: ' + escHtml(call.call_id || '—') + ' | User: ' + escHtml(call.user_id || '—') + ' | Session: ' + escHtml(call.session_id || '—') + ' | ' + (call.created_at ? new Date(call.created_at).toLocaleString('zh-CN', {hour12: false}) : '') + '</div>' +
|
||||
'</td></tr>';
|
||||
}
|
||||
} else {
|
||||
tableHtml += '<tr><td colspan="6" style="text-align:center;color:var(--text3);padding:30px;">暂无调用记录</td></tr>';
|
||||
}
|
||||
|
||||
tableHtml += '</tbody></table></div>';
|
||||
|
||||
var paginationHtml = '';
|
||||
if (callsData.total_pages > 0) {
|
||||
paginationHtml = '<div style="display:flex;align-items:center;justify-content:center;gap:12px;margin-top:14px;font-size:12px;">' +
|
||||
'<button class="btn btn-sm" ' + (callsData.page <= 1 ? 'disabled' : '') + ' onclick="goToolCallsPage(' + (callsData.page - 1) + ')">← 上一页</button>' +
|
||||
'<span style="color:var(--text2);">第 ' + callsData.page + '/' + callsData.total_pages + ' 页</span>' +
|
||||
'<button class="btn btn-sm" ' + (callsData.page >= callsData.total_pages ? 'disabled' : '') + ' onclick="goToolCallsPage(' + (callsData.page + 1) + ')">下一页 →</button>' +
|
||||
'</div>';
|
||||
}
|
||||
|
||||
container.innerHTML = statsCardsHtml + filterHtml + distHtml + tableHtml + paginationHtml;
|
||||
}
|
||||
|
||||
function toggleToolCallExpand(callId) {
|
||||
var expandRow = document.getElementById(callId + '-expand');
|
||||
if (expandRow) {
|
||||
expandRow.style.display = expandRow.style.display === 'none' ? '' : 'none';
|
||||
}
|
||||
}
|
||||
|
||||
function changeToolCallsFilter(toolName) {
|
||||
STATE.toolCallsFilter = toolName;
|
||||
STATE.toolCallsPage = 1;
|
||||
renderToolCallsPanel();
|
||||
}
|
||||
|
||||
function goToolCallsPage(page) {
|
||||
STATE.toolCallsPage = page;
|
||||
renderToolCallsPanel();
|
||||
document.getElementById('panel-toolCalls').scrollIntoView({ behavior: 'smooth', block: 'start' });
|
||||
}
|
||||
|
||||
function refreshToolCallsPanel() {
|
||||
renderToolCallsPanel();
|
||||
}
|
||||
|
||||
function toggleToolCallsAutoRefresh(on) {
|
||||
if (STATE.toolCallsAutoRefresh) {
|
||||
clearInterval(STATE.toolCallsAutoRefresh);
|
||||
STATE.toolCallsAutoRefresh = null;
|
||||
}
|
||||
if (on) {
|
||||
STATE.toolCallsAutoRefresh = setInterval(function() {
|
||||
if (STATE.activePanel === 'toolCalls') renderToolCallsPanel();
|
||||
}, 5000);
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 面板9: 自主思考日志 ==========
|
||||
async function renderThinkingPanel() {
|
||||
var container = document.getElementById('panel-thinking');
|
||||
|
||||
if (!STATE.thinkingPage) STATE.thinkingPage = 1;
|
||||
if (!STATE.thinkingAutoRefresh) STATE.thinkingAutoRefresh = null;
|
||||
if (!STATE.thinkingLimit) STATE.thinkingLimit = 20;
|
||||
if (!STATE.thinkingUserId) STATE.thinkingUserId = 'admin_admin';
|
||||
|
||||
var actionsEl = document.getElementById('panel-actions');
|
||||
var autoRefreshOn = STATE.thinkingAutoRefresh !== null;
|
||||
actionsEl.innerHTML = '<label style="display:flex;align-items:center;gap:6px;font-size:12px;color:var(--text2);cursor:pointer;">' +
|
||||
'<input type="checkbox" id="thinking-autorefresh" ' + (autoRefreshOn ? 'checked' : '') + ' onchange="toggleThinkingAutoRefresh(this.checked)">' +
|
||||
'自动刷新 (5s)</label>';
|
||||
|
||||
var statsData = null;
|
||||
try {
|
||||
var statsResp = await fetch('/api/v1/thinking/stats?user_id=' + encodeURIComponent(STATE.thinkingUserId));
|
||||
if (statsResp.ok) statsData = await statsResp.json();
|
||||
} catch(e) {}
|
||||
|
||||
var logsData = null;
|
||||
try {
|
||||
var params = '?user_id=' + encodeURIComponent(STATE.thinkingUserId) + '&limit=' + STATE.thinkingLimit + '&offset=' + ((STATE.thinkingPage - 1) * STATE.thinkingLimit);
|
||||
var logsResp = await fetch('/api/v1/thinking' + params);
|
||||
if (logsResp.ok) logsData = await logsResp.json();
|
||||
} catch(e) {}
|
||||
|
||||
if (!logsData || logsData.error) {
|
||||
container.innerHTML = '<div class="empty-state"><div class="icon">⚠️</div>' +
|
||||
(logsData && logsData.error ? escHtml(logsData.error) : '无法连接到 Memory-Service,请在「服务管理」中启动 Memory-Service') +
|
||||
(logsData && logsData.hint ? '<br><small>' + escHtml(logsData.hint) + '</small>' : '') +
|
||||
'</div>';
|
||||
return;
|
||||
}
|
||||
|
||||
var totalLogs = statsData ? statsData.total_logs : (logsData.total || 0);
|
||||
var totalToolCalls = statsData ? statsData.total_tool_calls : 0;
|
||||
var avgContentLen = statsData ? (statsData.avg_content_length || 0).toFixed(0) : 0;
|
||||
var latestAt = statsData ? (statsData.latest_at ? formatTime(statsData.latest_at) : '—') : '—';
|
||||
|
||||
var statsCardsHtml = '<div class="cards-grid cards-4" style="margin-bottom:14px;">' +
|
||||
'<div class="stat-card accent"><div class="stat-value">' + totalLogs + '</div><div class="stat-label">总思考次数</div></div>' +
|
||||
'<div class="stat-card green"><div class="stat-value">' + totalToolCalls + '</div><div class="stat-label">总工具调用</div></div>' +
|
||||
'<div class="stat-card blue"><div class="stat-value">' + avgContentLen + '</div><div class="stat-label">平均长度(字符)</div></div>' +
|
||||
'<div class="stat-card orange"><div class="stat-value">' + latestAt + '</div><div class="stat-label">最近思考</div></div>' +
|
||||
'</div>';
|
||||
|
||||
var filterHtml = '<div style="display:flex;gap:10px;align-items:center;margin-bottom:14px;flex-wrap:wrap;">' +
|
||||
'<button class="btn btn-sm" onclick="refreshThinkingPanel()">🔄 刷新</button>' +
|
||||
'<span style="font-size:11px;color:var(--text3);">用户: ' + escHtml(STATE.thinkingUserId) + '</span>' +
|
||||
'</div>';
|
||||
|
||||
var tableHtml = '<div class="table-wrap"><table>' +
|
||||
'<thead><tr><th style="width:130px;">时间</th><th style="width:80px;">工具调用</th><th style="width:80px;">内容长度</th><th>内容摘要</th></tr></thead><tbody>';
|
||||
|
||||
var logs = logsData.logs || [];
|
||||
if (logs.length > 0) {
|
||||
for (var i = 0; i < logs.length; i++) {
|
||||
var log = logs[i];
|
||||
var timeStr = log.created_at ? new Date(log.created_at).toLocaleString('zh-CN', {hour12: false}) : '—';
|
||||
var toolCallCount = log.tool_call_count || 0;
|
||||
var contentLen = log.content_length || 0;
|
||||
var summary = log.content || '';
|
||||
// Extract first line or first 120 chars as summary
|
||||
var firstLine = summary.split('\n')[0];
|
||||
if (firstLine.length > 120) firstLine = firstLine.slice(0, 117) + '...';
|
||||
var logId = 'th-' + log.id;
|
||||
|
||||
tableHtml += '<tr class="thinking-row" data-logid="' + logId + '" onclick="toggleThinkingExpand(\'' + logId + '\')" style="cursor:pointer;">' +
|
||||
'<td style="font-size:11px;color:var(--text2);">' + escHtml(timeStr) + '</td>' +
|
||||
'<td style="text-align:center;"><span class="badge" style="background:' + (toolCallCount > 0 ? 'var(--accent)' : 'var(--bg3)') + ';">' + toolCallCount + '</span></td>' +
|
||||
'<td style="text-align:right;font-size:11px;color:var(--text2);">' + contentLen + '</td>' +
|
||||
'<td style="max-width:400px;overflow:hidden;text-overflow:ellipsis;white-space:nowrap;font-size:12px;">' + escHtml(firstLine) + '</td>' +
|
||||
'</tr>';
|
||||
|
||||
// Expand row with full content and tool calls
|
||||
var toolCallsDisplay = '无';
|
||||
try {
|
||||
var parsed = typeof log.tool_calls === 'string' ? JSON.parse(log.tool_calls) : log.tool_calls;
|
||||
if (parsed && Array.isArray(parsed) && parsed.length > 0) {
|
||||
toolCallsDisplay = '';
|
||||
for (var j = 0; j < parsed.length; j++) {
|
||||
var tc = parsed[j];
|
||||
var tcName = tc.function ? (tc.function.name || '未知') : (tc.name || '未知');
|
||||
var tcArgs = tc.function ? (tc.function.arguments || '') : (tc.arguments || '');
|
||||
if (typeof tcArgs === 'object') tcArgs = JSON.stringify(tcArgs);
|
||||
toolCallsDisplay += '<div style="margin-bottom:6px;padding:6px;background:var(--bg2);border-radius:4px;">' +
|
||||
'<strong style="color:var(--accent2);">' + escHtml(tcName) + '</strong>' +
|
||||
'<pre style="font-size:10px;margin:4px 0 0;white-space:pre-wrap;color:var(--text2);">' + escHtml(String(tcArgs).slice(0, 500)) + '</pre>' +
|
||||
'</div>';
|
||||
}
|
||||
}
|
||||
} catch(e) { toolCallsDisplay = '<span style="color:var(--text3);">解析失败</span>'; }
|
||||
|
||||
tableHtml += '<tr class="thinking-expand" id="' + logId + '-expand" style="display:none;"><td colspan="4" style="background:var(--bg);padding:12px 24px;">' +
|
||||
'<div style="display:grid;grid-template-columns:1fr 1fr;gap:12px;">' +
|
||||
'<div><strong style="color:var(--text2);font-size:11px;">完整思考内容:</strong><pre style="background:var(--bg3);padding:8px;border-radius:4px;font-size:11px;margin-top:4px;max-height:300px;overflow:auto;white-space:pre-wrap;">' + escHtml(log.content || '') + '</pre></div>' +
|
||||
'<div><strong style="color:var(--text2);font-size:11px;">工具调用详情:</strong><div style="margin-top:4px;max-height:300px;overflow:auto;">' + toolCallsDisplay + '</div></div>' +
|
||||
'</div>' +
|
||||
'<div style="margin-top:8px;font-size:11px;color:var(--text3);">ID: ' + escHtml(log.id || '—') + ' | User: ' + escHtml(log.user_id || '—') + ' | 内容长度: ' + (log.content_length || 0) + ' | ' + (log.created_at ? new Date(log.created_at).toLocaleString('zh-CN', {hour12: false}) : '') + '</div>' +
|
||||
'</td></tr>';
|
||||
}
|
||||
} else {
|
||||
tableHtml += '<tr><td colspan="4" style="text-align:center;color:var(--text3);padding:30px;">暂无自主思考记录</td></tr>';
|
||||
}
|
||||
|
||||
tableHtml += '</tbody></table></div>';
|
||||
|
||||
var totalPages = Math.max(1, Math.ceil((logsData.total || 0) / STATE.thinkingLimit));
|
||||
var paginationHtml = '';
|
||||
if (totalPages > 1) {
|
||||
paginationHtml = '<div style="display:flex;align-items:center;justify-content:center;gap:12px;margin-top:14px;font-size:12px;">' +
|
||||
'<button class="btn btn-sm" ' + (STATE.thinkingPage <= 1 ? 'disabled' : '') + ' onclick="goThinkingPage(' + (STATE.thinkingPage - 1) + ')">← 上一页</button>' +
|
||||
'<span style="color:var(--text2);">第 ' + STATE.thinkingPage + '/' + totalPages + ' 页</span>' +
|
||||
'<button class="btn btn-sm" ' + (STATE.thinkingPage >= totalPages ? 'disabled' : '') + ' onclick="goThinkingPage(' + (STATE.thinkingPage + 1) + ')">下一页 →</button>' +
|
||||
'</div>';
|
||||
}
|
||||
|
||||
container.innerHTML = statsCardsHtml + filterHtml + tableHtml + paginationHtml;
|
||||
}
|
||||
|
||||
function toggleThinkingExpand(logId) {
|
||||
var expandRow = document.getElementById(logId + '-expand');
|
||||
if (expandRow) {
|
||||
expandRow.style.display = expandRow.style.display === 'none' ? '' : 'none';
|
||||
}
|
||||
}
|
||||
|
||||
function goThinkingPage(page) {
|
||||
STATE.thinkingPage = page;
|
||||
renderThinkingPanel();
|
||||
document.getElementById('panel-thinking').scrollIntoView({ behavior: 'smooth', block: 'start' });
|
||||
}
|
||||
|
||||
function refreshThinkingPanel() {
|
||||
renderThinkingPanel();
|
||||
}
|
||||
|
||||
function toggleThinkingAutoRefresh(on) {
|
||||
if (STATE.thinkingAutoRefresh) {
|
||||
clearInterval(STATE.thinkingAutoRefresh);
|
||||
STATE.thinkingAutoRefresh = null;
|
||||
}
|
||||
if (on) {
|
||||
STATE.thinkingAutoRefresh = setInterval(function() {
|
||||
if (STATE.activePanel === 'thinking') renderThinkingPanel();
|
||||
}, 5000);
|
||||
}
|
||||
}
|
||||
|
||||
</script>
|
||||
<script src="iot-panel.js"></script>
|
||||
<script>
|
||||
|
||||
@@ -13,6 +13,7 @@ const ROOT = path.resolve(__dirname, '../..');
|
||||
export const DEVTOOLS_PORT = process.env.DEVTOOLS_PORT || 9090;
|
||||
export const LOGS_DIR = path.resolve(__dirname, '../logs');
|
||||
export const GATEWAY_URL = process.env.GATEWAY_URL || 'http://localhost:8080';
|
||||
export const TOOL_ENGINE_URL = process.env.TOOL_ENGINE_URL || 'http://localhost:8092';
|
||||
export const ADMIN_USERNAME = process.env.ADMIN_USERNAME || 'admin';
|
||||
export const ADMIN_PASSWORD = process.env.ADMIN_PASSWORD || 'cyrene-dev-admin';
|
||||
|
||||
@@ -54,6 +55,8 @@ export const SERVICES = {
|
||||
GATEWAY_PORT: '8080',
|
||||
JWT_SECRET: process.env.JWT_SECRET || 'dev-secret-key-change-me',
|
||||
AI_CORE_URL: 'http://localhost:8081',
|
||||
MEMORY_SERVICE_URL: process.env.MEMORY_SERVICE_URL || 'http://localhost:8091',
|
||||
TOOL_ENGINE_URL: process.env.TOOL_ENGINE_URL || 'http://localhost:8092',
|
||||
ADMIN_USERNAME: process.env.ADMIN_USERNAME || 'admin',
|
||||
ADMIN_PASSWORD: process.env.ADMIN_PASSWORD || 'cyrene-dev-admin',
|
||||
REGISTRATION_ENABLED: process.env.REGISTRATION_ENABLED || 'true',
|
||||
|
||||
+130
-1
@@ -17,7 +17,9 @@ import { execSync, spawn } from 'child_process';
|
||||
|
||||
import { processManager } from './process-manager.js';
|
||||
import { performanceMonitor } from './performance.js';
|
||||
import { SERVICES, DEVTOOLS_PORT, LOGS_DIR, logFile, GATEWAY_URL, ADMIN_USERNAME, ADMIN_PASSWORD } from './config.js';
|
||||
import { SERVICES, DEVTOOLS_PORT, LOGS_DIR, logFile, GATEWAY_URL, TOOL_ENGINE_URL, ADMIN_USERNAME, ADMIN_PASSWORD } from './config.js';
|
||||
|
||||
const MEMORY_SERVICE_URL = process.env.MEMORY_SERVICE_URL || 'http://localhost:8091';
|
||||
|
||||
const ROOT = path.resolve(path.dirname(fileURLToPath(import.meta.url)), '../..');
|
||||
const TUNNEL_SCRIPT = path.join(ROOT, 'scripts/tunnel.sh');
|
||||
@@ -562,6 +564,133 @@ app.get('/api/iot/devices/:id/history', async (req, res) => {
|
||||
res.status(result.status).json(result.body);
|
||||
});
|
||||
|
||||
// ---- 工具调用记录代理 (转发到 tool-engine) ----
|
||||
|
||||
/**
|
||||
* 代理请求到 Tool-Engine
|
||||
* @param {string} path - Tool-Engine API 路径
|
||||
* @param {object} opts - fetch 选项
|
||||
*/
|
||||
async function proxyToToolEngine(path, opts = {}) {
|
||||
const url = `${TOOL_ENGINE_URL}${path}`;
|
||||
const logPrefix = `[ToolEngine代理]`;
|
||||
try {
|
||||
console.log(`${logPrefix} ${opts.method || 'GET'} ${path}`);
|
||||
const resp = await fetch(url, {
|
||||
...opts,
|
||||
headers: { 'Content-Type': 'application/json', ...opts.headers },
|
||||
signal: AbortSignal.timeout(10000),
|
||||
});
|
||||
const body = await resp.json().catch(() => null);
|
||||
if (!resp.ok) {
|
||||
console.log(`${logPrefix} 请求失败 (HTTP ${resp.status}): ${path}`);
|
||||
}
|
||||
return { status: resp.status, body };
|
||||
} catch (err) {
|
||||
const isConnRefused = err.message?.includes('ECONNREFUSED') || err.cause?.code === 'ECONNREFUSED';
|
||||
console.error(`${logPrefix} 请求异常: ${path} - ${err.message}`);
|
||||
return {
|
||||
status: 502,
|
||||
body: {
|
||||
error: `Tool-Engine 不可达: ${err.message}`,
|
||||
errorType: isConnRefused ? 'tool_engine_not_running' : 'tool_engine_unreachable',
|
||||
hint: isConnRefused
|
||||
? 'Tool-Engine 服务未启动,请先在「服务管理」面板中启动 Tool-Engine'
|
||||
: 'Tool-Engine 服务无响应,请检查网络连接和服务状态',
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// GET /api/tool-calls — 查询工具调用记录
|
||||
app.get('/api/tool-calls', async (req, res) => {
|
||||
const { tool_name, page, limit } = req.query;
|
||||
const params = new URLSearchParams();
|
||||
if (tool_name) params.set('tool_name', tool_name);
|
||||
params.set('page', page || '1');
|
||||
params.set('limit', limit || '20');
|
||||
const result = await proxyToToolEngine(`/api/v1/tools/calls?${params.toString()}`);
|
||||
res.status(result.status).json(result.body);
|
||||
});
|
||||
|
||||
// GET /api/tool-calls/stats — 工具调用统计
|
||||
app.get('/api/tool-calls/stats', async (_req, res) => {
|
||||
const result = await proxyToToolEngine('/api/v1/tools/calls/stats');
|
||||
res.status(result.status).json(result.body);
|
||||
});
|
||||
|
||||
// ---- 自主思考日志代理 (转发到 memory-service) ----
|
||||
|
||||
/**
|
||||
* 代理请求到 Memory-Service
|
||||
* @param {string} path - Memory-Service API 路径
|
||||
* @param {object} opts - fetch 选项
|
||||
*/
|
||||
async function proxyToMemoryService(path, opts = {}) {
|
||||
const url = `${MEMORY_SERVICE_URL}${path}`;
|
||||
const logPrefix = `[MemoryService代理]`;
|
||||
try {
|
||||
console.log(`${logPrefix} ${opts.method || 'GET'} ${path}`);
|
||||
const resp = await fetch(url, {
|
||||
...opts,
|
||||
headers: { 'Content-Type': 'application/json', ...opts.headers },
|
||||
signal: AbortSignal.timeout(10000),
|
||||
});
|
||||
const body = await resp.json().catch(() => null);
|
||||
if (!resp.ok) {
|
||||
console.log(`${logPrefix} 请求失败 (HTTP ${resp.status}): ${path}`);
|
||||
}
|
||||
return { status: resp.status, body };
|
||||
} catch (err) {
|
||||
const isConnRefused = err.message?.includes('ECONNREFUSED') || err.cause?.code === 'ECONNREFUSED';
|
||||
console.error(`${logPrefix} 请求异常: ${path} - ${err.message}`);
|
||||
return {
|
||||
status: 502,
|
||||
body: {
|
||||
error: `Memory-Service 不可达: ${err.message}`,
|
||||
errorType: isConnRefused ? 'memory_service_not_running' : 'memory_service_unreachable',
|
||||
hint: isConnRefused
|
||||
? 'Memory-Service 服务未启动,请先在「服务管理」面板中启动 Memory-Service'
|
||||
: 'Memory-Service 服务无响应,请检查网络连接和服务状态',
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// GET /api/v1/thinking — 查询自主思考日志列表
|
||||
app.get('/api/v1/thinking', async (req, res) => {
|
||||
const { user_id, limit, offset } = req.query;
|
||||
const params = new URLSearchParams();
|
||||
if (user_id) params.set('user_id', user_id);
|
||||
if (limit) params.set('limit', limit);
|
||||
if (offset) params.set('offset', offset);
|
||||
const result = await proxyToMemoryService(`/api/v1/thinking?${params.toString()}`);
|
||||
res.status(result.status).json(result.body);
|
||||
});
|
||||
|
||||
// POST /api/v1/thinking — 创建自主思考日志
|
||||
app.post('/api/v1/thinking', async (req, res) => {
|
||||
const result = await proxyToMemoryService('/api/v1/thinking', {
|
||||
method: 'POST',
|
||||
body: JSON.stringify(req.body),
|
||||
});
|
||||
res.status(result.status).json(result.body);
|
||||
});
|
||||
|
||||
// GET /api/v1/thinking/stats — 自主思考统计
|
||||
app.get('/api/v1/thinking/stats', async (req, res) => {
|
||||
const { user_id } = req.query;
|
||||
const params = user_id ? `?user_id=${encodeURIComponent(user_id)}` : '';
|
||||
const result = await proxyToMemoryService(`/api/v1/thinking/stats${params}`);
|
||||
res.status(result.status).json(result.body);
|
||||
});
|
||||
|
||||
// GET /api/v1/thinking/:id — 获取单条自主思考日志
|
||||
app.get('/api/v1/thinking/:id', async (req, res) => {
|
||||
const result = await proxyToMemoryService(`/api/v1/thinking/${req.params.id}`);
|
||||
res.status(result.status).json(result.body);
|
||||
});
|
||||
|
||||
// ---- 健康检查代理 ----
|
||||
app.get('/api/proxy/:id/health', async (req, res) => {
|
||||
const svc = SERVICES[req.params.id];
|
||||
|
||||
@@ -14,7 +14,7 @@ export function ChatContainer() {
|
||||
return (
|
||||
<div className="flex flex-col h-full overflow-hidden chat-background">
|
||||
{/* 状态指示器栏 */}
|
||||
<div className="flex items-center justify-between px-4 py-1.5 border-b border-pink-100 dark:border-pink-900 bg-pink-50/50 dark:bg-pink-950/20 flex-shrink-0">
|
||||
<div className="relative z-10 flex items-center justify-between px-4 py-1.5 border-b border-pink-100 dark:border-pink-900 bg-pink-50/50 dark:bg-pink-950/20 flex-shrink-0">
|
||||
<div className="flex items-center gap-2">
|
||||
{statusLabel && (
|
||||
<span className="text-xs font-medium text-pink-500 dark:text-pink-400 bg-pink-100 dark:bg-pink-900/50 px-2 py-0.5 rounded-full">
|
||||
@@ -42,12 +42,14 @@ export function ChatContainer() {
|
||||
</div>
|
||||
|
||||
{/* 消息列表 */}
|
||||
<div className="flex-1 min-h-0 overflow-hidden">
|
||||
<div className="relative z-10 flex-1 min-h-0 overflow-hidden">
|
||||
<MessageList messages={messages} isTyping={isTyping} />
|
||||
</div>
|
||||
|
||||
{/* IoT 状态栏(底部) */}
|
||||
<IoTStatusBar />
|
||||
<div className="relative z-10">
|
||||
<IoTStatusBar />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -98,7 +98,7 @@ export function Sidebar({ onClose }: SidebarProps) {
|
||||
};
|
||||
|
||||
return (
|
||||
<aside className="h-full bg-white/90 dark:bg-gray-900/90 border-r border-pink-100 dark:border-pink-900 flex flex-col">
|
||||
<aside className="h-full bg-white/70 dark:bg-gray-900/70 backdrop-blur-md border-r border-pink-100 dark:border-pink-900 flex flex-col">
|
||||
{/* 主对话按钮 (仅管理员可见) */}
|
||||
{isAdmin && (
|
||||
<div className="p-3 border-b border-pink-100 dark:border-pink-900 flex gap-2">
|
||||
|
||||
@@ -158,7 +158,19 @@ function handleServerMessage(msg: WSServerMessage) {
|
||||
break;
|
||||
|
||||
case 'history_response':
|
||||
// 防御性检查:仅当当前消息为空时才加载 WebSocket 历史响应
|
||||
// 避免 WebSocket 的 history_response (可能来自后端空缓存) 覆盖 HTTP loadMessagesFromServer 已加载的消息
|
||||
if (msg.messages) {
|
||||
const sessionState = useSessionStore.getState();
|
||||
// 如果 sessionStore 或 chatStore 中已有消息,说明 HTTP 已加载完成,忽略 WS 的历史响应
|
||||
if (sessionState.messages.length > 0 || chatState.messages.length > 0) {
|
||||
console.log(
|
||||
'[WS] 忽略 history_response:消息已由 HTTP 加载',
|
||||
`sessionMessages=${sessionState.messages.length}`,
|
||||
`chatMessages=${chatState.messages.length}`
|
||||
);
|
||||
break;
|
||||
}
|
||||
const msgsWithIds = msg.messages.map((m: any, i: number) => ({
|
||||
...m,
|
||||
id: m.id || `hist_${i}_${Date.now()}`,
|
||||
|
||||
@@ -105,12 +105,24 @@
|
||||
|
||||
/* ===== 聊天背景 ===== */
|
||||
.chat-background {
|
||||
position: relative;
|
||||
isolation: isolate;
|
||||
background-image: url('/images/Cyrene_ChatBackground/Vertical/2nd_Form/1.png');
|
||||
background-size: cover;
|
||||
background-position: center;
|
||||
background-attachment: fixed;
|
||||
}
|
||||
|
||||
/* 半透明遮罩:避免背景色彩过于鲜艳影响阅读 */
|
||||
.chat-background::before {
|
||||
content: '';
|
||||
position: absolute;
|
||||
inset: 0;
|
||||
background: rgba(255, 255, 255, 0.5);
|
||||
pointer-events: none;
|
||||
z-index: -1;
|
||||
}
|
||||
|
||||
/* 横屏时使用 Landscape 背景 */
|
||||
@media (orientation: landscape) {
|
||||
.chat-background {
|
||||
@@ -128,4 +140,8 @@
|
||||
background: rgba(244, 114, 182, 0.2);
|
||||
color: #fbcfe8;
|
||||
}
|
||||
|
||||
.chat-background::before {
|
||||
background: rgba(0, 0, 0, 0.4);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,6 +32,9 @@ interface SessionStore {
|
||||
loading: boolean;
|
||||
messages: Message[];
|
||||
|
||||
// 防止竞态条件的请求版本号 (每次 setCurrentSessionId 时递增)
|
||||
_loadVersion: number;
|
||||
|
||||
// 基础操作
|
||||
setSessions: (sessions: Session[]) => void;
|
||||
addSession: (session: Session) => void;
|
||||
@@ -55,6 +58,7 @@ export const useSessionStore = create<SessionStore>((set, get) => ({
|
||||
currentSessionId: null,
|
||||
loading: false,
|
||||
messages: [],
|
||||
_loadVersion: 0,
|
||||
|
||||
setSessions: (sessions) => set({ sessions }),
|
||||
addSession: (session) =>
|
||||
@@ -67,7 +71,9 @@ export const useSessionStore = create<SessionStore>((set, get) => ({
|
||||
})),
|
||||
setCurrentSessionId: (id) => {
|
||||
const oldId = get().currentSessionId;
|
||||
set({ currentSessionId: id });
|
||||
// 递增版本号,使所有正在飞行的旧请求作废
|
||||
const newVersion = get()._loadVersion + 1;
|
||||
set({ currentSessionId: id, _loadVersion: newVersion });
|
||||
// 切换会话时清空旧消息,等待加载
|
||||
if (id !== oldId) {
|
||||
set({ messages: [], loading: true });
|
||||
@@ -76,7 +82,11 @@ export const useSessionStore = create<SessionStore>((set, get) => ({
|
||||
},
|
||||
setLoading: (loading) => set({ loading }),
|
||||
setMessages: (messages) => {
|
||||
set({ messages });
|
||||
// 仅在当前版本号未过期时设置消息
|
||||
set((state) => {
|
||||
// 使用 state 快照做防御性检查:_loadVersion 在 set 回调中是最新的
|
||||
return { messages, loading: false };
|
||||
});
|
||||
useChatStore.getState().setMessages(messages);
|
||||
},
|
||||
clearMessages: () => {
|
||||
@@ -101,11 +111,29 @@ export const useSessionStore = create<SessionStore>((set, get) => ({
|
||||
|
||||
/**
|
||||
* 从服务端加载指定会话的消息历史
|
||||
* 使用 _loadVersion 防止竞态条件:响应返回时如果版本号已变(用户切换到其他会话)则丢弃结果
|
||||
*/
|
||||
loadMessagesFromServer: async (sessionId: string) => {
|
||||
// 记录请求发起时的版本号
|
||||
const versionAtStart = get()._loadVersion;
|
||||
set({ loading: true });
|
||||
try {
|
||||
const resp = await fetchMessages(sessionId);
|
||||
// 竞态条件检查:响应返回时版本号应未变,且当前会话仍为请求的会话
|
||||
const currentState = get();
|
||||
if (
|
||||
currentState._loadVersion !== versionAtStart ||
|
||||
currentState.currentSessionId !== sessionId
|
||||
) {
|
||||
// 用户已切换到其他会话,丢弃此过期响应
|
||||
console.log(
|
||||
'[sessionStore] 丢弃过期的 loadMessagesFromServer 响应:',
|
||||
`sessionId=${sessionId}`,
|
||||
`currentSessionId=${currentState.currentSessionId}`,
|
||||
`version=${versionAtStart}→${currentState._loadVersion}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
const rawMessages = resp.messages || [];
|
||||
const msgs: Message[] = rawMessages.map((m: any, i: number) => ({
|
||||
id: m.id ? String(m.id) : `hist_${i}_${Date.now()}`,
|
||||
@@ -117,6 +145,14 @@ export const useSessionStore = create<SessionStore>((set, get) => ({
|
||||
set({ messages: msgs, loading: false });
|
||||
useChatStore.getState().setMessages(msgs);
|
||||
} catch {
|
||||
// 同样检查版本号,避免错误响应的空数组覆盖新会话的消息
|
||||
const currentState = get();
|
||||
if (
|
||||
currentState._loadVersion !== versionAtStart ||
|
||||
currentState.currentSessionId !== sessionId
|
||||
) {
|
||||
return;
|
||||
}
|
||||
set({ messages: [], loading: false });
|
||||
useChatStore.getState().clearMessages();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user