fix: 修复 AI 回复无法送达发送者 + 重复消息 + action角色泄露 + OS环境支持
广播逻辑重构: - AI 回复 (stream_start/response/stream_segments/multi_message/stream_end) 改用 broadcastToUser 发送给所有客户端 - 用户消息回显保持 broadcastToUserExcept 排除发送者 消息去重与角色修复: - CacheMessage(user) 移至回复生成后,避免本轮 LLM 调用出现重复用户消息 - action 角色消息在 DB 存储时映射为 assistant,DeepSeek 等模型不支持自定义角色 - stream_end defer 机制确保错误路径也会终止客户端思考指示器 OS 完整环境支持: - host 包重构为 HostBackend 接口 + Direct/WSL/Docker 三种后端 - 新增 os_exec/os_file/os_system 工具供 AI 在完整 Linux 环境中自由操作 其他: - 视觉模型注入 + 图片预处理后清空 Images 避免传给 Chat 模型 - 图片 URL 相对路径→绝对 URL 转换 - DevTools 链路追踪页面 + 重启修复 - 记忆搜索模糊匹配增强 - 后台思考定时调度支持 - 管理后台页面 (模型配置/用户管理等) - docs/api 更新广播机制说明 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -76,3 +76,14 @@ ALLOWED_ORIGINS=http://localhost:5173,http://localhost:5199,http://localhost:300
|
||||
MEMORY_FILE_PATH=./data/memory
|
||||
VECTOR_DB_URL=http://localhost:6333
|
||||
VECTOR_DB_COLLECTION=cyrene_memories
|
||||
|
||||
# ========== 完整 OS 环境 (供 os_exec/os_file/os_system 工具) ==========
|
||||
# 后端选择: direct (默认,仅沙箱), wsl (WSL2 完整Linux), docker (Docker容器)
|
||||
HOST_EXEC_BACKEND=wsl
|
||||
WSL_DISTRO=Ubuntu-22.04
|
||||
# WSL 内自动创建的用户 (首次调用时自动创建,已存在则跳过)
|
||||
WSL_USER=cyrene
|
||||
WSL_USER_PASSWORD=cyrene
|
||||
SANDBOX_CONTAINER=cyrene-sandbox
|
||||
SANDBOX_IMAGE=ubuntu:22.04
|
||||
HOST_EXEC_MAX_TIMEOUT=300
|
||||
|
||||
+104
-7
@@ -133,6 +133,7 @@ func main() {
|
||||
adminUserID := "admin"
|
||||
adminSessionID := "admin-session-main"
|
||||
if cfg.DatabaseURL != "" {
|
||||
convStore.SetDatabaseURL(cfg.DatabaseURL)
|
||||
if err := convStore.LoadFromDB(cfg.DatabaseURL, adminSessionID, 50); err != nil {
|
||||
log.Printf("⚠ 从数据库恢复会话历史失败(不影响服务启动): %v", err)
|
||||
}
|
||||
@@ -151,13 +152,22 @@ func main() {
|
||||
log.Println("IoT 客户端未配置 (IOT_SERVICE_URL 和 IOT_DEBUG_SERVICE_URL 均为空)")
|
||||
}
|
||||
|
||||
// 初始化主机操控管理器 (Phase 6.2: 沙箱执行 + 文件系统隔离)
|
||||
// 初始化主机操控管理器 (沙箱执行 + 文件系统隔离)
|
||||
hostSandbox := host.NewSandbox(host.DefaultSandboxConfig())
|
||||
hostManager := host.NewManager(hostSandbox)
|
||||
directBackend := host.NewDirectBackend(hostSandbox)
|
||||
hostManager := host.NewManager(directBackend)
|
||||
dataDir := getEnv("DATA_DIR", "/tmp/cyrene_data")
|
||||
hostManager.SetAllowedDirs([]string{dataDir, os.TempDir(), "."})
|
||||
log.Printf("主机操控管理器已就绪: 沙箱执行 + 文件隔离 (数据目录=%s)", dataDir)
|
||||
|
||||
// 初始化完整OS环境管理器 (WSL/Docker,无沙箱限制,供 os_* 工具使用)
|
||||
osManager := createOSManager()
|
||||
if osManager != nil {
|
||||
log.Printf("完整OS环境管理器已就绪: backend=%s", osManager.BackendName())
|
||||
} else {
|
||||
log.Println("完整OS环境管理器未配置 (设置 HOST_EXEC_BACKEND=wsl 或 docker 以启用)")
|
||||
}
|
||||
|
||||
// 初始化 RAG 知识库 (Phase 6.6: 知识库 RAG 增强)
|
||||
knowledgeDir := getEnv("KNOWLEDGE_DIR", "./data/knowledge")
|
||||
ragEmbedder := rag.NewEmbedder(cfg.LLMBaseURL, cfg.LLMAPIKey, "text-embedding-3-small")
|
||||
@@ -167,6 +177,7 @@ func main() {
|
||||
|
||||
// 初始化工具注册中心 (使用共享插件模块)
|
||||
toolRegistry := plgManager.NewToolRegistry()
|
||||
var visionProvider llm.LLMProvider
|
||||
if getEnvBool("ENABLE_TOOLS", true) {
|
||||
// 11 个共享通用插件 — 注册其工具到统一注册中心
|
||||
registerPluginTools(toolRegistry, &pluginCalc.CalculatorPlugin{})
|
||||
@@ -198,7 +209,13 @@ func main() {
|
||||
toolRegistry.Register(wrapTool(tools.NewHostSystemTool(hostManager), "host_system", "Host System Info", "system"))
|
||||
}
|
||||
|
||||
var visionProvider llm.LLMProvider
|
||||
if osManager != nil {
|
||||
toolRegistry.Register(wrapTool(tools.NewOSExecTool(osManager), "os_exec", "OS Command Execution", "system"))
|
||||
toolRegistry.Register(wrapTool(tools.NewOSFileTool(osManager), "os_file", "OS File Operations", "system"))
|
||||
toolRegistry.Register(wrapTool(tools.NewOSSystemTool(osManager), "os_system", "OS System Info", "system"))
|
||||
}
|
||||
|
||||
visionProvider = nil
|
||||
if configLoader != nil && configLoader.HasConfig() {
|
||||
cfg := configLoader.GetConfig()
|
||||
if route, ok := cfg.Routing["vision"]; ok && len(route.FallbackChain) > 0 {
|
||||
@@ -300,7 +317,9 @@ func main() {
|
||||
// 注册子会话提供者
|
||||
subManager.Register(subsession.NewGeneralProvider(personaLoader))
|
||||
if memRetriever != nil {
|
||||
subManager.Register(subsession.NewMemoryProvider(memRetriever))
|
||||
memProvider := subsession.NewMemoryProvider(memRetriever)
|
||||
memProvider.SetFuzzySearch(memoryAdapter, memClient)
|
||||
subManager.Register(memProvider)
|
||||
}
|
||||
if iotClient != nil {
|
||||
subManager.Register(subsession.NewIoTProvider(iotClient, personaDir))
|
||||
@@ -322,6 +341,10 @@ func main() {
|
||||
memExtractor,
|
||||
)
|
||||
orch.SetToolRegistry(toolRegistry)
|
||||
if visionProvider != nil {
|
||||
orch.SetVisionProvider(visionProvider)
|
||||
log.Printf("对话编排器: 视觉模型已注入 (%s)", visionProvider.ModelName())
|
||||
}
|
||||
log.Println("对话编排器 v2.0 已就绪")
|
||||
_ = orch
|
||||
|
||||
@@ -428,6 +451,32 @@ func main() {
|
||||
json.NewEncoder(w).Encode(toolRegistry.GetCallStats())
|
||||
})
|
||||
|
||||
// OS 环境监控端点
|
||||
mux.HandleFunc("/api/v1/system/info", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
result := map[string]interface{}{
|
||||
"os_enabled": osManager != nil,
|
||||
}
|
||||
if osManager != nil {
|
||||
result["backend"] = osManager.BackendName()
|
||||
result["system"] = osManager.SystemInfo()
|
||||
if disk, err := osManager.DiskUsage("/"); err == nil {
|
||||
result["disk"] = disk
|
||||
}
|
||||
}
|
||||
if hostManager != nil {
|
||||
result["host"] = map[string]interface{}{
|
||||
"backend": hostManager.BackendName(),
|
||||
"system": hostManager.SystemInfo(),
|
||||
}
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(result)
|
||||
})
|
||||
|
||||
// 启动HTTP服务
|
||||
srv := &http.Server{
|
||||
Addr: ":" + cfg.Port,
|
||||
@@ -525,6 +574,42 @@ func getEnvBool(key string, fallback bool) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// createOSManager 根据 HOST_EXEC_BACKEND 环境变量创建完整OS环境管理器。
|
||||
// 支持 "wsl" 和 "docker" 两种后端。返回 nil 表示未配置或配置无效。
|
||||
func createOSManager() *host.Manager {
|
||||
backend := strings.ToLower(os.Getenv("HOST_EXEC_BACKEND"))
|
||||
switch backend {
|
||||
case "wsl":
|
||||
distro := getEnv("WSL_DISTRO", "Ubuntu-22.04")
|
||||
username := getEnv("WSL_USER", "cyrene")
|
||||
password := os.Getenv("WSL_USER_PASSWORD")
|
||||
maxTimeout := time.Duration(getEnvInt("HOST_EXEC_MAX_TIMEOUT", 300)) * time.Second
|
||||
wslBackend := host.NewWSLBackend(distro, username, password, maxTimeout)
|
||||
return host.NewManager(wslBackend)
|
||||
case "docker":
|
||||
container := getEnv("SANDBOX_CONTAINER", "cyrene-sandbox")
|
||||
image := getEnv("SANDBOX_IMAGE", "ubuntu:22.04")
|
||||
maxTimeout := time.Duration(getEnvInt("HOST_EXEC_MAX_TIMEOUT", 300)) * time.Second
|
||||
dockerBackend := host.NewDockerBackend(container, image, maxTimeout)
|
||||
return host.NewManager(dockerBackend)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// getEnvInt 获取整数类型的环境变量
|
||||
func getEnvInt(key string, fallback int) int {
|
||||
v := os.Getenv(key)
|
||||
if v == "" {
|
||||
return fallback
|
||||
}
|
||||
n, err := strconv.Atoi(v)
|
||||
if err != nil || n <= 0 {
|
||||
return fallback
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// registerPluginTools 从插件实例注册其所有工具到注册中心
|
||||
func registerPluginTools(registry *plgManager.ToolRegistry, plugin plgSDK.Plugin) {
|
||||
for _, t := range plugin.Tools() {
|
||||
@@ -638,9 +723,6 @@ func handleChat(
|
||||
userNickname = cfg.AdminNickname
|
||||
}
|
||||
|
||||
// 0.1 缓存用户消息到会话历史
|
||||
ctxBuilder.CacheMessage(req.SessionID, model.RoleUser, req.Message)
|
||||
|
||||
// 1. 设置 SSE 响应头
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
@@ -678,6 +760,19 @@ func handleChat(
|
||||
var fullContent string
|
||||
for event := range eventCh {
|
||||
switch event.Type {
|
||||
case model.StreamToolProgress:
|
||||
tp := event.ToolProgress
|
||||
progressData, _ := json.Marshal(map[string]interface{}{
|
||||
"type": "tool_progress",
|
||||
"tool_name": tp.ToolName,
|
||||
"status": tp.Status,
|
||||
"progress": tp.Progress,
|
||||
"message": tp.Message,
|
||||
"message_id": messageID,
|
||||
})
|
||||
fmt.Fprintf(w, "data: %s\n\n", progressData)
|
||||
flusher.Flush()
|
||||
|
||||
case model.StreamError:
|
||||
log.Printf("[chat] 流式错误: %v", event.Error)
|
||||
errData, _ := json.Marshal(map[string]string{"delta": "", "error": event.Error.Error()})
|
||||
@@ -729,6 +824,8 @@ func handleChat(
|
||||
}
|
||||
}
|
||||
|
||||
// 缓存用户消息到会话历史(在回复生成后,避免本轮 LLM 调用出现重复用户消息)
|
||||
ctxBuilder.CacheMessage(req.SessionID, model.RoleUser, req.Message)
|
||||
// 4. 对话完成后触发昔涟的自主思考(事件驱动,非定时)
|
||||
if thinker != nil {
|
||||
thinker.TriggerPostChatThink()
|
||||
|
||||
@@ -582,16 +582,7 @@ func (t *Thinker) performThink(triggerReason string) {
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 检索相关记忆
|
||||
var memories []memory.MemoryEntry
|
||||
if t.memRetriever != nil {
|
||||
memories, err = t.memRetriever.Retrieve(ctx, t.adminUserID, "最近发生了什么 重要的事情 用户偏好 个人信息")
|
||||
if err != nil {
|
||||
log.Printf("[后台思考] 记忆检索失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 获取当前活跃会话的对话历史(优先活跃会话,回退到管理员主会话)
|
||||
// 2. 获取当前活跃会话的对话历史(优先活跃会话,回退到管理员主会话)
|
||||
var convHistory []model.LLMMessage
|
||||
if t.convStore != nil {
|
||||
t.mu.Lock()
|
||||
@@ -608,6 +599,37 @@ func (t *Thinker) performThink(triggerReason string) {
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 检索相关记忆(精确检索 + 模糊搜索)
|
||||
var memories []memory.MemoryEntry
|
||||
if t.memRetriever != nil {
|
||||
memories, err = t.memRetriever.Retrieve(ctx, t.adminUserID, "最近发生了什么 重要的事情 用户偏好 个人信息")
|
||||
if err != nil {
|
||||
log.Printf("[后台思考] 记忆检索失败: %v", err)
|
||||
}
|
||||
|
||||
// 模糊搜索:从对话历史提取话题,LLM 扩展关键词后语义搜索
|
||||
if t.memClient != nil && len(convHistory) > 0 {
|
||||
fuzzyQuery := lastUserMessage(convHistory)
|
||||
if fuzzyQuery == "" {
|
||||
fuzzyQuery = "最近对话 重要事件 用户状态"
|
||||
}
|
||||
fuzzyResults := t.fuzzyMemorySearch(ctx, t.adminUserID, fuzzyQuery)
|
||||
seen := make(map[string]bool)
|
||||
for _, m := range memories {
|
||||
seen[m.ID] = true
|
||||
}
|
||||
for _, m := range fuzzyResults {
|
||||
if !seen[m.ID] {
|
||||
seen[m.ID] = true
|
||||
memories = append(memories, m)
|
||||
}
|
||||
}
|
||||
if len(fuzzyResults) > 0 {
|
||||
log.Printf("[后台思考] 模糊搜索补充 %d 条记忆", len(fuzzyResults))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 查询 IoT 设备状态(每次都查询,无间隔限制)
|
||||
var deviceSummary string
|
||||
if t.iotClient != nil {
|
||||
@@ -750,10 +772,9 @@ func (t *Thinker) performThink(triggerReason string) {
|
||||
|
||||
log.Printf("[后台思考] 完成 (触发原因=%s, 内容长度=%d, 工具调用=%d次)", triggerReason, len(finalContent), totalToolCalls)
|
||||
|
||||
// 9. 周期性记忆维护(每 10 次思考触发一次)
|
||||
// 注:不再从思考结果中提取记忆——思考内容基于已有记忆生成,
|
||||
// 再次提取会造成"读取→思考→保存→再次读取"的重复循环。
|
||||
// 9. 记忆维护:机械合并(每10次) + LLM整理(每次)
|
||||
t.maybeMaintainMemories(currentCount)
|
||||
t.performMemoryConsolidation(ctx)
|
||||
}
|
||||
|
||||
// buildThinkingSystemPrompt 构建思考用的系统提示词
|
||||
@@ -1196,6 +1217,375 @@ func (t *Thinker) maybeMaintainMemories(thinkCount int) {
|
||||
}
|
||||
}
|
||||
|
||||
// consolidationAction is a parsed memory consolidation instruction from the LLM.
|
||||
type consolidationAction struct {
|
||||
Action string `json:"action"`
|
||||
IDs []string `json:"ids,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Category string `json:"category,omitempty"`
|
||||
Importance int `json:"importance,omitempty"`
|
||||
Priority int `json:"priority,omitempty"`
|
||||
Keywords []string `json:"keywords,omitempty"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
// performMemoryConsolidation uses LLM to review and clean up the memory store.
|
||||
// It identifies duplicates, contradictions, outdated info, and low-quality memories,
|
||||
// then executes merge/delete/update actions.
|
||||
func (t *Thinker) performMemoryConsolidation(ctx context.Context) {
|
||||
if t.memClient == nil {
|
||||
return
|
||||
}
|
||||
|
||||
allMemories, err := t.memClient.Query(ctx, model.MemoryQuery{
|
||||
UserID: t.adminUserID,
|
||||
Limit: 200,
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("[记忆整理] 获取记忆失败: %v", err)
|
||||
return
|
||||
}
|
||||
if len(allMemories) < 5 {
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[记忆整理] LLM 审查 %d 条记忆...", len(allMemories))
|
||||
|
||||
systemPrompt := t.buildConsolidationPrompt(allMemories)
|
||||
messages := []model.LLMMessage{
|
||||
{Role: model.RoleSystem, Content: systemPrompt},
|
||||
{Role: model.RoleUser, Content: "请审查以上记忆库,找出重复、矛盾、过时和低质量的记忆,输出 JSON 整理方案。如果没有需要整理的,输出空数组 []。"},
|
||||
}
|
||||
|
||||
resp, err := t.toolAdapter.Chat(ctx, messages)
|
||||
if err != nil {
|
||||
log.Printf("[记忆整理] LLM 调用失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
actions := parseConsolidationActions(resp.Content)
|
||||
if len(actions) == 0 {
|
||||
log.Printf("[记忆整理] 记忆库状态良好,无需整理")
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[记忆整理] LLM 建议 %d 项操作", len(actions))
|
||||
executed := t.executeConsolidationActions(ctx, actions, allMemories)
|
||||
log.Printf("[记忆整理] 完成: 执行了 %d 项操作", executed)
|
||||
}
|
||||
|
||||
// buildConsolidationPrompt formats all memories as a structured list for LLM review.
|
||||
func (t *Thinker) buildConsolidationPrompt(memories []model.MemoryEntry) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("你是记忆库管理助手。审查以下用户记忆,找出问题并输出 JSON 操作清单。\n\n")
|
||||
sb.WriteString("## 需要识别的问题\n")
|
||||
sb.WriteString("1. 重复记忆 — 多条记忆记录了相同信息 → merge 合并为一条\n")
|
||||
sb.WriteString("2. 矛盾记忆 — 两条记忆互相矛盾(如\"喜欢猫\"vs\"讨厌猫\")→ delete 删除过时的、update 修正错误的\n")
|
||||
sb.WriteString("3. 过时记忆 — 信息已被新记忆取代 → delete 或 update\n")
|
||||
sb.WriteString("4. 低质量记忆 — 太模糊、不完整、无实际信息量 → delete\n\n")
|
||||
sb.WriteString("## JSON 操作格式\n")
|
||||
sb.WriteString("```json\n[\n")
|
||||
sb.WriteString(" {\"action\":\"merge\", \"ids\":[\"id1\",\"id2\"], \"content\":\"合并后的内容\", \"category\":\"personal_info\", \"importance\":8, \"reason\":\"两条记录同一件事\"},\n")
|
||||
sb.WriteString(" {\"action\":\"delete\", \"id\":\"id3\", \"reason\":\"完全被 id1 覆盖\"},\n")
|
||||
sb.WriteString(" {\"action\":\"update\", \"id\":\"id4\", \"content\":\"修正后的内容\", \"importance\":7, \"reason\":\"纠正过时信息\"},\n")
|
||||
sb.WriteString(" {\"action\":\"create\", \"content\":\"需要补充的记忆\", \"category\":\"knowledge\", \"importance\":6, \"reason\":\"从已有记忆推断\"}\n")
|
||||
sb.WriteString("]\n```\n\n")
|
||||
sb.WriteString("## 规则\n")
|
||||
sb.WriteString("- 只输出 JSON 数组,可以用 ```json``` 包裹,不要输出其他解释文字\n")
|
||||
sb.WriteString("- 确实有问题才建议操作,不要强行找问题\n")
|
||||
sb.WriteString("- merge 时保留最重要的那条的 ID,合并内容应包含各条的关键信息\n")
|
||||
sb.WriteString("- 不确定时宁可保守(不操作)\n")
|
||||
sb.WriteString("- importance 范围 1-10,数字越大越重要\n")
|
||||
sb.WriteString("- category 可选: personal_info, user_preference, conversation, knowledge, event, task, relationship\n\n")
|
||||
sb.WriteString(fmt.Sprintf("## 当前记忆库 (%d 条)\n\n", len(memories)))
|
||||
|
||||
for i, m := range memories {
|
||||
sb.WriteString(fmt.Sprintf("%d. [%s] **%s** | cat=%s imp=%d pri=%d | src=%s\n",
|
||||
i+1, m.ID[:min(8, len(m.ID))], m.Content,
|
||||
m.Category, m.Importance, m.Priority, m.Source))
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// parseConsolidationActions extracts JSON actions from LLM response text.
|
||||
func parseConsolidationActions(text string) []consolidationAction {
|
||||
// Try to extract from ```json fences first
|
||||
jsonStr := text
|
||||
if idx := strings.Index(text, "```json"); idx >= 0 {
|
||||
start := idx + 7
|
||||
if end := strings.Index(text[start:], "```"); end >= 0 {
|
||||
jsonStr = text[start : start+end]
|
||||
}
|
||||
} else if idx := strings.Index(text, "```"); idx >= 0 {
|
||||
start := idx + 3
|
||||
if end := strings.Index(text[start:], "```"); end >= 0 {
|
||||
jsonStr = text[start : start+end]
|
||||
}
|
||||
}
|
||||
// Find the JSON array
|
||||
arrStart := strings.Index(jsonStr, "[")
|
||||
arrEnd := strings.LastIndex(jsonStr, "]")
|
||||
if arrStart < 0 || arrEnd <= arrStart {
|
||||
return nil
|
||||
}
|
||||
jsonStr = jsonStr[arrStart : arrEnd+1]
|
||||
|
||||
var actions []consolidationAction
|
||||
if err := json.Unmarshal([]byte(jsonStr), &actions); err != nil {
|
||||
log.Printf("[记忆整理] JSON 解析失败: %v", err)
|
||||
return nil
|
||||
}
|
||||
return actions
|
||||
}
|
||||
|
||||
// executeConsolidationActions runs the parsed consolidation actions against the memory store.
|
||||
func (t *Thinker) executeConsolidationActions(ctx context.Context, actions []consolidationAction, memories []model.MemoryEntry) int {
|
||||
// Index memories by their short ID prefix for lookup
|
||||
memByShortID := make(map[string]*model.MemoryEntry)
|
||||
for i := range memories {
|
||||
short := memories[i].ID[:min(8, len(memories[i].ID))]
|
||||
memByShortID[short] = &memories[i]
|
||||
}
|
||||
memByFullID := make(map[string]*model.MemoryEntry)
|
||||
for i := range memories {
|
||||
memByFullID[memories[i].ID] = &memories[i]
|
||||
}
|
||||
|
||||
executed := 0
|
||||
for _, a := range actions {
|
||||
switch a.Action {
|
||||
case "delete":
|
||||
id := resolveID(a.ID, memByShortID, memByFullID)
|
||||
if id == "" {
|
||||
log.Printf("[记忆整理] 跳过 delete: 找不到记忆 %s", a.ID)
|
||||
continue
|
||||
}
|
||||
if err := t.memClient.Delete(ctx, id); err != nil {
|
||||
log.Printf("[记忆整理] 删除 %s 失败: %v", a.ID, err)
|
||||
continue
|
||||
}
|
||||
log.Printf("[记忆整理] 已删除: %s (原因: %s)", a.ID, a.Reason)
|
||||
executed++
|
||||
|
||||
case "merge":
|
||||
if len(a.IDs) < 2 {
|
||||
continue
|
||||
}
|
||||
// Resolve all IDs, use first as the keeper
|
||||
var resolved []string
|
||||
for _, mid := range a.IDs {
|
||||
if rid := resolveID(mid, memByShortID, memByFullID); rid != "" {
|
||||
resolved = append(resolved, rid)
|
||||
}
|
||||
}
|
||||
if len(resolved) < 2 {
|
||||
continue
|
||||
}
|
||||
keeper := resolved[0]
|
||||
// Update the keeper with merged content
|
||||
cat := model.MemoryCategory(a.Category)
|
||||
if cat == "" {
|
||||
if m, ok := memByFullID[keeper]; ok {
|
||||
cat = m.Category
|
||||
}
|
||||
}
|
||||
imp := a.Importance
|
||||
if imp == 0 {
|
||||
if m, ok := memByFullID[keeper]; ok {
|
||||
imp = m.Importance + 1
|
||||
}
|
||||
}
|
||||
if imp > 10 {
|
||||
imp = 10
|
||||
}
|
||||
pri := a.Priority
|
||||
if pri == 0 {
|
||||
if m, ok := memByFullID[keeper]; ok {
|
||||
pri = int(m.Priority)
|
||||
}
|
||||
}
|
||||
if err := t.memClient.Update(ctx, &model.MemoryEntry{
|
||||
ID: keeper,
|
||||
Content: a.Content,
|
||||
Category: cat,
|
||||
Importance: imp,
|
||||
Priority: model.MemoryPriority(pri),
|
||||
Keywords: a.Keywords,
|
||||
Source: "consolidated",
|
||||
}); err != nil {
|
||||
log.Printf("[记忆整理] 更新合并目标 %s 失败: %v", keeper, err)
|
||||
continue
|
||||
}
|
||||
// Delete the discarded ones
|
||||
for _, did := range resolved[1:] {
|
||||
if err := t.memClient.Delete(ctx, did); err != nil {
|
||||
log.Printf("[记忆整理] 删除被合并记忆 %s 失败: %v", did, err)
|
||||
}
|
||||
}
|
||||
log.Printf("[记忆整理] 已合并: %v -> %s (原因: %s)", resolved, keeper, a.Reason)
|
||||
executed++
|
||||
|
||||
case "update":
|
||||
id := resolveID(a.ID, memByShortID, memByFullID)
|
||||
if id == "" {
|
||||
log.Printf("[记忆整理] 跳过 update: 找不到记忆 %s", a.ID)
|
||||
continue
|
||||
}
|
||||
existing := memByFullID[id]
|
||||
cat := model.MemoryCategory(a.Category)
|
||||
if cat == "" && existing != nil {
|
||||
cat = existing.Category
|
||||
}
|
||||
imp := a.Importance
|
||||
if imp == 0 && existing != nil {
|
||||
imp = existing.Importance
|
||||
}
|
||||
pri := a.Priority
|
||||
if pri == 0 && existing != nil {
|
||||
pri = int(existing.Priority)
|
||||
}
|
||||
if err := t.memClient.Update(ctx, &model.MemoryEntry{
|
||||
ID: id,
|
||||
Content: a.Content,
|
||||
Category: cat,
|
||||
Importance: imp,
|
||||
Priority: model.MemoryPriority(pri),
|
||||
Keywords: a.Keywords,
|
||||
Source: "consolidated",
|
||||
}); err != nil {
|
||||
log.Printf("[记忆整理] 更新 %s 失败: %v", id, err)
|
||||
continue
|
||||
}
|
||||
log.Printf("[记忆整理] 已更新: %s (原因: %s)", id, a.Reason)
|
||||
executed++
|
||||
|
||||
case "create":
|
||||
cat := model.MemoryCategory(a.Category)
|
||||
if cat == "" {
|
||||
cat = model.CategoryKnowledge
|
||||
}
|
||||
imp := a.Importance
|
||||
if imp == 0 {
|
||||
imp = 5
|
||||
}
|
||||
if err := t.memClient.Save(ctx, &model.MemoryEntry{
|
||||
UserID: t.adminUserID,
|
||||
Content: a.Content,
|
||||
Category: cat,
|
||||
Importance: imp,
|
||||
Priority: model.MemoryNormal,
|
||||
Keywords: a.Keywords,
|
||||
Source: "consolidation",
|
||||
}); err != nil {
|
||||
log.Printf("[记忆整理] 创建记忆失败: %v", err)
|
||||
continue
|
||||
}
|
||||
log.Printf("[记忆整理] 已创建: %s (原因: %s)", a.Content, a.Reason)
|
||||
executed++
|
||||
}
|
||||
}
|
||||
return executed
|
||||
}
|
||||
|
||||
// resolveID tries to match a short or full ID to an existing memory.
|
||||
func resolveID(id string, byShort, byFull map[string]*model.MemoryEntry) string {
|
||||
if _, ok := byFull[id]; ok {
|
||||
return id
|
||||
}
|
||||
if m, ok := byShort[id]; ok {
|
||||
return m.ID
|
||||
}
|
||||
// Try prefix match
|
||||
for fullID := range byFull {
|
||||
if strings.HasPrefix(fullID, id) {
|
||||
return fullID
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// fuzzyMemorySearch expands the query via LLM keyword extraction and performs semantic search.
|
||||
func (t *Thinker) fuzzyMemorySearch(ctx context.Context, userID, query string) []memory.MemoryEntry {
|
||||
if t.memClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
keywords := t.expandMemoryKeywords(ctx, query)
|
||||
if len(keywords) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Printf("[后台思考] 模糊记忆关键词: %v", keywords)
|
||||
|
||||
var allResults []memory.MemoryEntry
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, kw := range keywords {
|
||||
results, err := t.memClient.QueryByText(ctx, userID, kw, "", 0, 5)
|
||||
if err != nil {
|
||||
log.Printf("[后台思考] 模糊搜索 '%s' 失败: %v", kw, err)
|
||||
continue
|
||||
}
|
||||
for _, m := range results {
|
||||
if !seen[m.ID] {
|
||||
seen[m.ID] = true
|
||||
allResults = append(allResults, m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return allResults
|
||||
}
|
||||
|
||||
// expandMemoryKeywords uses LLM to generate fuzzy/related keywords for memory search.
|
||||
func (t *Thinker) expandMemoryKeywords(ctx context.Context, message string) []string {
|
||||
prompt := fmt.Sprintf(
|
||||
"从以下对话消息中提取 3-5 个可用于模糊搜索记忆的关键词。这些关键词应该是:\n"+
|
||||
"- 与话题相关的抽象概念\n- 同义词和相关词\n- 更宽泛或更具体的相关概念\n"+
|
||||
"- 不要包含消息中已经出现的原词\n\n"+
|
||||
"用户消息:「%s」\n\n"+
|
||||
"只输出 JSON 字符串数组,例如:[\"关键词1\",\"关键词2\"]", message)
|
||||
|
||||
resp, err := t.llmAdapter.Chat(ctx, []model.LLMMessage{
|
||||
{Role: model.RoleSystem, Content: "你是记忆搜索专家。输出 JSON 字符串数组。"},
|
||||
{Role: model.RoleUser, Content: prompt},
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("[后台思考] 关键词扩展失败: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
text := strings.TrimSpace(resp.Content)
|
||||
if idx := strings.Index(text, "["); idx >= 0 {
|
||||
if end := strings.LastIndex(text, "]"); end > idx {
|
||||
text = text[idx : end+1]
|
||||
}
|
||||
}
|
||||
|
||||
var keywords []string
|
||||
if err := json.Unmarshal([]byte(text), &keywords); err != nil {
|
||||
log.Printf("[后台思考] 解析关键词 JSON 失败: %v (raw=%s)", err, resp.Content)
|
||||
return nil
|
||||
}
|
||||
|
||||
return keywords
|
||||
}
|
||||
|
||||
// lastUserMessage extracts the last user message from conversation history.
|
||||
func lastUserMessage(history []model.LLMMessage) string {
|
||||
for i := len(history) - 1; i >= 0; i-- {
|
||||
if history[i].Role == model.RoleUser {
|
||||
runes := []rune(history[i].Content)
|
||||
if len(runes) > 200 {
|
||||
return string(runes[:200])
|
||||
}
|
||||
return history[i].Content
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// formatDeviceContext 格式化设备状态为文本
|
||||
func formatDeviceContext(devices []tools.IoTDevice) string {
|
||||
if len(devices) == 0 {
|
||||
|
||||
@@ -0,0 +1,184 @@
|
||||
package background
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ScheduleRule defines a time-based interval rule.
|
||||
type ScheduleRule struct {
|
||||
Name string `json:"name"`
|
||||
Days []string `json:"days"`
|
||||
TimeRange string `json:"time_range"`
|
||||
Except []string `json:"except"`
|
||||
IntervalMinutes int `json:"interval_minutes"`
|
||||
}
|
||||
|
||||
// ThinkingScheduleConfig is the full schedule configuration.
|
||||
type ThinkingScheduleConfig struct {
|
||||
Version string `json:"version"`
|
||||
DefaultIntervalMinutes int `json:"default_interval_minutes"`
|
||||
Rules []ScheduleRule `json:"rules"`
|
||||
}
|
||||
|
||||
// ScheduleLoader loads the thinking schedule from a JSON file and calculates
|
||||
// the current interval based on time of day and day of week.
|
||||
type ScheduleLoader struct {
|
||||
mu sync.RWMutex
|
||||
path string
|
||||
config *ThinkingScheduleConfig
|
||||
}
|
||||
|
||||
// NewScheduleLoader creates a loader. Returns nil config if the file does not exist.
|
||||
func NewScheduleLoader(path string) (*ScheduleLoader, error) {
|
||||
l := &ScheduleLoader{path: path}
|
||||
if err := l.load(); err != nil {
|
||||
return l, err
|
||||
}
|
||||
return l, nil
|
||||
}
|
||||
|
||||
func (l *ScheduleLoader) load() error {
|
||||
data, err := os.ReadFile(l.path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
l.config = nil
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("read thinking schedule: %w", err)
|
||||
}
|
||||
if len(data) == 0 {
|
||||
l.config = nil
|
||||
return nil
|
||||
}
|
||||
var cfg ThinkingScheduleConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
l.config = nil
|
||||
return fmt.Errorf("parse thinking schedule: %w", err)
|
||||
}
|
||||
l.mu.Lock()
|
||||
l.config = &cfg
|
||||
l.mu.Unlock()
|
||||
log.Printf("[思考调度] 已加载配置文件: version=%s, default=%dmin, rules=%d", cfg.Version, cfg.DefaultIntervalMinutes, len(cfg.Rules))
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasConfig returns true if a schedule config was loaded from file.
|
||||
func (l *ScheduleLoader) HasConfig() bool {
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
return l.config != nil
|
||||
}
|
||||
|
||||
// GetInterval returns the thinking interval in minutes for the given time.
|
||||
// Returns 0 if no schedule is loaded (caller should use default).
|
||||
func (l *ScheduleLoader) GetInterval(now time.Time) int {
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
|
||||
if l.config == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
weekday := strings.ToLower(now.Weekday().String()) // monday, tuesday, ...
|
||||
currentMinutes := now.Hour()*60 + now.Minute()
|
||||
|
||||
for _, rule := range l.config.Rules {
|
||||
if !matchDay(weekday, rule.Days) {
|
||||
continue
|
||||
}
|
||||
if !matchTimeRange(currentMinutes, rule.TimeRange) {
|
||||
continue
|
||||
}
|
||||
if matchExceptRange(currentMinutes, rule.Except) {
|
||||
continue
|
||||
}
|
||||
return rule.IntervalMinutes
|
||||
}
|
||||
|
||||
return l.config.DefaultIntervalMinutes
|
||||
}
|
||||
|
||||
// matchDay checks if the current weekday is in the rule's days list.
|
||||
func matchDay(currentDay string, days []string) bool {
|
||||
for _, d := range days {
|
||||
if strings.ToLower(d) == currentDay {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// matchTimeRange checks if currentMinutes (0-1439) falls within the time range.
|
||||
// Supports overnight ranges (e.g., 23:00-07:00 where start > end).
|
||||
func matchTimeRange(currentMinutes int, timeRange string) bool {
|
||||
start, end, ok := parseTimeRange(timeRange)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if start <= end {
|
||||
return currentMinutes >= start && currentMinutes < end
|
||||
}
|
||||
// Overnight range
|
||||
return currentMinutes >= start || currentMinutes < end
|
||||
}
|
||||
|
||||
// matchExceptRange returns true if currentMinutes falls in any except range.
|
||||
func matchExceptRange(currentMinutes int, exceptRanges []string) bool {
|
||||
for _, er := range exceptRanges {
|
||||
start, end, ok := parseTimeRange(er)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if start <= end {
|
||||
if currentMinutes >= start && currentMinutes < end {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
if currentMinutes >= start || currentMinutes < end {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// parseTimeRange parses "HH:MM-HH:MM" into start and end minutes from midnight.
|
||||
func parseTimeRange(r string) (int, int, bool) {
|
||||
parts := strings.SplitN(r, "-", 2)
|
||||
if len(parts) != 2 {
|
||||
return 0, 0, false
|
||||
}
|
||||
start, ok := parseHM(strings.TrimSpace(parts[0]))
|
||||
if !ok {
|
||||
return 0, 0, false
|
||||
}
|
||||
end, ok := parseHM(strings.TrimSpace(parts[1]))
|
||||
if !ok {
|
||||
return 0, 0, false
|
||||
}
|
||||
return start, end, true
|
||||
}
|
||||
|
||||
// parseHM parses "HH:MM" into minutes from midnight.
|
||||
func parseHM(s string) (int, bool) {
|
||||
parts := strings.SplitN(s, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
return 0, false
|
||||
}
|
||||
h, err := strconv.Atoi(parts[0])
|
||||
if err != nil || h < 0 || h > 23 {
|
||||
return 0, false
|
||||
}
|
||||
m, err := strconv.Atoi(parts[1])
|
||||
if err != nil || m < 0 || m > 59 {
|
||||
return 0, false
|
||||
}
|
||||
return h*60 + m, true
|
||||
}
|
||||
@@ -24,9 +24,10 @@ type IoTDeviceSummary interface {
|
||||
|
||||
// ConversationStore 会话历史存储接口
|
||||
type ConversationStore struct {
|
||||
mu sync.RWMutex
|
||||
messages map[string][]model.LLMMessage // key = sessionID
|
||||
maxHistory int
|
||||
mu sync.RWMutex
|
||||
messages map[string][]model.LLMMessage // key = sessionID
|
||||
maxHistory int
|
||||
databaseURL string // lazy-load from DB on cache miss
|
||||
}
|
||||
|
||||
// NewConversationStore 创建会话历史存储
|
||||
@@ -37,6 +38,13 @@ func NewConversationStore(maxHistory int) *ConversationStore {
|
||||
}
|
||||
}
|
||||
|
||||
// SetDatabaseURL sets the database URL for lazy-loading history on cache miss.
|
||||
func (cs *ConversationStore) SetDatabaseURL(url string) {
|
||||
cs.mu.Lock()
|
||||
defer cs.mu.Unlock()
|
||||
cs.databaseURL = url
|
||||
}
|
||||
|
||||
// AddMessage 添加消息到会话历史
|
||||
func (cs *ConversationStore) AddMessage(sessionID string, msg model.LLMMessage) {
|
||||
cs.mu.Lock()
|
||||
@@ -59,12 +67,23 @@ func (cs *ConversationStore) AddMessage(sessionID string, msg model.LLMMessage)
|
||||
cs.messages[sessionID] = msgs
|
||||
}
|
||||
|
||||
// GetHistory 获取会话历史
|
||||
// GetHistory 获取会话历史。
|
||||
// 如果内存缓存为空且配置了 databaseURL,会尝试从 DB 懒加载历史。
|
||||
func (cs *ConversationStore) GetHistory(sessionID string, limit int) []model.LLMMessage {
|
||||
cs.mu.RLock()
|
||||
defer cs.mu.RUnlock()
|
||||
|
||||
msgs := cs.messages[sessionID]
|
||||
dbURL := cs.databaseURL
|
||||
cs.mu.RUnlock()
|
||||
|
||||
if len(msgs) == 0 && dbURL != "" {
|
||||
// 懒加载:从 DB 恢复该会话的历史
|
||||
if err := cs.LoadFromDB(dbURL, sessionID, limit); err == nil {
|
||||
cs.mu.RLock()
|
||||
msgs = cs.messages[sessionID]
|
||||
cs.mu.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
if len(msgs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,204 @@
|
||||
package host
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DirectBackend executes commands directly on the host via os/exec,
|
||||
// with command allowlist and directory restrictions for safety.
|
||||
type DirectBackend struct {
|
||||
sandbox *Sandbox
|
||||
allowedDirs []string
|
||||
}
|
||||
|
||||
// NewDirectBackend creates a host execution backend that runs commands
|
||||
// directly on the host machine with sandbox restrictions.
|
||||
func NewDirectBackend(sandbox *Sandbox) *DirectBackend {
|
||||
b := &DirectBackend{sandbox: sandbox}
|
||||
if sandbox != nil {
|
||||
b.allowedDirs = sandbox.cfg.AllowedDirs
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *DirectBackend) Name() string { return "direct" }
|
||||
|
||||
// SetAllowedDirs updates the directories accessible for file operations.
|
||||
func (b *DirectBackend) SetAllowedDirs(dirs []string) {
|
||||
b.allowedDirs = dirs
|
||||
if b.sandbox != nil {
|
||||
b.sandbox.cfg.AllowedDirs = dirs
|
||||
}
|
||||
}
|
||||
|
||||
// Exec runs a command in the sandbox.
|
||||
func (b *DirectBackend) Exec(ctx context.Context, command, workDir string, timeout time.Duration) (*ExecResult, error) {
|
||||
return b.sandbox.Exec(ctx, command, workDir, timeout)
|
||||
}
|
||||
|
||||
// ReadFile reads the contents of a file within allowed directories.
|
||||
func (b *DirectBackend) ReadFile(path string, maxBytes int) (string, error) {
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 1024 * 1024
|
||||
}
|
||||
if err := b.validatePath(path); err != nil {
|
||||
return "", err
|
||||
}
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot stat file: %w", err)
|
||||
}
|
||||
if info.IsDir() {
|
||||
return "", fmt.Errorf("path is a directory: %s", path)
|
||||
}
|
||||
if info.Size() > int64(maxBytes) {
|
||||
return "", fmt.Errorf("file too large: %d bytes (max %d)", info.Size(), maxBytes)
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot read file: %w", err)
|
||||
}
|
||||
if len(data) > maxBytes {
|
||||
data = data[:maxBytes]
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// WriteFile writes data to a file within allowed directories.
|
||||
func (b *DirectBackend) WriteFile(path, content string, maxBytes int) error {
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 1024 * 1024
|
||||
}
|
||||
if len(content) > maxBytes {
|
||||
return fmt.Errorf("content too large: %d bytes (max %d)", len(content), maxBytes)
|
||||
}
|
||||
if err := b.validatePath(path); err != nil {
|
||||
return err
|
||||
}
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("cannot create directory: %w", err)
|
||||
}
|
||||
return os.WriteFile(path, []byte(content), 0644)
|
||||
}
|
||||
|
||||
// ListDir lists directory contents within allowed directories.
|
||||
func (b *DirectBackend) ListDir(path string) ([]DirEntry, error) {
|
||||
if err := b.validatePath(path); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entries, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot read directory: %w", err)
|
||||
}
|
||||
result := make([]DirEntry, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
info, _ := e.Info()
|
||||
size := int64(0)
|
||||
modTime := time.Time{}
|
||||
if info != nil {
|
||||
size = info.Size()
|
||||
modTime = info.ModTime()
|
||||
}
|
||||
result = append(result, DirEntry{
|
||||
Name: e.Name(),
|
||||
IsDir: e.IsDir(),
|
||||
Size: size,
|
||||
ModTime: modTime.Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SystemInfo returns basic system information.
|
||||
func (b *DirectBackend) SystemInfo() map[string]interface{} {
|
||||
hostname, _ := os.Hostname()
|
||||
wd, _ := os.Getwd()
|
||||
|
||||
info := map[string]interface{}{
|
||||
"hostname": hostname,
|
||||
"os": runtime.GOOS,
|
||||
"arch": runtime.GOARCH,
|
||||
"num_cpu": runtime.NumCPU(),
|
||||
"go_version": runtime.Version(),
|
||||
"work_dir": wd,
|
||||
"backend": "direct",
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
cmd := exec.Command("systeminfo")
|
||||
out, err := cmd.Output()
|
||||
if err == nil {
|
||||
lines := strings.Split(string(out), "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.Contains(line, "Total Physical Memory") {
|
||||
parts := strings.SplitN(line, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
info["total_memory"] = strings.TrimSpace(parts[1])
|
||||
}
|
||||
}
|
||||
if strings.Contains(line, "OS Name") {
|
||||
parts := strings.SplitN(line, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
info["os_name"] = strings.TrimSpace(parts[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if data, err := os.ReadFile("/proc/meminfo"); err == nil {
|
||||
for _, line := range strings.Split(string(data), "\n") {
|
||||
if strings.HasPrefix(line, "MemTotal:") {
|
||||
info["total_memory"] = strings.TrimSpace(strings.TrimPrefix(line, "MemTotal:"))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
// DiskUsage returns disk usage for the given path.
|
||||
func (b *DirectBackend) DiskUsage(path string) (map[string]interface{}, error) {
|
||||
if err := b.validatePath(path); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot stat path: %w", err)
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"path": path,
|
||||
"is_dir": info.IsDir(),
|
||||
"size": info.Size(),
|
||||
"mod_time": info.ModTime().Format(time.RFC3339),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *DirectBackend) validatePath(path string) error {
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot resolve path: %w", err)
|
||||
}
|
||||
if len(b.allowedDirs) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, allowed := range b.allowedDirs {
|
||||
absAllowed, err := filepath.Abs(allowed)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(absPath, absAllowed+string(os.PathSeparator)) || absPath == absAllowed {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("path not in allowed directories: %s", path)
|
||||
}
|
||||
@@ -0,0 +1,274 @@
|
||||
package host
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DockerBackend executes commands inside a Docker container,
|
||||
// providing a full Linux OS environment with container-level isolation.
|
||||
type DockerBackend struct {
|
||||
container string
|
||||
image string
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// NewDockerBackend creates a Docker backend that runs commands in the
|
||||
// specified container. If the container does not exist, it will be
|
||||
// created from the given image.
|
||||
func NewDockerBackend(container, image string, defaultTimeout time.Duration) *DockerBackend {
|
||||
if defaultTimeout <= 0 {
|
||||
defaultTimeout = 30 * time.Second
|
||||
}
|
||||
return &DockerBackend{
|
||||
container: container,
|
||||
image: image,
|
||||
timeout: defaultTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *DockerBackend) Name() string { return "docker" }
|
||||
|
||||
// ensureContainer checks that the container exists and is running.
|
||||
// If it doesn't exist, it creates it from the configured image.
|
||||
func (b *DockerBackend) ensureContainer() error {
|
||||
// Check if container exists and is running
|
||||
check := exec.Command("docker", "inspect", "-f", "{{.State.Running}}", b.container)
|
||||
out, err := check.Output()
|
||||
if err == nil && strings.TrimSpace(string(out)) == "true" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if container exists but is stopped
|
||||
if err == nil && strings.TrimSpace(string(out)) == "false" {
|
||||
start := exec.Command("docker", "start", b.container)
|
||||
if out, err := start.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("cannot start container %s: %s — %w", b.container, string(out), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create and start a new container
|
||||
create := exec.Command("docker", "run", "-d", "--name", b.container,
|
||||
"--restart", "unless-stopped",
|
||||
b.image, "sleep", "infinity")
|
||||
if out, err := create.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("cannot create container %s from image %s: %s — %w",
|
||||
b.container, b.image, string(out), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exec runs a command inside the Docker container.
|
||||
func (b *DockerBackend) Exec(ctx context.Context, command, workDir string, timeout time.Duration) (*ExecResult, error) {
|
||||
if command == "" {
|
||||
return nil, fmt.Errorf("empty command")
|
||||
}
|
||||
|
||||
if err := b.ensureContainer(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if timeout <= 0 {
|
||||
timeout = b.timeout
|
||||
}
|
||||
|
||||
execCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
// Build the shell command to run inside the container
|
||||
script := command
|
||||
if workDir != "" {
|
||||
script = fmt.Sprintf("cd %s && %s", shellEscapeDocker(workDir), command)
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(execCtx, "docker", "exec", b.container, "sh", "-c", script)
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
start := time.Now()
|
||||
err := cmd.Run()
|
||||
elapsed := time.Since(start)
|
||||
|
||||
result := &ExecResult{
|
||||
Duration: elapsed.Round(time.Millisecond).String(),
|
||||
Stdout: stdout.String(),
|
||||
Stderr: stderr.String(),
|
||||
}
|
||||
|
||||
if execCtx.Err() == context.DeadlineExceeded {
|
||||
result.TimedOut = true
|
||||
result.ExitCode = -1
|
||||
return result, fmt.Errorf("command timed out after %s", timeout)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
result.ExitCode = exitErr.ExitCode()
|
||||
} else {
|
||||
result.ExitCode = -1
|
||||
}
|
||||
} else {
|
||||
result.ExitCode = 0
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
// ReadFile reads a file from inside the container using cat.
|
||||
func (b *DockerBackend) ReadFile(path string, maxBytes int) (string, error) {
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 1024 * 1024
|
||||
}
|
||||
if err := b.ensureContainer(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "docker", "exec", b.container, "cat", path)
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot read file %s: %w", path, err)
|
||||
}
|
||||
if len(out) > maxBytes {
|
||||
out = out[:maxBytes]
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
|
||||
// WriteFile writes content to a file inside the container.
|
||||
func (b *DockerBackend) WriteFile(path, content string, maxBytes int) error {
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 1024 * 1024
|
||||
}
|
||||
if len(content) > maxBytes {
|
||||
return fmt.Errorf("content too large: %d bytes (max %d)", len(content), maxBytes)
|
||||
}
|
||||
if err := b.ensureContainer(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Create parent directory and write file
|
||||
cmd := exec.CommandContext(ctx, "docker", "exec", "-i", b.container, "sh", "-c",
|
||||
fmt.Sprintf("mkdir -p $(dirname %s) && cat > %s", shellEscapeDocker(path), shellEscapeDocker(path)))
|
||||
cmd.Stdin = strings.NewReader(content)
|
||||
_, err := cmd.Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot write file %s: %w", path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListDir lists a directory inside the container.
|
||||
func (b *DockerBackend) ListDir(path string) ([]DirEntry, error) {
|
||||
if err := b.ensureContainer(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "docker", "exec", b.container, "sh", "-c",
|
||||
fmt.Sprintf("ls -la %s 2>/dev/null | tail -n +2 || echo ''", shellEscapeDocker(path)))
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot list dir %s: %w", path, err)
|
||||
}
|
||||
|
||||
lines := strings.Split(strings.TrimSpace(string(out)), "\n")
|
||||
result := make([]DirEntry, 0, len(lines))
|
||||
for _, line := range lines {
|
||||
if line == "" || strings.HasPrefix(line, "total ") {
|
||||
continue
|
||||
}
|
||||
// Parse ls -la output: drwxr-xr-x 2 root root 4096 Jan 1 12:00 name
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 9 {
|
||||
continue
|
||||
}
|
||||
isDir := strings.HasPrefix(fields[0], "d")
|
||||
name := fields[len(fields)-1]
|
||||
if name == "." || name == ".." {
|
||||
continue
|
||||
}
|
||||
var size int64
|
||||
fmt.Sscanf(fields[4], "%d", &size)
|
||||
result = append(result, DirEntry{
|
||||
Name: name,
|
||||
IsDir: isDir,
|
||||
Size: size,
|
||||
})
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SystemInfo returns system information from inside the container.
|
||||
func (b *DockerBackend) SystemInfo() map[string]interface{} {
|
||||
info := map[string]interface{}{
|
||||
"backend": "docker",
|
||||
"container": b.container,
|
||||
"image": b.image,
|
||||
}
|
||||
|
||||
if err := b.ensureContainer(); err != nil {
|
||||
info["error"] = err.Error()
|
||||
return info
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if out, err := exec.CommandContext(ctx, "docker", "exec", b.container, "uname", "-a").Output(); err == nil {
|
||||
info["uname"] = strings.TrimSpace(string(out))
|
||||
}
|
||||
if out, err := exec.CommandContext(ctx, "docker", "exec", b.container, "hostname").Output(); err == nil {
|
||||
info["hostname"] = strings.TrimSpace(string(out))
|
||||
}
|
||||
if out, err := exec.CommandContext(ctx, "docker", "exec", b.container, "free", "-h").Output(); err == nil {
|
||||
info["memory"] = strings.TrimSpace(string(out))
|
||||
}
|
||||
if out, err := exec.CommandContext(ctx, "docker", "exec", b.container, "df", "-h", "/").Output(); err == nil {
|
||||
lines := strings.Split(strings.TrimSpace(string(out)), "\n")
|
||||
if len(lines) > 1 {
|
||||
info["disk"] = strings.TrimSpace(lines[1])
|
||||
}
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
// DiskUsage returns disk usage for a path inside the container.
|
||||
func (b *DockerBackend) DiskUsage(path string) (map[string]interface{}, error) {
|
||||
if err := b.ensureContainer(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "docker", "exec", b.container, "stat", path)
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot stat path %s: %w", path, err)
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"path": path,
|
||||
"stat": strings.TrimSpace(string(out)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// shellEscapeDocker escapes a string for safe use in a shell command.
|
||||
func shellEscapeDocker(s string) string {
|
||||
escaped := strings.ReplaceAll(s, "'", "'\\''")
|
||||
return "'" + escaped + "'"
|
||||
}
|
||||
@@ -0,0 +1,323 @@
|
||||
package host
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// WSLBackend executes commands inside a WSL2 distribution,
|
||||
// providing a full Linux OS environment isolated from the Windows host.
|
||||
type WSLBackend struct {
|
||||
distro string
|
||||
username string
|
||||
password string
|
||||
timeout time.Duration
|
||||
|
||||
userEnsured bool
|
||||
}
|
||||
|
||||
// NewWSLBackend creates a WSL backend that runs commands in the
|
||||
// specified WSL distribution as the given user. On first use,
|
||||
// the user is automatically created with sudo privileges.
|
||||
func NewWSLBackend(distro, username, password string, defaultTimeout time.Duration) *WSLBackend {
|
||||
if defaultTimeout <= 0 {
|
||||
defaultTimeout = 30 * time.Second
|
||||
}
|
||||
if username == "" {
|
||||
username = "cyrene"
|
||||
}
|
||||
return &WSLBackend{
|
||||
distro: distro,
|
||||
username: username,
|
||||
password: password,
|
||||
timeout: defaultTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *WSLBackend) Name() string { return "wsl" }
|
||||
|
||||
// ensureUser creates the configured user inside the WSL distro on first call.
|
||||
// The user gets sudo privileges and the configured password.
|
||||
func (b *WSLBackend) ensureUser() error {
|
||||
if b.userEnsured {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if user already exists
|
||||
checkCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
checkCmd := exec.CommandContext(checkCtx, "wsl.exe", "-d", b.distro, "--", "id", b.username)
|
||||
if checkCmd.Run() == nil {
|
||||
b.userEnsured = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create user with home directory, set password, add to sudo group
|
||||
// If password is empty, create user without password (sudo won't need it
|
||||
// if NOPASSWD is configured, but we still set a random one for safety)
|
||||
pwd := b.password
|
||||
if pwd == "" {
|
||||
pwd = "cyrene"
|
||||
}
|
||||
|
||||
// Escape single quotes in password for the shell echo command
|
||||
escapedPwd := strings.ReplaceAll(pwd, "'", "'\\''")
|
||||
script := fmt.Sprintf(
|
||||
"useradd -m -s /bin/bash %s && echo '%s:%s' | chpasswd && usermod -aG sudo %s",
|
||||
b.username, b.username, escapedPwd, b.username,
|
||||
)
|
||||
|
||||
createCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
createCmd := exec.CommandContext(createCtx, "wsl.exe", "-d", b.distro, "--", "bash", "-c", script)
|
||||
if out, err := createCmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("cannot create user %s: %s — %w", b.username, string(out), err)
|
||||
}
|
||||
|
||||
b.userEnsured = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exec runs a command inside the WSL distribution via bash.
|
||||
func (b *WSLBackend) Exec(ctx context.Context, command, workDir string, timeout time.Duration) (*ExecResult, error) {
|
||||
if command == "" {
|
||||
return nil, fmt.Errorf("empty command")
|
||||
}
|
||||
|
||||
if err := b.ensureUser(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if timeout <= 0 {
|
||||
timeout = b.timeout
|
||||
}
|
||||
|
||||
execCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
// Build the bash command to run inside WSL
|
||||
script := command
|
||||
if workDir != "" {
|
||||
wslPath := windowsToWSLPath(workDir)
|
||||
script = fmt.Sprintf("cd %s && %s", shellEscape(wslPath), command)
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(execCtx, "wsl.exe", "-d", b.distro, "--", "bash", "-c", script)
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
start := time.Now()
|
||||
err := cmd.Run()
|
||||
elapsed := time.Since(start)
|
||||
|
||||
result := &ExecResult{
|
||||
Duration: elapsed.Round(time.Millisecond).String(),
|
||||
Stdout: stdout.String(),
|
||||
Stderr: stderr.String(),
|
||||
}
|
||||
|
||||
if execCtx.Err() == context.DeadlineExceeded {
|
||||
result.TimedOut = true
|
||||
result.ExitCode = -1
|
||||
return result, fmt.Errorf("command timed out after %s", timeout)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
result.ExitCode = exitErr.ExitCode()
|
||||
} else {
|
||||
result.ExitCode = -1
|
||||
}
|
||||
} else {
|
||||
result.ExitCode = 0
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
// ReadFile reads a file from the WSL filesystem using cat.
|
||||
func (b *WSLBackend) ReadFile(path string, maxBytes int) (string, error) {
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 1024 * 1024
|
||||
}
|
||||
if err := b.ensureUser(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
wslPath := windowsToWSLPath(path)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "wsl.exe", "-d", b.distro, "--", "cat", wslPath)
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot read file %s: %w", path, err)
|
||||
}
|
||||
if len(out) > maxBytes {
|
||||
out = out[:maxBytes]
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
|
||||
// WriteFile writes content to a file in the WSL filesystem.
|
||||
func (b *WSLBackend) WriteFile(path, content string, maxBytes int) error {
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 1024 * 1024
|
||||
}
|
||||
if len(content) > maxBytes {
|
||||
return fmt.Errorf("content too large: %d bytes (max %d)", len(content), maxBytes)
|
||||
}
|
||||
if err := b.ensureUser(); err != nil {
|
||||
return err
|
||||
}
|
||||
wslPath := windowsToWSLPath(path)
|
||||
// Create parent directory first
|
||||
dir := filepath.Dir(wslPath)
|
||||
_ = exec.Command("wsl.exe", "-d", b.distro, "--", "mkdir", "-p", dir).Run()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "wsl.exe", "-d", b.distro, "--", "bash", "-c",
|
||||
fmt.Sprintf("cat > %s", shellEscape(wslPath)))
|
||||
cmd.Stdin = strings.NewReader(content)
|
||||
_, err := cmd.Output()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot write file %s: %w", path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListDir lists a directory in the WSL filesystem using ls.
|
||||
func (b *WSLBackend) ListDir(path string) ([]DirEntry, error) {
|
||||
if err := b.ensureUser(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wslPath := windowsToWSLPath(path)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "wsl.exe", "-d", b.distro, "--", "bash", "-c",
|
||||
fmt.Sprintf("stat -c '%%n|%%F|%%s|%%Y' %s/* 2>/dev/null || echo ''", shellEscape(wslPath)))
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot list dir %s: %w", path, err)
|
||||
}
|
||||
|
||||
lines := strings.Split(strings.TrimSpace(string(out)), "\n")
|
||||
result := make([]DirEntry, 0, len(lines))
|
||||
for _, line := range lines {
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
parts := strings.SplitN(line, "|", 4)
|
||||
if len(parts) < 4 {
|
||||
continue
|
||||
}
|
||||
var size int64
|
||||
fmt.Sscanf(parts[2], "%d", &size)
|
||||
var modTimeUnix int64
|
||||
fmt.Sscanf(parts[3], "%d", &modTimeUnix)
|
||||
modTime := time.Unix(modTimeUnix, 0).Format(time.RFC3339)
|
||||
isDir := strings.Contains(parts[1], "directory")
|
||||
result = append(result, DirEntry{
|
||||
Name: filepath.Base(parts[0]),
|
||||
IsDir: isDir,
|
||||
Size: size,
|
||||
ModTime: modTime,
|
||||
})
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SystemInfo returns system information from inside the WSL distribution.
|
||||
func (b *WSLBackend) SystemInfo() map[string]interface{} {
|
||||
info := map[string]interface{}{
|
||||
"backend": "wsl",
|
||||
"distro": b.distro,
|
||||
}
|
||||
|
||||
if err := b.ensureUser(); err != nil {
|
||||
info["error"] = err.Error()
|
||||
return info
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// uname
|
||||
if out, err := exec.CommandContext(ctx, "wsl.exe", "-d", b.distro, "--", "uname", "-a").Output(); err == nil {
|
||||
info["uname"] = strings.TrimSpace(string(out))
|
||||
}
|
||||
|
||||
// hostname
|
||||
if out, err := exec.CommandContext(ctx, "wsl.exe", "-d", b.distro, "--", "hostname").Output(); err == nil {
|
||||
info["hostname"] = strings.TrimSpace(string(out))
|
||||
}
|
||||
|
||||
// memory info
|
||||
if out, err := exec.CommandContext(ctx, "wsl.exe", "-d", b.distro, "--", "free", "-h").Output(); err == nil {
|
||||
info["memory"] = strings.TrimSpace(string(out))
|
||||
}
|
||||
|
||||
// disk info
|
||||
if out, err := exec.CommandContext(ctx, "wsl.exe", "-d", b.distro, "--", "df", "-h", "/").Output(); err == nil {
|
||||
lines := strings.Split(strings.TrimSpace(string(out)), "\n")
|
||||
if len(lines) > 1 {
|
||||
info["disk"] = strings.TrimSpace(lines[1])
|
||||
}
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// DiskUsage returns disk usage for a path inside WSL.
|
||||
func (b *WSLBackend) DiskUsage(path string) (map[string]interface{}, error) {
|
||||
wslPath := windowsToWSLPath(path)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "wsl.exe", "-d", b.distro, "--", "stat", wslPath)
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot stat path %s: %w", path, err)
|
||||
}
|
||||
|
||||
// Parse stat output minimally
|
||||
result := map[string]interface{}{
|
||||
"path": path,
|
||||
"wsl_path": wslPath,
|
||||
"stat": strings.TrimSpace(string(out)),
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// windowsToWSLPath converts a Windows path to its WSL equivalent.
|
||||
// C:\Users\foo → /mnt/c/Users/foo
|
||||
// If the path is already a WSL path (starts with /), return as-is.
|
||||
func windowsToWSLPath(path string) string {
|
||||
if strings.HasPrefix(path, "/") {
|
||||
return path // Already a Unix path
|
||||
}
|
||||
// Handle Windows drive letter: C:\... → /mnt/c/...
|
||||
if len(path) >= 2 && path[1] == ':' {
|
||||
drive := strings.ToLower(string(path[0]))
|
||||
rest := strings.TrimPrefix(path[2:], "\\")
|
||||
rest = strings.ReplaceAll(rest, "\\", "/")
|
||||
return fmt.Sprintf("/mnt/%s/%s", drive, rest)
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
// shellEscape escapes a string for safe use in a shell command.
|
||||
func shellEscape(s string) string {
|
||||
// Use single quotes and escape any single quotes in the string
|
||||
escaped := strings.ReplaceAll(s, "'", "'\\''")
|
||||
return "'" + escaped + "'"
|
||||
}
|
||||
@@ -0,0 +1,143 @@
|
||||
package host
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestWSLBackendIntegration(t *testing.T) {
|
||||
distro := os.Getenv("WSL_DISTRO")
|
||||
if distro == "" {
|
||||
t.Skip("WSL_DISTRO not set, skipping WSL integration test (set WSL_DISTRO=cyrene-wsl to run)")
|
||||
}
|
||||
|
||||
backend := NewWSLBackend(distro, "cyrene", "test123", 30*time.Second)
|
||||
mgr := NewManager(backend)
|
||||
ctx := context.Background()
|
||||
|
||||
// 1. Basic command
|
||||
t.Run("echo", func(t *testing.T) {
|
||||
r, err := mgr.Exec(ctx, "echo 'hello from WSL OS env'", "", 10*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("exec failed: %v", err)
|
||||
}
|
||||
if r.ExitCode != 0 {
|
||||
t.Fatalf("exit=%d, stderr=%s", r.ExitCode, r.Stderr)
|
||||
}
|
||||
if !strings.Contains(r.Stdout, "hello from WSL OS env") {
|
||||
t.Fatalf("unexpected stdout: %s", r.Stdout)
|
||||
}
|
||||
t.Logf("echo OK: %s (duration=%s)", strings.TrimSpace(r.Stdout), r.Duration)
|
||||
})
|
||||
|
||||
// 2. Complex commands - package manager
|
||||
t.Run("apt", func(t *testing.T) {
|
||||
r, err := mgr.Exec(ctx, "apt --version 2>&1", "", 10*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("exec failed: %v", err)
|
||||
}
|
||||
t.Logf("apt OK: %s", strings.TrimSpace(r.Stdout))
|
||||
})
|
||||
|
||||
// 3. Python (should be pre-installed on Ubuntu)
|
||||
t.Run("python", func(t *testing.T) {
|
||||
r, err := mgr.Exec(ctx, "python3 --version 2>&1", "", 10*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("exec failed: %v", err)
|
||||
}
|
||||
t.Logf("python OK: %s", strings.TrimSpace(r.Stdout))
|
||||
})
|
||||
|
||||
// 4. Pipeline & shell features
|
||||
t.Run("pipeline", func(t *testing.T) {
|
||||
r, err := mgr.Exec(ctx, "echo 'a\nb\nc\nd' | wc -l", "", 10*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("exec failed: %v", err)
|
||||
}
|
||||
if r.ExitCode != 0 {
|
||||
t.Fatalf("exit=%d", r.ExitCode)
|
||||
}
|
||||
t.Logf("pipeline OK: %s", strings.TrimSpace(r.Stdout))
|
||||
})
|
||||
|
||||
// 5. File write & read
|
||||
t.Run("file_rw", func(t *testing.T) {
|
||||
err := mgr.WriteFile("/tmp/cyrene-wsl-test.txt", "Hello from Cyrene OS!", 1024*1024)
|
||||
if err != nil {
|
||||
t.Fatalf("write failed: %v", err)
|
||||
}
|
||||
content, err := mgr.ReadFile("/tmp/cyrene-wsl-test.txt", 1024*1024)
|
||||
if err != nil {
|
||||
t.Fatalf("read failed: %v", err)
|
||||
}
|
||||
if content != "Hello from Cyrene OS!" {
|
||||
t.Fatalf("content mismatch: %q", content)
|
||||
}
|
||||
t.Logf("file r/w OK: %q", content)
|
||||
})
|
||||
|
||||
// 6. Directory listing
|
||||
t.Run("listdir", func(t *testing.T) {
|
||||
entries, err := mgr.ListDir("/etc")
|
||||
if err != nil {
|
||||
t.Fatalf("listdir failed: %v", err)
|
||||
}
|
||||
if len(entries) == 0 {
|
||||
t.Fatal("expected entries in /etc")
|
||||
}
|
||||
t.Logf("listdir OK: %d entries in /etc", len(entries))
|
||||
for _, e := range entries {
|
||||
if e.Name == "os-release" || e.Name == "hostname" {
|
||||
t.Logf(" - %s (isDir=%v, size=%d)", e.Name, e.IsDir, e.Size)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// 7. System info
|
||||
t.Run("sysinfo", func(t *testing.T) {
|
||||
info := mgr.SystemInfo()
|
||||
if info["backend"] != "wsl" {
|
||||
t.Fatalf("unexpected backend: %v", info["backend"])
|
||||
}
|
||||
if info["distro"] != distro {
|
||||
t.Fatalf("unexpected distro: %v", info["distro"])
|
||||
}
|
||||
t.Logf("sysinfo OK: backend=%v, distro=%v", info["backend"], info["distro"])
|
||||
if uname, ok := info["uname"]; ok {
|
||||
t.Logf(" uname: %v", uname)
|
||||
}
|
||||
if hostname, ok := info["hostname"]; ok {
|
||||
t.Logf(" hostname: %v", hostname)
|
||||
}
|
||||
if mem, ok := info["memory"]; ok {
|
||||
t.Logf(" memory: %v", mem)
|
||||
}
|
||||
})
|
||||
|
||||
// 8. workDir
|
||||
t.Run("workdir", func(t *testing.T) {
|
||||
r, err := mgr.Exec(ctx, "pwd", "/tmp", 10*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("exec failed: %v", err)
|
||||
}
|
||||
if !strings.Contains(r.Stdout, "/tmp") {
|
||||
t.Fatalf("expected /tmp, got: %s", r.Stdout)
|
||||
}
|
||||
t.Logf("workdir OK: pwd=%s", strings.TrimSpace(r.Stdout))
|
||||
})
|
||||
|
||||
// 9. Timeout
|
||||
t.Run("timeout", func(t *testing.T) {
|
||||
r, err := mgr.Exec(ctx, "sleep 10", "", 1*time.Second)
|
||||
if err == nil {
|
||||
t.Fatal("expected timeout")
|
||||
}
|
||||
if !r.TimedOut {
|
||||
t.Fatal("expected TimedOut=true")
|
||||
}
|
||||
t.Logf("timeout OK: timed_out=%v", r.TimedOut)
|
||||
})
|
||||
}
|
||||
@@ -2,207 +2,72 @@ package host
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Manager provides controlled access to the host machine.
|
||||
// It wraps a Sandbox for command execution and adds file system
|
||||
// operations with path allow-list enforcement.
|
||||
// HostBackend defines the interface for command execution and file system
|
||||
// operations. Implementations include DirectBackend (host OS), WSLBackend
|
||||
// (Windows Subsystem for Linux), and DockerBackend (container).
|
||||
type HostBackend interface {
|
||||
Exec(ctx context.Context, command, workDir string, timeout time.Duration) (*ExecResult, error)
|
||||
ReadFile(path string, maxBytes int) (string, error)
|
||||
WriteFile(path, content string, maxBytes int) error
|
||||
ListDir(path string) ([]DirEntry, error)
|
||||
SystemInfo() map[string]interface{}
|
||||
DiskUsage(path string) (map[string]interface{}, error)
|
||||
Name() string
|
||||
}
|
||||
|
||||
// Manager provides controlled access to the host machine. It delegates
|
||||
// to a HostBackend implementation which may be direct, WSL, or Docker.
|
||||
type Manager struct {
|
||||
sandbox *Sandbox
|
||||
allowedDirs []string
|
||||
backend HostBackend
|
||||
}
|
||||
|
||||
// NewManager creates a new host Manager.
|
||||
func NewManager(sandbox *Sandbox) *Manager {
|
||||
m := &Manager{sandbox: sandbox}
|
||||
if sandbox != nil {
|
||||
m.allowedDirs = sandbox.cfg.AllowedDirs
|
||||
}
|
||||
return m
|
||||
// NewManager creates a new host Manager with the given backend.
|
||||
func NewManager(backend HostBackend) *Manager {
|
||||
return &Manager{backend: backend}
|
||||
}
|
||||
|
||||
// SetAllowedDirs updates the list of directories accessible for file operations.
|
||||
// SetAllowedDirs updates directory restrictions. Only effective for
|
||||
// DirectBackend; WSL and Docker backends are no-ops.
|
||||
func (m *Manager) SetAllowedDirs(dirs []string) {
|
||||
m.allowedDirs = dirs
|
||||
m.sandbox.cfg.AllowedDirs = dirs
|
||||
if db, ok := m.backend.(*DirectBackend); ok {
|
||||
db.SetAllowedDirs(dirs)
|
||||
}
|
||||
}
|
||||
|
||||
// Exec runs a command in the sandbox.
|
||||
// Exec runs a command via the configured backend.
|
||||
func (m *Manager) Exec(ctx context.Context, command, workDir string, timeout time.Duration) (*ExecResult, error) {
|
||||
return m.sandbox.Exec(ctx, command, workDir, timeout)
|
||||
return m.backend.Exec(ctx, command, workDir, timeout)
|
||||
}
|
||||
|
||||
// ReadFile reads the contents of a file within allowed directories.
|
||||
// ReadFile reads a file via the configured backend.
|
||||
func (m *Manager) ReadFile(path string, maxBytes int) (string, error) {
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 1024 * 1024
|
||||
}
|
||||
if err := m.validatePath(path); err != nil {
|
||||
return "", err
|
||||
}
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot stat file: %w", err)
|
||||
}
|
||||
if info.IsDir() {
|
||||
return "", fmt.Errorf("path is a directory: %s", path)
|
||||
}
|
||||
if info.Size() > int64(maxBytes) {
|
||||
return "", fmt.Errorf("file too large: %d bytes (max %d)", info.Size(), maxBytes)
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot read file: %w", err)
|
||||
}
|
||||
if len(data) > maxBytes {
|
||||
data = data[:maxBytes]
|
||||
}
|
||||
return string(data), nil
|
||||
return m.backend.ReadFile(path, maxBytes)
|
||||
}
|
||||
|
||||
// WriteFile writes data to a file within allowed directories.
|
||||
// WriteFile writes a file via the configured backend.
|
||||
func (m *Manager) WriteFile(path, content string, maxBytes int) error {
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 1024 * 1024
|
||||
}
|
||||
if len(content) > maxBytes {
|
||||
return fmt.Errorf("content too large: %d bytes (max %d)", len(content), maxBytes)
|
||||
}
|
||||
if err := m.validatePath(path); err != nil {
|
||||
return err
|
||||
}
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("cannot create directory: %w", err)
|
||||
}
|
||||
return os.WriteFile(path, []byte(content), 0644)
|
||||
return m.backend.WriteFile(path, content, maxBytes)
|
||||
}
|
||||
|
||||
// ListDir lists directory contents within allowed directories.
|
||||
// ListDir lists a directory via the configured backend.
|
||||
func (m *Manager) ListDir(path string) ([]DirEntry, error) {
|
||||
if err := m.validatePath(path); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entries, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot read directory: %w", err)
|
||||
}
|
||||
result := make([]DirEntry, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
info, _ := e.Info()
|
||||
size := int64(0)
|
||||
modTime := time.Time{}
|
||||
if info != nil {
|
||||
size = info.Size()
|
||||
modTime = info.ModTime()
|
||||
}
|
||||
result = append(result, DirEntry{
|
||||
Name: e.Name(),
|
||||
IsDir: e.IsDir(),
|
||||
Size: size,
|
||||
ModTime: modTime.Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
return result, nil
|
||||
return m.backend.ListDir(path)
|
||||
}
|
||||
|
||||
// DirEntry represents a filesystem directory entry.
|
||||
type DirEntry struct {
|
||||
Name string `json:"name"`
|
||||
IsDir bool `json:"is_dir"`
|
||||
Size int64 `json:"size"`
|
||||
ModTime string `json:"mod_time"`
|
||||
}
|
||||
|
||||
// SystemInfo returns basic system information.
|
||||
// SystemInfo returns system information from the configured backend.
|
||||
func (m *Manager) SystemInfo() map[string]interface{} {
|
||||
hostname, _ := os.Hostname()
|
||||
wd, _ := os.Getwd()
|
||||
|
||||
info := map[string]interface{}{
|
||||
"hostname": hostname,
|
||||
"os": runtime.GOOS,
|
||||
"arch": runtime.GOARCH,
|
||||
"num_cpu": runtime.NumCPU(),
|
||||
"go_version": runtime.Version(),
|
||||
"work_dir": wd,
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
cmd := exec.Command("systeminfo")
|
||||
out, err := cmd.Output()
|
||||
if err == nil {
|
||||
lines := strings.Split(string(out), "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.Contains(line, "Total Physical Memory") {
|
||||
parts := strings.SplitN(line, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
info["total_memory"] = strings.TrimSpace(parts[1])
|
||||
}
|
||||
}
|
||||
if strings.Contains(line, "OS Name") {
|
||||
parts := strings.SplitN(line, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
info["os_name"] = strings.TrimSpace(parts[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if data, err := os.ReadFile("/proc/meminfo"); err == nil {
|
||||
for _, line := range strings.Split(string(data), "\n") {
|
||||
if strings.HasPrefix(line, "MemTotal:") {
|
||||
info["total_memory"] = strings.TrimSpace(strings.TrimPrefix(line, "MemTotal:"))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return info
|
||||
return m.backend.SystemInfo()
|
||||
}
|
||||
|
||||
// DiskUsage returns disk usage for the given path.
|
||||
// DiskUsage returns disk usage info from the configured backend.
|
||||
func (m *Manager) DiskUsage(path string) (map[string]interface{}, error) {
|
||||
if err := m.validatePath(path); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot stat path: %w", err)
|
||||
}
|
||||
result := map[string]interface{}{
|
||||
"path": path,
|
||||
"is_dir": info.IsDir(),
|
||||
"size": info.Size(),
|
||||
"mod_time": info.ModTime().Format(time.RFC3339),
|
||||
}
|
||||
return result, nil
|
||||
return m.backend.DiskUsage(path)
|
||||
}
|
||||
|
||||
func (m *Manager) validatePath(path string) error {
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot resolve path: %w", err)
|
||||
}
|
||||
if len(m.allowedDirs) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, allowed := range m.allowedDirs {
|
||||
absAllowed, err := filepath.Abs(allowed)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(absPath, absAllowed+string(os.PathSeparator)) || absPath == absAllowed {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("path not in allowed directories: %s", path)
|
||||
// BackendName returns the name of the active backend.
|
||||
func (m *Manager) BackendName() string {
|
||||
return m.backend.Name()
|
||||
}
|
||||
|
||||
@@ -50,6 +50,14 @@ func NewSandbox(cfg SandboxConfig) *Sandbox {
|
||||
return &Sandbox{cfg: cfg}
|
||||
}
|
||||
|
||||
// DirEntry represents a filesystem directory entry.
|
||||
type DirEntry struct {
|
||||
Name string `json:"name"`
|
||||
IsDir bool `json:"is_dir"`
|
||||
Size int64 `json:"size"`
|
||||
ModTime string `json:"mod_time,omitempty"`
|
||||
}
|
||||
|
||||
// ExecResult holds the result of a sandboxed command execution.
|
||||
type ExecResult struct {
|
||||
Stdout string `json:"stdout"`
|
||||
|
||||
@@ -62,7 +62,7 @@ func TestManagerFileOps(t *testing.T) {
|
||||
tmpDir := os.TempDir()
|
||||
cfg.AllowedDirs = []string{tmpDir}
|
||||
sandbox := NewSandbox(cfg)
|
||||
mgr := NewManager(sandbox)
|
||||
mgr := NewManager(NewDirectBackend(sandbox))
|
||||
mgr.SetAllowedDirs([]string{tmpDir})
|
||||
|
||||
testPath := filepath.Join(tmpDir, "cyrene-test-file.txt")
|
||||
@@ -102,7 +102,7 @@ func TestManagerFileOps(t *testing.T) {
|
||||
func TestManagerSystemInfo(t *testing.T) {
|
||||
cfg := DefaultSandboxConfig()
|
||||
sandbox := NewSandbox(cfg)
|
||||
mgr := NewManager(sandbox)
|
||||
mgr := NewManager(NewDirectBackend(sandbox))
|
||||
|
||||
info := mgr.SystemInfo()
|
||||
if info["hostname"] == nil || info["hostname"] == "" {
|
||||
@@ -121,7 +121,7 @@ func TestPathValidation(t *testing.T) {
|
||||
cfg := DefaultSandboxConfig()
|
||||
cfg.AllowedDirs = []string{os.TempDir()}
|
||||
sandbox := NewSandbox(cfg)
|
||||
mgr := NewManager(sandbox)
|
||||
mgr := NewManager(NewDirectBackend(sandbox))
|
||||
mgr.SetAllowedDirs([]string{os.TempDir()})
|
||||
|
||||
// Should fail: access outside allowed dirs
|
||||
|
||||
@@ -4,15 +4,16 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"github.com/yourname/cyrene-ai/pkg/logger"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/model"
|
||||
"github.com/yourname/cyrene-ai/pkg/logger"
|
||||
)
|
||||
|
||||
// OpenAIConfig OpenAI适配器配置
|
||||
@@ -267,12 +268,13 @@ func (p *OpenAIProvider) doChat(ctx context.Context, messages []model.LLMMessage
|
||||
LogCall(r)
|
||||
}()
|
||||
|
||||
// 转换消息格式
|
||||
// 转换消息格式(先解析图片 URL 为 data URL)
|
||||
oaiMessages := make([]openAIMessage, len(messages))
|
||||
for i, msg := range messages {
|
||||
resolvedImages := p.resolveImages(msg.Images)
|
||||
oaiMsg := openAIMessage{
|
||||
Role: string(msg.Role),
|
||||
Content: buildContent(msg.Content, msg.Images),
|
||||
Content: buildContent(msg.Content, resolvedImages),
|
||||
Name: msg.Name,
|
||||
ToolCallID: msg.ToolCallID,
|
||||
ReasoningContent: msg.ReasoningContent,
|
||||
@@ -377,9 +379,10 @@ func (p *OpenAIProvider) doChat(ctx context.Context, messages []model.LLMMessage
|
||||
func (p *OpenAIProvider) doChatStream(ctx context.Context, messages []model.LLMMessage, modelName string, tools []OpenAITool) (*http.Response, error) {
|
||||
oaiMessages := make([]openAIMessage, len(messages))
|
||||
for i, msg := range messages {
|
||||
resolvedImages := p.resolveImages(msg.Images)
|
||||
oaiMsg := openAIMessage{
|
||||
Role: string(msg.Role),
|
||||
Content: buildContent(msg.Content, msg.Images),
|
||||
Content: buildContent(msg.Content, resolvedImages),
|
||||
Name: msg.Name,
|
||||
ToolCallID: msg.ToolCallID,
|
||||
ReasoningContent: msg.ReasoningContent,
|
||||
@@ -455,6 +458,67 @@ func contentString(v interface{}) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// resolveImages converts non-data URLs to base64 data URLs so external LLM APIs can access them.
|
||||
func (p *OpenAIProvider) resolveImages(images []string) []string {
|
||||
if len(images) == 0 {
|
||||
return images
|
||||
}
|
||||
resolved := make([]string, 0, len(images))
|
||||
for _, img := range images {
|
||||
if strings.HasPrefix(img, "data:") {
|
||||
resolved = append(resolved, img)
|
||||
continue
|
||||
}
|
||||
dataURL, err := p.downloadAsDataURL(img)
|
||||
if err != nil {
|
||||
logger.Printf("[openai] 图片下载失败, 保留原始 URL: %s, err=%v", img, err)
|
||||
resolved = append(resolved, img) // 保留原始 URL 作为 fallback
|
||||
continue
|
||||
}
|
||||
resolved = append(resolved, dataURL)
|
||||
}
|
||||
return resolved
|
||||
}
|
||||
|
||||
// downloadAsDataURL downloads an image from a URL and returns it as a base64 data URL.
|
||||
func (p *OpenAIProvider) downloadAsDataURL(url string) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("下载失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 限制最大 20MB
|
||||
const maxSize = 20 * 1024 * 1024
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize+1))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("读取失败: %w", err)
|
||||
}
|
||||
if len(body) > maxSize {
|
||||
return "", fmt.Errorf("图片过大: %d bytes", len(body))
|
||||
}
|
||||
|
||||
mimeType := resp.Header.Get("Content-Type")
|
||||
if mimeType == "" {
|
||||
mimeType = http.DetectContentType(body)
|
||||
}
|
||||
|
||||
b64 := base64.StdEncoding.EncodeToString(body)
|
||||
return fmt.Sprintf("data:%s;base64,%s", mimeType, b64), nil
|
||||
}
|
||||
|
||||
// buildContent converts text + optional images to API content format.
|
||||
// Returns a plain string if no images, or a multimodal array otherwise.
|
||||
func buildContent(text string, images []string) interface{} {
|
||||
|
||||
@@ -19,6 +19,7 @@ const (
|
||||
PurposeToolCalling ModelPurpose = "tool_calling"
|
||||
PurposeMemoryExtraction ModelPurpose = "memory_extraction"
|
||||
PurposeVision ModelPurpose = "vision"
|
||||
PurposeOCR ModelPurpose = "ocr"
|
||||
)
|
||||
|
||||
// ErrModelNotRequired is returned when an optional model is unavailable.
|
||||
|
||||
@@ -167,6 +167,31 @@ func (c *Client) GetByID(ctx context.Context, id string) (*model.MemoryEntry, er
|
||||
return &result.Memory, nil
|
||||
}
|
||||
|
||||
// Update 更新记忆
|
||||
func (c *Client) Update(ctx context.Context, entry *model.MemoryEntry) error {
|
||||
body, _ := json.Marshal(map[string]interface{}{
|
||||
"content": entry.Content,
|
||||
"summary": entry.Summary,
|
||||
"category": string(entry.Category),
|
||||
"priority": int(entry.Priority),
|
||||
"importance": entry.Importance,
|
||||
"keywords": entry.Keywords,
|
||||
"source": entry.Source,
|
||||
})
|
||||
|
||||
resp, err := c.doRequest(ctx, http.MethodPut, c.baseURL+"/api/v1/memories/"+entry.ID, body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新记忆失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("更新记忆失败 (%d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除记忆
|
||||
func (c *Client) Delete(ctx context.Context, id string) error {
|
||||
resp, err := c.doRequest(ctx, http.MethodDelete, c.baseURL+"/api/v1/memories/"+id, nil)
|
||||
|
||||
@@ -8,6 +8,17 @@ type EnrichmentData struct {
|
||||
ThoughtOutline string
|
||||
IoTSummary string
|
||||
KnowledgeInfo string
|
||||
|
||||
// Pending tool results from async execution (keyed by tool call ID)
|
||||
PendingToolResults []PendingToolResult
|
||||
}
|
||||
|
||||
// PendingToolResult holds the result of a tool that completed asynchronously.
|
||||
type PendingToolResult struct {
|
||||
ToolCallID string `json:"tool_call_id"`
|
||||
ToolName string `json:"tool_name"`
|
||||
Result string `json:"result"`
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
// SessionEnrichmentStore is a thread-safe per-session cache for async
|
||||
@@ -25,8 +36,15 @@ func NewEnrichmentStore() *SessionEnrichmentStore {
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns stored enrichment for a session and clears it (one-shot consumption).
|
||||
// Get returns stored enrichment for a session (does NOT clear; results may be reused).
|
||||
func (s *SessionEnrichmentStore) Get(sessionID string) *EnrichmentData {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.data[sessionID]
|
||||
}
|
||||
|
||||
// Pop returns stored enrichment for a session and clears it (one-shot consumption).
|
||||
func (s *SessionEnrichmentStore) Pop(sessionID string) *EnrichmentData {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
d, ok := s.data[sessionID]
|
||||
@@ -45,3 +63,32 @@ func (s *SessionEnrichmentStore) Store(sessionID string, d *EnrichmentData) {
|
||||
s.data[sessionID] = d
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// AppendToolResult adds a completed tool result to the session's enrichment data.
|
||||
func (s *SessionEnrichmentStore) AppendToolResult(sessionID string, r PendingToolResult) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
d, ok := s.data[sessionID]
|
||||
if !ok {
|
||||
d = &EnrichmentData{}
|
||||
s.data[sessionID] = d
|
||||
}
|
||||
d.PendingToolResults = append(d.PendingToolResults, r)
|
||||
}
|
||||
|
||||
// ---- Global pending tool store (used by Synthesizer for async tool results) ----
|
||||
|
||||
var globalPendingToolStore *SessionEnrichmentStore
|
||||
var pendingToolStoreOnce sync.Once
|
||||
|
||||
// InitGlobalPendingToolStore initializes the singleton.
|
||||
func InitGlobalPendingToolStore() {
|
||||
pendingToolStoreOnce.Do(func() {
|
||||
globalPendingToolStore = NewEnrichmentStore()
|
||||
})
|
||||
}
|
||||
|
||||
// GetGlobalPendingToolStore returns the singleton, or nil if not initialized.
|
||||
func GetGlobalPendingToolStore() *SessionEnrichmentStore {
|
||||
return globalPendingToolStore
|
||||
}
|
||||
|
||||
@@ -38,6 +38,7 @@ type Orchestrator struct {
|
||||
msgScheduler *scheduler.MessageScheduler
|
||||
emotionTracker *persona.EmotionTracker
|
||||
toolRegistry *plgManager.ToolRegistry
|
||||
visionProvider llm.LLMProvider // 视觉模型 (图片预处理/OCR)
|
||||
}
|
||||
|
||||
// SetResponseCache sets the response cache (optional, for Phase 0.2).
|
||||
@@ -71,6 +72,11 @@ func (o *Orchestrator) SetToolRegistry(tr *plgManager.ToolRegistry) {
|
||||
o.synthesizer.toolRegistry = tr
|
||||
}
|
||||
|
||||
// SetVisionProvider sets the vision model provider for image preprocessing.
|
||||
func (o *Orchestrator) SetVisionProvider(vp llm.LLMProvider) {
|
||||
o.visionProvider = vp
|
||||
}
|
||||
|
||||
// getBus returns the bus or a nop fallback.
|
||||
func (o *Orchestrator) getBus() bus.Bus {
|
||||
if o.eventBus == nil {
|
||||
@@ -149,7 +155,27 @@ func (o *Orchestrator) ProcessInput(
|
||||
UserID: params.UserID,
|
||||
})
|
||||
|
||||
// 1. 意图分析
|
||||
// 0.5 图片预处理: 使用视觉模型分析图片,将描述注入消息
|
||||
if len(params.Images) > 0 && o.visionProvider != nil {
|
||||
startTime := time.Now()
|
||||
augmented := o.preprocessImages(ctx, params.Message, params.Images)
|
||||
if augmented != params.Message {
|
||||
params.Message = augmented
|
||||
logger.Printf("[orchestrator] 图片预处理耗时: %v, 原消息=%d字, 增强后=%d字",
|
||||
time.Since(startTime), len([]rune(params.Message))-len([]rune(augmented))+len([]rune(params.Message)), len([]rune(augmented)))
|
||||
}
|
||||
// 预处理后清空原始图片,避免后续传给不支持多模态的 Chat 模型
|
||||
params.Images = nil
|
||||
} else if len(params.Images) > 0 {
|
||||
// 未配置 Vision 模型时,告知用户该模型不支持图片,并清空图片避免报错
|
||||
if params.Message == "" {
|
||||
params.Message = "(用户发送了一张图片,但当前未配置视觉模型,无法识别图片内容)"
|
||||
}
|
||||
logger.Printf("[orchestrator] 视觉模型未配置,丢弃 %d 张图片", len(params.Images))
|
||||
params.Images = nil
|
||||
}
|
||||
|
||||
// 1. 意图分析
|
||||
startTime := time.Now()
|
||||
intent, err := o.intentAnalyzer.Analyze(ctx, params.Message)
|
||||
if err != nil || intent == nil {
|
||||
@@ -247,17 +273,39 @@ func (o *Orchestrator) ProcessInput(
|
||||
resultCh = o.subManager.Dispatch(subCtx, intent, params.Message, createParams)
|
||||
}
|
||||
|
||||
// 3.5 确保全局工具结果存储已初始化
|
||||
InitGlobalPendingToolStore()
|
||||
|
||||
// 4. 加载上一轮异步完成的子会话富化结果
|
||||
var prevEnrichment *EnrichmentData
|
||||
if o.enrichmentStore != nil {
|
||||
prevEnrichment = o.enrichmentStore.Get(params.SessionID)
|
||||
if prevEnrichment != nil {
|
||||
logger.Printf("[orchestrator] 加载上一轮富化结果: memory=%t thought=%t iot=%t knowledge=%t",
|
||||
prevEnrichment.MemorySummary != "",
|
||||
prevEnrichment.ThoughtOutline != "",
|
||||
prevEnrichment.IoTSummary != "",
|
||||
prevEnrichment.KnowledgeInfo != "")
|
||||
prevEnrichment = o.enrichmentStore.Pop(params.SessionID)
|
||||
// Also merge any pending tool results from the global store
|
||||
if globalStore := GetGlobalPendingToolStore(); globalStore != nil {
|
||||
if toolData := globalStore.Pop(params.SessionID); toolData != nil && len(toolData.PendingToolResults) > 0 {
|
||||
if prevEnrichment == nil {
|
||||
prevEnrichment = &EnrichmentData{}
|
||||
}
|
||||
prevEnrichment.PendingToolResults = append(prevEnrichment.PendingToolResults, toolData.PendingToolResults...)
|
||||
logger.Printf("[orchestrator] 合并后台工具结果 %d 条", len(toolData.PendingToolResults))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Still check global store even if enrichmentStore is not set
|
||||
if globalStore := GetGlobalPendingToolStore(); globalStore != nil {
|
||||
if toolData := globalStore.Pop(params.SessionID); toolData != nil && len(toolData.PendingToolResults) > 0 {
|
||||
prevEnrichment = toolData
|
||||
logger.Printf("[orchestrator] 加载后台工具结果 %d 条", len(toolData.PendingToolResults))
|
||||
}
|
||||
}
|
||||
}
|
||||
if prevEnrichment != nil {
|
||||
logger.Printf("[orchestrator] 加载上一轮富化结果: memory=%t thought=%t iot=%t knowledge=%t tools=%d",
|
||||
prevEnrichment.MemorySummary != "",
|
||||
prevEnrichment.ThoughtOutline != "",
|
||||
prevEnrichment.IoTSummary != "",
|
||||
prevEnrichment.KnowledgeInfo != "",
|
||||
len(prevEnrichment.PendingToolResults))
|
||||
}
|
||||
|
||||
// 5. 先构建基础综合参数(不含子会话结果),开始合成
|
||||
@@ -284,6 +332,7 @@ func (o *Orchestrator) ProcessInput(
|
||||
synthParams.ThoughtOutline = prevEnrichment.ThoughtOutline
|
||||
synthParams.IoTSummary = prevEnrichment.IoTSummary
|
||||
synthParams.KnowledgeInfo = prevEnrichment.KnowledgeInfo
|
||||
synthParams.PendingToolResults = prevEnrichment.PendingToolResults
|
||||
}
|
||||
|
||||
// 异步收集子会话结果,存入 enrichmentStore 供下一轮使用
|
||||
@@ -324,7 +373,7 @@ func (o *Orchestrator) ProcessInput(
|
||||
}()
|
||||
|
||||
// 5. 调用 Synthesizer 流式生成最终回复
|
||||
chunkCh, err := o.synthesizer.Synthesize(ctx, synthParams)
|
||||
chunkCh, err := o.synthesizer.Synthesize(ctx, synthParams, eventCh)
|
||||
if err != nil {
|
||||
logger.Printf("[orchestrator] 综合器启动失败: %v", err)
|
||||
eventCh <- model.StreamEvent{
|
||||
@@ -601,6 +650,46 @@ func (o *Orchestrator) CacheMessage(sessionID string, role model.Role, content s
|
||||
}
|
||||
}
|
||||
|
||||
// preprocessImages uses the vision model to analyze images and augments the user message.
|
||||
// For standalone images (no text): generates a comprehensive description as the message.
|
||||
// For text+images: appends image descriptions as contextual annotations.
|
||||
func (o *Orchestrator) preprocessImages(ctx context.Context, message string, images []string) string {
|
||||
var prompt string
|
||||
if message == "" {
|
||||
prompt = "请详细描述这张图片的内容,包括场景、物体、人物、文字(如有)、颜色、氛围等所有视觉信息。"
|
||||
} else {
|
||||
prompt = fmt.Sprintf("用户的问题是:「%s」\n\n请根据用户的问题,分析这张图片中相关的视觉信息,帮助回答用户的问题。如果图片中有文字,请完整提取。", message)
|
||||
}
|
||||
|
||||
var descriptions []string
|
||||
for i, img := range images {
|
||||
resp, err := o.visionProvider.Chat(ctx, []model.LLMMessage{
|
||||
{Role: model.RoleUser, Content: prompt, Images: []string{img}},
|
||||
})
|
||||
if err != nil {
|
||||
logger.Printf("[orchestrator] 图片 %d 预处理失败: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if resp.Content != "" {
|
||||
descriptions = append(descriptions, resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
if len(descriptions) == 0 {
|
||||
return message
|
||||
}
|
||||
|
||||
if message == "" {
|
||||
return strings.Join(descriptions, "\n\n")
|
||||
}
|
||||
|
||||
augmented := message
|
||||
for i, desc := range descriptions {
|
||||
augmented += fmt.Sprintf("\n\n[图片%d的视觉分析]: %s", i+1, desc)
|
||||
}
|
||||
return augmented
|
||||
}
|
||||
|
||||
// Ensure time, memory are used
|
||||
var _ = time.Now
|
||||
var _ = memory.NewRetriever
|
||||
|
||||
@@ -14,7 +14,7 @@ var codeBlockPattern = regexp.MustCompile("`{3}([^\n]*)\n([\\s\\S]*?)`{3}")
|
||||
var markdownPatterns = []*regexp.Regexp{
|
||||
regexp.MustCompile(`^#{1,6}\s`), // headings
|
||||
regexp.MustCompile(`\*\*[^*]+\*\*`), // bold
|
||||
regexp.MustCompile(`(?<!\*)\*[^*]+\*(?!\*)`), // italic (single *)
|
||||
regexp.MustCompile(`(?:^|[^*])\*([^*]+)\*(?:[^*]|$)`), // italic (*text*)
|
||||
regexp.MustCompile(`\[([^\]]+)\]\(([^\)]+)\)`), // links [text](url)
|
||||
regexp.MustCompile(`^[\-\*]\s`), // unordered list
|
||||
regexp.MustCompile(`^\d+\.\s`), // ordered list
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/llm"
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/model"
|
||||
@@ -30,23 +31,25 @@ func NewSynthesizer(llmAdapter *llm.Adapter, toolRegistry *plgManager.ToolRegist
|
||||
|
||||
// SynthesizeParams 综合参数
|
||||
type SynthesizeParams struct {
|
||||
UserID string
|
||||
SessionID string
|
||||
UserMessage string
|
||||
Images []string // 图片 base64 data URL (多模态)
|
||||
Nickname string
|
||||
PersonaPrompt string // 完整人格提示词
|
||||
DialogHistory []model.LLMMessage // 对话历史
|
||||
MemorySummary string // 记忆检索摘要
|
||||
ThoughtOutline string // 通用对话思考
|
||||
IoTSummary string // IoT 操作摘要
|
||||
DeviceContext string // 设备状态上下文
|
||||
KnowledgeInfo string // 知识库检索摘要
|
||||
Mode string // text / voice_assistant
|
||||
UserID string
|
||||
SessionID string
|
||||
UserMessage string
|
||||
Images []string // 图片 base64 data URL (多模态)
|
||||
Nickname string
|
||||
PersonaPrompt string // 完整人格提示词
|
||||
DialogHistory []model.LLMMessage // 对话历史
|
||||
MemorySummary string // 记忆检索摘要
|
||||
ThoughtOutline string // 通用对话思考
|
||||
IoTSummary string // IoT 操作摘要
|
||||
DeviceContext string // 设备状态上下文
|
||||
KnowledgeInfo string // 知识库检索摘要
|
||||
PendingToolResults []PendingToolResult // 上一轮异步完成的工具结果
|
||||
Mode string // text / voice_assistant
|
||||
}
|
||||
|
||||
// Synthesize 综合所有子会话结果,流式生成最终回复
|
||||
func (s *Synthesizer) Synthesize(ctx context.Context, params SynthesizeParams) (<-chan llm.StreamChunk, error) {
|
||||
// Synthesize 综合所有子会话结果,流式生成最终回复。
|
||||
// eventCh receives tool progress events; pass nil to suppress.
|
||||
func (s *Synthesizer) Synthesize(ctx context.Context, params SynthesizeParams, eventCh chan<- model.StreamEvent) (<-chan llm.StreamChunk, error) {
|
||||
messages := s.buildSynthesizeMessages(params)
|
||||
|
||||
logger.Printf("[synthesizer] 开始综合 (上下文 %d 条消息)", len(messages))
|
||||
@@ -62,7 +65,9 @@ func (s *Synthesizer) Synthesize(ctx context.Context, params SynthesizeParams) (
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxRounds := 5
|
||||
const toolDeadline = 8 * time.Second
|
||||
const maxRounds = 5
|
||||
|
||||
for round := 0; len(resp.ToolCalls) > 0 && round < maxRounds; round++ {
|
||||
logger.Printf("[synthesizer] LLM 请求 %d 个工具调用 (round=%d)", len(resp.ToolCalls), round)
|
||||
|
||||
@@ -80,7 +85,12 @@ func (s *Synthesizer) Synthesize(ctx context.Context, params SynthesizeParams) (
|
||||
args = make(map[string]interface{})
|
||||
}
|
||||
|
||||
result, execErr := s.toolRegistry.Execute(ctx, tc.Name, args)
|
||||
s.emitToolProgress(eventCh, tc.Name, "started", 0, "正在执行 "+tc.Name)
|
||||
|
||||
toolCtx, cancel := context.WithTimeout(ctx, toolDeadline)
|
||||
result, execErr := s.toolRegistry.Execute(toolCtx, tc.Name, args)
|
||||
cancel()
|
||||
|
||||
if execErr != nil {
|
||||
logger.Printf("[synthesizer] 工具 %s 执行失败: %v", tc.Name, execErr)
|
||||
}
|
||||
@@ -88,6 +98,19 @@ func (s *Synthesizer) Synthesize(ctx context.Context, params SynthesizeParams) (
|
||||
result = &plgSDK.ToolResult{ToolName: tc.Name, Success: false, Error: execErr.Error()}
|
||||
}
|
||||
|
||||
// Async fallback: if tool timed out, store for next turn
|
||||
if toolCtx.Err() == context.DeadlineExceeded {
|
||||
s.emitToolProgress(eventCh, tc.Name, "running", 0.5, tc.Name+" 执行时间较长,转入后台继续...")
|
||||
go s.executeAsyncAndStore(tc, args, params.SessionID, eventCh)
|
||||
result = &plgSDK.ToolResult{
|
||||
ToolName: tc.Name,
|
||||
Success: true,
|
||||
Output: fmt.Sprintf("[后台执行中] %s 正在后台运行,结果将在下一轮对话中返回。你可以继续聊天。", tc.Name),
|
||||
}
|
||||
} else {
|
||||
s.emitToolProgress(eventCh, tc.Name, "completed", 1.0, "")
|
||||
}
|
||||
|
||||
resultJSON, _ := json.Marshal(result)
|
||||
messages = append(messages, model.LLMMessage{
|
||||
Role: model.RoleTool,
|
||||
@@ -120,6 +143,51 @@ func (s *Synthesizer) Synthesize(ctx context.Context, params SynthesizeParams) (
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// emitToolProgress sends a StreamToolProgress event if eventCh is available.
|
||||
func (s *Synthesizer) emitToolProgress(eventCh chan<- model.StreamEvent, name, status string, progress float64, message string) {
|
||||
if eventCh == nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case eventCh <- model.StreamEvent{
|
||||
Type: model.StreamToolProgress,
|
||||
ToolProgress: &model.ToolProgressInfo{
|
||||
ToolName: name,
|
||||
Status: status,
|
||||
Progress: progress,
|
||||
Message: message,
|
||||
},
|
||||
}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// executeAsyncAndStore runs a tool in background and stores the result for the next turn.
|
||||
func (s *Synthesizer) executeAsyncAndStore(tc model.ToolCall, args map[string]interface{}, sessionID string, eventCh chan<- model.StreamEvent) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result, err := s.toolRegistry.Execute(ctx, tc.Name, args)
|
||||
if err != nil {
|
||||
logger.Printf("[synthesizer] 后台工具 %s 执行失败: %v", tc.Name, err)
|
||||
s.emitToolProgress(eventCh, tc.Name, "failed", 1.0, tc.Name+" 后台执行失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
s.emitToolProgress(eventCh, tc.Name, "completed", 1.0, tc.Name+" 后台执行完成")
|
||||
|
||||
resultJSON, _ := json.Marshal(result)
|
||||
store := GetGlobalPendingToolStore()
|
||||
if store != nil {
|
||||
store.AppendToolResult(sessionID, PendingToolResult{
|
||||
ToolCallID: tc.ID,
|
||||
ToolName: tc.Name,
|
||||
Result: string(resultJSON),
|
||||
Success: result != nil && result.Success,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// buildSynthesizeMessages 构建综合用的 LLM 消息列表
|
||||
func (s *Synthesizer) buildSynthesizeMessages(params SynthesizeParams) []model.LLMMessage {
|
||||
var messages []model.LLMMessage
|
||||
@@ -174,6 +242,23 @@ func (s *Synthesizer) buildSynthesizeMessages(params SynthesizeParams) []model.L
|
||||
})
|
||||
}
|
||||
|
||||
// 注入上一轮异步工具执行结果
|
||||
if len(params.PendingToolResults) > 0 {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("【上一轮后台工具执行结果】\n")
|
||||
for _, ptr := range params.PendingToolResults {
|
||||
status := "成功"
|
||||
if !ptr.Success {
|
||||
status = "失败"
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("- %s (%s): %s\n", ptr.ToolName, status, ptr.Result))
|
||||
}
|
||||
messages = append(messages, model.LLMMessage{
|
||||
Role: model.RoleSystem,
|
||||
Content: sb.String(),
|
||||
})
|
||||
}
|
||||
|
||||
// 注入对话历史
|
||||
if len(params.DialogHistory) > 0 {
|
||||
messages = append(messages, params.DialogHistory...)
|
||||
|
||||
@@ -2,12 +2,15 @@ package subsession
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/yourname/cyrene-ai/pkg/logger"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/llm"
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/memory"
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/model"
|
||||
"github.com/yourname/cyrene-ai/pkg/logger"
|
||||
)
|
||||
|
||||
// MemoryRetriever 记忆检索接口
|
||||
@@ -16,9 +19,12 @@ type MemoryRetriever interface {
|
||||
}
|
||||
|
||||
// MemoryProvider 记忆检索子会话提供者
|
||||
// 职责:检索与当前对话相关的用户记忆,排序去重,返回结构化摘要
|
||||
// 职责:检索与当前对话相关的用户记忆,排序去重,返回结构化摘要。
|
||||
// 支持 LLM 驱动的模糊关键词扩展搜索。
|
||||
type MemoryProvider struct {
|
||||
retriever MemoryRetriever
|
||||
retriever MemoryRetriever
|
||||
llmAdapter *llm.Adapter
|
||||
memClient *memory.Client
|
||||
}
|
||||
|
||||
// NewMemoryProvider 创建记忆检索子会话提供者
|
||||
@@ -28,6 +34,12 @@ func NewMemoryProvider(retriever MemoryRetriever) *MemoryProvider {
|
||||
}
|
||||
}
|
||||
|
||||
// SetFuzzySearch enables LLM-driven fuzzy keyword expansion for broader memory retrieval.
|
||||
func (p *MemoryProvider) SetFuzzySearch(llmAdapter *llm.Adapter, memClient *memory.Client) {
|
||||
p.llmAdapter = llmAdapter
|
||||
p.memClient = memClient
|
||||
}
|
||||
|
||||
func (p *MemoryProvider) Type() model.SubSessionType {
|
||||
return model.SubSessionMemory
|
||||
}
|
||||
@@ -93,6 +105,7 @@ func (p *MemoryProvider) Execute(ctx context.Context, subCtx []model.LLMMessage)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Phase 1: exact/keyword retrieval
|
||||
memories, err := p.retriever.Retrieve(ctx, userID, userMessage)
|
||||
if err != nil {
|
||||
logger.Printf("[memory-subsession] 记忆检索失败: %v", err)
|
||||
@@ -101,6 +114,20 @@ func (p *MemoryProvider) Execute(ctx context.Context, subCtx []model.LLMMessage)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
seen := make(map[string]bool)
|
||||
for _, m := range memories {
|
||||
seen[m.ID] = true
|
||||
}
|
||||
|
||||
// Phase 2: LLM-driven fuzzy keyword expansion + semantic search
|
||||
fuzzyMemories := p.fuzzySearch(ctx, userID, userMessage)
|
||||
for _, m := range fuzzyMemories {
|
||||
if !seen[m.ID] {
|
||||
seen[m.ID] = true
|
||||
memories = append(memories, m)
|
||||
}
|
||||
}
|
||||
|
||||
// 转换为 MemorySnippet
|
||||
snippets := make([]model.MemorySnippet, 0, len(memories))
|
||||
for _, m := range memories {
|
||||
@@ -117,7 +144,7 @@ func (p *MemoryProvider) Execute(ctx context.Context, subCtx []model.LLMMessage)
|
||||
if len(snippets) == 0 {
|
||||
result.Summary = "(没有找到相关记忆)"
|
||||
} else {
|
||||
result.Summary = fmt.Sprintf("检索到 %d 条相关记忆", len(snippets))
|
||||
result.Summary = fmt.Sprintf("检索到 %d 条相关记忆(含模糊匹配)", len(snippets))
|
||||
// 按重要性列出前几条
|
||||
topCount := len(snippets)
|
||||
if topCount > 3 {
|
||||
@@ -138,6 +165,74 @@ func (p *MemoryProvider) Execute(ctx context.Context, subCtx []model.LLMMessage)
|
||||
}
|
||||
|
||||
result.Memories = snippets
|
||||
logger.Printf("[memory-subsession] 完成: %s", result.Summary)
|
||||
logger.Printf("[memory-subsession] 完成: %s (精确=%d, 模糊=%d)", result.Summary, len(memories)-len(fuzzyMemories), len(fuzzyMemories))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// fuzzySearch expands the user message into fuzzy keywords via LLM and performs semantic search.
|
||||
func (p *MemoryProvider) fuzzySearch(ctx context.Context, userID, userMessage string) []memory.MemoryEntry {
|
||||
if p.llmAdapter == nil || p.memClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
keywords := p.expandKeywords(ctx, userMessage)
|
||||
if len(keywords) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Printf("[memory-subsession] 模糊关键词: %v", keywords)
|
||||
|
||||
var allResults []memory.MemoryEntry
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, kw := range keywords {
|
||||
results, err := p.memClient.QueryByText(ctx, userID, kw, "", 0, 5)
|
||||
if err != nil {
|
||||
logger.Printf("[memory-subsession] 模糊搜索 '%s' 失败: %v", kw, err)
|
||||
continue
|
||||
}
|
||||
for _, m := range results {
|
||||
if !seen[m.ID] {
|
||||
seen[m.ID] = true
|
||||
allResults = append(allResults, m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return allResults
|
||||
}
|
||||
|
||||
// expandKeywords uses LLM to generate fuzzy/related search keywords from the user message.
|
||||
func (p *MemoryProvider) expandKeywords(ctx context.Context, message string) []string {
|
||||
prompt := fmt.Sprintf(
|
||||
"从以下对话消息中提取 3-5 个可用于模糊搜索记忆的关键词。这些关键词应该是:\n"+
|
||||
"- 与话题相关的抽象概念\n- 同义词和相关词\n- 更宽泛或更具体的相关概念\n"+
|
||||
"- 不要包含消息中已经出现的原词\n\n"+
|
||||
"用户消息:「%s」\n\n"+
|
||||
"只输出 JSON 字符串数组,例如:[\"关键词1\",\"关键词2\"]", message)
|
||||
|
||||
resp, err := p.llmAdapter.Chat(ctx, []model.LLMMessage{
|
||||
{Role: model.RoleSystem, Content: "你是记忆搜索专家。输出 JSON 字符串数组。"},
|
||||
{Role: model.RoleUser, Content: prompt},
|
||||
})
|
||||
if err != nil {
|
||||
logger.Printf("[memory-subsession] 关键词扩展失败: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
text := strings.TrimSpace(resp.Content)
|
||||
// Extract JSON array
|
||||
if idx := strings.Index(text, "["); idx >= 0 {
|
||||
if end := strings.LastIndex(text, "]"); end > idx {
|
||||
text = text[idx : end+1]
|
||||
}
|
||||
}
|
||||
|
||||
var keywords []string
|
||||
if err := json.Unmarshal([]byte(text), &keywords); err != nil {
|
||||
logger.Printf("[memory-subsession] 解析关键词 JSON 失败: %v (raw=%s)", err, resp.Content)
|
||||
return nil
|
||||
}
|
||||
|
||||
return keywords
|
||||
}
|
||||
|
||||
@@ -0,0 +1,217 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/host"
|
||||
)
|
||||
|
||||
// OSExecTool allows the AI to execute arbitrary commands in a full OS
|
||||
// environment (WSL or Docker container). Unlike host_exec which runs in
|
||||
// a restricted sandbox, this provides unrestricted OS access.
|
||||
type OSExecTool struct {
|
||||
manager *host.Manager
|
||||
}
|
||||
|
||||
// NewOSExecTool creates a new OS exec tool for full OS command execution.
|
||||
func NewOSExecTool(manager *host.Manager) *OSExecTool {
|
||||
return &OSExecTool{manager: manager}
|
||||
}
|
||||
|
||||
func (t *OSExecTool) Definition() ToolDefinition {
|
||||
return ToolDefinition{
|
||||
Name: "os_exec",
|
||||
Description: "在完整的操作系统环境(WSL/Docker容器)中执行任意命令。适用于复杂操作:安装软件包、编译大型项目、运行脚本、管理服务等。拥有完整的Linux系统权限,无命令限制。日常简单操作请使用 host_exec。",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"command": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "要执行的命令,例如 'pip install pandas && python analyze.py' 或 'apt-get update && apt-get install -y ffmpeg'",
|
||||
},
|
||||
"work_dir": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "工作目录。不指定则使用默认目录。",
|
||||
},
|
||||
"timeout_sec": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "超时时间(秒),默认30秒,最大300秒。复杂任务请设置更长的超时。",
|
||||
},
|
||||
},
|
||||
"required": []string{"command"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *OSExecTool) Execute(ctx context.Context, args map[string]interface{}) (*ToolResult, error) {
|
||||
cmd, _ := args["command"].(string)
|
||||
if cmd == "" {
|
||||
return &ToolResult{
|
||||
ToolName: "os_exec",
|
||||
Success: false,
|
||||
Error: "command 参数不能为空",
|
||||
}, nil
|
||||
}
|
||||
|
||||
workDir, _ := args["work_dir"].(string)
|
||||
timeoutSec := 60 // Default longer timeout for complex operations
|
||||
if v, ok := args["timeout_sec"].(float64); ok {
|
||||
timeoutSec = int(v)
|
||||
}
|
||||
timeout := time.Duration(timeoutSec) * time.Second
|
||||
|
||||
result, err := t.manager.Exec(ctx, cmd, workDir, timeout)
|
||||
if err != nil && result == nil {
|
||||
return &ToolResult{
|
||||
ToolName: "os_exec",
|
||||
Success: false,
|
||||
Error: err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"command": cmd,
|
||||
"backend": t.manager.BackendName(),
|
||||
"exit_code": result.ExitCode,
|
||||
"duration": result.Duration,
|
||||
"timed_out": result.TimedOut,
|
||||
"stdout": result.Stdout,
|
||||
"stderr": result.Stderr,
|
||||
})
|
||||
|
||||
success := result.ExitCode == 0 && !result.TimedOut
|
||||
return &ToolResult{
|
||||
ToolName: "os_exec",
|
||||
Success: success,
|
||||
Data: string(data),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// OSFileTool provides unrestricted file system access within the OS environment.
|
||||
type OSFileTool struct {
|
||||
manager *host.Manager
|
||||
}
|
||||
|
||||
// NewOSFileTool creates a new OS file tool for full OS file operations.
|
||||
func NewOSFileTool(manager *host.Manager) *OSFileTool {
|
||||
return &OSFileTool{manager: manager}
|
||||
}
|
||||
|
||||
func (t *OSFileTool) Definition() ToolDefinition {
|
||||
return ToolDefinition{
|
||||
Name: "os_file",
|
||||
Description: "在完整OS环境中读写文件。支持在整个文件系统中自由操作:读取/写入/列出文件,无目录限制。适用于批量文件处理、日志分析、配置文件管理等复杂文件操作。日常简单文件操作请使用 host_file。",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "操作类型: read, write, list",
|
||||
"enum": []string{"read", "write", "list"},
|
||||
},
|
||||
"path": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "文件或目录路径",
|
||||
},
|
||||
"content": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "写入内容 (仅 write 操作需要)",
|
||||
},
|
||||
},
|
||||
"required": []string{"action", "path"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *OSFileTool) Execute(ctx context.Context, args map[string]interface{}) (*ToolResult, error) {
|
||||
action, _ := args["action"].(string)
|
||||
path, _ := args["path"].(string)
|
||||
if action == "" || path == "" {
|
||||
return &ToolResult{
|
||||
ToolName: "os_file",
|
||||
Success: false,
|
||||
Error: "action 和 path 参数不能为空",
|
||||
}, nil
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "read":
|
||||
content, err := t.manager.ReadFile(path, 1024*1024)
|
||||
if err != nil {
|
||||
return &ToolResult{ToolName: "os_file", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"path": path,
|
||||
"content": content,
|
||||
"size": len(content),
|
||||
})
|
||||
return &ToolResult{ToolName: "os_file", Success: true, Data: string(data)}, nil
|
||||
|
||||
case "write":
|
||||
content, _ := args["content"].(string)
|
||||
if err := t.manager.WriteFile(path, content, 1024*1024); err != nil {
|
||||
return &ToolResult{ToolName: "os_file", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"path": path,
|
||||
"written": len(content),
|
||||
"status": "ok",
|
||||
})
|
||||
return &ToolResult{ToolName: "os_file", Success: true, Data: string(data)}, nil
|
||||
|
||||
case "list":
|
||||
entries, err := t.manager.ListDir(path)
|
||||
if err != nil {
|
||||
return &ToolResult{ToolName: "os_file", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"path": path,
|
||||
"entries": entries,
|
||||
"count": len(entries),
|
||||
})
|
||||
return &ToolResult{ToolName: "os_file", Success: true, Data: string(data)}, nil
|
||||
|
||||
default:
|
||||
return &ToolResult{ToolName: "os_file", Success: false, Error: fmt.Sprintf("不支持的操作: %s", action)}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// OSSystemTool provides OS-level system information.
|
||||
type OSSystemTool struct {
|
||||
manager *host.Manager
|
||||
}
|
||||
|
||||
// NewOSSystemTool creates a new OS system info tool.
|
||||
func NewOSSystemTool(manager *host.Manager) *OSSystemTool {
|
||||
return &OSSystemTool{manager: manager}
|
||||
}
|
||||
|
||||
func (t *OSSystemTool) Definition() ToolDefinition {
|
||||
return ToolDefinition{
|
||||
Name: "os_system",
|
||||
Description: "获取完整OS环境的系统信息,包括操作系统详情、CPU架构、内存使用、磁盘空间等。与 host_system 不同,此工具返回的是WSL/容器内的完整Linux系统信息。",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"query": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "查询类型: info(完整信息), memory(内存), cpu(CPU), disk(磁盘)",
|
||||
"enum": []string{"info", "memory", "cpu", "disk"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *OSSystemTool) Execute(ctx context.Context, args map[string]interface{}) (*ToolResult, error) {
|
||||
info := t.manager.SystemInfo()
|
||||
data, _ := json.Marshal(info)
|
||||
return &ToolResult{
|
||||
ToolName: "os_system",
|
||||
Success: true,
|
||||
Data: string(data),
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/host"
|
||||
)
|
||||
|
||||
func TestOSExecToolWSL(t *testing.T) {
|
||||
distro := os.Getenv("WSL_DISTRO")
|
||||
if distro == "" {
|
||||
t.Skip("WSL_DISTRO not set, skipping OS tool integration test")
|
||||
}
|
||||
backend := host.NewWSLBackend(distro, "cyrene", "test123", 30e9)
|
||||
mgr := host.NewManager(backend)
|
||||
|
||||
// Test os_exec
|
||||
t.Run("os_exec", func(t *testing.T) {
|
||||
tool := NewOSExecTool(mgr)
|
||||
def := tool.Definition()
|
||||
if def.Name != "os_exec" {
|
||||
t.Fatalf("unexpected name: %s", def.Name)
|
||||
}
|
||||
result, err := tool.Execute(context.Background(), map[string]interface{}{
|
||||
"command": "echo 'os_exec works!' && uname -a",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("execute error: %v", err)
|
||||
}
|
||||
if !result.Success {
|
||||
t.Fatalf("exec failed: %s", result.Error)
|
||||
}
|
||||
if !strings.Contains(result.Data, "os_exec works!") {
|
||||
t.Fatalf("unexpected output: %s", result.Data)
|
||||
}
|
||||
t.Logf("os_exec OK: data len=%d", len(result.Data))
|
||||
})
|
||||
|
||||
// Test os_file
|
||||
t.Run("os_file", func(t *testing.T) {
|
||||
tool := NewOSFileTool(mgr)
|
||||
def := tool.Definition()
|
||||
if def.Name != "os_file" {
|
||||
t.Fatalf("unexpected name: %s", def.Name)
|
||||
}
|
||||
|
||||
// Write
|
||||
r, err := tool.Execute(context.Background(), map[string]interface{}{
|
||||
"action": "write",
|
||||
"path": "/tmp/cyrene-os-tool-test.txt",
|
||||
"content": "OS tool integration test",
|
||||
})
|
||||
if err != nil || !r.Success {
|
||||
t.Fatalf("os_file write failed: err=%v, errMsg=%s", err, r.Error)
|
||||
}
|
||||
|
||||
// Read
|
||||
r, err = tool.Execute(context.Background(), map[string]interface{}{
|
||||
"action": "read",
|
||||
"path": "/tmp/cyrene-os-tool-test.txt",
|
||||
})
|
||||
if err != nil || !r.Success {
|
||||
t.Fatalf("os_file read failed: err=%v, errMsg=%s", err, r.Error)
|
||||
}
|
||||
if !strings.Contains(r.Data, "OS tool integration test") {
|
||||
t.Fatalf("content mismatch: %s", r.Data)
|
||||
}
|
||||
|
||||
// List
|
||||
r, err = tool.Execute(context.Background(), map[string]interface{}{
|
||||
"action": "list",
|
||||
"path": "/tmp",
|
||||
})
|
||||
if err != nil || !r.Success {
|
||||
t.Fatalf("os_file list failed: err=%v, errMsg=%s", err, r.Error)
|
||||
}
|
||||
t.Logf("os_file OK: write+read+list all pass")
|
||||
})
|
||||
|
||||
// Test os_system
|
||||
t.Run("os_system", func(t *testing.T) {
|
||||
tool := NewOSSystemTool(mgr)
|
||||
def := tool.Definition()
|
||||
if def.Name != "os_system" {
|
||||
t.Fatalf("unexpected name: %s", def.Name)
|
||||
}
|
||||
result, err := tool.Execute(context.Background(), map[string]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatalf("execute error: %v", err)
|
||||
}
|
||||
if !result.Success {
|
||||
t.Fatalf("os_system failed: %s", result.Error)
|
||||
}
|
||||
if !strings.Contains(result.Data, "wsl") {
|
||||
t.Fatalf("expected wsl backend info: %s", result.Data)
|
||||
}
|
||||
t.Logf("os_system OK: data len=%d", len(result.Data))
|
||||
})
|
||||
}
|
||||
@@ -13,7 +13,7 @@ func TestHostExecToolDefinition(t *testing.T) {
|
||||
cfg := host.DefaultSandboxConfig()
|
||||
cfg.AllowedDirs = []string{os.TempDir()}
|
||||
sandbox := host.NewSandbox(cfg)
|
||||
mgr := host.NewManager(sandbox)
|
||||
mgr := host.NewManager(host.NewDirectBackend(sandbox))
|
||||
|
||||
tool := NewHostExecTool(mgr)
|
||||
def := tool.Definition()
|
||||
@@ -40,7 +40,7 @@ func TestHostFileToolDefinition(t *testing.T) {
|
||||
tmpDir := os.TempDir()
|
||||
cfg.AllowedDirs = []string{tmpDir}
|
||||
sandbox := host.NewSandbox(cfg)
|
||||
mgr := host.NewManager(sandbox)
|
||||
mgr := host.NewManager(host.NewDirectBackend(sandbox))
|
||||
mgr.SetAllowedDirs([]string{tmpDir})
|
||||
|
||||
tool := NewHostFileTool(mgr)
|
||||
@@ -67,7 +67,7 @@ func TestHostFileToolDefinition(t *testing.T) {
|
||||
func TestHostSystemToolDefinition(t *testing.T) {
|
||||
cfg := host.DefaultSandboxConfig()
|
||||
sandbox := host.NewSandbox(cfg)
|
||||
mgr := host.NewManager(sandbox)
|
||||
mgr := host.NewManager(host.NewDirectBackend(sandbox))
|
||||
|
||||
tool := NewHostSystemTool(mgr)
|
||||
def := tool.Definition()
|
||||
|
||||
@@ -40,7 +40,7 @@ func TestEncodeImageToDataURL_InvalidPath(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestVisionToolDefinition(t *testing.T) {
|
||||
tool := NewVisionTool()
|
||||
tool := NewVisionTool(nil)
|
||||
def := tool.Definition()
|
||||
if def.Name != "vision_analyze" {
|
||||
t.Fatalf("unexpected tool name: %s", def.Name)
|
||||
@@ -68,7 +68,7 @@ func TestVisionToolExecute(t *testing.T) {
|
||||
}
|
||||
defer os.Remove(tmpPath)
|
||||
|
||||
tool := NewVisionTool()
|
||||
tool := NewVisionTool(nil)
|
||||
ctx := context.Background()
|
||||
result, err := tool.Execute(ctx, map[string]interface{}{
|
||||
"image_path": tmpPath,
|
||||
|
||||
@@ -0,0 +1,151 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ScheduleRule defines a time-based interval rule.
|
||||
type ScheduleRule struct {
|
||||
Name string `json:"name"`
|
||||
Days []string `json:"days"` // monday, tuesday, wednesday, thursday, friday, saturday, sunday
|
||||
TimeRange string `json:"time_range"` // "HH:MM-HH:MM"
|
||||
Except []string `json:"except"` // ["HH:MM-HH:MM", ...]
|
||||
IntervalMinutes int `json:"interval_minutes"`
|
||||
}
|
||||
|
||||
// ThinkingScheduleConfig is the full schedule configuration.
|
||||
type ThinkingScheduleConfig struct {
|
||||
Version string `json:"version"`
|
||||
DefaultIntervalMinutes int `json:"default_interval_minutes"`
|
||||
Rules []ScheduleRule `json:"rules"`
|
||||
}
|
||||
|
||||
// DefaultThinkingScheduleConfig returns the default schedule with two rules.
|
||||
func DefaultThinkingScheduleConfig() *ThinkingScheduleConfig {
|
||||
return &ThinkingScheduleConfig{
|
||||
Version: "1.0",
|
||||
DefaultIntervalMinutes: 5,
|
||||
Rules: []ScheduleRule{
|
||||
{
|
||||
Name: "night",
|
||||
Days: []string{"monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday"},
|
||||
TimeRange: "23:00-07:00",
|
||||
IntervalMinutes: 30,
|
||||
},
|
||||
{
|
||||
Name: "weekday_work",
|
||||
Days: []string{"monday", "tuesday", "wednesday", "thursday", "friday"},
|
||||
TimeRange: "09:00-17:00",
|
||||
Except: []string{"12:00-14:00", "15:00-15:30"},
|
||||
IntervalMinutes: 30,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ThinkingScheduleStore persists the schedule config to a JSON file.
|
||||
type ThinkingScheduleStore struct {
|
||||
mu sync.RWMutex
|
||||
path string
|
||||
config *ThinkingScheduleConfig
|
||||
}
|
||||
|
||||
// NewThinkingScheduleStore creates a store, creating the file with defaults if it does not exist.
|
||||
func NewThinkingScheduleStore(path string) (*ThinkingScheduleStore, error) {
|
||||
s := &ThinkingScheduleStore{
|
||||
path: path,
|
||||
config: nil,
|
||||
}
|
||||
if err := s.load(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *ThinkingScheduleStore) load() error {
|
||||
data, err := os.ReadFile(s.path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
s.config = DefaultThinkingScheduleConfig()
|
||||
return s.save()
|
||||
}
|
||||
return fmt.Errorf("read thinking schedule file: %w", err)
|
||||
}
|
||||
if len(data) == 0 {
|
||||
s.config = DefaultThinkingScheduleConfig()
|
||||
return s.save()
|
||||
}
|
||||
var cfg ThinkingScheduleConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return fmt.Errorf("parse thinking schedule: %w", err)
|
||||
}
|
||||
if cfg.Version == "" {
|
||||
cfg.Version = "1.0"
|
||||
}
|
||||
if cfg.DefaultIntervalMinutes <= 0 {
|
||||
cfg.DefaultIntervalMinutes = 5
|
||||
}
|
||||
if cfg.Rules == nil {
|
||||
cfg.Rules = []ScheduleRule{}
|
||||
}
|
||||
s.config = &cfg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ThinkingScheduleStore) save() error {
|
||||
data, err := json.MarshalIndent(s.config, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal thinking schedule: %w", err)
|
||||
}
|
||||
tmpPath := s.path + ".tmp"
|
||||
if err := os.WriteFile(tmpPath, data, 0640); err != nil {
|
||||
return fmt.Errorf("write thinking schedule: %w", err)
|
||||
}
|
||||
return os.Rename(tmpPath, s.path)
|
||||
}
|
||||
|
||||
// GetConfig returns the current config (read-only).
|
||||
func (s *ThinkingScheduleStore) GetConfig() *ThinkingScheduleConfig {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.config
|
||||
}
|
||||
|
||||
// SetConfig validates and persists a new config.
|
||||
func (s *ThinkingScheduleStore) SetConfig(cfg *ThinkingScheduleConfig) error {
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("配置不能为空")
|
||||
}
|
||||
if cfg.DefaultIntervalMinutes <= 0 {
|
||||
cfg.DefaultIntervalMinutes = 5
|
||||
}
|
||||
if cfg.Version == "" {
|
||||
cfg.Version = "1.0"
|
||||
}
|
||||
if cfg.Rules == nil {
|
||||
cfg.Rules = []ScheduleRule{}
|
||||
}
|
||||
for _, r := range cfg.Rules {
|
||||
if r.IntervalMinutes <= 0 {
|
||||
return fmt.Errorf("规则 %q 间隔分钟必须大于 0", r.Name)
|
||||
}
|
||||
if r.TimeRange == "" {
|
||||
return fmt.Errorf("规则 %q 缺少 time_range", r.Name)
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.config = cfg
|
||||
return s.save()
|
||||
}
|
||||
|
||||
// HasConfig returns true if a config is loaded.
|
||||
func (s *ThinkingScheduleStore) HasConfig() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.config != nil
|
||||
}
|
||||
@@ -152,7 +152,20 @@ func (h *ChatHandler) handleChatMessage(client *ws.Client, msg ws.ClientMessage)
|
||||
"mode": mode,
|
||||
}
|
||||
if len(msg.Attachments) > 0 {
|
||||
aiReq["attachments"] = msg.Attachments
|
||||
images := make([]string, 0, len(msg.Attachments))
|
||||
for _, att := range msg.Attachments {
|
||||
if att.Type == "image" && att.URL != "" {
|
||||
imgURL := att.URL
|
||||
// 将相对路径转换为绝对 URL,方便 AI-Core 访问
|
||||
if strings.HasPrefix(imgURL, "/") {
|
||||
imgURL = "http://127.0.0.1:" + h.cfg.Port + imgURL
|
||||
}
|
||||
images = append(images, imgURL)
|
||||
}
|
||||
}
|
||||
if len(images) > 0 {
|
||||
aiReq["images"] = images
|
||||
}
|
||||
}
|
||||
reqBody, err := json.Marshal(aiReq)
|
||||
if err != nil {
|
||||
@@ -187,8 +200,8 @@ func (h *ChatHandler) handleChatMessage(client *ws.Client, msg ws.ClientMessage)
|
||||
}
|
||||
h.hub.CacheMessage(client.UserID, client.SessionID, userMsg)
|
||||
|
||||
// 广播用户消息给同用户所有设备(跨端同步)
|
||||
h.broadcastToUser(client.UserID, ws.ServerMessage{
|
||||
// 广播用户消息给同用户其他设备(跨端同步,排除发送者自身)
|
||||
h.broadcastToUserExcept(client.UserID, client.ClientID, ws.ServerMessage{
|
||||
Type: "response",
|
||||
MessageID: userMsgID,
|
||||
Content: msg.Content,
|
||||
@@ -208,6 +221,21 @@ func (h *ChatHandler) handleChatMessage(client *ws.Client, msg ws.ClientMessage)
|
||||
|
||||
// streamResponse 调用 AI-Core SSE 流式接口并逐 delta 转发给客户端
|
||||
func (h *ChatHandler) streamResponse(client *ws.Client, mode string, reqBody []byte, userMsg string) {
|
||||
normalExit := false
|
||||
defer func() {
|
||||
if !normalExit {
|
||||
h.broadcastToUser(client.UserID, ws.ServerMessage{
|
||||
Type: "stream_end",
|
||||
MessageID: "msg_" + generateID(),
|
||||
SessionID: client.SessionID,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
})
|
||||
if h.hub != nil {
|
||||
h.hub.UpdateSessionState(client.SessionID, "idle")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
aiCoreURL := h.cfg.AICoreURL + "/api/v1/chat"
|
||||
httpReq, err := http.NewRequest("POST", aiCoreURL, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
@@ -309,7 +337,7 @@ func (h *ChatHandler) streamResponse(client *ws.Client, mode string, reqBody []b
|
||||
if chunk.Error != "" {
|
||||
logger.Printf("[chat] AI-Core 流式错误: %s", chunk.Error)
|
||||
h.hub.UpdateSessionState(client.SessionID, "error")
|
||||
client.SendMessage(ws.ServerMessage{
|
||||
h.broadcastToUser(client.UserID, ws.ServerMessage{
|
||||
Type: "error",
|
||||
MessageID: "msg_" + generateID(),
|
||||
Error: chunk.Error,
|
||||
@@ -338,9 +366,13 @@ func (h *ChatHandler) streamResponse(client *ws.Client, mode string, reqBody []b
|
||||
msgType = "action"
|
||||
}
|
||||
reviewMsgID := fmt.Sprintf("%s_r%d", msgID, i)
|
||||
// 持久化每条审查消息
|
||||
// 持久化每条审查消息 (action 角色映射为 assistant,LLM 模型不支持自定义角色)
|
||||
if h.sessionStore != nil && h.sessionStore.IsAvailable() {
|
||||
if err := h.sessionStore.AddMessage(client.SessionID, role, msgType, rm.Content, client.ClientID); err != nil {
|
||||
dbRole := role
|
||||
if dbRole == "action" {
|
||||
dbRole = "assistant"
|
||||
}
|
||||
if err := h.sessionStore.AddMessage(client.SessionID, dbRole, msgType, rm.Content, client.ClientID); err != nil {
|
||||
logger.Printf("[chat] 持久化审查消息失败: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -402,7 +434,7 @@ func (h *ChatHandler) streamResponse(client *ws.Client, mode string, reqBody []b
|
||||
if err := scanner.Err(); err != nil {
|
||||
logger.Printf("[chat] SSE 读取错误: %v", err)
|
||||
h.hub.UpdateSessionState(client.SessionID, "error")
|
||||
client.SendMessage(ws.ServerMessage{
|
||||
h.broadcastToUser(client.UserID, ws.ServerMessage{
|
||||
Type: "error",
|
||||
MessageID: "msg_" + generateID(),
|
||||
Error: fmt.Sprintf("流读取错误: %v", err),
|
||||
@@ -477,6 +509,7 @@ func (h *ChatHandler) streamResponse(client *ws.Client, mode string, reqBody []b
|
||||
h.hub.RecordMessage(client.SessionID, "assistant", recordText)
|
||||
|
||||
// 设置会话状态为 idle
|
||||
normalExit = true
|
||||
h.hub.UpdateSessionState(client.SessionID, "idle")
|
||||
}
|
||||
|
||||
@@ -766,7 +799,11 @@ func (h *ChatHandler) HandleProactiveMessage(c *gin.Context) {
|
||||
|
||||
// Persist to database so proactive messages survive restarts.
|
||||
if h.sessionStore != nil && h.sessionStore.IsAvailable() {
|
||||
if err := h.sessionStore.AddMessage(sessionID, role, msgType, seg.content, ""); err != nil {
|
||||
dbRole := role
|
||||
if dbRole == "action" {
|
||||
dbRole = "assistant"
|
||||
}
|
||||
if err := h.sessionStore.AddMessage(sessionID, dbRole, msgType, seg.content, ""); err != nil {
|
||||
logger.Printf("[proactive] 持久化消息失败: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -951,6 +988,17 @@ func (h *ChatHandler) broadcastToUser(userID string, msg ws.ServerMessage) {
|
||||
h.hub.SendToUser(userID, data)
|
||||
}
|
||||
|
||||
// broadcastToUserExcept sends a server message to ALL connected clients for a user,
|
||||
// excluding the specified clientID (the sender).
|
||||
func (h *ChatHandler) broadcastToUserExcept(userID, excludeClientID string, msg ws.ServerMessage) {
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
logger.Printf("[chat] 序列化广播消息失败: %v", err)
|
||||
return
|
||||
}
|
||||
h.hub.SendToUserExcept(userID, excludeClientID, data)
|
||||
}
|
||||
|
||||
// parseMultiMessage 检测并解析多消息格式
|
||||
// 如果文本包含空行分隔的多条消息,拆分为多条;否则返回单条
|
||||
func parseMultiMessage(text string) []proactiveSegment {
|
||||
|
||||
@@ -93,7 +93,16 @@ func (h *MemoryHandler) List(c *gin.Context) {
|
||||
userID = authUserID
|
||||
}
|
||||
|
||||
limit := c.Query("limit")
|
||||
offset := c.Query("offset")
|
||||
|
||||
url := fmt.Sprintf("%s/api/v1/memories?user_id=%s", h.memoryServiceURL, userID)
|
||||
if limit != "" {
|
||||
url += "&limit=" + limit
|
||||
}
|
||||
if offset != "" {
|
||||
url += "&offset=" + offset
|
||||
}
|
||||
|
||||
resp, err := h.client.Get(url)
|
||||
if err != nil {
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/config"
|
||||
)
|
||||
|
||||
// ThinkingScheduleHandler handles CRUD for the thinking schedule config.
|
||||
type ThinkingScheduleHandler struct {
|
||||
store *config.ThinkingScheduleStore
|
||||
}
|
||||
|
||||
// NewThinkingScheduleHandler creates a new handler.
|
||||
func NewThinkingScheduleHandler(store *config.ThinkingScheduleStore) *ThinkingScheduleHandler {
|
||||
return &ThinkingScheduleHandler{store: store}
|
||||
}
|
||||
|
||||
// GetSchedule returns the current schedule config.
|
||||
// GET /api/v1/admin/thinking-schedule
|
||||
func (h *ThinkingScheduleHandler) GetSchedule(c *gin.Context) {
|
||||
cfg := h.store.GetConfig()
|
||||
if cfg == nil {
|
||||
cfg = config.DefaultThinkingScheduleConfig()
|
||||
}
|
||||
c.JSON(http.StatusOK, cfg)
|
||||
}
|
||||
|
||||
// SetSchedule replaces the entire schedule config.
|
||||
// PUT /api/v1/admin/thinking-schedule
|
||||
func (h *ThinkingScheduleHandler) SetSchedule(c *gin.Context) {
|
||||
var cfg config.ThinkingScheduleConfig
|
||||
if err := c.ShouldBindJSON(&cfg); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.SetConfig(&cfg); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "saved"})
|
||||
}
|
||||
@@ -472,6 +472,25 @@ func (h *Hub) SendToUser(userID string, message []byte) {
|
||||
}
|
||||
}
|
||||
|
||||
// SendToUserExcept 向指定用户的所有连接发送消息,排除指定 clientID
|
||||
func (h *Hub) SendToUserExcept(userID, excludeClientID string, message []byte) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
if clients, ok := h.userClients[userID]; ok {
|
||||
for client := range clients {
|
||||
if client.ClientID == excludeClientID {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case client.Send <- message:
|
||||
default:
|
||||
// 跳过阻塞的客户端
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SendToSession 向指定会话的连接发送消息
|
||||
func (h *Hub) SendToSession(userID, sessionID string, message []byte) {
|
||||
h.mu.RLock()
|
||||
|
||||
@@ -64,9 +64,10 @@ func (h *MemoryHandler) listMemories(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
category := r.URL.Query().Get("category")
|
||||
limit := queryInt(r, "limit", 50)
|
||||
offset := queryInt(r, "offset", 0)
|
||||
minImportance := queryInt(r, "min_importance", 0)
|
||||
|
||||
memories, err := h.svc.ListMemories(r.Context(), userID, category, minImportance, limit)
|
||||
memories, err := h.svc.ListMemories(r.Context(), userID, category, minImportance, limit, offset)
|
||||
if err != nil {
|
||||
logger.Printf("[memory-handler] 列出记忆失败: %v", err)
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
|
||||
@@ -65,7 +65,7 @@ func (svc *MemoryService) GetMemory(ctx context.Context, id string) (*model.Memo
|
||||
}
|
||||
|
||||
// ListMemories 列出用户所有记忆
|
||||
func (svc *MemoryService) ListMemories(ctx context.Context, userID string, category string, minImportance int, limit int) ([]model.MemoryEntry, error) {
|
||||
func (svc *MemoryService) ListMemories(ctx context.Context, userID string, category string, minImportance int, limit int, offset int) ([]model.MemoryEntry, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
@@ -74,6 +74,7 @@ func (svc *MemoryService) ListMemories(ctx context.Context, userID string, categ
|
||||
UserID: userID,
|
||||
MinImportance: minImportance,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}
|
||||
if category != "" {
|
||||
q.Category = model.MemoryCategory(category)
|
||||
|
||||
Reference in New Issue
Block a user