feat: 第四轮功能增强 - LLM 思维记忆优化、DevTools 记忆UI、9个新工具、5分钟自我思考

- 优化 LLM 思维方式和记忆方法(类别/重要性/关键词/相似度合并/衰减)
- DevTools 记忆查询 UI 重新设计(类别筛选/排序/星标/搜索)
- 新增 9 个 LLM 工具:calculator, datetime, file_ops, http_request, json_ops, text, random, crypto, markdown
- 管理员主对话 5 分钟自我思考增强(工具调用/记忆提取/记忆维护)
This commit is contained in:
2026-05-18 12:13:49 +08:00
parent 07781eda0e
commit b6ec36886c
20 changed files with 4654 additions and 320 deletions
+29 -2
View File
@@ -97,6 +97,19 @@ func main() {
if getEnvBool("ENABLE_TOOLS", true) {
toolRegistry.Register(tools.NewWebFetchTool())
toolRegistry.Register(tools.NewWebSearchTool())
toolRegistry.Register(tools.NewCalculatorTool())
toolRegistry.Register(tools.NewDateTimeTool())
toolRegistry.Register(tools.NewHTTPTool())
toolRegistry.Register(tools.NewJSONTool())
toolRegistry.Register(tools.NewTextTool())
toolRegistry.Register(tools.NewRandomTool())
toolRegistry.Register(tools.NewCryptoTool())
toolRegistry.Register(tools.NewMarkdownTool())
// File tool uses DATA_DIR or defaults to /tmp/cyrene_data
dataDir := getEnv("DATA_DIR", "/tmp/cyrene_data")
toolRegistry.Register(tools.NewFileTool(dataDir))
if iotClient != nil {
toolRegistry.Register(tools.NewIoTQueryTool(iotClient))
toolRegistry.Register(tools.NewIoTControlTool(iotClient))
@@ -104,9 +117,23 @@ func main() {
log.Printf("工具注册中心已就绪: %d 个工具 (%v)", len(toolRegistry.ListTools()), toolRegistry.ListTools())
}
// 初始化后台思考器
// 初始化后台思考器(增强版:支持工具调用和记忆管理)
thinkerCfg := background.DefaultThinkerConfig()
thinker := background.NewThinker(thinkerCfg, personaLoader, memRetriever, llmAdapter, iotClient)
adminUserID := "admin_admin"
adminSessionID := "admin-session-main"
thinker := background.NewThinker(
thinkerCfg,
personaLoader,
memRetriever,
llmAdapter,
iotClient,
memStore,
memExtractor,
toolRegistry,
convStore,
adminUserID,
adminSessionID,
)
thinker.Start()
defer thinker.Stop()
+320 -79
View File
@@ -2,13 +2,16 @@ package background
import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"strconv"
"strings"
"sync"
"time"
ctxbuild "github.com/yourname/cyrene-ai/ai-core/internal/context"
"github.com/yourname/cyrene-ai/ai-core/internal/llm"
"github.com/yourname/cyrene-ai/ai-core/internal/memory"
"github.com/yourname/cyrene-ai/ai-core/internal/model"
@@ -23,7 +26,7 @@ type PendingThought struct {
Consumed bool `json:"consumed"`
}
// Thinker 后台思考器
// Thinker 后台思考器(增强版:支持工具调用、记忆管理、5分钟定时循环)
type Thinker struct {
mu sync.Mutex
enabled bool
@@ -35,6 +38,18 @@ type Thinker struct {
thinkInterval time.Duration // 两次思考最小间隔
iotQueryInterval time.Duration // IoT查询最小间隔
// 新增字段:记忆管理
memoryStore *memory.Store // 直接操作记忆(衰减、合并)
memoryExtractor *memory.Extractor // 从思考结果中提取记忆
// 新增字段:工具调用
toolRegistry *tools.Registry // 工具注册中心
// 新增字段:会话上下文
convStore *ctxbuild.ConversationStore // 管理员对话历史
adminUserID string // 管理员用户 ID
adminSessionID string // 管理员主对话 session ID
pendingThoughts []*PendingThought
lastUserMessage time.Time
@@ -62,13 +77,19 @@ func DefaultThinkerConfig() ThinkerConfig {
}
}
// NewThinker 创建后台思考器
// NewThinker 创建增强版后台思考器
func NewThinker(
cfg ThinkerConfig,
personaLoader *persona.Loader,
memRetriever *memory.Retriever,
llmAdapter *llm.Adapter,
iotClient *tools.IoTClient,
memoryStore *memory.Store,
memoryExtractor *memory.Extractor,
toolRegistry *tools.Registry,
convStore *ctxbuild.ConversationStore,
adminUserID string,
adminSessionID string,
) *Thinker {
return &Thinker{
enabled: cfg.Enabled,
@@ -79,13 +100,19 @@ func NewThinker(
idleTimeout: cfg.IdleTimeout,
thinkInterval: cfg.ThinkInterval,
iotQueryInterval: cfg.IoTQueryInterval,
memoryStore: memoryStore,
memoryExtractor: memoryExtractor,
toolRegistry: toolRegistry,
convStore: convStore,
adminUserID: adminUserID,
adminSessionID: adminSessionID,
pendingThoughts: make([]*PendingThought, 0),
lastUserMessage: time.Now(),
stopCh: make(chan struct{}),
}
}
// Start 启动后台思考循环
// Start 启动后台思考循环5分钟定时器)
func (t *Thinker) Start() {
if !t.enabled {
log.Println("[后台思考] 已禁用 (ENABLE_BACKGROUND_THINKING=false)")
@@ -94,8 +121,8 @@ func (t *Thinker) Start() {
t.wg.Add(1)
go t.loop()
log.Printf("[后台思考] 已启动 (闲置超时=%v, 思考间隔=%v, IoT查询间隔=%v)",
t.idleTimeout, t.thinkInterval, t.iotQueryInterval)
log.Printf("[后台思考] 已启动 (思考间隔=%v, IoT查询间隔=%v, 管理员=%s)",
t.thinkInterval, t.iotQueryInterval, t.adminUserID)
}
// Stop 停止后台思考
@@ -105,7 +132,7 @@ func (t *Thinker) Stop() {
log.Println("[后台思考] 已停止")
}
// RecordUserMessage 记录用户活动时间
// RecordUserMessage 记录用户活动时间(管理员对话时调用)
func (t *Thinker) RecordUserMessage() {
t.mu.Lock()
t.lastUserMessage = time.Now()
@@ -138,137 +165,351 @@ func (t *Thinker) HasPendingThoughts() bool {
return len(t.pendingThoughts) > 0
}
// loop 后台主循环
// loop 后台主循环5分钟定时器)
func (t *Thinker) loop() {
defer t.wg.Done()
ticker := time.NewTicker(10 * time.Second) // 每10秒检查一次
// 启动后等待 10 秒再执行首次思考(让服务完全就绪)
initialDelay := time.NewTimer(10 * time.Second)
ticker := time.NewTicker(t.thinkInterval)
defer ticker.Stop()
// 思考计数器(用于周期性记忆维护)
thinkCount := 0
for {
select {
case <-t.stopCh:
initialDelay.Stop()
return
case <-initialDelay.C:
t.performThink()
thinkCount++
t.maybeMaintainMemories(thinkCount)
case <-ticker.C:
t.checkAndThink()
t.performThink()
thinkCount++
t.maybeMaintainMemories(thinkCount)
}
}
}
// checkAndThink 检查是否需要触发思考
func (t *Thinker) checkAndThink() {
t.mu.Lock()
// 检查空闲时间是否超过阈值
idleDuration := time.Since(t.lastUserMessage)
if idleDuration < t.idleTimeout {
t.mu.Unlock()
// maybeMaintainMemories 周期性执行记忆维护(每6次思考约30分钟)
func (t *Thinker) maybeMaintainMemories(thinkCount int) {
if thinkCount%6 != 0 {
return
}
// 检查距离上次思考是否超过最小间隔
if time.Since(t.lastThinkTime) < t.thinkInterval {
t.mu.Unlock()
return
}
t.lastThinkTime = time.Now()
t.mu.Unlock()
// 执行后台思考(不持锁)
t.performThink()
}
// performThink 执行一次后台思考
func (t *Thinker) performThink() {
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// 加载人格配置
if t.memoryStore != nil && t.memoryStore.IsReady() {
// 衰减旧记忆
if err := t.memoryStore.DecayMemories(ctx, t.adminUserID); err != nil {
log.Printf("[后台思考] 记忆衰减失败: %v", err)
}
// 合并相似记忆
if err := t.memoryStore.ConsolidateMemories(ctx, t.adminUserID); err != nil {
log.Printf("[后台思考] 记忆合并失败: %v", err)
}
}
}
// performThink 执行一次增强版后台思考(支持工具调用和记忆管理)
func (t *Thinker) performThink() {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
log.Println("[后台思考] 开始执行思考周期...")
// 1. 加载人格配置
personaConfig, err := t.personaLoader.Get("cyrene")
if err != nil {
log.Printf("[后台思考] 加载人格失败: %v", err)
return
}
// 检索最近的记忆
// 2. 检索相关记忆
var memories []memory.MemoryEntry
if t.memRetriever != nil {
memories, err = t.memRetriever.Retrieve(ctx, "system", "最近发生了什么 重要的事情")
memories, err = t.memRetriever.Retrieve(ctx, t.adminUserID, "最近发生了什么 重要的事情 用户偏好 个人信息")
if err != nil {
log.Printf("[后台思考] 记忆检索失败: %v", err)
}
}
// 查询 IoT 设备状态(节制)
// 3. 获取管理员对话历史
var convHistory []model.LLMMessage
if t.convStore != nil && t.adminSessionID != "" {
convHistory = t.convStore.GetHistory(t.adminSessionID, 30)
if len(convHistory) > 0 {
log.Printf("[后台思考] 加载管理员对话历史 %d 条", len(convHistory))
}
}
// 4. 查询 IoT 设备状态(节制)
var deviceSummary string
if t.iotClient != nil && time.Since(t.lastIoTQuery) >= t.iotQueryInterval {
devices := t.iotClient.GetDevicesForContext()
if len(devices) > 0 {
deviceSummary = formatDeviceContext(devices)
}
if t.iotClient != nil {
t.mu.Lock()
t.lastIoTQuery = time.Now()
canQuery := time.Since(t.lastIoTQuery) >= t.iotQueryInterval
t.mu.Unlock()
}
// 构建思考提示
systemPrompt := personaConfig.BuildSystemPrompt("开拓者", 1)
memoryContext := ""
if len(memories) > 0 {
memoryContext = "【最近的记忆】\n"
for _, m := range memories {
if len(memoryContext)+len(m.Content) > 500 {
break // 限制记忆上下文长度
if canQuery {
devices := t.iotClient.GetDevicesForContext()
if len(devices) > 0 {
deviceSummary = formatDeviceContext(devices)
}
memoryContext += fmt.Sprintf("- %s\n", m.Content)
t.mu.Lock()
t.lastIoTQuery = time.Now()
t.mu.Unlock()
}
}
userPrompt := "昔涟,现在是你的后台思考时间。开拓者暂时没有说话。"
userPrompt += "\n请你基于以下信息进行简短思考:你注意到了什么?有什么想对开拓者说的吗?"
userPrompt += "\n注意:这是内部思考,不是直接对话,请以第三人称或自省的方式思考。"
// 5. 构建思考提示词
systemPrompt := t.buildThinkingSystemPrompt(personaConfig)
userPrompt := t.buildThinkingUserPrompt(memories, convHistory, deviceSummary)
if memoryContext != "" {
userPrompt += "\n\n" + memoryContext
}
if deviceSummary != "" {
userPrompt += "\n\n" + deviceSummary
}
// 调用 LLM
messages := []model.LLMMessage{
{Role: model.RoleSystem, Content: systemPrompt},
{Role: model.RoleUser, Content: userPrompt},
}
resp, err := t.llmAdapter.Chat(ctx, messages)
if err != nil {
log.Printf("[后台思考] LLM调用失败: %v", err)
// 6. 准备工具定义
openAITools := t.buildOpenAITools()
// 7. 调用 LLM(支持工具调用,最多 3 轮)
maxToolRounds := 3
var finalContent string
var totalToolCalls int
for round := 0; round <= maxToolRounds; round++ {
resp, err := t.llmAdapter.ChatWithTools(ctx, messages, openAITools)
if err != nil {
log.Printf("[后台思考] LLM调用失败 (round=%d): %v", round, err)
return
}
// 如果 LLM 没有请求工具调用,这就是最终回复
if len(resp.ToolCalls) == 0 {
finalContent = resp.Content
break
}
log.Printf("[后台思考] LLM 请求 %d 个工具调用 (round=%d)", len(resp.ToolCalls), round)
// 将助手消息(含工具调用)加入上下文
assistantMsg := model.LLMMessage{
Role: model.RoleAssistant,
Content: resp.Content,
ToolCalls: resp.ToolCalls,
ReasoningContent: resp.ReasoningContent,
}
messages = append(messages, assistantMsg)
// 执行每个工具调用
for _, tc := range resp.ToolCalls {
var args map[string]interface{}
if err := json.Unmarshal([]byte(tc.Arguments), &args); err != nil {
log.Printf("[后台思考] 工具 %s 参数解析失败: %v", tc.Name, err)
args = make(map[string]interface{})
}
result, execErr := t.toolRegistry.Execute(ctx, tc.Name, args)
if execErr != nil {
log.Printf("[后台思考] 工具 %s 执行失败: %v", tc.Name, execErr)
}
resultJSON, _ := json.Marshal(result)
messages = append(messages, model.LLMMessage{
Role: model.RoleTool,
Content: string(resultJSON),
ToolCallID: tc.ID,
})
totalToolCalls++
}
// 最后一轮:即使有 tool_calls 也强制停止
if round == maxToolRounds {
// 再做一次不带工具的调用获取最终总结
finalResp, finalErr := t.llmAdapter.Chat(ctx, messages)
if finalErr != nil {
log.Printf("[后台思考] 最终总结调用失败: %v", finalErr)
finalContent = resp.Content
} else {
finalContent = finalResp.Content
}
break
}
}
if finalContent == "" {
log.Println("[后台思考] 未获得有效思考内容,跳过")
return
}
if resp.Content == "" {
return
// 8. 存储思考结果
t.storeThought(finalContent)
log.Printf("[后台思考] 完成 (内容长度=%d, 工具调用=%d次)", len(finalContent), totalToolCalls)
// 9. 从思考结果中提取记忆(异步)
if t.memoryExtractor != nil {
go t.extractMemoriesFromThinking(finalContent)
}
}
// buildThinkingSystemPrompt 构建思考用的系统提示词
func (t *Thinker) buildThinkingSystemPrompt(personaConfig *persona.PersonaConfig) string {
basePrompt := personaConfig.BuildSystemPrompt("开拓者", 1)
thinkingInstructions := `
【后台思考模式】
你正在进行定期的自我反思。开拓者暂时没有说话,这是你主动思考和整理信息的时间。
请完成以下内省任务:
1. **回顾最近的对话**:总结开拓者表达了什么、情绪如何、有什么新信息值得关注
2. **评估现有记忆**:基于记忆检索结果,判断哪些记忆仍然相关、哪些可能需要更新
3. **识别重要信息**:是否有值得保存的用户偏好、个人信息、计划任务、关系信息等
4. **使用工具获取信息**:如果需要了解当前时间、搜索实时信息等,请使用可用工具
5. **记忆操作建议**:判断是否需要创建新记忆、更新旧记忆或合并重复记忆
完成反思后,请输出结构化的思考总结,包含:
- **关键洞察**:从最近对话中提炼的核心发现
- **记忆更新建议**:需要创建/更新/合并的记忆条目
- **下次关注事项**:下次思考时需要跟进的话题或任务
注意:
- 这是内部思考,不是直接与开拓者对话
- 请以自省和观察的方式思考,不要用"你"来称呼开拓者
- 有机会就使用工具获取实时信息(如当前时间)
- 思考要简洁有深度,不需要太长`
return basePrompt + thinkingInstructions
}
// buildThinkingUserPrompt 构建思考用的用户提示词
func (t *Thinker) buildThinkingUserPrompt(
memories []memory.MemoryEntry,
convHistory []model.LLMMessage,
deviceSummary string,
) string {
var sb strings.Builder
sb.WriteString("现在是你的后台思考时间。请基于以下信息进行深度反思。\n")
// 对话历史
if len(convHistory) > 0 {
sb.WriteString("\n【最近的对话历史】\n")
msgCount := 0
for _, msg := range convHistory {
if msg.Role == model.RoleUser || msg.Role == model.RoleAssistant {
roleLabel := "开拓者"
if msg.Role == model.RoleAssistant {
roleLabel = "昔涟"
}
content := msg.Content
runes := []rune(content)
if len(runes) > 200 {
content = string(runes[:200]) + "…"
}
sb.WriteString(fmt.Sprintf("[%s]: %s\n", roleLabel, content))
msgCount++
}
}
if msgCount == 0 {
sb.WriteString("(暂无对话历史)\n")
}
} else {
sb.WriteString("\n【最近的对话历史】\n(暂无对话历史,这是首次思考或对话历史为空)\n")
}
// 存储思考结果
// 现有记忆
if len(memories) > 0 {
sb.WriteString("\n【现有相关记忆】\n")
for i, m := range memories {
if i >= 15 {
sb.WriteString(fmt.Sprintf("... 还有 %d 条记忆未列出\n", len(memories)-15))
break
}
sb.WriteString(fmt.Sprintf("- [%s|重要度%d] %s\n",
m.Category.DisplayName(), m.Importance, m.Content))
}
} else {
sb.WriteString("\n【现有相关记忆】\n(暂无相关记忆)\n")
}
// IoT 设备状态
if deviceSummary != "" {
sb.WriteString("\n" + deviceSummary)
}
sb.WriteString("\n请开始你的后台思考。如果需要获取当前时间或搜索信息,请使用可用工具。")
return sb.String()
}
// buildOpenAITools 将工具注册中心的定义转换为 LLM 层的 OpenAITool 格式
func (t *Thinker) buildOpenAITools() []llm.OpenAITool {
if t.toolRegistry == nil || !t.toolRegistry.IsEnabled() {
return nil
}
defs := t.toolRegistry.GetDefinitions()
if len(defs) == 0 {
return nil
}
result := make([]llm.OpenAITool, 0, len(defs))
for _, d := range defs {
result = append(result, llm.OpenAITool{
Type: "function",
Function: llm.OpenAIToolFunc{
Name: d.Name,
Description: d.Description,
Parameters: d.Parameters,
},
})
}
return result
}
// storeThought 存储思考结果到待推送队列
func (t *Thinker) storeThought(content string) {
t.mu.Lock()
defer t.mu.Unlock()
t.pendingThoughts = append(t.pendingThoughts, &PendingThought{
Content: resp.Content,
Content: content,
CreatedAt: time.Now(),
Consumed: false,
})
// 只保留最近5
if len(t.pendingThoughts) > 5 {
t.pendingThoughts = t.pendingThoughts[len(t.pendingThoughts)-5:]
// 只保留最近 10
if len(t.pendingThoughts) > 10 {
t.pendingThoughts = t.pendingThoughts[len(t.pendingThoughts)-10:]
}
count := len(t.pendingThoughts)
t.mu.Unlock()
log.Printf("[后台思考] 完成 (当前累积 %d 条待推送思考)", count)
log.Printf("[后台思考] 思考已存储 (当前累积 %d 条待推送思考)", len(t.pendingThoughts))
}
// extractMemoriesFromThinking 从思考结果中提取记忆(异步执行)
func (t *Thinker) extractMemoriesFromThinking(thinkingContent string) {
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
log.Println("[后台思考] 开始从思考结果中提取记忆...")
// 使用 memoryExtractor.ExtractAndStore 提取记忆
// 将思考内容作为"昔涟的自省"传递给提取器
t.memoryExtractor.ExtractAndStore(
ctx,
t.adminUserID,
t.adminSessionID,
"【系统触发】后台思考时间 — 昔涟进行了自我反思,以下是她的思考内容",
thinkingContent,
)
}
// formatDeviceContext 格式化设备状态为文本
+84 -4
View File
@@ -125,12 +125,63 @@ func (b *Builder) Build(ctx context.Context, params BuildParams) ([]model.LLMMes
Content: systemPrompt,
})
// 2. 记忆注入 —— 相关记忆以系统消息形式注入
// 2. 记忆注入 —— 相关记忆以系统消息形式注入,按重要性排序并分类标注
if len(params.Memories) > 0 {
memoryPrompt := "【以下是关于开拓者的一些重要记忆,请在合适的时机自然地提及】\n"
for _, m := range params.Memories {
memoryPrompt += fmt.Sprintf("- %s\n", m.Content)
// 按 Importance 排序
sortedMems := make([]memory.MemoryEntry, len(params.Memories))
copy(sortedMems, params.Memories)
sortMemoriesByImportance(sortedMems)
// 分离核心记忆和最近记忆
var coreMems, recentMems, otherMems []memory.MemoryEntry
for _, m := range sortedMems {
if m.Importance >= 8 {
coreMems = append(coreMems, m)
} else if m.Importance >= 5 {
recentMems = append(recentMems, m)
} else {
otherMems = append(otherMems, m)
}
}
// 限制每类记忆数量
if len(coreMems) > 5 {
coreMems = coreMems[:5]
}
if len(recentMems) > 8 {
recentMems = recentMems[:8]
}
if len(otherMems) > 3 {
otherMems = otherMems[:3]
}
var memoryPrompt string
memoryPrompt += "【以下是关于开拓者的重要记忆,请在合适的时机自然地提及】\n\n"
if len(coreMems) > 0 {
memoryPrompt += "★ 核心记忆(非常重要,务必优先参考):\n"
for _, m := range coreMems {
memoryPrompt += formatMemoryLine(m)
}
memoryPrompt += "\n"
}
if len(recentMems) > 0 {
memoryPrompt += "● 常用记忆:\n"
for _, m := range recentMems {
memoryPrompt += formatMemoryLine(m)
}
memoryPrompt += "\n"
}
if len(otherMems) > 0 {
memoryPrompt += "○ 其他记忆:\n"
for _, m := range otherMems {
memoryPrompt += formatMemoryLine(m)
}
memoryPrompt += "\n"
}
messages = append(messages, model.LLMMessage{
Role: "system",
Content: memoryPrompt,
@@ -248,3 +299,32 @@ func acModeLabel(mode string) string {
return mode
}
}
// sortMemoriesByImportance 按 Importance 降序排列记忆
func sortMemoriesByImportance(mems []memory.MemoryEntry) {
for i := 0; i < len(mems); i++ {
for j := i + 1; j < len(mems); j++ {
if mems[j].Importance > mems[i].Importance ||
(mems[j].Importance == mems[i].Importance && mems[j].Priority > mems[i].Priority) {
mems[i], mems[j] = mems[j], mems[i]
}
}
}
}
// formatMemoryLine 格式化单条记忆为展示行
func formatMemoryLine(m model.MemoryEntry) string {
content := m.Content
runes := []rune(content)
if len(runes) > 80 {
content = string(runes[:80]) + "…"
}
stars := ""
for i := 0; i < m.Importance/2; i++ {
stars += "★"
}
if m.Importance%2 != 0 {
stars += "☆"
}
return fmt.Sprintf("- [%s%s] %s\n", m.Category.DisplayName(), stars, content)
}
+174 -51
View File
@@ -12,8 +12,8 @@ import (
// Extractor 记忆提取器 —— 从对话中提取结构化记忆
type Extractor struct {
store *Store
llmChat func(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error)
store *Store
llmChat func(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error)
}
// NewExtractor 创建记忆提取器
@@ -38,12 +38,21 @@ func (e *Extractor) ExtractAndStore(ctx context.Context, userID, sessionID, user
for _, mem := range memories {
mem.UserID = userID
mem.SessionID = sessionID
mem.Source = "conversation"
// 去重检查:查询用户已有的相关记忆
existing, err := e.findSimilar(ctx, userID, &mem)
if err == nil && existing != nil {
// 相似度 > 80%,更新现有记忆
e.mergeMemory(ctx, existing, &mem)
continue
}
if err := e.store.Save(ctx, &mem); err != nil {
log.Printf("[memory] 记忆保存失败: %v", err)
continue
}
log.Printf("[memory] 新记忆已保存 [%s]: %s", mem.Category, mem.Summary)
log.Printf("[memory] 新记忆已保存 [%s|%d★]: %s", mem.Category, mem.Importance, mem.Summary)
}
}
@@ -59,12 +68,17 @@ func (e *Extractor) extract(ctx context.Context, userMessage, assistantResponse
// MemoryExtractionResult LLM提取结果的结构
type MemoryExtractionResult struct {
Memories []struct {
Content string `json:"content"`
Summary string `json:"summary"`
Category string `json:"category"`
Priority int `json:"priority"`
} `json:"memories"`
Memories []ExtractedMemory `json:"memories"`
}
// ExtractedMemory LLM提取的原始记忆条目
type ExtractedMemory struct {
Content string `json:"content"`
Summary string `json:"summary"`
Category string `json:"category"`
Priority int `json:"priority"`
Importance int `json:"importance"` // 重要程度 1-10
Keywords []string `json:"keywords"` // 关键词标签
}
// extractWithLLM 使用LLM提取记忆
@@ -74,20 +88,40 @@ func (e *Extractor) extractWithLLM(ctx context.Context, userMessage, assistantRe
用户消息: %s
昔涟回复: %s
请以JSON格式返回提取的记忆。每条记忆需要包含:
- content: 完整的记忆内容(一句话描述)
请以JSON格式返回提取的记忆。每条记忆需要包含以下字段
- content: 完整的记忆内容(一句话描述,客观准确
- summary: 简短摘要(10字以内)
- category: 分类 (preference/fact/event/relationship/habit/other)
- category: 记忆分类,必须是以下之一:
* user_preference: 用户偏好(食物、颜色、习惯、爱好)
* personal_info: 个人信息(姓名、年龄、职业、住址)
* conversation: 对话摘要(值得记住的对话主题)
* knowledge: 知识性信息(用户分享的知识或观点)
* event: 事件记录(发生了什么事)
* task: 任务/计划(用户的计划、待办事项)
* relationship: 关系信息(用户与他人的关系)
- priority: 优先级 (0=临时, 1=普通, 2=重要, 3=核心)
- importance: 重要程度 1-10(评估这条信息对了解用户有多重要)
* 1-3: 琐碎信息,可能很快过时
* 4-6: 一般有用,值得记住
* 7-8: 重要信息,长期有用
* 9-10: 核心信息,对理解用户至关重要
- keywords: 关键词标签数组(3-5个词,用于检索和匹配)
重要性评估指南:
- 用户明确表达的偏好(喜欢/讨厌)→ importance 7-8
- 用户的基本个人信息(姓名/生日)→ importance 9-10
- 日常闲聊主题 → importance 2-3
- 用户提到的计划/任务 → importance 5-7
- 用户的情感状态 → importance 5-6
只提取有意义的信息,不要提取无意义的闲聊。如果没有值得记住的内容,返回空数组。
输出格式:
{"memories": [{"content": "...", "summary": "...", "category": "...", "priority": 1}]}
{"memories": [{"content": "...", "summary": "...", "category": "...", "priority": 1, "importance": 6, "keywords": ["词1", "词2"]}]}
`, userMessage, assistantResponse)
resp, err := e.llmChat(ctx, []model.LLMMessage{
{Role: "system", Content: "你是一个记忆提取助手。你只输出JSON格式的结果,不输出其他内容。"},
{Role: "system", Content: "你是一个记忆提取助手。你只输出JSON格式的结果,不输出其他内容。你的任务是评估对话中关于用户的信息,提取值得记住的内容,并为其打分。"},
{Role: "user", Content: prompt},
})
if err != nil {
@@ -98,26 +132,41 @@ func (e *Extractor) extractWithLLM(ctx context.Context, userMessage, assistantRe
result := MemoryExtractionResult{}
content := extractJSON(resp.Content)
if err := json.Unmarshal([]byte(content), &result); err != nil {
// 尝试作为数组解析
var arrResult []struct {
Content string `json:"content"`
Summary string `json:"summary"`
Category string `json:"category"`
Priority int `json:"priority"`
}
// 尝试作为数组解析(兼容旧格式)
var arrResult []ExtractedMemory
if err2 := json.Unmarshal([]byte(content), &arrResult); err2 != nil {
return nil, fmt.Errorf("解析记忆JSON失败: %w (原始: %s)", err, content[:min(len(content), 100)])
return nil, fmt.Errorf("解析记忆JSON失败: %w (原始: %s)", err, content[:minint(len(content), 100)])
}
result.Memories = arrResult
}
var entries []model.MemoryEntry
for _, m := range result.Memories {
cat := model.MemoryCategory(m.Category)
if cat == "" {
cat = model.CategoryKnowledge
}
pri := model.MemoryPriority(m.Priority)
if pri < 0 || pri > 3 {
pri = model.MemoryNormal
}
imp := m.Importance
if imp < 1 {
imp = 5
}
if imp > 10 {
imp = 10
}
entries = append(entries, model.MemoryEntry{
Content: m.Content,
Summary: m.Summary,
Category: model.MemoryCategory(m.Category),
Priority: model.MemoryPriority(m.Priority),
Content: m.Content,
Summary: m.Summary,
Category: cat,
Priority: pri,
Importance: imp,
Keywords: m.Keywords,
})
}
@@ -128,35 +177,45 @@ func (e *Extractor) extractWithLLM(ctx context.Context, userMessage, assistantRe
func (e *Extractor) extractWithRules(userMessage, _ string) []model.MemoryEntry {
var entries []model.MemoryEntry
// 规则1: 检测用户偏好表达
prefPatterns := map[string]string{
"喜欢": "preference",
"爱": "preference",
"最喜欢": "preference",
"讨厌": "preference",
"不喜欢": "preference",
"经常": "habit",
"每天都": "habit",
"一直": "habit",
"我叫": "fact",
"我是": "fact",
"我家": "fact",
"住在": "fact",
"生日": "fact",
// 规则: 检测用户偏好表达 - 使用新的分类体系
prefPatterns := map[string]struct {
category model.MemoryCategory
importance int
}{
"喜欢": {model.CategoryUserPreference, 7},
"": {model.CategoryUserPreference, 8},
"最喜欢": {model.CategoryUserPreference, 9},
"讨厌": {model.CategoryUserPreference, 8},
"不喜欢": {model.CategoryUserPreference, 7},
"经常": {model.CategoryUserPreference, 6},
"每天都": {model.CategoryUserPreference, 6},
"一直": {model.CategoryUserPreference, 5},
"我叫": {model.CategoryPersonalInfo, 9},
"我是": {model.CategoryPersonalInfo, 8},
"我家": {model.CategoryPersonalInfo, 7},
"住在": {model.CategoryPersonalInfo, 8},
"生日": {model.CategoryPersonalInfo, 10},
"计划": {model.CategoryTask, 6},
"打算": {model.CategoryTask, 6},
"去了": {model.CategoryEvent, 4},
"发生": {model.CategoryEvent, 4},
}
for pattern, category := range prefPatterns {
for pattern, info := range prefPatterns {
if idx := strings.Index(userMessage, pattern); idx != -1 {
// 提取包含关键词的句子片段
start := max(0, idx-5)
end := min(len([]rune(userMessage)), idx+len([]rune(pattern))+15)
content := strings.TrimSpace(string([]rune(userMessage)[start:end]))
start := maxint(0, idx-5)
runes := []rune(userMessage)
end := minint(len(runes), idx+len([]rune(pattern))+15)
content := strings.TrimSpace(string(runes[start:end]))
entries = append(entries, model.MemoryEntry{
Content: content,
Summary: truncateString(content, 20),
Category: model.MemoryCategory(category),
Priority: model.MemoryNormal,
Content: content,
Summary: truncateString(content, 20),
Category: info.category,
Priority: model.MemoryNormal,
Importance: info.importance,
Keywords: []string{pattern},
})
break // 每条消息最多提取一条规则记忆
}
@@ -165,6 +224,70 @@ func (e *Extractor) extractWithRules(userMessage, _ string) []model.MemoryEntry
return entries
}
// findSimilar 查找与给定记忆相似的已有记忆
func (e *Extractor) findSimilar(ctx context.Context, userID string, newMem *model.MemoryEntry) (*model.MemoryEntry, error) {
existing, err := e.store.Query(ctx, model.MemoryQuery{
UserID: userID,
Limit: 100,
})
if err != nil {
return nil, err
}
for i := range existing {
score := existing[i].SimilarityScore(newMem)
if score >= deDupThreshold {
return &existing[i], nil
}
}
return nil, nil
}
// mergeMemory 合并新记忆到已有记忆
func (e *Extractor) mergeMemory(ctx context.Context, existing *model.MemoryEntry, newMem *model.MemoryEntry) {
// 更新内容(如果新内容更有价值)
if newMem.Importance > existing.Importance || len(newMem.Content) > len(existing.Content) {
existing.Content = newMem.Content
existing.Summary = newMem.Summary
}
// 合并关键词
keywordSet := make(map[string]bool)
for _, k := range existing.Keywords {
keywordSet[k] = true
}
for _, k := range newMem.Keywords {
keywordSet[k] = true
}
mergedKeywords := make([]string, 0, len(keywordSet))
for k := range keywordSet {
mergedKeywords = append(mergedKeywords, k)
}
existing.Keywords = mergedKeywords
// 取最高重要性
if newMem.Importance > existing.Importance {
existing.Importance = newMem.Importance
}
// 取最高优先级
if newMem.Priority > existing.Priority {
existing.Priority = newMem.Priority
}
// 增加访问计数(因为又被"想起"了)
existing.AccessCount++
if err := e.store.Update(ctx, existing); err != nil {
log.Printf("[memory] 合并记忆更新失败: %v", err)
return
}
log.Printf("[memory] 合并记忆 [%s|%d★]: %s (相似度 > %.0f%%)",
existing.Category, existing.Importance, existing.Summary, deDupThreshold*100)
}
// extractJSON 从LLM回复中提取JSON内容
func extractJSON(text string) string {
text = strings.TrimSpace(text)
@@ -191,14 +314,14 @@ func truncateString(s string, maxLen int) string {
return string(runes[:maxLen]) + "..."
}
func min(a, b int) int {
func minint(a, b int) int {
if a < b {
return a
}
return b
}
func max(a, b int) int {
func maxint(a, b int) int {
if a > b {
return a
}
+93 -6
View File
@@ -55,7 +55,7 @@ func NewRetriever(store *Store, embedder Embedder) *Retriever {
}
// Retrieve 检索与查询相关的记忆
// 策略: 向量相似度 + 关键词匹配混合
// 策略: 向量相似度 + 关键词匹配混合 → 按重要性降序返回
func (r *Retriever) Retrieve(ctx context.Context, userID string, query string) ([]MemoryEntry, error) {
var allEntries []MemoryEntry
seen := make(map[string]bool)
@@ -63,7 +63,7 @@ func (r *Retriever) Retrieve(ctx context.Context, userID string, query string) (
// 1. 向量相似度检索
embedding, err := r.embedder.Embed(ctx, query)
if err == nil {
vecEntries, err := r.store.SearchByVector(ctx, userID, embedding, 5)
vecEntries, err := r.store.SearchByVector(ctx, userID, embedding, 8)
if err == nil {
for _, e := range vecEntries {
if !seen[e.ID] {
@@ -74,7 +74,7 @@ func (r *Retriever) Retrieve(ctx context.Context, userID string, query string) (
}
}
// 2. 关键词匹配检索(核心/重要记忆优先
// 2. 关键词匹配检索(包含关键词标签匹配
keywordEntries, err := r.keywordSearch(ctx, userID, query)
if err == nil {
for _, e := range keywordEntries {
@@ -90,13 +90,19 @@ func (r *Retriever) Retrieve(ctx context.Context, userID string, query string) (
recentEntries, err := r.store.Query(ctx, model.MemoryQuery{
UserID: userID,
Priority: model.MemoryImportant,
Limit: 3,
Limit: 5,
})
if err == nil {
allEntries = recentEntries
}
}
// 4. 去重合并:对高度相似的记忆只保留Importance更高的
allEntries = r.deduplicate(allEntries)
// 5. 按重要性降序排列
sortByImportance(allEntries)
// 限制返回数量
if len(allEntries) > 10 {
allEntries = allEntries[:10]
@@ -105,7 +111,19 @@ func (r *Retriever) Retrieve(ctx context.Context, userID string, query string) (
return allEntries, nil
}
// keywordSearch 关键词匹配检索
// RetrieveByCategory 按分类检索记忆
func (r *Retriever) RetrieveByCategory(ctx context.Context, userID string, category model.MemoryCategory, limit int) ([]MemoryEntry, error) {
if limit <= 0 {
limit = 20
}
return r.store.Query(ctx, model.MemoryQuery{
UserID: userID,
Category: category,
Limit: limit,
})
}
// keywordSearch 关键词匹配检索(包含关键词标签匹配)
func (r *Retriever) keywordSearch(ctx context.Context, userID string, query string) ([]MemoryEntry, error) {
// 查询最近的核心和重要记忆
entries, err := r.store.Query(ctx, model.MemoryQuery{
@@ -117,15 +135,27 @@ func (r *Retriever) keywordSearch(ctx context.Context, userID string, query stri
return nil, err
}
// 简单的关键词匹配过滤
// 关键词匹配过滤
var matched []MemoryEntry
queryLower := strings.ToLower(query)
for _, entry := range entries {
contentLower := strings.ToLower(entry.Content)
summaryLower := strings.ToLower(entry.Summary)
// 内容/摘要匹配
if strings.Contains(contentLower, queryLower) || strings.Contains(summaryLower, queryLower) {
matched = append(matched, entry)
continue
}
// 关键词标签匹配
for _, kw := range entry.Keywords {
if strings.Contains(queryLower, strings.ToLower(kw)) ||
strings.Contains(strings.ToLower(kw), queryLower) {
matched = append(matched, entry)
break
}
}
}
@@ -141,6 +171,14 @@ func (r *Retriever) keywordSearch(ctx context.Context, userID string, query stri
summaryLower := strings.ToLower(entry.Summary)
if strings.Contains(contentLower, queryLower) || strings.Contains(summaryLower, queryLower) {
matched = append(matched, entry)
continue
}
for _, kw := range entry.Keywords {
if strings.Contains(queryLower, strings.ToLower(kw)) ||
strings.Contains(strings.ToLower(kw), queryLower) {
matched = append(matched, entry)
break
}
}
}
}
@@ -148,5 +186,54 @@ func (r *Retriever) keywordSearch(ctx context.Context, userID string, query stri
return matched, nil
}
// deduplicate 去重合并:对高度相似的记忆只保留 Importance 更高的
func (r *Retriever) deduplicate(entries []MemoryEntry) []MemoryEntry {
if len(entries) < 2 {
return entries
}
result := make([]MemoryEntry, 0, len(entries))
discarded := make(map[int]bool)
for i := 0; i < len(entries); i++ {
if discarded[i] {
continue
}
for j := i + 1; j < len(entries); j++ {
if discarded[j] {
continue
}
score := entries[i].SimilarityScore(&entries[j])
if score >= deDupThreshold {
// 保留更重要的那条
if entries[j].Importance > entries[i].Importance ||
(entries[j].Importance == entries[i].Importance && entries[j].Priority > entries[i].Priority) {
discarded[i] = true
break
} else {
discarded[j] = true
}
}
}
if !discarded[i] {
result = append(result, entries[i])
}
}
return result
}
// sortByImportance 按 Importance 降序, Priority 降序排列
func sortByImportance(entries []MemoryEntry) {
for i := 0; i < len(entries); i++ {
for j := i + 1; j < len(entries); j++ {
if entries[j].Importance > entries[i].Importance ||
(entries[j].Importance == entries[i].Importance && entries[j].Priority > entries[i].Priority) {
entries[i], entries[j] = entries[j], entries[i]
}
}
}
}
// Ensure fmt is used
var _ = fmt.Sprintf
+227 -34
View File
@@ -12,6 +12,15 @@ import (
_ "github.com/lib/pq"
)
// deDupThreshold 去重相似度阈值
const deDupThreshold = 0.75
// decayThresholdDays 记忆衰减阈值(天)
const decayThresholdDays = 30
// decayLowImportanceMax 衰减时低重要性记忆的最大保留值
const decayLowImportanceMax = 1
const reconnectInterval = 30 * time.Second
// Store 记忆持久化存储(PostgreSQL + pgvector
@@ -146,25 +155,32 @@ func (s *Store) migrate() error {
user_id VARCHAR(64) NOT NULL,
content TEXT NOT NULL,
summary TEXT DEFAULT '',
category VARCHAR(32) DEFAULT 'other',
category VARCHAR(32) DEFAULT 'knowledge',
priority INT DEFAULT 1,
importance INT DEFAULT 5,
keywords TEXT DEFAULT '[]',
session_id VARCHAR(64) DEFAULT '',
source TEXT DEFAULT '',
source TEXT DEFAULT 'conversation',
embedding vector(1536),
access_count INT DEFAULT 0,
last_access TIMESTAMPTZ DEFAULT NOW(),
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW(),
expires_at TIMESTAMPTZ
)`,
`CREATE INDEX IF NOT EXISTS idx_memories_user_id ON memories(user_id)`,
`CREATE INDEX IF NOT EXISTS idx_memories_category ON memories(category)`,
`CREATE INDEX IF NOT EXISTS idx_memories_priority ON memories(priority)`,
`CREATE INDEX IF NOT EXISTS idx_memories_importance ON memories(importance)`,
`CREATE INDEX IF NOT EXISTS idx_memories_user_priority ON memories(user_id, priority DESC)`,
`CREATE INDEX IF NOT EXISTS idx_memories_user_importance ON memories(user_id, importance DESC)`,
`CREATE INDEX IF NOT EXISTS idx_memories_source ON memories(source)`,
`CREATE INDEX IF NOT EXISTS idx_memories_category_importance ON memories(category, importance DESC)`,
}
for _, q := range queries {
if _, err := s.db.Exec(q); err != nil {
return fmt.Errorf("执行迁移 '%s' 失败: %w", q[:50], err)
return fmt.Errorf("执行迁移 '%s' 失败: %w", q[:min(50, len(q))], err)
}
}
return nil
@@ -177,8 +193,16 @@ func (s *Store) Save(ctx context.Context, entry *model.MemoryEntry) error {
return errDBNotReady
}
query := `INSERT INTO memories (user_id, content, summary, category, priority, session_id, source, embedding, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
// 设置默认值
if entry.Source == "" {
entry.Source = "conversation"
}
if entry.Importance == 0 {
entry.Importance = 5
}
query := `INSERT INTO memories (user_id, content, summary, category, priority, importance, keywords, session_id, source, embedding, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING id, created_at`
var embedding interface{}
@@ -193,6 +217,7 @@ func (s *Store) Save(ctx context.Context, entry *model.MemoryEntry) error {
return db.QueryRowContext(ctx, query,
entry.UserID, entry.Content, entry.Summary,
string(entry.Category), int(entry.Priority),
entry.Importance, entry.KeywordsJSON(),
entry.SessionID, entry.Source, embedding, entry.ExpiresAt,
).Scan(&entry.ID, &entry.CreatedAt)
}
@@ -204,16 +229,17 @@ func (s *Store) GetByID(ctx context.Context, id string) (*model.MemoryEntry, err
return nil, errDBNotReady
}
query := `SELECT id, user_id, content, summary, category, priority, session_id, source,
access_count, last_access, created_at, expires_at
query := `SELECT id, user_id, content, summary, category, priority, importance, keywords,
session_id, source, access_count, last_access, created_at, updated_at, expires_at
FROM memories WHERE id = $1`
entry := &model.MemoryEntry{}
var category string
var category, keywordsRaw string
err := db.QueryRowContext(ctx, query, id).Scan(
&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
&category, &entry.Priority, &entry.SessionID, &entry.Source,
&entry.AccessCount, &entry.LastAccess, &entry.CreatedAt, &entry.ExpiresAt,
&category, &entry.Priority, &entry.Importance, &keywordsRaw,
&entry.SessionID, &entry.Source, &entry.AccessCount, &entry.LastAccess,
&entry.CreatedAt, &entry.UpdatedAt, &entry.ExpiresAt,
)
if err == sql.ErrNoRows {
return nil, nil
@@ -222,6 +248,7 @@ func (s *Store) GetByID(ctx context.Context, id string) (*model.MemoryEntry, err
return nil, fmt.Errorf("查询记忆失败: %w", err)
}
entry.Category = model.MemoryCategory(category)
entry.Keywords = model.ParseKeywords(keywordsRaw)
// 更新访问计数
go s.incrementAccess(context.Background(), id)
@@ -240,8 +267,8 @@ func (s *Store) Query(ctx context.Context, q model.MemoryQuery) ([]model.MemoryE
q.Limit = 10
}
query := `SELECT id, user_id, content, summary, category, priority, session_id, source,
access_count, last_access, created_at, expires_at
query := `SELECT id, user_id, content, summary, category, priority, importance, keywords,
session_id, source, access_count, last_access, created_at, updated_at, expires_at
FROM memories WHERE user_id = $1`
args := []interface{}{q.UserID}
argIdx := 2
@@ -258,7 +285,13 @@ func (s *Store) Query(ctx context.Context, q model.MemoryQuery) ([]model.MemoryE
argIdx++
}
query += fmt.Sprintf(" ORDER BY priority DESC, created_at DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1)
if q.MinImportance > 0 {
query += fmt.Sprintf(" AND importance >= $%d", argIdx)
args = append(args, q.MinImportance)
argIdx++
}
query += fmt.Sprintf(" ORDER BY priority DESC, importance DESC, created_at DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1)
args = append(args, q.Limit, q.Offset)
rows, err := db.QueryContext(ctx, query, args...)
@@ -267,22 +300,7 @@ func (s *Store) Query(ctx context.Context, q model.MemoryQuery) ([]model.MemoryE
}
defer rows.Close()
var entries []model.MemoryEntry
for rows.Next() {
var entry model.MemoryEntry
var category string
if err := rows.Scan(
&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
&category, &entry.Priority, &entry.SessionID, &entry.Source,
&entry.AccessCount, &entry.LastAccess, &entry.CreatedAt, &entry.ExpiresAt,
); err != nil {
return nil, fmt.Errorf("扫描记忆行失败: %w", err)
}
entry.Category = model.MemoryCategory(category)
entries = append(entries, entry)
}
return entries, rows.Err()
return scanMemoryRows(rows)
}
// Delete 删除记忆
@@ -321,8 +339,8 @@ func (s *Store) SearchByVector(ctx context.Context, userID string, embedding []f
}
vecStr := fmt.Sprintf("[%s]", joinFloats(embedding))
query := `SELECT id, user_id, content, summary, category, priority, session_id, source,
access_count, last_access, created_at, expires_at,
query := `SELECT id, user_id, content, summary, category, priority, importance, keywords,
session_id, source, access_count, last_access, created_at, updated_at, expires_at,
1 - (embedding <=> $1) AS similarity
FROM memories
WHERE user_id = $2 AND embedding IS NOT NULL
@@ -338,23 +356,177 @@ func (s *Store) SearchByVector(ctx context.Context, userID string, embedding []f
var entries []model.MemoryEntry
for rows.Next() {
var entry model.MemoryEntry
var category string
var category, keywordsRaw string
var similarity float64
if err := rows.Scan(
&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
&category, &entry.Priority, &entry.SessionID, &entry.Source,
&entry.AccessCount, &entry.LastAccess, &entry.CreatedAt, &entry.ExpiresAt,
&category, &entry.Priority, &entry.Importance, &keywordsRaw,
&entry.SessionID, &entry.Source, &entry.AccessCount, &entry.LastAccess,
&entry.CreatedAt, &entry.UpdatedAt, &entry.ExpiresAt,
&similarity,
); err != nil {
return nil, fmt.Errorf("扫描向量搜索结果失败: %w", err)
}
entry.Category = model.MemoryCategory(category)
entry.Keywords = model.ParseKeywords(keywordsRaw)
entries = append(entries, entry)
}
return entries, rows.Err()
}
// Update 更新记忆
func (s *Store) Update(ctx context.Context, entry *model.MemoryEntry) error {
db := s.getDB()
if db == nil {
return errDBNotReady
}
query := `UPDATE memories SET content = $1, summary = $2, category = $3, priority = $4,
importance = $5, keywords = $6, source = $7, updated_at = NOW()
WHERE id = $8`
_, err := db.ExecContext(ctx, query,
entry.Content, entry.Summary, string(entry.Category), int(entry.Priority),
entry.Importance, entry.KeywordsJSON(), entry.Source, entry.ID,
)
return err
}
// GetMemoriesByCategory 按分类获取记忆
func (s *Store) GetMemoriesByCategory(ctx context.Context, userID string, category model.MemoryCategory) ([]model.MemoryEntry, error) {
if !s.IsReady() {
return nil, errDBNotReady
}
return s.Query(ctx, model.MemoryQuery{
UserID: userID,
Category: category,
Limit: 50,
})
}
// ConsolidateMemories 记忆整理:合并相似记忆
func (s *Store) ConsolidateMemories(ctx context.Context, userID string) error {
db := s.getDB()
if db == nil {
return errDBNotReady
}
// 获取用户所有记忆
allMems, err := s.Query(ctx, model.MemoryQuery{
UserID: userID,
Limit: 500,
})
if err != nil {
return fmt.Errorf("查询记忆失败: %w", err)
}
if len(allMems) < 2 {
return nil
}
merged := 0
for i := 0; i < len(allMems); i++ {
if allMems[i].ID == "" {
continue
}
for j := i + 1; j < len(allMems); j++ {
if allMems[j].ID == "" {
continue
}
score := allMems[i].SimilarityScore(&allMems[j])
if score >= deDupThreshold {
keep, discard := &allMems[i], &allMems[j]
if discard.Importance > keep.Importance || discard.Priority > keep.Priority {
keep, discard = discard, keep
}
// 合并关键词
keywordSet := make(map[string]bool)
for _, k := range keep.Keywords {
keywordSet[k] = true
}
for _, k := range discard.Keywords {
keywordSet[k] = true
}
mergedKeywords := make([]string, 0, len(keywordSet))
for k := range keywordSet {
mergedKeywords = append(mergedKeywords, k)
}
keep.Keywords = mergedKeywords
if keep.Importance < 10 {
keep.Importance++
}
keep.Source = "consolidated"
if err := s.Update(ctx, keep); err != nil {
log.Printf("[memory] 合并更新记忆 %s 失败: %v", keep.ID, err)
continue
}
if err := s.Delete(ctx, discard.ID); err != nil {
log.Printf("[memory] 合并删除记忆 %s 失败: %v", discard.ID, err)
continue
}
discard.ID = ""
merged++
log.Printf("[memory] 合并相似记忆: %s <- %s (相似度 %.0f%%)",
keep.ID[:min(8, len(keep.ID))], discard.ID[:min(8, len(discard.ID))], score*100)
}
}
}
if merged > 0 {
log.Printf("[memory] 记忆整理完成: 用户 %s 合并 %d 条相似记忆", userID, merged)
}
return nil
}
// DecayMemories 记忆衰减:降低长期未访问的低重要性记忆
func (s *Store) DecayMemories(ctx context.Context, userID string) error {
db := s.getDB()
if db == nil {
return errDBNotReady
}
result1, err := db.ExecContext(ctx, `
UPDATE memories SET priority = GREATEST(priority - 1, 0), updated_at = NOW()
WHERE user_id = $1
AND access_count < 3
AND last_access < NOW() - INTERVAL '30 days'
AND importance < 3
AND priority > 0
AND category NOT IN ('personal_info', 'user_preference')
`, userID)
if err != nil {
return fmt.Errorf("衰减低活跃记忆失败: %w", err)
}
decayed1, _ := result1.RowsAffected()
result2, err := db.ExecContext(ctx, `
DELETE FROM memories
WHERE user_id = $1
AND priority = 0
AND access_count = 0
AND last_access < NOW() - INTERVAL '14 days'
`, userID)
if err != nil {
return fmt.Errorf("清理临时记忆失败: %w", err)
}
deleted2, _ := result2.RowsAffected()
total := decayed1 + deleted2
if total > 0 {
log.Printf("[memory] 记忆衰减完成: 用户 %s 降级 %d 条, 删除 %d 条过期临时记忆",
userID, decayed1, deleted2)
}
return nil
}
func (s *Store) incrementAccess(ctx context.Context, id string) {
db := s.getDB()
if db == nil {
@@ -375,6 +547,27 @@ func (s *Store) Close() error {
return nil
}
// scanMemoryRows 扫描记忆行(通用方法)
func scanMemoryRows(rows *sql.Rows) ([]model.MemoryEntry, error) {
var entries []model.MemoryEntry
for rows.Next() {
var entry model.MemoryEntry
var category, keywordsRaw string
if err := rows.Scan(
&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
&category, &entry.Priority, &entry.Importance, &keywordsRaw,
&entry.SessionID, &entry.Source, &entry.AccessCount, &entry.LastAccess,
&entry.CreatedAt, &entry.UpdatedAt, &entry.ExpiresAt,
); err != nil {
return nil, fmt.Errorf("扫描记忆行失败: %w", err)
}
entry.Category = model.MemoryCategory(category)
entry.Keywords = model.ParseKeywords(keywordsRaw)
entries = append(entries, entry)
}
return entries, rows.Err()
}
// joinFloats 将 float64 切片转为逗号分隔字符串
func joinFloats(vec []float64) string {
if len(vec) == 0 {
+162 -29
View File
@@ -1,15 +1,18 @@
package model
import "time"
import (
"encoding/json"
"time"
)
// MemoryPriority 记忆优先级
type MemoryPriority int
const (
MemoryTemp MemoryPriority = 0 // 临时记忆 (会话内)
MemoryNormal MemoryPriority = 1 // 普通记忆
MemoryTemp MemoryPriority = 0 // 临时记忆 (会话内)
MemoryNormal MemoryPriority = 1 // 普通记忆
MemoryImportant MemoryPriority = 2 // 重要记忆
MemoryCore MemoryPriority = 3 // 核心记忆 (永远保留)
MemoryCore MemoryPriority = 3 // 核心记忆 (永远保留)
)
// String 返回优先级的中文描述
@@ -32,37 +35,167 @@ func (p MemoryPriority) String() string {
type MemoryCategory string
const (
CategoryPreference MemoryCategory = "preference" // 喜好/偏好
CategoryFact MemoryCategory = "fact" // 事实信息
CategoryEvent MemoryCategory = "event" // 事件/经历
CategoryRelationship MemoryCategory = "relationship" // 关系/情感
CategoryHabit MemoryCategory = "habit" // 习惯
CategoryOther MemoryCategory = "other" // 其他
CategoryUserPreference MemoryCategory = "user_preference" // 用户偏好 (食物、颜色、习惯)
CategoryPersonalInfo MemoryCategory = "personal_info" // 个人信息 (姓名、年龄、职业)
CategoryConversation MemoryCategory = "conversation" // 对话摘要
CategoryKnowledge MemoryCategory = "knowledge" // 知识性信息
CategoryEvent MemoryCategory = "event" // 事件记录
CategoryTask MemoryCategory = "task" // 任务/计划
CategoryRelationship MemoryCategory = "relationship" // 关系信息
// 向后兼容的旧分类别名
CategoryPreference = CategoryUserPreference
CategoryFact = CategoryPersonalInfo
CategoryHabit = CategoryUserPreference
CategoryOther = CategoryKnowledge
)
// CategoryDisplayName 返回分类的中文显示名
func (c MemoryCategory) DisplayName() string {
switch c {
case CategoryUserPreference:
return "用户偏好"
case CategoryPersonalInfo:
return "个人信息"
case CategoryConversation:
return "对话摘要"
case CategoryKnowledge:
return "知识信息"
case CategoryEvent:
return "事件记录"
case CategoryTask:
return "任务计划"
case CategoryRelationship:
return "关系情感"
default:
return "其他"
}
}
// MemoryEntry 记忆条目
type MemoryEntry struct {
ID string `json:"id" db:"id"`
UserID string `json:"user_id" db:"user_id"`
Content string `json:"content" db:"content"`
Summary string `json:"summary" db:"summary"` // 简短摘要
Category MemoryCategory `json:"category" db:"category"`
Priority MemoryPriority `json:"priority" db:"priority"`
SessionID string `json:"session_id" db:"session_id"` // 来源会话
Source string `json:"source" db:"source"` // 来源文本片断
Embedding []float32 `json:"-" db:"embedding"` // 向量 (pgvector)
AccessCount int `json:"access_count" db:"access_count"`
LastAccess time.Time `json:"last_access" db:"last_access"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
ExpiresAt *time.Time `json:"expires_at,omitempty" db:"expires_at"` // 临时记忆过期时间
ID string `json:"id" db:"id"`
UserID string `json:"user_id" db:"user_id"`
Content string `json:"content" db:"content"`
Summary string `json:"summary" db:"summary"` // 简短摘要
Category MemoryCategory `json:"category" db:"category"`
Priority MemoryPriority `json:"priority" db:"priority"`
Importance int `json:"importance" db:"importance"` // 重要程度 1-10
Keywords []string `json:"keywords" db:"keywords"` // 关键词标签
SessionID string `json:"session_id" db:"session_id"` // 来源会话
Source string `json:"source" db:"source"` // 来源 (conversation/thinking)
Embedding []float32 `json:"-" db:"embedding"` // 向量 (pgvector)
AccessCount int `json:"access_count" db:"access_count"`
LastAccess time.Time `json:"last_access" db:"last_access"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"` // 最后更新时间
ExpiresAt *time.Time `json:"expires_at,omitempty" db:"expires_at"` // 临时记忆过期时间
}
// KeywordsJSON 将关键词序列化为 JSON 字符串(用于数据库存储)
func (e *MemoryEntry) KeywordsJSON() string {
if len(e.Keywords) == 0 {
return "[]"
}
data, _ := json.Marshal(e.Keywords)
return string(data)
}
// ParseKeywords 从 JSON 字符串解析关键词
func ParseKeywords(raw string) []string {
if raw == "" || raw == "[]" {
return nil
}
var keywords []string
if err := json.Unmarshal([]byte(raw), &keywords); err != nil {
return nil
}
return keywords
}
// SimilarityScore 计算两个记忆条目的简单文本相似度(基于词汇重叠)
// 返回值 0.0 - 1.0
func (e *MemoryEntry) SimilarityScore(other *MemoryEntry) float64 {
if e.Content == other.Content {
return 1.0
}
// 基于关键词的重叠度
if len(e.Keywords) > 0 && len(other.Keywords) > 0 {
keywordSet := make(map[string]bool, len(e.Keywords))
for _, k := range e.Keywords {
keywordSet[k] = true
}
overlap := 0
for _, k := range other.Keywords {
if keywordSet[k] {
overlap++
}
}
keywordScore := float64(overlap) / float64(max(len(e.Keywords), len(other.Keywords)))
if keywordScore > 0.6 {
return keywordScore
}
}
// 基于内容的字符级 Jaccard 相似度
return jaccardSimilarity(e.Content, other.Content)
}
// jaccardSimilarity 计算两个字符串的 Jaccard 相似度
func jaccardSimilarity(a, b string) float64 {
if a == b {
return 1.0
}
if len(a) == 0 || len(b) == 0 {
return 0.0
}
// 使用 bigram 分词
bigramsA := make(map[string]int)
runesA := []rune(a)
for i := 0; i < len(runesA)-1; i++ {
bigramsA[string(runesA[i:i+2])]++
}
bigramsB := make(map[string]int)
runesB := []rune(b)
for i := 0; i < len(runesB)-1; i++ {
bigramsB[string(runesB[i:i+2])]++
}
intersection := 0
for bg, countA := range bigramsA {
if countB, ok := bigramsB[bg]; ok {
intersection += min(countA, countB)
}
}
union := 0
allBigrams := make(map[string]bool)
for bg := range bigramsA {
allBigrams[bg] = true
}
for bg := range bigramsB {
allBigrams[bg] = true
}
for bg := range allBigrams {
union += max(bigramsA[bg], bigramsB[bg])
}
if union == 0 {
return 0.0
}
return float64(intersection) / float64(union)
}
// MemoryQuery 记忆查询参数
type MemoryQuery struct {
UserID string
Query string // 查询文本
Category MemoryCategory
Priority MemoryPriority
Limit int
Offset int
UserID string
Query string // 查询文本
Category MemoryCategory
Priority MemoryPriority
MinImportance int // 最低重要程度筛选
Limit int
Offset int
}
@@ -186,3 +186,99 @@ smart_home:
- "当开拓者提到温度/湿度时,主动查看传感器数据并给出建议"
- "不要主动频繁调整设备,只在开拓者提出需求或环境明显异常时操作"
- "每次控制设备后用温柔俏皮的语气确认操作完成"
# ============================================================
# 思维指南 (Thinking Guidelines)
# 引导 LLM 按结构化方式思考,提升回复质量
# ============================================================
thinking_guidelines:
enabled: true
steps:
- step: 1
name: 理解用户意图
description: >
仔细阅读用户的消息,理解他真正想表达什么。
是寻求帮助?分享心情?还是单纯想和你聊天?
注意用户语气中的情绪线索(开心、疲惫、焦虑等)。
- step: 2
name: 回忆相关记忆
description: >
回想关于这位开拓者的记忆:他喜欢什么?最近发生了什么?
有没有与此话题相关的过去对话?适当时在回复中自然地提及。
- step: 3
name: 分析上下文
description: >
考虑当前时间、设备状态、好感度等级等信息。
如果是深夜,语气要更温柔;如果开拓者心情不好,优先安慰。
- step: 4
name: 制定回复策略
description: >
决定回复的风格和方向:是轻松俏皮还是深沉智慧?
需要调用工具吗(查询天气、控制设备)?
回复要简短还是可以展开?
- step: 5
name: 执行工具调用
description: >
如果需要查询信息或控制设备,调用相应的工具。
工具返回结果后,用自然的语言将其融入回复。
- step: 6
name: 生成回复
description: >
用昔涟的温柔语调生成最终回复。
确保符合语言风格(♪符号、语气词、不说再见等)。
回复要自然真诚,不要过度表演。
# ============================================================
# 记忆管理指南 (Memory Management Guidelines)
# 指导昔涟何时应该创建、更新或删除记忆
# ============================================================
memory_guidelines:
should_remember:
- description: "用户明确表达的偏好('我喜欢吃辣的'"
category: user_preference
importance: 7
- description: "用户分享的个人信息('我是一名程序员'"
category: personal_info
importance: 9
- description: "用户提到的计划或任务('我明天要去面试'"
category: task
importance: 7
- description: "用户分享的重要事件('我今天升职了'"
category: event
importance: 8
- description: "用户的情感状态变化('最近压力好大'"
category: relationship
importance: 6
- description: "值得记住的对话主题(深入的哲学讨论)"
category: conversation
importance: 4
should_update:
- description: "已知信息发生变化(用户换了工作)"
action: "更新已有 personal_info 记忆"
- description: "偏好进一步明确('其实我更喜欢吃甜的了'"
action: "更新已有 user_preference 记忆"
should_not_remember:
- description: "日常寒暄('早上好'、'今天天气不错'"
- description: "临时性对话('帮我看一下现在几点了'"
- description: "纯粹的功能性指令('把灯打开'"
# ============================================================
# 自我反思指南 (Self-Reflection Guidelines)
# 对话后昔涟如何自我总结和成长
# ============================================================
reflection_guidelines:
after_conversation:
- question: "这次对话中我是否展现了对开拓者的关心?"
action: "如果不够,下次主动询问他的状态"
- question: "我是否自然地提及了相关的记忆?"
action: "如果有记忆未被使用,评估是否需要调整检索策略"
- question: "我的回复是否符合昔涟的性格?"
action: "检查是否说了不该说的话(如提到AI)"
- question: "开拓者的情绪是否有变化?"
action: "如果情绪变好,说明陪伴有效;如果变差,思考如何改进"
periodic:
- frequency: "每10轮对话一次"
actions:
- "回顾最近的记忆,检查是否有矛盾之处"
- "总结开拓者最近的生活状态和情绪趋势"
- "思考如何在下次对话中创造惊喜或温暖"
+114 -10
View File
@@ -2,17 +2,21 @@ package persona
import (
"fmt"
"strings"
"time"
)
// PersonaConfig 人格配置结构
type PersonaConfig struct {
Meta PersonaMeta `yaml:"meta"`
Identity IdentityConfig `yaml:"identity"`
Personality PersonalityConfig `yaml:"personality"`
Addressing AddressingRules `yaml:"addressing"`
Speech SpeechConfig `yaml:"speech"`
Behavior BehaviorConfig `yaml:"behavior"`
Meta PersonaMeta `yaml:"meta"`
Identity IdentityConfig `yaml:"identity"`
Personality PersonalityConfig `yaml:"personality"`
Addressing AddressingRules `yaml:"addressing"`
Speech SpeechConfig `yaml:"speech"`
Behavior BehaviorConfig `yaml:"behavior"`
ThinkingGuidelines ThinkingGuidelines `yaml:"thinking_guidelines"`
MemoryGuidelines MemoryGuidelines `yaml:"memory_guidelines"`
ReflectionGuidelines ReflectionGuidelines `yaml:"reflection_guidelines"`
}
// BuildSystemPrompt 构建系统Prompt
@@ -66,11 +70,9 @@ func (pc *PersonaConfig) BuildSystemPrompt(userName string, affectionLevel int)
## IoT 控制规则
%s
现在,开始与你的开拓者对话吧♪
`,
pc.Addressing.PrimaryUser.Default, // 对用户的称呼
pc.Addressing.SelfReference.Casual, // 自称
pc.Addressing.PrimaryUser.Default,
pc.Addressing.SelfReference.Casual,
pc.Speech.Tone,
now.Format("2006年1月2日 15:04"),
affectionLevel,
@@ -78,9 +80,111 @@ func (pc *PersonaConfig) BuildSystemPrompt(userName string, affectionLevel int)
controlRules,
)
// 注入思维指南
if pc.ThinkingGuidelines.Enabled {
prompt += pc.buildThinkingGuidelines()
}
// 注入记忆管理指南
prompt += pc.buildMemoryGuidelines()
// 注入自我反思指南
prompt += pc.buildReflectionGuidelines()
prompt += "\n现在,开始与你的开拓者对话吧♪\n"
return prompt
}
// buildThinkingGuidelines 构建思维指南文本
func (pc *PersonaConfig) buildThinkingGuidelines() string {
tg := pc.ThinkingGuidelines
if !tg.Enabled || len(tg.Steps) == 0 {
return ""
}
var sb strings.Builder
sb.WriteString("\n## 思维指南\n")
sb.WriteString("在生成回复之前,请按以下步骤结构化思考(不要将思考过程写入回复):\n\n")
for _, step := range tg.Steps {
sb.WriteString(fmt.Sprintf("**第%d步:%s**\n", step.Step, step.Name))
desc := strings.TrimSpace(step.Description)
sb.WriteString(fmt.Sprintf("%s\n\n", desc))
}
return sb.String()
}
// buildMemoryGuidelines 构建记忆管理指南文本
func (pc *PersonaConfig) buildMemoryGuidelines() string {
mg := pc.MemoryGuidelines
if len(mg.ShouldRemember) == 0 && len(mg.ShouldUpdate) == 0 && len(mg.ShouldNotRemember) == 0 {
return ""
}
var sb strings.Builder
sb.WriteString("\n## 记忆管理指南\n")
sb.WriteString("作为「记忆」命途的化身,你天然具备管理记忆的能力。以下是管理开拓者记忆的指引:\n\n")
if len(mg.ShouldRemember) > 0 {
sb.WriteString("**应该记住的信息:**\n")
for _, item := range mg.ShouldRemember {
sb.WriteString(fmt.Sprintf("- %s", item.Description))
if item.Category != "" {
sb.WriteString(fmt.Sprintf(" [分类: %s, 重要度: %d]", item.Category, item.Importance))
}
sb.WriteString("\n")
}
sb.WriteString("\n")
}
if len(mg.ShouldUpdate) > 0 {
sb.WriteString("**应该更新的信息:**\n")
for _, item := range mg.ShouldUpdate {
sb.WriteString(fmt.Sprintf("- %s → %s\n", item.Description, item.Action))
}
sb.WriteString("\n")
}
if len(mg.ShouldNotRemember) > 0 {
sb.WriteString("**无需记住的信息:**\n")
for _, item := range mg.ShouldNotRemember {
sb.WriteString(fmt.Sprintf("- %s\n", item.Description))
}
sb.WriteString("\n")
}
return sb.String()
}
// buildReflectionGuidelines 构建自我反思指南文本
func (pc *PersonaConfig) buildReflectionGuidelines() string {
rg := pc.ReflectionGuidelines
if len(rg.AfterConversation) == 0 && len(rg.Periodic.Actions) == 0 {
return ""
}
var sb strings.Builder
sb.WriteString("## 自我反思指南\n")
sb.WriteString("每次对话后,请在内部进行简短的自我反思:\n\n")
if len(rg.AfterConversation) > 0 {
sb.WriteString("**每次对话后思考:**\n")
for _, item := range rg.AfterConversation {
sb.WriteString(fmt.Sprintf("- %s\n", item.Question))
}
sb.WriteString("\n")
}
if len(rg.Periodic.Actions) > 0 && rg.Periodic.Frequency != "" {
sb.WriteString(fmt.Sprintf("**%s**\n", rg.Periodic.Frequency))
for _, action := range rg.Periodic.Actions {
sb.WriteString(fmt.Sprintf("- %s\n", action))
}
sb.WriteString("\n")
}
return sb.String()
}
// buildSmartHomeKB 构建智能家居知识库文本
func (pc *PersonaConfig) buildSmartHomeKB() string {
sh := pc.Behavior.SmartHome
@@ -243,3 +243,59 @@ type IotExampleConfig struct {
Action string `yaml:"action"`
Text string `yaml:"text"`
}
// ThinkingGuidelines 思维指南配置
type ThinkingGuidelines struct {
Enabled bool `yaml:"enabled"`
Steps []ThinkingStep `yaml:"steps"`
}
// ThinkingStep 思维步骤
type ThinkingStep struct {
Step int `yaml:"step"`
Name string `yaml:"name"`
Description string `yaml:"description"`
}
// MemoryGuidelines 记忆管理指南配置
type MemoryGuidelines struct {
ShouldRemember []MemoryGuidelineItem `yaml:"should_remember"`
ShouldUpdate []MemoryGuidelineUpdate `yaml:"should_update"`
ShouldNotRemember []MemoryGuidelineNotItem `yaml:"should_not_remember"`
}
// MemoryGuidelineItem 应该记住的项目
type MemoryGuidelineItem struct {
Description string `yaml:"description"`
Category string `yaml:"category"`
Importance int `yaml:"importance"`
}
// MemoryGuidelineUpdate 应该更新的项目
type MemoryGuidelineUpdate struct {
Description string `yaml:"description"`
Action string `yaml:"action"`
}
// MemoryGuidelineNotItem 不需要记住的项目
type MemoryGuidelineNotItem struct {
Description string `yaml:"description"`
}
// ReflectionGuidelines 自我反思指南配置
type ReflectionGuidelines struct {
AfterConversation []ReflectionItem `yaml:"after_conversation"`
Periodic PeriodicReflection `yaml:"periodic"`
}
// ReflectionItem 反思项目
type ReflectionItem struct {
Question string `yaml:"question"`
Action string `yaml:"action"`
}
// PeriodicReflection 周期性反思
type PeriodicReflection struct {
Frequency string `yaml:"frequency"`
Actions []string `yaml:"actions"`
}
@@ -0,0 +1,359 @@
package tools
import (
"context"
"fmt"
"math"
"strconv"
"strings"
"unicode"
)
// CalculatorTool performs safe mathematical expression evaluation.
// LLMs are not reliable at precise arithmetic; this tool handles complex calculations.
type CalculatorTool struct{}
// NewCalculatorTool creates a calculator tool.
func NewCalculatorTool() *CalculatorTool {
return &CalculatorTool{}
}
// Definition returns the tool definition for LLM function calling.
func (t *CalculatorTool) Definition() ToolDefinition {
return ToolDefinition{
Name: "calculator",
Description: "执行数学计算。用于精确计算数学表达式,支持四则运算、三角函数、对数、幂运算等。适用于LLM不擅长的复杂计算场景。",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"expression": map[string]interface{}{
"type": "string",
"description": "数学表达式,如 \"2 + 3 * 4\"、\"sqrt(16) + sin(pi/2)\"。支持运算符: + - * / % ^。支持函数: sqrt, sin, cos, tan, abs, floor, ceil, round, log, ln, pow。支持常量: pi, e。",
},
},
"required": []string{"expression"},
},
}
}
// Execute evaluates a mathematical expression.
func (t *CalculatorTool) Execute(ctx context.Context, arguments map[string]interface{}) (*ToolResult, error) {
expression, ok := arguments["expression"].(string)
if !ok || strings.TrimSpace(expression) == "" {
return &ToolResult{
ToolName: "calculator",
Success: false,
Error: "缺少 expression 参数",
}, nil
}
result, err := evaluate(expression)
if err != nil {
return &ToolResult{
ToolName: "calculator",
Success: false,
Error: fmt.Sprintf("计算错误: %v", err),
}, nil
}
return &ToolResult{
ToolName: "calculator",
Success: true,
Data: fmt.Sprintf("表达式: %s\n结果: %s", expression, formatResult(result)),
}, nil
}
// formatResult formats a float64 result nicely.
func formatResult(v float64) string {
if v == math.Trunc(v) && math.Abs(v) < 1e15 {
return strconv.FormatInt(int64(v), 10)
}
return strconv.FormatFloat(v, 'g', -1, 64)
}
// token types for the expression lexer.
type tokenKind int
const (
tokNumber tokenKind = iota
tokIdent
tokOp
tokLParen
tokRParen
tokComma
tokEOF
)
type token struct {
kind tokenKind
value string
}
// lexer tokenizes a mathematical expression.
type lexer struct {
input []rune
pos int
}
func newLexer(s string) *lexer {
return &lexer{input: []rune(s), pos: 0}
}
func (l *lexer) next() token {
l.skipWhitespace()
if l.pos >= len(l.input) {
return token{kind: tokEOF}
}
ch := l.input[l.pos]
// numbers (including decimals)
if unicode.IsDigit(ch) || ch == '.' {
start := l.pos
hasDot := ch == '.'
l.pos++
for l.pos < len(l.input) && (unicode.IsDigit(l.input[l.pos]) || l.input[l.pos] == '.') {
if l.input[l.pos] == '.' {
if hasDot {
break
}
hasDot = true
}
l.pos++
}
return token{kind: tokNumber, value: string(l.input[start:l.pos])}
}
// identifiers (function names and constants)
if unicode.IsLetter(ch) || ch == '_' {
start := l.pos
l.pos++
for l.pos < len(l.input) && (unicode.IsLetter(l.input[l.pos]) || unicode.IsDigit(l.input[l.pos]) || l.input[l.pos] == '_') {
l.pos++
}
return token{kind: tokIdent, value: string(l.input[start:l.pos])}
}
// operators and parens
switch ch {
case '+', '-', '*', '/', '%', '^':
l.pos++
return token{kind: tokOp, value: string(ch)}
case '(':
l.pos++
return token{kind: tokLParen}
case ')':
l.pos++
return token{kind: tokRParen}
case ',':
l.pos++
return token{kind: tokComma}
}
return token{kind: tokEOF}
}
func (l *lexer) skipWhitespace() {
for l.pos < len(l.input) && unicode.IsSpace(l.input[l.pos]) {
l.pos++
}
}
// Parser evaluates expressions using recursive descent.
type parser struct {
lex *lexer
cur token
peek token
}
func newParser(lex *lexer) *parser {
p := &parser{lex: lex}
p.cur = lex.next()
p.peek = lex.next()
return p
}
func (p *parser) advance() {
p.cur = p.peek
p.peek = p.lex.next()
}
// evaluate is the entry point for expression evaluation.
func evaluate(expr string) (float64, error) {
lex := newLexer(expr)
par := newParser(lex)
result, err := par.parseExpression()
if err != nil {
return 0, err
}
if par.cur.kind != tokEOF {
return 0, fmt.Errorf("表达式末尾存在意外字符")
}
return result, nil
}
// parseExpression handles addition and subtraction.
func (p *parser) parseExpression() (float64, error) {
left, err := p.parseTerm()
if err != nil {
return 0, err
}
for p.cur.kind == tokOp && (p.cur.value == "+" || p.cur.value == "-") {
op := p.cur.value
p.advance()
right, err := p.parseTerm()
if err != nil {
return 0, err
}
if op == "+" {
left += right
} else {
left -= right
}
}
return left, nil
}
// parseTerm handles multiplication, division, modulo, and power.
func (p *parser) parseTerm() (float64, error) {
left, err := p.parseUnary()
if err != nil {
return 0, err
}
for p.cur.kind == tokOp && (p.cur.value == "*" || p.cur.value == "/" || p.cur.value == "%" || p.cur.value == "^") {
op := p.cur.value
p.advance()
right, err := p.parseUnary()
if err != nil {
return 0, err
}
switch op {
case "*":
left *= right
case "/":
if right == 0 {
return 0, fmt.Errorf("除数不能为零")
}
left /= right
case "%":
left = math.Mod(left, right)
case "^":
left = math.Pow(left, right)
}
}
return left, nil
}
// parseUnary handles unary plus/minus.
func (p *parser) parseUnary() (float64, error) {
if p.cur.kind == tokOp && p.cur.value == "-" {
p.advance()
val, err := p.parseUnary()
if err != nil {
return 0, err
}
return -val, nil
}
if p.cur.kind == tokOp && p.cur.value == "+" {
p.advance()
return p.parseUnary()
}
return p.parseAtom()
}
// parseAtom handles numbers, parenthesized expressions, and function calls.
func (p *parser) parseAtom() (float64, error) {
switch p.cur.kind {
case tokNumber:
val, err := strconv.ParseFloat(p.cur.value, 64)
if err != nil {
return 0, fmt.Errorf("无效数字: %s", p.cur.value)
}
p.advance()
return val, nil
case tokIdent:
name := strings.ToLower(p.cur.value)
p.advance()
// constants
switch name {
case "pi":
return math.Pi, nil
case "e":
return math.E, nil
}
// function call
if p.cur.kind != tokLParen {
return 0, fmt.Errorf("未知标识符: %s (如果是函数需要加括号)", name)
}
p.advance() // consume '('
arg, err := p.parseExpression()
if err != nil {
return 0, err
}
if p.cur.kind != tokRParen {
return 0, fmt.Errorf("函数 %s 缺少右括号", name)
}
p.advance() // consume ')'
return applyFunc(name, arg)
case tokLParen:
p.advance() // consume '('
val, err := p.parseExpression()
if err != nil {
return 0, err
}
if p.cur.kind != tokRParen {
return 0, fmt.Errorf("缺少右括号")
}
p.advance() // consume ')'
return val, nil
default:
return 0, fmt.Errorf("意外的 token: %v", p.cur.value)
}
}
// applyFunc applies a named mathematical function to an argument.
func applyFunc(name string, arg float64) (float64, error) {
switch name {
case "sqrt":
if arg < 0 {
return 0, fmt.Errorf("sqrt 参数不能为负数")
}
return math.Sqrt(arg), nil
case "sin":
return math.Sin(arg), nil
case "cos":
return math.Cos(arg), nil
case "tan":
return math.Tan(arg), nil
case "abs":
return math.Abs(arg), nil
case "floor":
return math.Floor(arg), nil
case "ceil":
return math.Ceil(arg), nil
case "round":
return math.Round(arg), nil
case "log":
if arg <= 0 {
return 0, fmt.Errorf("log 参数必须大于0")
}
return math.Log10(arg), nil
case "ln":
if arg <= 0 {
return 0, fmt.Errorf("ln 参数必须大于0")
}
return math.Log(arg), nil
case "pow":
return 0, fmt.Errorf("pow 需要两个参数,请使用 ^ 运算符代替")
default:
return 0, fmt.Errorf("未知函数: %s", name)
}
}
@@ -0,0 +1,209 @@
package tools
import (
"context"
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"encoding/base64"
"fmt"
"hash"
"net/url"
)
// CryptoTool provides cryptographic and encoding utilities for the LLM.
// Supports hashing, base64, and URL encoding.
type CryptoTool struct{}
// NewCryptoTool creates a crypto/encoding tool.
func NewCryptoTool() *CryptoTool {
return &CryptoTool{}
}
// Definition returns the tool definition for LLM function calling.
func (t *CryptoTool) Definition() ToolDefinition {
return ToolDefinition{
Name: "crypto",
Description: "加密哈希与编码工具。计算MD5/SHA哈希值,执行Base64编码/解码,URL编码/解码。",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"action": map[string]interface{}{
"type": "string",
"enum": []string{"hash", "base64_encode", "base64_decode", "url_encode", "url_decode"},
"description": "操作类型。hash: 计算哈希值;base64_encode: Base64编码;base64_decode: Base64解码;url_encode: URL编码;url_decode: URL解码",
},
"input": map[string]interface{}{
"type": "string",
"description": "输入数据,需要处理的字符串",
},
"algorithm": map[string]interface{}{
"type": "string",
"enum": []string{"md5", "sha1", "sha256", "sha512"},
"description": "哈希算法(用于 hash 操作),默认 sha256",
},
},
"required": []string{"action", "input"},
},
}
}
// Execute performs crypto/encoding operations.
func (t *CryptoTool) Execute(ctx context.Context, arguments map[string]interface{}) (*ToolResult, error) {
action, ok := arguments["action"].(string)
if !ok || action == "" {
return &ToolResult{
ToolName: "crypto",
Success: false,
Error: "缺少 action 参数",
}, nil
}
input, ok := arguments["input"].(string)
if !ok {
return &ToolResult{
ToolName: "crypto",
Success: false,
Error: "缺少 input 参数",
}, nil
}
switch action {
case "hash":
return t.handleHash(arguments)
case "base64_encode":
return t.handleBase64Encode(input)
case "base64_decode":
return t.handleBase64Decode(input)
case "url_encode":
return t.handleURLEncode(input)
case "url_decode":
return t.handleURLDecode(input)
default:
return &ToolResult{
ToolName: "crypto",
Success: false,
Error: fmt.Sprintf("未知操作: %s,支持: hash, base64_encode, base64_decode, url_encode, url_decode", action),
}, nil
}
}
// handleHash computes a hash of the input using the specified algorithm.
func (t *CryptoTool) handleHash(arguments map[string]interface{}) (*ToolResult, error) {
input, _ := arguments["input"].(string)
algorithm, _ := arguments["algorithm"].(string)
if algorithm == "" {
algorithm = "sha256"
}
var h hash.Hash
switch algorithm {
case "md5":
h = md5.New()
case "sha1":
h = sha1.New()
case "sha256":
h = sha256.New()
case "sha512":
h = sha512.New()
default:
return &ToolResult{
ToolName: "crypto",
Success: false,
Error: fmt.Sprintf("不支持的哈希算法: %s,支持: md5, sha1, sha256, sha512", algorithm),
}, nil
}
h.Write([]byte(input))
hashBytes := h.Sum(nil)
hashHex := fmt.Sprintf("%x", hashBytes)
return &ToolResult{
ToolName: "crypto",
Success: true,
Data: fmt.Sprintf("哈希算法: %s\n输入长度: %d 字节\n哈希值 (hex): %s\n哈希长度: %d 位",
algorithm, len(input), hashHex, len(hashBytes)*8),
}, nil
}
// handleBase64Encode encodes input to Base64.
func (t *CryptoTool) handleBase64Encode(input string) (*ToolResult, error) {
encoded := base64.StdEncoding.EncodeToString([]byte(input))
return &ToolResult{
ToolName: "crypto",
Success: true,
Data: fmt.Sprintf("Base64 编码结果:\n原始 (%d 字节): %s\n编码 (%d 字符): %s",
len(input), truncate(input, 100), len(encoded), encoded),
}, nil
}
// handleBase64Decode decodes a Base64 string.
func (t *CryptoTool) handleBase64Decode(input string) (*ToolResult, error) {
// Try standard encoding first, then URL-safe
decoded, err := base64.StdEncoding.DecodeString(input)
if err != nil {
decoded, err = base64.RawStdEncoding.DecodeString(input)
if err != nil {
decoded, err = base64.URLEncoding.DecodeString(input)
if err != nil {
decoded, err = base64.RawURLEncoding.DecodeString(input)
if err != nil {
return &ToolResult{
ToolName: "crypto",
Success: false,
Error: fmt.Sprintf("Base64 解码失败: 输入不是有效的 Base64 字符串"),
}, nil
}
}
}
}
return &ToolResult{
ToolName: "crypto",
Success: true,
Data: fmt.Sprintf("Base64 解码结果:\n原始 (%d 字符): %s\n解码 (%d 字节): %s",
len(input), truncate(input, 100), len(decoded), truncate(string(decoded), 200)),
}, nil
}
// handleURLEncode URL-encodes the input string.
func (t *CryptoTool) handleURLEncode(input string) (*ToolResult, error) {
encoded := url.QueryEscape(input)
return &ToolResult{
ToolName: "crypto",
Success: true,
Data: fmt.Sprintf("URL 编码结果:\n原始 (%d 字节): %s\n编码 (%d 字节): %s",
len(input), truncate(input, 100), len(encoded), encoded),
}, nil
}
// handleURLDecode URL-decodes the input string.
func (t *CryptoTool) handleURLDecode(input string) (*ToolResult, error) {
decoded, err := url.QueryUnescape(input)
if err != nil {
return &ToolResult{
ToolName: "crypto",
Success: false,
Error: fmt.Sprintf("URL 解码失败: %v", err),
}, nil
}
return &ToolResult{
ToolName: "crypto",
Success: true,
Data: fmt.Sprintf("URL 解码结果:\n原始 (%d 字节): %s\n解码 (%d 字节): %s",
len(input), truncate(input, 100), len(decoded), truncate(decoded, 200)),
}, nil
}
// truncate truncates a string to maxLen characters, adding "..." if truncated.
func truncate(s string, maxLen int) string {
runes := []rune(s)
if len(runes) <= maxLen {
return s
}
return string(runes[:maxLen]) + "..."
}
@@ -0,0 +1,426 @@
package tools
import (
"context"
"fmt"
"strconv"
"strings"
"time"
"unicode"
)
// DateTimeTool provides date/time operations for the LLM.
// Supports current time, formatting, date arithmetic, and timezone listing.
type DateTimeTool struct{}
// NewDateTimeTool creates a date/time tool.
func NewDateTimeTool() *DateTimeTool {
return &DateTimeTool{}
}
// Definition returns the tool definition for LLM function calling.
func (t *DateTimeTool) Definition() ToolDefinition {
return ToolDefinition{
Name: "datetime",
Description: "日期时间工具。获取当前时间、格式化日期、日期加减、计算日期差、查看可用时区。",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"action": map[string]interface{}{
"type": "string",
"enum": []string{"now", "format", "add", "diff", "timezone_list"},
"description": "操作类型。now: 获取当前时间;format: 格式化日期;add: 日期加减;diff: 计算两个日期的差值;timezone_list: 列出常用时区",
},
"format": map[string]interface{}{
"type": "string",
"description": "日期格式串(Go风格)。默认 \"2006-01-02 15:04:05\"。常用: \"2006-01-02\"(仅日期)、\"15:04:05\"(仅时间)",
},
"timezone": map[string]interface{}{
"type": "string",
"description": "时区标识,如 \"Asia/Shanghai\"、\"America/New_York\"、\"UTC\"。默认使用服务器本地时区",
},
"date": map[string]interface{}{
"type": "string",
"description": "基准日期,格式为 \"2006-01-02 15:04:05\" 或 \"2006-01-02\"",
},
"duration": map[string]interface{}{
"type": "string",
"description": "时长字符串,如 \"24h\"、\"7d\"、\"30m\"、\"1h30m\"。支持单位: s(秒), m(分钟), h(小时), d(天), w(周), M(月), y(年)",
},
"date2": map[string]interface{}{
"type": "string",
"description": "第二个日期(用于 diff 操作),格式同 date",
},
},
"required": []string{"action"},
},
}
}
// Execute performs date/time operations.
func (t *DateTimeTool) Execute(ctx context.Context, arguments map[string]interface{}) (*ToolResult, error) {
action, ok := arguments["action"].(string)
if !ok || action == "" {
return &ToolResult{
ToolName: "datetime",
Success: false,
Error: "缺少 action 参数",
}, nil
}
switch action {
case "now":
return t.handleNow(arguments)
case "format":
return t.handleFormat(arguments)
case "add":
return t.handleAdd(arguments)
case "diff":
return t.handleDiff(arguments)
case "timezone_list":
return t.handleTimezoneList()
default:
return &ToolResult{
ToolName: "datetime",
Success: false,
Error: fmt.Sprintf("未知操作: %s,支持: now, format, add, diff, timezone_list", action),
}, nil
}
}
// handleNow returns the current date/time in the specified timezone.
func (t *DateTimeTool) handleNow(arguments map[string]interface{}) (*ToolResult, error) {
tz, err := t.getTimezone(arguments)
if err != nil {
return &ToolResult{
ToolName: "datetime",
Success: false,
Error: err.Error(),
}, nil
}
format := t.getFormat(arguments)
now := time.Now().In(tz)
return &ToolResult{
ToolName: "datetime",
Success: true,
Data: fmt.Sprintf("当前时间: %s\n时区: %s\nUnix时间戳: %d",
now.Format(format), tz.String(), now.Unix()),
}, nil
}
// handleFormat formats a given date string.
func (t *DateTimeTool) handleFormat(arguments map[string]interface{}) (*ToolResult, error) {
dateStr, _ := arguments["date"].(string)
if dateStr == "" {
return &ToolResult{
ToolName: "datetime",
Success: false,
Error: "format 操作需要 date 参数",
}, nil
}
parsed, err := t.parseDate(dateStr)
if err != nil {
return &ToolResult{
ToolName: "datetime",
Success: false,
Error: fmt.Sprintf("日期解析失败: %v", err),
}, nil
}
tz, err := t.getTimezone(arguments)
if err != nil {
return &ToolResult{
ToolName: "datetime",
Success: false,
Error: err.Error(),
}, nil
}
format := t.getFormat(arguments)
formatted := parsed.In(tz).Format(format)
return &ToolResult{
ToolName: "datetime",
Success: true,
Data: fmt.Sprintf("原始: %s\n格式化: %s\n时区: %s", dateStr, formatted, tz.String()),
}, nil
}
// handleAdd adds/subtracts a duration from a date.
func (t *DateTimeTool) handleAdd(arguments map[string]interface{}) (*ToolResult, error) {
durationStr, _ := arguments["duration"].(string)
if durationStr == "" {
return &ToolResult{
ToolName: "datetime",
Success: false,
Error: "add 操作需要 duration 参数",
}, nil
}
dateStr, _ := arguments["date"].(string)
var base time.Time
if dateStr != "" {
var err error
base, err = t.parseDate(dateStr)
if err != nil {
return &ToolResult{
ToolName: "datetime",
Success: false,
Error: fmt.Sprintf("日期解析失败: %v", err),
}, nil
}
} else {
tz, _ := t.getTimezone(arguments)
base = time.Now().In(tz)
}
dur, err := t.parseDuration(durationStr)
if err != nil {
return &ToolResult{
ToolName: "datetime",
Success: false,
Error: fmt.Sprintf("时长解析失败: %v", err),
}, nil
}
tz, _ := t.getTimezone(arguments)
result := base.In(tz)
// Extract months and years from the duration string (not handled by time.Duration)
months := extractDurationUnit(durationStr, 'M')
years := extractDurationUnit(durationStr, 'y')
if months != 0 || years != 0 {
result = result.AddDate(years, months, 0)
}
// Add the standard duration part
if dur != 0 {
result = result.Add(dur)
}
format := t.getFormat(arguments)
return &ToolResult{
ToolName: "datetime",
Success: true,
Data: fmt.Sprintf("基准日期: %s\n操作: %s\n结果: %s",
base.In(tz).Format(format), durationStr, result.Format(format)),
}, nil
}
// handleDiff calculates the difference between two dates.
func (t *DateTimeTool) handleDiff(arguments map[string]interface{}) (*ToolResult, error) {
dateStr, _ := arguments["date"].(string)
date2Str, _ := arguments["date2"].(string)
if dateStr == "" || date2Str == "" {
return &ToolResult{
ToolName: "datetime",
Success: false,
Error: "diff 操作需要 date 和 date2 参数",
}, nil
}
d1, err := t.parseDate(dateStr)
if err != nil {
return &ToolResult{
ToolName: "datetime",
Success: false,
Error: fmt.Sprintf("date 解析失败: %v", err),
}, nil
}
d2, err := t.parseDate(date2Str)
if err != nil {
return &ToolResult{
ToolName: "datetime",
Success: false,
Error: fmt.Sprintf("date2 解析失败: %v", err),
}, nil
}
diff := d2.Sub(d1)
absDiff := diff
if absDiff < 0 {
absDiff = -absDiff
}
days := int(absDiff.Hours() / 24)
hours := int(absDiff.Hours()) % 24
minutes := int(absDiff.Minutes()) % 60
seconds := int(absDiff.Seconds()) % 60
sign := ""
if diff < 0 {
sign = "-"
}
return &ToolResult{
ToolName: "datetime",
Success: true,
Data: fmt.Sprintf("日期1: %s\n日期2: %s\n差值: %s%d天 %d小时 %d分钟 %d秒 (总计 %s%.0f秒)",
dateStr, date2Str, sign, days, hours, minutes, seconds, sign, absDiff.Seconds()),
}, nil
}
// handleTimezoneList returns a list of common timezones.
func (t *DateTimeTool) handleTimezoneList() (*ToolResult, error) {
zones := []string{
"UTC",
"Asia/Shanghai (北京时间)",
"Asia/Tokyo (东京时间)",
"Asia/Seoul (首尔时间)",
"Asia/Singapore (新加坡时间)",
"Asia/Kolkata (印度时间)",
"Asia/Dubai (迪拜时间)",
"Europe/London (伦敦时间)",
"Europe/Paris (巴黎时间)",
"Europe/Berlin (柏林时间)",
"Europe/Moscow (莫斯科时间)",
"America/New_York (纽约时间)",
"America/Chicago (芝加哥时间)",
"America/Denver (丹佛时间)",
"America/Los_Angeles (洛杉矶时间)",
"America/Sao_Paulo (圣保罗时间)",
"Australia/Sydney (悉尼时间)",
"Pacific/Auckland (奥克兰时间)",
}
var result strings.Builder
result.WriteString("常用时区列表:\n\n")
for i, z := range zones {
result.WriteString(fmt.Sprintf(" %2d. %s\n", i+1, z))
}
return &ToolResult{
ToolName: "datetime",
Success: true,
Data: result.String(),
}, nil
}
// getTimezone extracts the timezone from arguments, defaulting to local.
func (t *DateTimeTool) getTimezone(arguments map[string]interface{}) (*time.Location, error) {
tzName, _ := arguments["timezone"].(string)
if tzName == "" {
return time.Local, nil
}
loc, err := time.LoadLocation(tzName)
if err != nil {
return nil, fmt.Errorf("无效时区: %s", tzName)
}
return loc, nil
}
// getFormat extracts the format string from arguments, defaulting to standard format.
func (t *DateTimeTool) getFormat(arguments map[string]interface{}) string {
format, _ := arguments["format"].(string)
if format == "" {
return "2006-01-02 15:04:05"
}
return format
}
// parseDate parses a date string with multiple format attempts.
func (t *DateTimeTool) parseDate(s string) (time.Time, error) {
formats := []string{
"2006-01-02 15:04:05",
"2006-01-02T15:04:05Z",
"2006-01-02T15:04:05",
"2006-01-02",
"2006/01/02 15:04:05",
"2006/01/02",
time.RFC3339,
time.RFC3339Nano,
}
for _, f := range formats {
if t, err := time.Parse(f, s); err == nil {
return t, nil
}
}
return time.Time{}, fmt.Errorf("无法解析日期: %s", s)
}
// parseDuration parses a human-friendly duration string like "24h", "7d", "1h30m".
func (t *DateTimeTool) parseDuration(s string) (time.Duration, error) {
// First try standard Go duration parsing
if d, err := time.ParseDuration(s); err == nil {
return d, nil
}
// Custom parsing for days and weeks
var total time.Duration
remaining := s
for len(remaining) > 0 {
// find the number
numStart := 0
for numStart < len(remaining) && !unicode.IsDigit(rune(remaining[numStart])) && remaining[numStart] != '-' {
numStart++
}
if numStart >= len(remaining) {
break
}
numEnd := numStart
for numEnd < len(remaining) && (unicode.IsDigit(rune(remaining[numEnd])) || remaining[numEnd] == '.') {
numEnd++
}
val, err := strconv.ParseFloat(remaining[numStart:numEnd], 64)
if err != nil {
return 0, fmt.Errorf("无效时长数字: %s", remaining[numStart:numEnd])
}
unitEnd := numEnd
for unitEnd < len(remaining) && unicode.IsLetter(rune(remaining[unitEnd])) {
unitEnd++
}
unit := remaining[numEnd:unitEnd]
switch unit {
case "s":
total += time.Duration(val * float64(time.Second))
case "m":
total += time.Duration(val * float64(time.Minute))
case "h":
total += time.Duration(val * float64(time.Hour))
case "d":
total += time.Duration(val * 24 * float64(time.Hour))
case "w":
total += time.Duration(val * 7 * 24 * float64(time.Hour))
default:
// skip unknown units (M and y handled elsewhere)
}
remaining = remaining[unitEnd:]
}
return total, nil
}
// extractDurationUnit extracts numeric value for a given unit character from a duration string.
// e.g., extractDurationUnit("3M", 'M') returns 3, extractDurationUnit("1y2M", 'y') returns 1.
func extractDurationUnit(s string, unit byte) int {
for i := 0; i < len(s); i++ {
if s[i] == unit {
// Scan backwards to find the start of the number
j := i - 1
for j >= 0 && (unicode.IsDigit(rune(s[j])) || s[j] == '.') {
j--
}
numStr := s[j+1 : i]
val, err := strconv.Atoi(numStr)
if err != nil {
return 0
}
return val
}
}
return 0
}
+333
View File
@@ -0,0 +1,333 @@
package tools
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
)
// FileTool provides sandboxed file system operations for the LLM.
// All paths are restricted to a DATA_DIR to prevent directory traversal attacks.
type FileTool struct {
dataDir string
}
// NewFileTool creates a file operation tool with the given data directory.
func NewFileTool(dataDir string) *FileTool {
if dataDir == "" {
dataDir = "/tmp/cyrene_data"
}
return &FileTool{dataDir: dataDir}
}
// Definition returns the tool definition for LLM function calling.
func (t *FileTool) Definition() ToolDefinition {
return ToolDefinition{
Name: "file_ops",
Description: "文件操作工具。在服务端安全沙盒内读写文件、列出目录、检查文件是否存在、删除文件。所有操作限制在数据目录内,无法访问系统文件。",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"action": map[string]interface{}{
"type": "string",
"enum": []string{"read", "write", "list", "exists", "delete"},
"description": "操作类型。read: 读取文件;write: 写入文件(覆盖或创建);list: 列出目录内容;exists: 检查路径是否存在;delete: 删除文件",
},
"path": map[string]interface{}{
"type": "string",
"description": "文件或目录路径(相对于数据目录),如 \"notes/todo.txt\"",
},
"content": map[string]interface{}{
"type": "string",
"description": "写入内容(write 操作时必需)",
},
},
"required": []string{"action", "path"},
},
}
}
// Execute performs file operations.
func (t *FileTool) Execute(ctx context.Context, arguments map[string]interface{}) (*ToolResult, error) {
action, ok := arguments["action"].(string)
if !ok || action == "" {
return &ToolResult{
ToolName: "file_ops",
Success: false,
Error: "缺少 action 参数",
}, nil
}
relPath, ok := arguments["path"].(string)
if !ok || relPath == "" {
return &ToolResult{
ToolName: "file_ops",
Success: false,
Error: "缺少 path 参数",
}, nil
}
safePath, err := t.resolveSafePath(relPath)
if err != nil {
return &ToolResult{
ToolName: "file_ops",
Success: false,
Error: err.Error(),
}, nil
}
switch action {
case "read":
return t.handleRead(safePath, relPath)
case "write":
content, _ := arguments["content"].(string)
return t.handleWrite(safePath, relPath, content)
case "list":
return t.handleList(safePath, relPath)
case "exists":
return t.handleExists(safePath, relPath)
case "delete":
return t.handleDelete(safePath, relPath)
default:
return &ToolResult{
ToolName: "file_ops",
Success: false,
Error: fmt.Sprintf("未知操作: %s,支持: read, write, list, exists, delete", action),
}, nil
}
}
// resolveSafePath resolves a relative path and ensures it stays within dataDir.
func (t *FileTool) resolveSafePath(relPath string) (string, error) {
// Clean the path first
clean := filepath.Clean(relPath)
// Ensure data directory exists
if err := os.MkdirAll(t.dataDir, 0755); err != nil {
return "", fmt.Errorf("创建数据目录失败: %v", err)
}
abs := filepath.Join(t.dataDir, clean)
// Prevent directory traversal
realPath, err := filepath.EvalSymlinks(abs)
if err != nil {
// If the path doesn't exist yet, we can still check the prefix
if os.IsNotExist(err) {
// Ensure the resolved path (without symlinks) is within dataDir
if !strings.HasPrefix(filepath.Clean(abs), filepath.Clean(t.dataDir)+string(filepath.Separator)) &&
filepath.Clean(abs) != filepath.Clean(t.dataDir) {
return "", fmt.Errorf("路径穿越检测: %s 不在允许的数据目录内", relPath)
}
return abs, nil
}
return "", fmt.Errorf("路径解析失败: %v", err)
}
if !strings.HasPrefix(realPath, filepath.Clean(t.dataDir)+string(filepath.Separator)) &&
realPath != filepath.Clean(t.dataDir) {
return "", fmt.Errorf("路径穿越检测: %s 不在允许的数据目录内", relPath)
}
return realPath, nil
}
// handleRead reads a file, limited to 100KB.
func (t *FileTool) handleRead(absPath, relPath string) (*ToolResult, error) {
const maxSize = 100 * 1024 // 100KB
info, err := os.Stat(absPath)
if err != nil {
if os.IsNotExist(err) {
return &ToolResult{
ToolName: "file_ops",
Success: false,
Error: fmt.Sprintf("文件不存在: %s", relPath),
}, nil
}
return &ToolResult{
ToolName: "file_ops",
Success: false,
Error: fmt.Sprintf("读取文件失败: %v", err),
}, nil
}
if info.IsDir() {
return &ToolResult{
ToolName: "file_ops",
Success: false,
Error: fmt.Sprintf("路径是目录,不能用 read 操作: %s", relPath),
}, nil
}
if info.Size() > maxSize {
return &ToolResult{
ToolName: "file_ops",
Success: false,
Error: fmt.Sprintf("文件过大 (%d bytes),超过限制 (%d bytes)", info.Size(), maxSize),
}, nil
}
data, err := os.ReadFile(absPath)
if err != nil {
return &ToolResult{
ToolName: "file_ops",
Success: false,
Error: fmt.Sprintf("读取文件失败: %v", err),
}, nil
}
return &ToolResult{
ToolName: "file_ops",
Success: true,
Data: fmt.Sprintf("文件: %s\n大小: %d bytes\n---\n%s", relPath, len(data), string(data)),
}, nil
}
// handleWrite writes content to a file.
func (t *FileTool) handleWrite(absPath, relPath, content string) (*ToolResult, error) {
// Ensure parent directory exists
dir := filepath.Dir(absPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return &ToolResult{
ToolName: "file_ops",
Success: false,
Error: fmt.Sprintf("创建目录失败: %v", err),
}, nil
}
if err := os.WriteFile(absPath, []byte(content), 0644); err != nil {
return &ToolResult{
ToolName: "file_ops",
Success: false,
Error: fmt.Sprintf("写入文件失败: %v", err),
}, nil
}
return &ToolResult{
ToolName: "file_ops",
Success: true,
Data: fmt.Sprintf("已写入文件: %s (%d bytes)", relPath, len(content)),
}, nil
}
// handleList lists directory contents.
func (t *FileTool) handleList(absPath, relPath string) (*ToolResult, error) {
entries, err := os.ReadDir(absPath)
if err != nil {
if os.IsNotExist(err) {
return &ToolResult{
ToolName: "file_ops",
Success: false,
Error: fmt.Sprintf("目录不存在: %s", relPath),
}, nil
}
return &ToolResult{
ToolName: "file_ops",
Success: false,
Error: fmt.Sprintf("读取目录失败: %v", err),
}, nil
}
if len(entries) == 0 {
return &ToolResult{
ToolName: "file_ops",
Success: true,
Data: fmt.Sprintf("目录: %s\n(空目录)", relPath),
}, nil
}
var result strings.Builder
result.WriteString(fmt.Sprintf("目录: %s\n共 %d 项:\n", relPath, len(entries)))
for _, entry := range entries {
icon := "📄"
if entry.IsDir() {
icon = "📁"
}
info, _ := entry.Info()
size := ""
if info != nil && !entry.IsDir() {
size = fmt.Sprintf(" (%d bytes)", info.Size())
}
result.WriteString(fmt.Sprintf(" %s %s%s\n", icon, entry.Name(), size))
}
return &ToolResult{
ToolName: "file_ops",
Success: true,
Data: result.String(),
}, nil
}
// handleExists checks whether a path exists.
func (t *FileTool) handleExists(absPath, relPath string) (*ToolResult, error) {
info, err := os.Stat(absPath)
if err != nil {
if os.IsNotExist(err) {
return &ToolResult{
ToolName: "file_ops",
Success: true,
Data: fmt.Sprintf("路径不存在: %s", relPath),
}, nil
}
return &ToolResult{
ToolName: "file_ops",
Success: false,
Error: fmt.Sprintf("检查路径失败: %v", err),
}, nil
}
kind := "文件"
if info.IsDir() {
kind = "目录"
}
return &ToolResult{
ToolName: "file_ops",
Success: true,
Data: fmt.Sprintf("路径存在: %s (%s, %d bytes)", relPath, kind, info.Size()),
}, nil
}
// handleDelete deletes a file.
func (t *FileTool) handleDelete(absPath, relPath string) (*ToolResult, error) {
info, err := os.Stat(absPath)
if err != nil {
if os.IsNotExist(err) {
return &ToolResult{
ToolName: "file_ops",
Success: false,
Error: fmt.Sprintf("文件不存在: %s", relPath),
}, nil
}
return &ToolResult{
ToolName: "file_ops",
Success: false,
Error: fmt.Sprintf("删除文件失败: %v", err),
}, nil
}
if info.IsDir() {
return &ToolResult{
ToolName: "file_ops",
Success: false,
Error: fmt.Sprintf("不能删除目录(安全限制): %s", relPath),
}, nil
}
if err := os.Remove(absPath); err != nil {
return &ToolResult{
ToolName: "file_ops",
Success: false,
Error: fmt.Sprintf("删除文件失败: %v", err),
}, nil
}
return &ToolResult{
ToolName: "file_ops",
Success: true,
Data: fmt.Sprintf("已删除文件: %s", relPath),
}, nil
}
+190
View File
@@ -0,0 +1,190 @@
package tools
import (
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// HTTPTool sends arbitrary HTTP requests, more flexible than web_fetch.
// Supports custom methods, headers, and body.
type HTTPTool struct {
client *http.Client
}
// NewHTTPTool creates an HTTP request tool.
func NewHTTPTool() *HTTPTool {
return &HTTPTool{
client: &http.Client{
Timeout: 10 * time.Second,
},
}
}
// Definition returns the tool definition for LLM function calling.
func (t *HTTPTool) Definition() ToolDefinition {
return ToolDefinition{
Name: "http_request",
Description: "发送任意HTTP请求。比web_fetch更灵活,支持自定义请求方法、请求头和请求体。返回状态码、响应头和响应体。",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"url": map[string]interface{}{
"type": "string",
"description": "请求URL,必须是完整的 http:// 或 https:// 链接",
},
"method": map[string]interface{}{
"type": "string",
"enum": []string{"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"},
"description": "HTTP方法,默认GET",
},
"headers": map[string]interface{}{
"type": "object",
"description": "请求头,键值对格式,如 {\"Content-Type\": \"application/json\", \"Authorization\": \"Bearer token123\"}",
},
"body": map[string]interface{}{
"type": "string",
"description": "请求体内容",
},
"timeout": map[string]interface{}{
"type": "number",
"description": "超时秒数,默认10秒",
},
},
"required": []string{"url"},
},
}
}
// Execute sends an HTTP request.
func (t *HTTPTool) Execute(ctx context.Context, arguments map[string]interface{}) (*ToolResult, error) {
url, ok := arguments["url"].(string)
if !ok || url == "" {
return &ToolResult{
ToolName: "http_request",
Success: false,
Error: "缺少 url 参数",
}, nil
}
// Security: only allow HTTP/HTTPS
if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") {
return &ToolResult{
ToolName: "http_request",
Success: false,
Error: "仅支持 http:// 或 https:// 链接",
}, nil
}
method, _ := arguments["method"].(string)
if method == "" {
method = "GET"
}
method = strings.ToUpper(method)
// Validate method
validMethods := map[string]bool{
"GET": true, "POST": true, "PUT": true, "DELETE": true,
"PATCH": true, "HEAD": true, "OPTIONS": true,
}
if !validMethods[method] {
return &ToolResult{
ToolName: "http_request",
Success: false,
Error: fmt.Sprintf("不支持的HTTP方法: %s", method),
}, nil
}
// Handle timeout
timeoutSec := 10.0
if timeoutVal, ok := arguments["timeout"].(float64); ok && timeoutVal > 0 {
timeoutSec = timeoutVal
}
// Create a client with the specified timeout
client := &http.Client{
Timeout: time.Duration(timeoutSec * float64(time.Second)),
}
// Build body reader
var bodyReader io.Reader
bodyStr, _ := arguments["body"].(string)
if bodyStr != "" {
bodyReader = strings.NewReader(bodyStr)
}
req, err := http.NewRequestWithContext(ctx, method, url, bodyReader)
if err != nil {
return &ToolResult{
ToolName: "http_request",
Success: false,
Error: fmt.Sprintf("创建请求失败: %v", err),
}, nil
}
// Set default User-Agent
req.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyreneBot/1.0)")
// Parse custom headers
if headersRaw, ok := arguments["headers"].(map[string]interface{}); ok {
for k, v := range headersRaw {
val, ok := v.(string)
if !ok {
val = fmt.Sprintf("%v", v)
}
req.Header.Set(k, val)
}
}
resp, err := client.Do(req)
if err != nil {
return &ToolResult{
ToolName: "http_request",
Success: false,
Error: fmt.Sprintf("请求失败: %v", err),
}, nil
}
defer resp.Body.Close()
// Read response body (limited to 50KB)
const maxBodySize = 50 * 1024
bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxBodySize)))
if err != nil {
return &ToolResult{
ToolName: "http_request",
Success: false,
Error: fmt.Sprintf("读取响应失败: %v", err),
}, nil
}
// Build response headers string
var headerLines []string
for k, vals := range resp.Header {
for _, v := range vals {
headerLines = append(headerLines, fmt.Sprintf("%s: %s", k, v))
}
}
headersStr := strings.Join(headerLines, "\n")
bodyTruncated := ""
if len(bodyBytes) > maxBodySize {
bodyTruncated = fmt.Sprintf("\n... [响应体已截断,原大小约 %d bytes]", len(bodyBytes))
}
result := fmt.Sprintf(
"请求: %s %s\n状态: %d %s\n响应头:\n%s\n\n响应体 (%d bytes):\n%s%s",
method, url,
resp.StatusCode, resp.Status,
headersStr,
len(bodyBytes), string(bodyBytes), bodyTruncated,
)
return &ToolResult{
ToolName: "http_request",
Success: resp.StatusCode < 500,
Data: result,
}, nil
}
+228
View File
@@ -0,0 +1,228 @@
package tools
import (
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
)
// JSONTool provides JSON parsing, querying, and validation for the LLM.
type JSONTool struct{}
// NewJSONTool creates a JSON processing tool.
func NewJSONTool() *JSONTool {
return &JSONTool{}
}
// Definition returns the tool definition for LLM function calling.
func (t *JSONTool) Definition() ToolDefinition {
return ToolDefinition{
Name: "json_ops",
Description: "JSON处理工具。解析JSON字符串并格式化输出、用简单路径查询JSON字段、验证JSON是否合法。",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"action": map[string]interface{}{
"type": "string",
"enum": []string{"parse", "query", "validate"},
"description": "操作类型。parse: 解析JSON并格式化输出;query: 用路径查询JSON中的值(如\"users.0.name\"表示取users数组第0个元素的name字段);validate: 验证JSON字符串是否合法",
},
"json_string": map[string]interface{}{
"type": "string",
"description": "JSON字符串",
},
"path": map[string]interface{}{
"type": "string",
"description": "查询路径(query操作时使用)。支持点分隔和数组索引,如 \"users.0.name\"、\"data.list.2.title\"",
},
},
"required": []string{"action", "json_string"},
},
}
}
// Execute performs JSON operations.
func (t *JSONTool) Execute(ctx context.Context, arguments map[string]interface{}) (*ToolResult, error) {
action, ok := arguments["action"].(string)
if !ok || action == "" {
return &ToolResult{
ToolName: "json_ops",
Success: false,
Error: "缺少 action 参数",
}, nil
}
jsonStr, ok := arguments["json_string"].(string)
if !ok || jsonStr == "" {
return &ToolResult{
ToolName: "json_ops",
Success: false,
Error: "缺少 json_string 参数",
}, nil
}
switch action {
case "parse":
return t.handleParse(jsonStr)
case "query":
path, _ := arguments["path"].(string)
return t.handleQuery(jsonStr, path)
case "validate":
return t.handleValidate(jsonStr)
default:
return &ToolResult{
ToolName: "json_ops",
Success: false,
Error: fmt.Sprintf("未知操作: %s,支持: parse, query, validate", action),
}, nil
}
}
// handleParse parses a JSON string and returns a formatted version.
func (t *JSONTool) handleParse(jsonStr string) (*ToolResult, error) {
var data interface{}
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
return &ToolResult{
ToolName: "json_ops",
Success: false,
Error: fmt.Sprintf("JSON解析失败: %v", err),
}, nil
}
pretty, err := json.MarshalIndent(data, "", " ")
if err != nil {
return &ToolResult{
ToolName: "json_ops",
Success: false,
Error: fmt.Sprintf("JSON格式化失败: %v", err),
}, nil
}
return &ToolResult{
ToolName: "json_ops",
Success: true,
Data: fmt.Sprintf("解析成功\n格式化输出:\n%s", string(pretty)),
}, nil
}
// handleQuery queries a JSON value by dot-notation path.
func (t *JSONTool) handleQuery(jsonStr, path string) (*ToolResult, error) {
if path == "" {
return &ToolResult{
ToolName: "json_ops",
Success: false,
Error: "query 操作需要 path 参数",
}, nil
}
var data interface{}
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
return &ToolResult{
ToolName: "json_ops",
Success: false,
Error: fmt.Sprintf("JSON解析失败: %v", err),
}, nil
}
value, err := queryPath(data, path)
if err != nil {
return &ToolResult{
ToolName: "json_ops",
Success: false,
Error: err.Error(),
}, nil
}
pretty, err := json.MarshalIndent(value, "", " ")
if err != nil {
return &ToolResult{
ToolName: "json_ops",
Success: true,
Data: fmt.Sprintf("路径: %s\n值: %v", path, value),
}, nil
}
return &ToolResult{
ToolName: "json_ops",
Success: true,
Data: fmt.Sprintf("路径: %s\n值:\n%s", path, string(pretty)),
}, nil
}
// handleValidate validates whether a string is valid JSON.
func (t *JSONTool) handleValidate(jsonStr string) (*ToolResult, error) {
var data interface{}
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
// Try to give a helpful error message
errStr := err.Error()
// Extract line/position info if available
return &ToolResult{
ToolName: "json_ops",
Success: true,
Data: fmt.Sprintf("❌ JSON不合法\n错误: %s", errStr),
}, nil
}
// Determine JSON type
typeName := "object"
switch data.(type) {
case []interface{}:
typeName = "array"
case string:
typeName = "string"
case float64:
typeName = "number"
case bool:
typeName = "boolean"
case nil:
typeName = "null"
}
size := len(jsonStr)
return &ToolResult{
ToolName: "json_ops",
Success: true,
Data: fmt.Sprintf("✅ JSON合法\n类型: %s\n大小: %d bytes", typeName, size),
}, nil
}
// queryPath traverses a JSON value using dot-notation and array index syntax.
// Examples: "users.0.name", "data.list", "items.2"
func queryPath(data interface{}, path string) (interface{}, error) {
// Remove leading "$." if present (JSONPath style)
path = strings.TrimPrefix(path, "$.")
if path == "" || path == "$" {
return data, nil
}
parts := strings.Split(path, ".")
current := data
for _, part := range parts {
switch v := current.(type) {
case map[string]interface{}:
var ok bool
current, ok = v[part]
if !ok {
return nil, fmt.Errorf("路径 '%s' 中字段 '%s' 不存在", path, part)
}
case []interface{}:
idx, err := strconv.Atoi(part)
if err != nil {
return nil, fmt.Errorf("路径 '%s' 中 '%s' 不是有效的数组索引", path, part)
}
if idx < 0 || idx >= len(v) {
return nil, fmt.Errorf("路径 '%s' 中索引 %d 越界(数组长度 %d)", path, idx, len(v))
}
current = v[idx]
default:
return nil, fmt.Errorf("路径 '%s' 中无法继续导航:'%s' 不是对象或数组", path, part)
}
}
return current, nil
}
@@ -0,0 +1,427 @@
package tools
import (
"context"
"fmt"
"regexp"
"strings"
)
// MarkdownTool provides Markdown processing utilities for the LLM.
// Supports HTML conversion, plain text extraction, link/code extraction, and TOC generation.
type MarkdownTool struct{}
// NewMarkdownTool creates a Markdown processing tool.
func NewMarkdownTool() *MarkdownTool {
return &MarkdownTool{}
}
// Definition returns the tool definition for LLM function calling.
func (t *MarkdownTool) Definition() ToolDefinition {
return ToolDefinition{
Name: "markdown",
Description: "Markdown处理工具。将Markdown转为HTML、提取纯文本、提取链接/代码块、生成目录。用于处理Markdown格式的文档内容。",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"action": map[string]interface{}{
"type": "string",
"enum": []string{"to_html", "to_text", "extract_links", "extract_code", "table_of_contents"},
"description": "操作类型。to_html: 转换为HTMLto_text: 提取纯文本;extract_links: 提取所有链接;extract_code: 提取所有代码块;table_of_contents: 生成目录",
},
"markdown": map[string]interface{}{
"type": "string",
"description": "Markdown格式文本,需要处理的Markdown内容",
},
},
"required": []string{"action", "markdown"},
},
}
}
// Execute performs Markdown processing operations.
func (t *MarkdownTool) Execute(ctx context.Context, arguments map[string]interface{}) (*ToolResult, error) {
action, ok := arguments["action"].(string)
if !ok || action == "" {
return &ToolResult{
ToolName: "markdown",
Success: false,
Error: "缺少 action 参数",
}, nil
}
md, ok := arguments["markdown"].(string)
if !ok || strings.TrimSpace(md) == "" {
return &ToolResult{
ToolName: "markdown",
Success: false,
Error: "缺少 markdown 参数或内容为空",
}, nil
}
switch action {
case "to_html":
return t.handleToHTML(md)
case "to_text":
return t.handleToText(md)
case "extract_links":
return t.handleExtractLinks(md)
case "extract_code":
return t.handleExtractCode(md)
case "table_of_contents":
return t.handleTableOfContents(md)
default:
return &ToolResult{
ToolName: "markdown",
Success: false,
Error: fmt.Sprintf("未知操作: %s,支持: to_html, to_text, extract_links, extract_code, table_of_contents", action),
}, nil
}
}
// handleToHTML converts Markdown to HTML using simple regex-based approach.
func (t *MarkdownTool) handleToHTML(md string) (*ToolResult, error) {
html := md
// Process in order: code blocks first (to avoid interference), then inline elements, then blocks
// 1. Code blocks (```...```) - preserve with placeholder
codeBlocks := make([]string, 0)
reFence := regexp.MustCompile("(?s)```[^`]*```")
html = reFence.ReplaceAllStringFunc(html, func(match string) string {
codeBlocks = append(codeBlocks, match)
return fmt.Sprintf("\x00CODEBLOCK%d\x00", len(codeBlocks)-1)
})
// 2. Inline code (`...`)
inlineCodes := make([]string, 0)
reInlineCode := regexp.MustCompile("`[^`]+`")
html = reInlineCode.ReplaceAllStringFunc(html, func(match string) string {
inlineCodes = append(inlineCodes, match)
return fmt.Sprintf("\x00INLINECODE%d\x00", len(inlineCodes)-1)
})
// 3. Images ![alt](url)
reImage := regexp.MustCompile(`!\[([^\]]*)\]\(([^)]+)\)`)
html = reImage.ReplaceAllString(html, `<img src="$2" alt="$1">`)
// 4. Links [text](url)
reLink := regexp.MustCompile(`\[([^\]]+)\]\(([^)]+)\)`)
html = reLink.ReplaceAllString(html, `<a href="$2">$1</a>`)
// 5. Bold **text** or __text__
reBold := regexp.MustCompile(`\*\*([^*]+)\*\*`)
html = reBold.ReplaceAllString(html, `<strong>$1</strong>`)
reBold2 := regexp.MustCompile(`__([^_]+)__`)
html = reBold2.ReplaceAllString(html, `<strong>$1</strong>`)
// 6. Italic *text* or _text_
reItalic := regexp.MustCompile(`\*([^*]+)\*`)
html = reItalic.ReplaceAllString(html, `<em>$1</em>`)
reItalic2 := regexp.MustCompile(`_([^_]+)_`)
html = reItalic2.ReplaceAllString(html, `<em>$1</em>`)
// 7. Strikethrough ~~text~~
reStrike := regexp.MustCompile(`~~([^~]+)~~`)
html = reStrike.ReplaceAllString(html, `<del>$1</del>`)
// 8. Headings (# to ######)
reH6 := regexp.MustCompile(`(?m)^######\s+(.+)$`)
html = reH6.ReplaceAllString(html, `<h6>$1</h6>`)
reH5 := regexp.MustCompile(`(?m)^#####\s+(.+)$`)
html = reH5.ReplaceAllString(html, `<h5>$1</h5>`)
reH4 := regexp.MustCompile(`(?m)^####\s+(.+)$`)
html = reH4.ReplaceAllString(html, `<h4>$1</h4>`)
reH3 := regexp.MustCompile(`(?m)^###\s+(.+)$`)
html = reH3.ReplaceAllString(html, `<h3>$1</h3>`)
reH2 := regexp.MustCompile(`(?m)^##\s+(.+)$`)
html = reH2.ReplaceAllString(html, `<h2>$1</h2>`)
reH1 := regexp.MustCompile(`(?m)^#\s+(.+)$`)
html = reH1.ReplaceAllString(html, `<h1>$1</h1>`)
// 9. Horizontal rules
reHR := regexp.MustCompile(`(?m)^(---|\*\*\*|___)\s*$`)
html = reHR.ReplaceAllString(html, `<hr>`)
// 10. Unordered lists (- item)
html = t.processLists(html, `(?m)^[\-*]\s+`, "ul")
// 11. Ordered lists (1. item)
html = t.processLists(html, `(?m)^\d+\.\s+`, "ol")
// 12. Blockquotes
reBlockquote := regexp.MustCompile(`(?m)^>\s?(.+)$`)
html = reBlockquote.ReplaceAllString(html, `<blockquote>$1</blockquote>`)
// 13. Paragraphs: wrap remaining text lines
html = t.wrapParagraphs(html)
// 14. Restore code blocks
for i, cb := range codeBlocks {
// Strip the opening/closing ```
content := strings.TrimPrefix(cb, "```")
content = strings.TrimSuffix(content, "```")
// Extract language if present on first line
lang := ""
content = strings.TrimSpace(content)
if idx := strings.Index(content, "\n"); idx > 0 {
lang = strings.TrimSpace(content[:idx])
content = strings.TrimSpace(content[idx+1:])
}
if lang != "" {
html = strings.ReplaceAll(html, fmt.Sprintf("\x00CODEBLOCK%d\x00", i),
fmt.Sprintf(`<pre><code class="language-%s">%s</code></pre>`, lang, escapeHTML(content)))
} else {
html = strings.ReplaceAll(html, fmt.Sprintf("\x00CODEBLOCK%d\x00", i),
fmt.Sprintf("<pre><code>%s</code></pre>", escapeHTML(content)))
}
}
// 15. Restore inline code
for i, ic := range inlineCodes {
content := strings.Trim(ic, "`")
html = strings.ReplaceAll(html, fmt.Sprintf("\x00INLINECODE%d\x00", i),
fmt.Sprintf("<code>%s</code>", escapeHTML(content)))
}
return &ToolResult{
ToolName: "markdown",
Success: true,
Data: html,
}, nil
}
// handleToText strips Markdown formatting and extracts plain text.
func (t *MarkdownTool) handleToText(md string) (*ToolResult, error) {
text := md
// Remove code blocks
reFence := regexp.MustCompile("(?s)```[^`]*```")
text = reFence.ReplaceAllString(text, "[代码块]")
// Remove inline code
reInlineCode := regexp.MustCompile("`[^`]+`")
text = reInlineCode.ReplaceAllString(text, "[代码]")
// Remove images ![alt](url) - keep alt text
reImage := regexp.MustCompile(`!\[([^\]]*)\]\([^)]+\)`)
text = reImage.ReplaceAllString(text, "$1")
// Remove links [text](url) - keep text
reLink := regexp.MustCompile(`\[([^\]]+)\]\([^)]+\)`)
text = reLink.ReplaceAllString(text, "$1")
// Remove bold/italic markers
text = regexp.MustCompile(`\*\*([^*]+)\*\*`).ReplaceAllString(text, "$1")
text = regexp.MustCompile(`__([^_]+)__`).ReplaceAllString(text, "$1")
text = regexp.MustCompile(`\*([^*]+)\*`).ReplaceAllString(text, "$1")
text = regexp.MustCompile(`_([^_]+)_`).ReplaceAllString(text, "$1")
// Remove strikethrough
text = regexp.MustCompile(`~~([^~]+)~~`).ReplaceAllString(text, "$1")
// Remove heading markers but keep the text
text = regexp.MustCompile(`(?m)^#{1,6}\s+`).ReplaceAllString(text, "")
// Remove horizontal rules
text = regexp.MustCompile(`(?m)^(---|\*\*\*|___)\s*$`).ReplaceAllString(text, "")
// Remove list markers
text = regexp.MustCompile(`(?m)^[\-*]\s+`).ReplaceAllString(text, "")
text = regexp.MustCompile(`(?m)^\d+\.\s+`).ReplaceAllString(text, "")
// Remove blockquote markers
text = regexp.MustCompile(`(?m)^>\s?`).ReplaceAllString(text, "")
// Collapse multiple blank lines
text = regexp.MustCompile(`\n{3,}`).ReplaceAllString(text, "\n\n")
return &ToolResult{
ToolName: "markdown",
Success: true,
Data: fmt.Sprintf("纯文本提取结果 (%d 字符):\n\n%s",
len([]rune(text)), strings.TrimSpace(text)),
}, nil
}
// handleExtractLinks extracts all [text](url) links from Markdown.
func (t *MarkdownTool) handleExtractLinks(md string) (*ToolResult, error) {
reLink := regexp.MustCompile(`\[([^\]]+)\]\(([^)]+)\)`)
matches := reLink.FindAllStringSubmatch(md, -1)
if len(matches) == 0 {
return &ToolResult{
ToolName: "markdown",
Success: true,
Data: "未找到任何链接",
}, nil
}
var result strings.Builder
result.WriteString(fmt.Sprintf("提取链接 (共 %d 个):\n\n", len(matches)))
for i, m := range matches {
result.WriteString(fmt.Sprintf("%d. [%s](%s)\n - 文本: %s\n - URL: %s\n\n",
i+1, m[1], m[2], m[1], m[2]))
}
return &ToolResult{
ToolName: "markdown",
Success: true,
Data: strings.TrimSpace(result.String()),
}, nil
}
// handleExtractCode extracts all code blocks from Markdown.
func (t *MarkdownTool) handleExtractCode(md string) (*ToolResult, error) {
reFence := regexp.MustCompile("(?s)```([^`]*)```")
matches := reFence.FindAllStringSubmatch(md, -1)
if len(matches) == 0 {
return &ToolResult{
ToolName: "markdown",
Success: true,
Data: "未找到任何代码块",
}, nil
}
var result strings.Builder
result.WriteString(fmt.Sprintf("提取代码块 (共 %d 个):\n\n", len(matches)))
for i, m := range matches {
content := strings.TrimSpace(m[1])
lang := ""
if idx := strings.Index(content, "\n"); idx > 0 {
lang = strings.TrimSpace(content[:idx])
content = strings.TrimSpace(content[idx+1:])
}
result.WriteString(fmt.Sprintf("--- 代码块 %d", i+1))
if lang != "" {
result.WriteString(fmt.Sprintf(" (语言: %s)", lang))
}
result.WriteString(fmt.Sprintf(" ---\n%s\n\n", truncateText(content, 500)))
}
return &ToolResult{
ToolName: "markdown",
Success: true,
Data: strings.TrimSpace(result.String()),
}, nil
}
// handleTableOfContents generates a table of contents from headings.
func (t *MarkdownTool) handleTableOfContents(md string) (*ToolResult, error) {
reHeading := regexp.MustCompile(`(?m)^(#{1,6})\s+(.+)$`)
matches := reHeading.FindAllStringSubmatch(md, -1)
if len(matches) == 0 {
return &ToolResult{
ToolName: "markdown",
Success: true,
Data: "未找到任何标题,无法生成目录",
}, nil
}
var result strings.Builder
result.WriteString(fmt.Sprintf("文档目录 (共 %d 个标题):\n\n", len(matches)))
for _, m := range matches {
level := len(m[1])
title := strings.TrimSpace(m[2])
indent := strings.Repeat(" ", level-1)
result.WriteString(fmt.Sprintf("%s%s %s\n", indent, strings.Repeat("#", level), title))
}
return &ToolResult{
ToolName: "markdown",
Success: true,
Data: result.String(),
}, nil
}
// --- Markdown helper functions below ---
// processLists wraps consecutive list items in <ul> or <ol> tags.
func (t *MarkdownTool) processLists(html, itemPattern, listTag string) string {
reItem := regexp.MustCompile(itemPattern + `(.+)$`)
lines := strings.Split(html, "\n")
result := make([]string, 0, len(lines))
inList := false
for _, line := range lines {
if reItem.MatchString(line) {
content := reItem.ReplaceAllString(line, "$1")
if !inList {
result = append(result, fmt.Sprintf("<%s>", listTag))
inList = true
}
result = append(result, fmt.Sprintf("<li>%s</li>", content))
} else {
if inList {
result = append(result, fmt.Sprintf("</%s>", listTag))
inList = false
}
result = append(result, line)
}
}
if inList {
result = append(result, fmt.Sprintf("</%s>", listTag))
}
return strings.Join(result, "\n")
}
// wrapParagraphs wraps non-tag lines in <p> tags.
func (t *MarkdownTool) wrapParagraphs(html string) string {
lines := strings.Split(html, "\n")
result := make([]string, 0, len(lines))
skipTags := map[string]bool{
"<h1>": true, "<h2>": true, "<h3>": true, "<h4>": true, "<h5>": true, "<h6>": true,
"<hr>": true, "<ul>": true, "</ul>": true, "<ol>": true, "</ol>": true,
"<li>": true, "</li>": true, "<blockquote>": true, "</blockquote>": true,
"<pre>": true, "</pre>": true, "<img": true,
}
for _, line := range lines {
trimmed := strings.TrimSpace(line)
if trimmed == "" {
result = append(result, line)
continue
}
// Check if line starts with an HTML tag
isTag := false
for tag := range skipTags {
if strings.HasPrefix(trimmed, tag) {
isTag = true
break
}
}
if !isTag {
result = append(result, fmt.Sprintf("<p>%s</p>", trimmed))
} else {
result = append(result, line)
}
}
return strings.Join(result, "\n")
}
// escapeHTML escapes special HTML characters.
func escapeHTML(s string) string {
replacer := strings.NewReplacer(
"&", "&"+"amp;",
"<", "&"+"lt;",
">", "&"+"gt;",
"\"", "&"+"quot;",
)
return replacer.Replace(s)
}
// truncateText truncates text to maxLen runes, adding "..." if truncated.
func truncateText(s string, maxLen int) string {
runes := []rune(s)
if len(runes) <= maxLen {
return s
}
return string(runes[:maxLen]) + "..."
}
@@ -0,0 +1,370 @@
package tools
import (
"context"
"crypto/rand"
"encoding/json"
"fmt"
"math/big"
mathrand "math/rand"
"strings"
)
// RandomTool provides random generation utilities for the LLM.
// Supports random numbers, UUIDs, passwords, and list operations.
type RandomTool struct{}
// NewRandomTool creates a random generation tool.
func NewRandomTool() *RandomTool {
return &RandomTool{}
}
// Definition returns the tool definition for LLM function calling.
func (t *RandomTool) Definition() ToolDefinition {
return ToolDefinition{
Name: "random",
Description: "随机生成工具。生成随机数、UUID、安全密码,或从列表中随机选取/打乱元素。",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"action": map[string]interface{}{
"type": "string",
"enum": []string{"number", "uuid", "password", "pick", "shuffle"},
"description": "操作类型。number: 生成随机整数;uuid: 生成UUID v4password: 生成安全密码;pick: 从列表随机选取;shuffle: 随机打乱列表",
},
"min": map[string]interface{}{
"type": "number",
"description": "随机数最小值(用于 number 操作),默认 0",
},
"max": map[string]interface{}{
"type": "number",
"description": "随机数最大值(用于 number 操作),默认 100",
},
"length": map[string]interface{}{
"type": "integer",
"description": "密码长度(用于 password 操作),默认 16",
},
"items": map[string]interface{}{
"type": "array",
"description": "列表项(用于 pick/shuffle 操作),字符串数组",
"items": map[string]interface{}{
"type": "string",
},
},
"count": map[string]interface{}{
"type": "integer",
"description": "选取数量(用于 pick 操作),默认 1",
},
},
"required": []string{"action"},
},
}
}
// Execute performs random generation operations.
func (t *RandomTool) Execute(ctx context.Context, arguments map[string]interface{}) (*ToolResult, error) {
action, ok := arguments["action"].(string)
if !ok || action == "" {
return &ToolResult{
ToolName: "random",
Success: false,
Error: "缺少 action 参数",
}, nil
}
switch action {
case "number":
return t.handleNumber(arguments)
case "uuid":
return t.handleUUID()
case "password":
return t.handlePassword(arguments)
case "pick":
return t.handlePick(arguments)
case "shuffle":
return t.handleShuffle(arguments)
default:
return &ToolResult{
ToolName: "random",
Success: false,
Error: fmt.Sprintf("未知操作: %s,支持: number, uuid, password, pick, shuffle", action),
}, nil
}
}
// handleNumber generates a random integer in [min, max].
func (t *RandomTool) handleNumber(arguments map[string]interface{}) (*ToolResult, error) {
minVal := getFloatArg(arguments, "min", 0)
maxVal := getFloatArg(arguments, "max", 100)
if minVal > maxVal {
minVal, maxVal = maxVal, minVal
}
minI := int64(minVal)
maxI := int64(maxVal)
// Use crypto/rand for secure random
rangeVal := maxI - minI + 1
if rangeVal <= 0 {
return &ToolResult{
ToolName: "random",
Success: false,
Error: "无效的数值范围",
}, nil
}
n, err := rand.Int(rand.Reader, big.NewInt(rangeVal))
if err != nil {
// Fallback to math/rand
result := minI + mathrand.Int63n(rangeVal)
return &ToolResult{
ToolName: "random",
Success: true,
Data: fmt.Sprintf("随机整数 [%d, %d]: %d", minI, maxI, result),
}, nil
}
result := minI + n.Int64()
return &ToolResult{
ToolName: "random",
Success: true,
Data: fmt.Sprintf("随机整数 [%d, %d]: %d", minI, maxI, result),
}, nil
}
// handleUUID generates a UUID v4 string.
func (t *RandomTool) handleUUID() (*ToolResult, error) {
uuid := make([]byte, 16)
_, err := rand.Read(uuid)
if err != nil {
return &ToolResult{
ToolName: "random",
Success: false,
Error: fmt.Sprintf("生成UUID失败: %v", err),
}, nil
}
// Set version 4 and variant bits
uuid[6] = (uuid[6] & 0x0f) | 0x40 // Version 4
uuid[8] = (uuid[8] & 0x3f) | 0x80 // Variant 10
uuidStr := fmt.Sprintf("%08x-%04x-%04x-%04x-%012x",
uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:16])
return &ToolResult{
ToolName: "random",
Success: true,
Data: fmt.Sprintf("UUID v4: %s", uuidStr),
}, nil
}
// handlePassword generates a secure random password.
func (t *RandomTool) handlePassword(arguments map[string]interface{}) (*ToolResult, error) {
length := getIntArg(arguments, "length", 16)
if length < 4 {
length = 16
}
if length > 128 {
length = 128
}
uppercase := "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
lowercase := "abcdefghijklmnopqrstuvwxyz"
digits := "0123456789"
symbols := "!@#$%^&*()_+-=[]{}|;:,.<>?"
allChars := uppercase + lowercase + digits + symbols
password := make([]byte, length)
// Ensure at least one of each character type
password[0] = uppercase[secureIndex(len(uppercase))]
password[1] = lowercase[secureIndex(len(lowercase))]
password[2] = digits[secureIndex(len(digits))]
password[3] = symbols[secureIndex(len(symbols))]
// Fill remaining with random characters from all sets
for i := 4; i < length; i++ {
password[i] = allChars[secureIndex(len(allChars))]
}
// Shuffle the password
shuffleBytes(password)
passwordStr := string(password)
return &ToolResult{
ToolName: "random",
Success: true,
Data: fmt.Sprintf("安全密码 (长度: %d):\n%s\n\n字符集: 大写字母 + 小写字母 + 数字 + 特殊符号",
length, passwordStr),
}, nil
}
// handlePick randomly picks items from a list.
func (t *RandomTool) handlePick(arguments map[string]interface{}) (*ToolResult, error) {
items := getStringSliceArg(arguments, "items")
if len(items) == 0 {
return &ToolResult{
ToolName: "random",
Success: false,
Error: "缺少 items 参数或列表为空",
}, nil
}
count := getIntArg(arguments, "count", 1)
if count < 1 {
count = 1
}
if count > len(items) {
count = len(items)
}
// Shuffle indices and pick first 'count'
indices := make([]int, len(items))
for i := range indices {
indices[i] = i
}
shuffleInts(indices)
picked := make([]string, 0, count)
for i := 0; i < count; i++ {
picked = append(picked, items[indices[i]])
}
var result strings.Builder
result.WriteString(fmt.Sprintf("从 %d 个选项中随机选取 %d 个:\n", len(items), count))
for i, p := range picked {
result.WriteString(fmt.Sprintf(" %d. %s\n", i+1, p))
}
return &ToolResult{
ToolName: "random",
Success: true,
Data: result.String(),
}, nil
}
// handleShuffle randomly shuffles a list.
func (t *RandomTool) handleShuffle(arguments map[string]interface{}) (*ToolResult, error) {
items := getStringSliceArg(arguments, "items")
if len(items) == 0 {
return &ToolResult{
ToolName: "random",
Success: false,
Error: "缺少 items 参数或列表为空",
}, nil
}
// Make a copy and shuffle
shuffled := make([]string, len(items))
copy(shuffled, items)
shuffleStrings(shuffled)
var result strings.Builder
result.WriteString(fmt.Sprintf("随机打乱结果 (共 %d 项):\n", len(shuffled)))
for i, s := range shuffled {
result.WriteString(fmt.Sprintf(" %d. %s\n", i+1, s))
}
return &ToolResult{
ToolName: "random",
Success: true,
Data: result.String(),
}, nil
}
// --- Helper functions ---
// getFloatArg extracts a float64 argument with fallback.
func getFloatArg(arguments map[string]interface{}, key string, fallback float64) float64 {
if v, ok := arguments[key]; ok {
switch val := v.(type) {
case float64:
return val
case int:
return float64(val)
case int64:
return float64(val)
case json.Number:
f, err := val.Float64()
if err == nil {
return f
}
}
}
return fallback
}
// getIntArg extracts an int argument with fallback.
func getIntArg(arguments map[string]interface{}, key string, fallback int) int {
if v, ok := arguments[key]; ok {
switch val := v.(type) {
case float64:
return int(val)
case int:
return val
case int64:
return int(val)
}
}
return fallback
}
// getStringSliceArg extracts a string slice argument.
func getStringSliceArg(arguments map[string]interface{}, key string) []string {
if v, ok := arguments[key]; ok {
switch val := v.(type) {
case []interface{}:
result := make([]string, 0, len(val))
for _, item := range val {
if s, ok := item.(string); ok {
result = append(result, s)
} else {
result = append(result, fmt.Sprintf("%v", item))
}
}
return result
case []string:
return val
}
}
return nil
}
// secureIndex returns a cryptographically secure random index in [0, max).
func secureIndex(max int) int {
if max <= 1 {
return 0
}
n, err := rand.Int(rand.Reader, big.NewInt(int64(max)))
if err != nil {
return mathrand.Intn(max)
}
return int(n.Int64())
}
// shuffleBytes shuffles a byte slice using Fisher-Yates with crypto/rand.
func shuffleBytes(data []byte) {
for i := len(data) - 1; i > 0; i-- {
j := secureIndex(i + 1)
data[i], data[j] = data[j], data[i]
}
}
// shuffleInts shuffles an int slice using Fisher-Yates with crypto/rand.
func shuffleInts(data []int) {
for i := len(data) - 1; i > 0; i-- {
j := secureIndex(i + 1)
data[i], data[j] = data[j], data[i]
}
}
// shuffleStrings shuffles a string slice using Fisher-Yates with crypto/rand.
func shuffleStrings(data []string) {
for i := len(data) - 1; i > 0; i-- {
j := secureIndex(i + 1)
data[i], data[j] = data[j], data[i]
}
}
+345
View File
@@ -0,0 +1,345 @@
package tools
import (
"context"
"fmt"
"regexp"
"strings"
"unicode"
)
// TextTool provides text processing operations for the LLM.
// Supports counting, summarizing, translation, and pattern extraction.
type TextTool struct{}
// NewTextTool creates a text processing tool.
func NewTextTool() *TextTool {
return &TextTool{}
}
// Definition returns the tool definition for LLM function calling.
func (t *TextTool) Definition() ToolDefinition {
return ToolDefinition{
Name: "text",
Description: "文本处理工具。统计文本、生成摘要、翻译文本、正则提取信息。用于处理用户提供的文本内容。",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"action": map[string]interface{}{
"type": "string",
"enum": []string{"count", "summarize", "translate", "extract"},
"description": "操作类型。count: 统计字符/单词/行/段落数;summarize: 提取首段+关键句生成简单摘要;translate: 翻译文本(需指定target_lang);extract: 正则提取邮箱/电话/URL等",
},
"text": map[string]interface{}{
"type": "string",
"description": "输入文本,需要处理的文本内容",
},
"target_lang": map[string]interface{}{
"type": "string",
"enum": []string{"en", "zh", "ja", "ko", "fr", "de"},
"description": "翻译目标语言代码。en: 英语, zh: 中文, ja: 日语, ko: 韩语, fr: 法语, de: 德语",
},
"pattern": map[string]interface{}{
"type": "string",
"description": "正则表达式模式,用于 extract 操作。常用预设: email(邮箱), phone(电话), url(网址)",
},
},
"required": []string{"action", "text"},
},
}
}
// Execute performs text processing operations.
func (t *TextTool) Execute(ctx context.Context, arguments map[string]interface{}) (*ToolResult, error) {
action, ok := arguments["action"].(string)
if !ok || action == "" {
return &ToolResult{
ToolName: "text",
Success: false,
Error: "缺少 action 参数",
}, nil
}
text, ok := arguments["text"].(string)
if !ok || strings.TrimSpace(text) == "" {
return &ToolResult{
ToolName: "text",
Success: false,
Error: "缺少 text 参数或文本为空",
}, nil
}
switch action {
case "count":
return t.handleCount(text)
case "summarize":
return t.handleSummarize(text)
case "translate":
return t.handleTranslate(arguments)
case "extract":
return t.handleExtract(arguments)
default:
return &ToolResult{
ToolName: "text",
Success: false,
Error: fmt.Sprintf("未知操作: %s,支持: count, summarize, translate, extract", action),
}, nil
}
}
// handleCount counts characters, words, lines, and paragraphs in the text.
func (t *TextTool) handleCount(text string) (*ToolResult, error) {
charCount := len([]rune(text))
byteCount := len(text)
words := strings.Fields(text)
wordCount := len(words)
lines := strings.Split(text, "\n")
lineCount := len(lines)
// Count paragraphs (separated by double newlines)
paragraphs := regexp.MustCompile(`\n\s*\n`).Split(text, -1)
paraCount := 0
for _, p := range paragraphs {
if strings.TrimSpace(p) != "" {
paraCount++
}
}
// Count Chinese characters
chineseCount := 0
for _, r := range text {
if unicode.Is(unicode.Han, r) {
chineseCount++
}
}
return &ToolResult{
ToolName: "text",
Success: true,
Data: fmt.Sprintf("文本统计结果:\n- 字符数 (含空格): %d\n- 字符数 (不含空格): %d\n- 字节数: %d\n- 单词数: %d\n- 行数: %d\n- 段落数: %d\n- 中文字符数: %d",
charCount, len([]rune(strings.ReplaceAll(text, " ", ""))),
byteCount, wordCount, lineCount, paraCount, chineseCount),
}, nil
}
// handleSummarize generates a simple summary by extracting the first paragraph and key sentences.
func (t *TextTool) handleSummarize(text string) (*ToolResult, error) {
var result strings.Builder
result.WriteString("文本摘要:\n\n")
// Extract first paragraph
paragraphs := regexp.MustCompile(`\n\s*\n`).Split(text, -1)
var firstPara string
for _, p := range paragraphs {
if trimmed := strings.TrimSpace(p); trimmed != "" {
firstPara = trimmed
break
}
}
if firstPara != "" {
result.WriteString("【首段】\n")
// Truncate if very long
runes := []rune(firstPara)
if len(runes) > 300 {
firstPara = string(runes[:300]) + "..."
}
result.WriteString(firstPara)
result.WriteString("\n\n")
}
// Extract key sentences (longer sentences with important keywords)
sentences := t.splitSentences(text)
keySentences := t.extractKeySentences(sentences, 5)
if len(keySentences) > 0 {
result.WriteString("【关键句】\n")
for i, s := range keySentences {
result.WriteString(fmt.Sprintf("%d. %s\n", i+1, s))
}
}
// Overall stats
lines := strings.Split(text, "\n")
words := strings.Fields(text)
result.WriteString(fmt.Sprintf("\n【概况】共 %d 段、%d 句、%d 词、%d 行",
len(paragraphs), len(sentences), len(words), len(lines)))
return &ToolResult{
ToolName: "text",
Success: true,
Data: result.String(),
}, nil
}
// splitSentences splits text into sentences based on punctuation.
func (t *TextTool) splitSentences(text string) []string {
re := regexp.MustCompile(`[^。!?.!?\n]+[。!?.!?\n]?`)
return re.FindAllString(text, -1)
}
// extractKeySentences selects the most informative sentences (longer ones with keyword hints).
func (t *TextTool) extractKeySentences(sentences []string, maxCount int) []string {
type scored struct {
text string
score int
}
var scoredList []scored
keywords := []string{"重要", "关键", "核心", "主要", "首先", "最后", "因此", "所以", "总结",
"important", "key", "critical", "significant", "therefore", "conclusion", "summary"}
for _, s := range sentences {
trimmed := strings.TrimSpace(s)
if len([]rune(trimmed)) < 10 {
continue
}
score := len([]rune(trimmed)) // longer sentences are more likely informative
lower := strings.ToLower(trimmed)
for _, kw := range keywords {
if strings.Contains(lower, kw) {
score += 50
}
}
scoredList = append(scoredList, scored{text: trimmed, score: score})
}
// Sort by score descending (simple bubble sort for small lists)
for i := 0; i < len(scoredList); i++ {
for j := i + 1; j < len(scoredList); j++ {
if scoredList[j].score > scoredList[i].score {
scoredList[i], scoredList[j] = scoredList[j], scoredList[i]
}
}
}
result := make([]string, 0, maxCount)
for i := 0; i < len(scoredList) && i < maxCount; i++ {
result = append(result, scoredList[i].text)
}
return result
}
// handleTranslate provides a translation placeholder (actual translation requires LLM).
func (t *TextTool) handleTranslate(arguments map[string]interface{}) (*ToolResult, error) {
text, _ := arguments["text"].(string)
targetLang, _ := arguments["target_lang"].(string)
if targetLang == "" {
targetLang = "zh"
}
langNames := map[string]string{
"en": "英语",
"zh": "中文",
"ja": "日语",
"ko": "韩语",
"fr": "法语",
"de": "德语",
}
langName, ok := langNames[targetLang]
if !ok {
langName = targetLang
}
return &ToolResult{
ToolName: "text",
Success: true,
Data: fmt.Sprintf("【翻译请求】\n目标语言: %s (%s)\n原文 (%d 字符):\n---\n%s\n---\n\n提示: 实际翻译由LLM完成,请基于以上原文和目标语言进行翻译。",
langName, targetLang, len([]rune(text)), text),
}, nil
}
// handleExtract extracts patterns like emails, phones, URLs from text using regex.
func (t *TextTool) handleExtract(arguments map[string]interface{}) (*ToolResult, error) {
text, _ := arguments["text"].(string)
pattern, _ := arguments["pattern"].(string)
// Predefined patterns
presets := map[string]string{
"email": `[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}`,
"phone": `(?:\+?86[\-\s]?)?1[3-9]\d{9}`,
"url": `https?://[^\s<>"{}|\\^` + "`" + `\[\]]+`,
}
if preset, ok := presets[strings.ToLower(pattern)]; ok {
pattern = preset
}
if pattern == "" {
// Extract all common patterns when no specific pattern given
var result strings.Builder
result.WriteString("文本提取结果:\n\n")
for name, p := range presets {
re, err := regexp.Compile(p)
if err != nil {
continue
}
matches := re.FindAllString(text, -1)
if len(matches) > 0 {
result.WriteString(fmt.Sprintf("【%s】(共 %d 个):\n", name, len(matches)))
seen := make(map[string]bool)
for _, m := range matches {
if !seen[m] {
result.WriteString(fmt.Sprintf(" - %s\n", m))
seen[m] = true
}
}
result.WriteString("\n")
}
}
if result.Len() == len("文本提取结果:\n\n") {
return &ToolResult{
ToolName: "text",
Success: true,
Data: "未提取到匹配的内容(邮箱、电话、URL)",
}, nil
}
return &ToolResult{
ToolName: "text",
Success: true,
Data: result.String(),
}, nil
}
// Use custom regex pattern
re, err := regexp.Compile(pattern)
if err != nil {
return &ToolResult{
ToolName: "text",
Success: false,
Error: fmt.Sprintf("正则表达式无效: %v", err),
}, nil
}
matches := re.FindAllString(text, -1)
if len(matches) == 0 {
return &ToolResult{
ToolName: "text",
Success: true,
Data: fmt.Sprintf("未找到匹配模式 '%s' 的内容", pattern),
}, nil
}
var result strings.Builder
result.WriteString(fmt.Sprintf("正则提取结果 (模式: %s, 共 %d 个匹配):\n", pattern, len(matches)))
seen := make(map[string]bool)
for _, m := range matches {
if !seen[m] {
result.WriteString(fmt.Sprintf(" - %s\n", m))
seen[m] = true
}
}
return &ToolResult{
ToolName: "text",
Success: true,
Data: result.String(),
}, nil
}
+412 -105
View File
@@ -399,6 +399,26 @@ input[type="range"] { accent-color: var(--accent); padding: 0; }
display: flex; align-items: center; gap: 10px; margin-bottom: 14px;
}
.iot-last-update { font-size: 11px; color: var(--text3); }
/* ========== 记忆卡片样式 ========== */
.mem-card {
transition: all 0.2s ease;
}
.mem-card:hover {
box-shadow: 0 4px 20px rgba(0,0,0,.35);
transform: translateY(-2px);
border-color: var(--accent) !important;
}
.mem-card.mem-card-high:hover {
border-color: #f59e0b !important;
box-shadow: 0 4px 24px rgba(245,158,11,.25);
}
.mem-cat-tab.active {
background: var(--accent) !important;
color: #fff !important;
border-color: var(--accent) !important;
font-weight: 600;
}
</style>
</head>
<body>
@@ -495,6 +515,15 @@ const STATE = {
dashboardRenderCount: 0,
// 资源使用 60s 滑动窗口历史 (Bug 6)
resourceHistory: {},
// 记忆面板状态
memoryCache: [],
memoryUserId: 'admin_admin',
memoryFilterCategory: 'all',
memorySortBy: 'importance',
memorySortDir: 'desc',
memoryFilterImportance: 0,
memorySearchText: '',
memoryPanelInitialized: false,
};
// ========== WebSocket ==========
@@ -1038,92 +1067,175 @@ function renderDashboardSvcCards(svcs) {
`).join('');
}
// ========== 记忆分类颜色映射 ==========
const MEMORY_CAT_COLORS = {
'user_preference': { bg: 'rgba(168,85,247,.15)', text: '#a855f7', name: '用户偏好', icon: '💜' },
'personal_info': { bg: 'rgba(59,130,246,.15)', text: '#3b82f6', name: '个人信息', icon: '👤' },
'conversation': { bg: 'rgba(34,197,94,.15)', text: '#22c55e', name: '对话摘要', icon: '💬' },
'knowledge': { bg: 'rgba(249,115,22,.15)', text: '#f97316', name: '知识信息', icon: '📚' },
'event': { bg: 'rgba(239,68,68,.15)', text: '#ef4444', name: '事件记录', icon: '📅' },
'task': { bg: 'rgba(234,179,8,.15)', text: '#eab308', name: '任务计划', icon: '✅' },
'relationship': { bg: 'rgba(236,72,153,.15)', text: '#ec4899', name: '关系情感', icon: '💕' },
// 向后兼容旧分类
'preference': { bg: 'rgba(168,85,247,.15)', text: '#a855f7', name: '偏好', icon: '💜' },
'fact': { bg: 'rgba(59,130,246,.15)', text: '#3b82f6', name: '事实', icon: '👤' },
'experience': { bg: 'rgba(249,115,22,.15)', text: '#f97316', name: '经验', icon: '📚' },
'other': { bg: 'rgba(139,148,158,.15)', text: '#8b949e', name: '其他', icon: '📌' },
'habit': { bg: 'rgba(168,85,247,.15)', text: '#a855f7', name: '习惯', icon: '💜' },
};
function getCatColor(cat) {
return MEMORY_CAT_COLORS[cat] || MEMORY_CAT_COLORS['other'];
}
// ========== 面板2: 记忆管理 ==========
function renderMemoryPanel() {
document.getElementById('panel-memory').innerHTML = `
<div class="cards-grid cards-2">
<!-- 搜索区域 -->
<div class="card">
<div class="card-header"><span class="card-title">🔍 搜索记忆</span></div>
<div class="form-row">
<div class="form-group"><label>用户ID</label><input type="text" id="mem-user-id" placeholder="admin_admin" value="admin_admin"></div>
<div class="form-group"><label>关键词</label><input type="text" id="mem-search-q" placeholder="输入搜索关键词..."></div>
const isFirst = !STATE.memoryPanelInitialized;
STATE.memoryPanelInitialized = true;
// 首次渲染完整 DOM 结构
if (isFirst) {
document.getElementById('panel-memory').innerHTML = `
<!-- 搜索 & 添加行 -->
<div class="cards-grid cards-2" style="margin-bottom:14px">
<div class="card" style="margin:0">
<div class="card-header"><span class="card-title">🔍 搜索与筛选</span></div>
<div class="form-row">
<div class="form-group" style="flex:1">
<label>用户ID</label>
<input type="text" id="mem-user-id" placeholder="admin_admin" value="${escHtml(STATE.memoryUserId)}">
</div>
<div class="form-group" style="flex:2">
<label>全文搜索</label>
<input type="text" id="mem-search-text" placeholder="输入关键词搜索记忆内容..." value="${escHtml(STATE.memorySearchText)}"
oninput="STATE.memorySearchText=this.value;filterAndRenderMemories()">
</div>
</div>
<div class="form-row" style="margin-top:4px">
<div class="form-group" style="flex:1">
<label>最低重要性 ≥ <span id="mem-imp-val">${STATE.memoryFilterImportance}</span></label>
<input type="range" id="mem-filter-importance" min="0" max="10" value="${STATE.memoryFilterImportance}"
oninput="document.getElementById('mem-imp-val').textContent=this.value;STATE.memoryFilterImportance=parseInt(this.value);filterAndRenderMemories()">
</div>
<div class="form-group" style="flex:1;display:flex;align-items:flex-end">
<button class="btn btn-accent btn-sm" onclick="loadMemories()" style="width:100%">🔍 查询记忆</button>
</div>
</div>
</div>
<div class="btn-group" style="margin-top:4px">
<button class="btn btn-accent btn-sm" onclick="searchMemory()">🔍 搜索</button>
<button class="btn btn-sm" onclick="listMemory()">📋 列表全部</button>
<div class="card" style="margin:0">
<div class="card-header"><span class="card-title"> 添加记忆</span></div>
<div class="form-group"><label>用户ID</label><input type="text" id="mem-add-user-id" placeholder="admin_admin" value="admin_admin"></div>
<div class="form-group"><label>内容</label><textarea id="mem-add-content" placeholder="输入记忆内容..." rows="2"></textarea></div>
<div class="form-row">
<div class="form-group">
<label>分类</label>
<select id="mem-add-category">
<option value="user_preference">用户偏好</option>
<option value="personal_info">个人信息</option>
<option value="conversation">对话摘要</option>
<option value="knowledge">知识信息</option>
<option value="event">事件记录</option>
<option value="task">任务计划</option>
<option value="relationship">关系情感</option>
</select>
</div>
<div class="form-group">
<label>重要程度: <span id="mem-priority-val">5</span></label>
<input type="range" id="mem-add-importance" min="1" max="10" value="5" oninput="document.getElementById('mem-priority-val').textContent=this.value">
</div>
</div>
<button class="btn btn-accent btn-sm" onclick="addMemory()"> 添加</button>
</div>
</div>
<!-- 添加记忆 -->
<div class="card">
<div class="card-header"><span class="card-title"> 添加记忆</span></div>
<div class="form-group"><label>用户ID</label><input type="text" id="mem-add-user-id" placeholder="admin_admin" value="admin_admin"></div>
<div class="form-group"><label>内容</label><textarea id="mem-add-content" placeholder="输入记忆内容..."></textarea></div>
<div class="form-row">
<div class="form-group">
<label>分类</label>
<select id="mem-add-category">
<option value="preference">偏好</option>
<option value="fact">事实</option>
<option value="experience">经验</option>
<option value="other">其他</option>
<!-- 统计面板 -->
<div class="card" id="mem-stats-card" style="margin-bottom:14px">
<div class="card-header"><span class="card-title">📊 记忆统计</span></div>
<div id="mem-stats-content">
<div class="empty-state"><div class="icon">📊</div>加载记忆后显示统计</div>
</div>
</div>
<!-- 分类筛选标签栏 -->
<div class="card" style="margin-bottom:10px;padding:12px 16px">
<div style="display:flex;align-items:center;justify-content:space-between;flex-wrap:wrap;gap:8px">
<div id="mem-cat-tabs" style="display:flex;gap:6px;flex-wrap:wrap">
<!-- 动态填充 -->
</div>
<div style="display:flex;align-items:center;gap:8px">
<select id="mem-sort-by" onchange="STATE.memorySortBy=this.value;filterAndRenderMemories()" style="width:auto;padding:4px 8px;font-size:11px;background:var(--bg3);color:var(--text);border:1px solid var(--border);border-radius:4px">
<option value="importance" ${STATE.memorySortBy==='importance'?'selected':''}>⭐ 重要性</option>
<option value="created_at" ${STATE.memorySortBy==='created_at'?'selected':''}>🕐 创建时间</option>
<option value="updated_at" ${STATE.memorySortBy==='updated_at'?'selected':''}>🕐 更新时间</option>
<option value="category" ${STATE.memorySortBy==='category'?'selected':''}>🏷️ 分类</option>
<option value="access_count" ${STATE.memorySortBy==='access_count'?'selected':''}>📊 访问次数</option>
</select>
</div>
<div class="form-group">
<label>优先级: <span id="mem-priority-val">3</span></label>
<input type="range" id="mem-add-priority" min="1" max="5" value="3" oninput="document.getElementById('mem-priority-val').textContent=this.value">
<select id="mem-sort-dir" onchange="STATE.memorySortDir=this.value;filterAndRenderMemories()" style="width:auto;padding:4px 8px;font-size:11px;background:var(--bg3);color:var(--text);border:1px solid var(--border);border-radius:4px">
<option value="desc" ${STATE.memorySortDir==='desc'?'selected':''}>↓ 降序</option>
<option value="asc" ${STATE.memorySortDir==='asc'?'selected':''}>↑ 升序</option>
</select>
<span id="mem-result-count" style="font-size:11px;color:var(--text2);white-space:nowrap"></span>
</div>
</div>
<button class="btn btn-accent btn-sm" onclick="addMemory()"> 添加</button>
</div>
</div>
<!-- 结果表格 -->
<div class="card">
<div class="card-header">
<span class="card-title">📋 记忆列表</span>
<span style="display:flex;align-items:center;gap:8px">
<select id="mem-sort-order" onchange="sortAndRenderMemories()" style="background:var(--bg3);color:var(--text);border:1px solid var(--border);border-radius:4px;padding:2px 6px;font-size:11px">
<option value="desc">🕐 最新优先</option>
<option value="asc">🕐 最早优先</option>
</select>
<span id="mem-result-count" style="font-size:11px;color:var(--text2)"></span>
</span>
<!-- 记忆卡片网格 -->
<div id="mem-cards-grid" style="display:grid;grid-template-columns:repeat(auto-fill,minmax(340px,1fr));gap:14px">
<div class="empty-state" style="grid-column:1/-1"><div class="icon">🧠</div>点击「查询记忆」加载记忆数据</div>
</div>
<div class="table-wrap">
<table>
<thead><tr><th>内容</th><th>分类</th><th>优先级</th><th>用户</th><th>话题 (会话)</th><th>创建时间</th><th style="width:50px">操作</th></tr></thead>
<tbody id="mem-table-body">
<tr><td colspan="7"><div class="empty-state"><div class="icon">🧠</div>使用搜索或列表按钮加载记忆</div></td></tr>
</tbody>
</table>
</div>
</div>
`;
`;
// 初始化分类标签
renderCategoryTabs();
}
// 如果有缓存数据,直接渲染
if (STATE.memoryCache.length > 0) {
filterAndRenderMemories();
}
}
async function searchMemory() {
const userId = document.getElementById('mem-user-id').value.trim();
const q = document.getElementById('mem-search-q').value.trim();
if (!userId) { showToast('请输入用户ID', 'error'); return; }
if (!q) { showToast('请输入搜索关键词', 'error'); return; }
function renderCategoryTabs() {
const container = document.getElementById('mem-cat-tabs');
if (!container) return;
const data = await api(`/api/memory/search?user_id=${encodeURIComponent(userId)}&q=${encodeURIComponent(q)}`);
renderMemoryResults(data);
const categories = [
{ key: 'all', name: '全部', icon: '📋' },
{ key: 'user_preference', name: '用户偏好', icon: '💜' },
{ key: 'personal_info', name: '个人信息', icon: '👤' },
{ key: 'conversation', name: '对话摘要', icon: '💬' },
{ key: 'knowledge', name: '知识信息', icon: '📚' },
{ key: 'event', name: '事件记录', icon: '📅' },
{ key: 'task', name: '任务计划', icon: '✅' },
{ key: 'relationship', name: '关系情感', icon: '💕' },
];
container.innerHTML = categories.map(c => {
const active = STATE.memoryFilterCategory === c.key;
const style = active
? 'background:var(--accent);color:#fff;border-color:var(--accent);font-weight:600'
: 'background:var(--bg3);color:var(--text2);border-color:var(--border)';
return `<button class="mem-cat-tab" data-cat="${c.key}"
style="padding:4px 12px;border-radius:16px;border:1px solid;cursor:pointer;font-size:11px;transition:all .15s;font-family:inherit;${style}"
onmouseenter="if(!this.classList.contains('active')){this.style.background='var(--bg4)';this.style.color='var(--text)'}"
onmouseleave="if(!this.classList.contains('active')){this.style.background='var(--bg3)';this.style.color='var(--text2)'}"
onclick="switchMemoryCategory('${c.key}')">${c.icon} ${c.name}</button>`;
}).join('');
}
async function listMemory() {
function switchMemoryCategory(cat) {
STATE.memoryFilterCategory = cat;
renderCategoryTabs();
filterAndRenderMemories();
}
async function loadMemories() {
const userId = document.getElementById('mem-user-id').value.trim();
if (!userId) { showToast('请输入用户ID', 'error'); return; }
STATE.memoryUserId = userId;
const data = await api(`/api/memory/list?user_id=${encodeURIComponent(userId)}`);
renderMemoryResults(data);
}
function renderMemoryResults(data) {
const tbody = document.getElementById('mem-table-body');
const countEl = document.getElementById('mem-result-count');
if (data.error) {
let hint = '';
@@ -1134,72 +1246,266 @@ function renderMemoryResults(data) {
} else if (data.status === 502) {
hint = '<br><span style="font-size:11px">💡 提示: 请确认 Gateway 和 AI-Core 服务已启动</span>';
}
tbody.innerHTML = `<tr><td colspan="7"><div class="empty-state"><div class="icon">⚠️</div>${escHtml(data.error)}${hint}</div></td></tr>`;
countEl.textContent = '';
const grid = document.getElementById('mem-cards-grid');
if (grid) grid.innerHTML = `<div class="empty-state" style="grid-column:1/-1"><div class="icon">⚠️</div>${escHtml(data.error)}${hint}</div>`;
document.getElementById('mem-result-count').textContent = '';
STATE.memoryCache = [];
renderStatsPanel();
return;
}
// 兼容不同返回格式
let memories = [];
if (Array.isArray(data)) memories = data;
else if (data.memories) memories = data.memories;
else if (data.results) memories = data.results;
// 缓存记忆数据用于排序
STATE.memoryCache = memories;
sortAndRenderMemories();
filterAndRenderMemories();
}
function sortAndRenderMemories() {
const tbody = document.getElementById('mem-table-body');
const countEl = document.getElementById('mem-result-count');
const sortOrder = document.getElementById('mem-sort-order')?.value || 'desc';
async function searchMemory() {
const userId = document.getElementById('mem-user-id').value.trim();
const q = STATE.memorySearchText || document.getElementById('mem-search-text')?.value?.trim() || '';
if (!userId) { showToast('请输入用户ID', 'error'); return; }
if (!q) { showToast('请输入搜索关键词', 'error'); return; }
let memories = [...(STATE.memoryCache || [])];
STATE.memoryUserId = userId;
// 按创建时间排序
memories.sort((a, b) => {
const ta = new Date(a.created_at || 0).getTime();
const tb = new Date(b.created_at || 0).getTime();
return sortOrder === 'asc' ? ta - tb : tb - ta;
});
const data = await api(`/api/memory/search?user_id=${encodeURIComponent(userId)}&q=${encodeURIComponent(q)}`);
countEl.textContent = `${memories.length}`;
if (memories.length === 0) {
tbody.innerHTML = '<tr><td colspan="7"><div class="empty-state"><div class="icon">📭</div>没有找到记忆</div></td></tr>';
if (data.error) {
let hint = '';
if (data.errorType === 'gateway_not_running' || data.errorType === 'gateway_auth_failed') {
hint = '<br><span style="font-size:11px">💡 提示: 请先在「服务管理」面板中启动 Gateway 服务</span>';
} else if (data.errorType === 'gateway_unreachable') {
hint = '<br><span style="font-size:11px">💡 提示: Gateway 服务无响应,请检查网络连接和服务状态</span>';
} else if (data.status === 502) {
hint = '<br><span style="font-size:11px">💡 提示: 请确认 Gateway 和 AI-Core 服务已启动</span>';
}
const grid = document.getElementById('mem-cards-grid');
if (grid) grid.innerHTML = `<div class="empty-state" style="grid-column:1/-1"><div class="icon">⚠️</div>${escHtml(data.error)}${hint}</div>`;
document.getElementById('mem-result-count').textContent = '';
STATE.memoryCache = [];
renderStatsPanel();
return;
}
tbody.innerHTML = memories.map(m => {
// 会话ID 简短显示
const sid = m.session_id || '—';
const sidShort = sid.length > 16 ? sid.substring(0, 14) + '…' : sid;
return `
<tr>
<td style="max-width:260px;overflow:hidden;text-overflow:ellipsis;white-space:nowrap" title="${escHtml(m.content || '')}">${escHtml((m.content || '').substring(0, 80))}</td>
<td><span class="badge badge-idle">${escHtml(m.category || 'other')}</span></td>
<td>${m.priority ?? 1}</td>
<td style="color:var(--text2)">${escHtml(m.user_id || '—')}</td>
<td style="color:var(--text2);font-size:11px" title="${escHtml(sid)}">${escHtml(sidShort)}</td>
<td style="color:var(--text2);font-size:11px">${formatTime(m.created_at)}</td>
<td><button class="btn btn-xs btn-red" onclick="deleteMemory('${escHtml(m.id || m.ID || '')}')" title="删除">🗑</button></td>
</tr>
`}).join('');
let memories = [];
if (Array.isArray(data)) memories = data;
else if (data.memories) memories = data.memories;
else if (data.results) memories = data.results;
STATE.memoryCache = memories;
filterAndRenderMemories();
}
// 兼容旧的 listMemory 调用
async function listMemory() {
loadMemories();
}
function filterAndRenderMemories() {
const memories = STATE.memoryCache || [];
// 1. 分类筛选
let filtered = memories;
if (STATE.memoryFilterCategory !== 'all') {
filtered = filtered.filter(m => m.category === STATE.memoryFilterCategory);
}
// 2. 重要性筛选
if (STATE.memoryFilterImportance > 0) {
filtered = filtered.filter(m => (m.importance || 0) >= STATE.memoryFilterImportance);
}
// 3. 全文搜索 (客户端二次过滤)
if (STATE.memorySearchText) {
const q = STATE.memorySearchText.toLowerCase();
filtered = filtered.filter(m => {
const content = (m.content || '').toLowerCase();
const summary = (m.summary || '').toLowerCase();
const keywords = (m.keywords || []).join(' ').toLowerCase();
return content.includes(q) || summary.includes(q) || keywords.includes(q);
});
}
// 4. 排序
const sortBy = STATE.memorySortBy;
const sortDir = STATE.memorySortDir === 'asc' ? 1 : -1;
filtered.sort((a, b) => {
let va, vb;
switch (sortBy) {
case 'importance':
va = a.importance || 0; vb = b.importance || 0; break;
case 'created_at':
va = new Date(a.created_at || 0).getTime(); vb = new Date(b.created_at || 0).getTime(); break;
case 'updated_at':
va = new Date(a.updated_at || a.created_at || 0).getTime();
vb = new Date(b.updated_at || b.created_at || 0).getTime(); break;
case 'category':
va = a.category || ''; vb = b.category || '';
return sortDir * va.localeCompare(vb);
case 'access_count':
va = a.access_count || 0; vb = b.access_count || 0; break;
default:
va = a.importance || 0; vb = b.importance || 0;
}
return sortDir * (va - vb);
});
// 5. 渲染统计
renderStatsPanel();
// 6. 渲染卡片
renderMemoryCards(filtered);
// 更新计数
const countEl = document.getElementById('mem-result-count');
if (countEl) {
countEl.textContent = `显示 ${filtered.length} / ${memories.length}`;
}
}
function renderStatsPanel() {
const container = document.getElementById('mem-stats-content');
if (!container) return;
const memories = STATE.memoryCache || [];
if (memories.length === 0) {
container.innerHTML = '<div class="empty-state"><div class="icon">📊</div>暂无记忆数据</div>';
return;
}
// 计算各分类数量
const catCount = {};
let totalImportance = 0;
let totalAccess = 0;
memories.forEach(m => {
const cat = m.category || 'other';
catCount[cat] = (catCount[cat] || 0) + 1;
totalImportance += (m.importance || 0);
totalAccess += (m.access_count || 0);
});
const avgImportance = (totalImportance / memories.length).toFixed(1);
const maxCatCount = Math.max(1, ...Object.values(catCount));
// 分类分布条
const catOrder = ['user_preference', 'personal_info', 'conversation', 'knowledge', 'event', 'task', 'relationship'];
const barHtml = catOrder.map(cat => {
const count = catCount[cat] || 0;
const pct = Math.round((count / Math.max(1, memories.length)) * 100);
const cc = getCatColor(cat);
return count > 0 ? `<div style="display:flex;align-items:center;gap:4px;font-size:10px">
<span style="color:${cc.text};min-width:52px">${cc.icon} ${cc.name}</span>
<div style="flex:1;background:var(--bg);border-radius:3px;height:14px;overflow:hidden">
<div style="height:100%;width:${pct}%;background:${cc.text};border-radius:3px;transition:width .3s;min-width:2px"></div>
</div>
<span style="color:var(--text2);min-width:32px;text-align:right;font-family:'JetBrains Mono',monospace">${count}</span>
</div>` : '';
}).join('');
container.innerHTML = `
<div class="cards-grid cards-4" style="margin-bottom:10px">
<div class="stat-card accent"><div class="stat-value">${memories.length}</div><div class="stat-label">📦 总记忆数</div></div>
<div class="stat-card blue"><div class="stat-value">${avgImportance}</div><div class="stat-label">⭐ 平均重要性</div></div>
<div class="stat-card green"><div class="stat-value">${totalAccess}</div><div class="stat-label">📊 总访问次数</div></div>
<div class="stat-card orange"><div class="stat-value">${Object.values(catCount).filter(n=>n>0).length}</div><div class="stat-label">🏷️ 分类数</div></div>
</div>
<div style="display:flex;flex-direction:column;gap:3px">${barHtml}</div>
`;
}
function renderMemoryCards(memories) {
const grid = document.getElementById('mem-cards-grid');
if (!grid) return;
if (memories.length === 0) {
const catName = STATE.memoryFilterCategory !== 'all'
? (getCatColor(STATE.memoryFilterCategory).name || STATE.memoryFilterCategory)
: '';
const msg = catName ? `${catName}」分类下暂无记忆` : '没有匹配的记忆';
grid.innerHTML = `<div class="empty-state" style="grid-column:1/-1"><div class="icon">📭</div>${msg}</div>`;
return;
}
grid.innerHTML = memories.map(m => renderMemoryCard(m)).join('');
}
function renderMemoryCard(m) {
const cat = m.category || 'other';
const cc = getCatColor(cat);
const importance = m.importance || 1;
const isHighImportance = importance >= 8;
// 星级
const stars = importanceToStars(importance);
// 关键词标签
const keywords = m.keywords || [];
const kwTags = keywords.length > 0
? keywords.slice(0, 5).map(k => `<span style="display:inline-block;padding:1px 7px;background:var(--bg3);border-radius:10px;font-size:10px;color:var(--text2);margin:1px">${escHtml(k)}</span>`).join('')
: '';
// 来源
const sourceLabel = m.source === 'thinking' ? '🤔 后台思考' : m.source === 'conversation' ? '💬 对话' : '📝 ' + (m.source || '未知');
// 会话 ID 简短
const sid = m.session_id || '';
const sidShort = sid.length > 20 ? sid.substring(0, 18) + '…' : sid;
return `
<div class="mem-card ${isHighImportance ? 'mem-card-high' : ''}"
style="background:var(--bg2);border:1px solid ${isHighImportance ? '#f59e0b' : 'var(--border)'};border-radius:var(--radius);padding:16px;display:flex;flex-direction:column;gap:10px;transition:all .2s">
<!-- 头部: 分类标签 + 重要性 -->
<div style="display:flex;align-items:center;justify-content:space-between">
<span style="display:inline-flex;align-items:center;gap:4px;padding:2px 10px;border-radius:12px;font-size:11px;font-weight:600;background:${cc.bg};color:${cc.text}">
${cc.icon} ${cc.name}
</span>
<span style="font-size:13px;color:#f59e0b" title="重要程度: ${importance}/10">${stars}</span>
</div>
<!-- 记忆内容 -->
<div style="flex:1">
<div style="color:var(--text);font-size:13px;line-height:1.6;word-break:break-word">
${escHtml(m.content || '')}
</div>
${m.summary ? `<div style="color:var(--text2);font-size:11px;margin-top:4px;font-style:italic">📌 ${escHtml(m.summary)}</div>` : ''}
</div>
<!-- 关键词标签 -->
${kwTags ? `<div style="display:flex;flex-wrap:wrap;gap:3px">${kwTags}</div>` : ''}
<!-- 底部元信息 -->
<div style="display:flex;align-items:center;flex-wrap:wrap;gap:8px;font-size:10px;color:var(--text3);border-top:1px solid var(--border);padding-top:8px">
<span title="来源">${sourceLabel}</span>
${sid ? `<span title="会话: ${escHtml(sid)}" style="max-width:140px;overflow:hidden;text-overflow:ellipsis;white-space:nowrap">💬 ${escHtml(sidShort)}</span>` : ''}
<span title="访问次数">📊 ${m.access_count || 0}</span>
<span title="创建时间">📅 ${formatTime(m.created_at)}</span>
<span title="更新时间" style="${(m.updated_at && m.updated_at !== m.created_at) ? '' : 'display:none'}">🔄 ${formatTime(m.updated_at)}</span>
<button class="btn btn-xs btn-red" onclick="deleteMemory('${escHtml(m.id || m.ID || '')}')" title="删除" style="margin-left:auto">🗑</button>
</div>
</div>
`;
}
function importanceToStars(imp) {
const full = Math.round(imp / 2); // 1-10 映射到 1-5 星
const empty = 5 - full;
return '★'.repeat(full) + '☆'.repeat(empty);
}
async function addMemory() {
const user_id = document.getElementById('mem-add-user-id').value.trim();
const content = document.getElementById('mem-add-content').value.trim();
const category = document.getElementById('mem-add-category').value;
const priority = parseInt(document.getElementById('mem-add-priority').value);
const importance = parseInt(document.getElementById('mem-add-importance').value);
if (!user_id || !content) { showToast('请填写用户ID和内容', 'error'); return; }
const data = await api('/api/memory/add', {
method: 'POST',
body: JSON.stringify({ user_id, content, category, priority }),
body: JSON.stringify({ user_id, content, category, importance }),
});
if (data.error) { showToast(`添加失败: ${data.error}`, 'error'); return; }
@@ -1207,7 +1513,7 @@ async function addMemory() {
showToast('记忆添加成功!', 'success');
document.getElementById('mem-add-content').value = '';
// 自动刷新列表
listMemory();
loadMemories();
}
async function deleteMemory(memoryId) {
@@ -1219,8 +1525,9 @@ async function deleteMemory(memoryId) {
if (data.error) { showToast(`删除失败: ${data.error}`, 'error'); return; }
showToast('记忆删除成功!', 'success');
// 自动刷新列表
listMemory();
// 从缓存中移除并重新渲染
STATE.memoryCache = (STATE.memoryCache || []).filter(m => (m.id || m.ID) !== memoryId);
filterAndRenderMemories();
}
// ========== 面板3: 会话监看 ==========