fix: 修复 AI 回复无法送达发送者 + 重复消息 + action角色泄露 + OS环境支持

广播逻辑重构:
- AI 回复 (stream_start/response/stream_segments/multi_message/stream_end) 改用 broadcastToUser 发送给所有客户端
- 用户消息回显保持 broadcastToUserExcept 排除发送者

消息去重与角色修复:
- CacheMessage(user) 移至回复生成后,避免本轮 LLM 调用出现重复用户消息
- action 角色消息在 DB 存储时映射为 assistant,DeepSeek 等模型不支持自定义角色
- stream_end defer 机制确保错误路径也会终止客户端思考指示器

OS 完整环境支持:
- host 包重构为 HostBackend 接口 + Direct/WSL/Docker 三种后端
- 新增 os_exec/os_file/os_system 工具供 AI 在完整 Linux 环境中自由操作

其他:
- 视觉模型注入 + 图片预处理后清空 Images 避免传给 Chat 模型
- 图片 URL 相对路径→绝对 URL 转换
- DevTools 链路追踪页面 + 重启修复
- 记忆搜索模糊匹配增强
- 后台思考定时调度支持
- 管理后台页面 (模型配置/用户管理等)
- docs/api 更新广播机制说明

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-29 12:46:17 +08:00
parent aac64ed8b7
commit 91c9ee4b2d
49 changed files with 5032 additions and 299 deletions
+403 -13
View File
@@ -582,16 +582,7 @@ func (t *Thinker) performThink(triggerReason string) {
return
}
// 2. 检索相关记忆
var memories []memory.MemoryEntry
if t.memRetriever != nil {
memories, err = t.memRetriever.Retrieve(ctx, t.adminUserID, "最近发生了什么 重要的事情 用户偏好 个人信息")
if err != nil {
log.Printf("[后台思考] 记忆检索失败: %v", err)
}
}
// 3. 获取当前活跃会话的对话历史(优先活跃会话,回退到管理员主会话)
// 2. 获取当前活跃会话的对话历史(优先活跃会话,回退到管理员主会话)
var convHistory []model.LLMMessage
if t.convStore != nil {
t.mu.Lock()
@@ -608,6 +599,37 @@ func (t *Thinker) performThink(triggerReason string) {
}
}
// 3. 检索相关记忆(精确检索 + 模糊搜索)
var memories []memory.MemoryEntry
if t.memRetriever != nil {
memories, err = t.memRetriever.Retrieve(ctx, t.adminUserID, "最近发生了什么 重要的事情 用户偏好 个人信息")
if err != nil {
log.Printf("[后台思考] 记忆检索失败: %v", err)
}
// 模糊搜索:从对话历史提取话题,LLM 扩展关键词后语义搜索
if t.memClient != nil && len(convHistory) > 0 {
fuzzyQuery := lastUserMessage(convHistory)
if fuzzyQuery == "" {
fuzzyQuery = "最近对话 重要事件 用户状态"
}
fuzzyResults := t.fuzzyMemorySearch(ctx, t.adminUserID, fuzzyQuery)
seen := make(map[string]bool)
for _, m := range memories {
seen[m.ID] = true
}
for _, m := range fuzzyResults {
if !seen[m.ID] {
seen[m.ID] = true
memories = append(memories, m)
}
}
if len(fuzzyResults) > 0 {
log.Printf("[后台思考] 模糊搜索补充 %d 条记忆", len(fuzzyResults))
}
}
}
// 4. 查询 IoT 设备状态(每次都查询,无间隔限制)
var deviceSummary string
if t.iotClient != nil {
@@ -750,10 +772,9 @@ func (t *Thinker) performThink(triggerReason string) {
log.Printf("[后台思考] 完成 (触发原因=%s, 内容长度=%d, 工具调用=%d次)", triggerReason, len(finalContent), totalToolCalls)
// 9. 周期性记忆维护(每 10 次思考触发一次)
// 注:不再从思考结果中提取记忆——思考内容基于已有记忆生成,
// 再次提取会造成"读取→思考→保存→再次读取"的重复循环。
// 9. 记忆维护:机械合并(每10次) + LLM整理(每次)
t.maybeMaintainMemories(currentCount)
t.performMemoryConsolidation(ctx)
}
// buildThinkingSystemPrompt 构建思考用的系统提示词
@@ -1196,6 +1217,375 @@ func (t *Thinker) maybeMaintainMemories(thinkCount int) {
}
}
// consolidationAction is a single memory-consolidation instruction decoded
// from the LLM's JSON reply. Which fields matter depends on Action:
// executeConsolidationActions handles "delete", "merge", "update" and "create".
type consolidationAction struct {
// Action selects the operation to perform.
Action string `json:"action"`
// IDs lists the memories to merge; the first one that resolves is kept.
IDs []string `json:"ids,omitempty"`
// ID identifies the target memory of a delete or update.
ID string `json:"id,omitempty"`
// Content is the new text for merge, update and create.
Content string `json:"content,omitempty"`
// Category is optional; when empty the executor falls back to the existing
// memory's category (merge/update) or to model.CategoryKnowledge (create).
Category string `json:"category,omitempty"`
// Importance is optional; 0 is treated as "not provided" by the executor.
Importance int `json:"importance,omitempty"`
// Priority is optional; 0 is treated as "not provided" for merge and update.
Priority int `json:"priority,omitempty"`
// Keywords are passed through to the stored memory entry.
Keywords []string `json:"keywords,omitempty"`
// Reason is the LLM's justification; it is only written to the log.
Reason string `json:"reason,omitempty"`
}
// performMemoryConsolidation asks the LLM to review the admin user's stored
// memories and applies whatever merge/delete/update/create actions it proposes.
//
// It does nothing when no memory client is configured or when fewer than five
// memories exist. Every failure is logged and ends the run; nothing is returned.
func (t *Thinker) performMemoryConsolidation(ctx context.Context) {
	if t.memClient == nil {
		return
	}

	// Load at most 200 memories for review.
	stored, err := t.memClient.Query(ctx, model.MemoryQuery{UserID: t.adminUserID, Limit: 200})
	if err != nil {
		log.Printf("[记忆整理] 获取记忆失败: %v", err)
		return
	}
	if len(stored) < 5 {
		return
	}
	log.Printf("[记忆整理] LLM 审查 %d 条记忆...", len(stored))

	// The system prompt carries the rules plus the memory listing; the user
	// turn only triggers the review.
	reply, err := t.toolAdapter.Chat(ctx, []model.LLMMessage{
		{Role: model.RoleSystem, Content: t.buildConsolidationPrompt(stored)},
		{Role: model.RoleUser, Content: "请审查以上记忆库,找出重复、矛盾、过时和低质量的记忆,输出 JSON 整理方案。如果没有需要整理的,输出空数组 []。"},
	})
	if err != nil {
		log.Printf("[记忆整理] LLM 调用失败: %v", err)
		return
	}

	plan := parseConsolidationActions(reply.Content)
	if len(plan) == 0 {
		log.Printf("[记忆整理] 记忆库状态良好,无需整理")
		return
	}
	log.Printf("[记忆整理] LLM 建议 %d 项操作", len(plan))

	done := t.executeConsolidationActions(ctx, plan, stored)
	log.Printf("[记忆整理] 完成: 执行了 %d 项操作", done)
}
// buildConsolidationPrompt renders the system prompt for memory consolidation:
// fixed instructions (problem types, JSON action format, rules) followed by a
// numbered listing of the given memories. Each memory is listed with the first
// eight characters of its ID, so the LLM normally answers with short IDs.
func (t *Thinker) buildConsolidationPrompt(memories []model.MemoryEntry) string {
	// Static part of the prompt, emitted in order.
	header := []string{
		"你是记忆库管理助手。审查以下用户记忆,找出问题并输出 JSON 操作清单。\n\n",
		"## 需要识别的问题\n",
		"1. 重复记忆 — 多条记忆记录了相同信息 → merge 合并为一条\n",
		"2. 矛盾记忆 — 两条记忆互相矛盾(如\"喜欢猫\"vs\"讨厌猫\")→ delete 删除过时的、update 修正错误的\n",
		"3. 过时记忆 — 信息已被新记忆取代 → delete 或 update\n",
		"4. 低质量记忆 — 太模糊、不完整、无实际信息量 → delete\n\n",
		"## JSON 操作格式\n",
		"```json\n[\n",
		" {\"action\":\"merge\", \"ids\":[\"id1\",\"id2\"], \"content\":\"合并后的内容\", \"category\":\"personal_info\", \"importance\":8, \"reason\":\"两条记录同一件事\"},\n",
		" {\"action\":\"delete\", \"id\":\"id3\", \"reason\":\"完全被 id1 覆盖\"},\n",
		" {\"action\":\"update\", \"id\":\"id4\", \"content\":\"修正后的内容\", \"importance\":7, \"reason\":\"纠正过时信息\"},\n",
		" {\"action\":\"create\", \"content\":\"需要补充的记忆\", \"category\":\"knowledge\", \"importance\":6, \"reason\":\"从已有记忆推断\"}\n",
		"]\n```\n\n",
		"## 规则\n",
		"- 只输出 JSON 数组,可以用 ```json``` 包裹,不要输出其他解释文字\n",
		"- 确实有问题才建议操作,不要强行找问题\n",
		"- merge 时保留最重要的那条的 ID,合并内容应包含各条的关键信息\n",
		"- 不确定时宁可保守(不操作)\n",
		"- importance 范围 1-10,数字越大越重要\n",
		"- category 可选: personal_info, user_preference, conversation, knowledge, event, task, relationship\n\n",
	}

	var out strings.Builder
	for _, part := range header {
		out.WriteString(part)
	}

	fmt.Fprintf(&out, "## 当前记忆库 (%d 条)\n\n", len(memories))
	for idx := range memories {
		entry := &memories[idx]
		shortID := entry.ID[:min(8, len(entry.ID))]
		fmt.Fprintf(&out, "%d. [%s] **%s** | cat=%s imp=%d pri=%d | src=%s\n",
			idx+1, shortID, entry.Content,
			entry.Category, entry.Importance, entry.Priority, entry.Source)
	}
	return out.String()
}
// parseConsolidationActions decodes the LLM reply into consolidation actions.
//
// If the reply contains a ```json fence (or, failing that, any ``` fence) with a
// closing fence, only the fenced text is considered. Within that text the span
// from the first '[' to the last ']' is decoded as a JSON array. It returns nil
// when no array is present or when decoding fails (the failure is logged).
func parseConsolidationActions(text string) []consolidationAction {
	payload := text

	// Locate an opening fence, preferring the language-tagged form.
	fenceAt, fenceLen := strings.Index(text, "```json"), len("```json")
	if fenceAt < 0 {
		fenceAt, fenceLen = strings.Index(text, "```"), len("```")
	}
	if fenceAt >= 0 {
		inner := text[fenceAt+fenceLen:]
		// Without a closing fence the whole reply is scanned instead.
		if closeAt := strings.Index(inner, "```"); closeAt >= 0 {
			payload = inner[:closeAt]
		}
	}

	// Cut out the outermost bracketed span.
	first := strings.Index(payload, "[")
	last := strings.LastIndex(payload, "]")
	if first < 0 || last <= first {
		return nil
	}

	var parsed []consolidationAction
	if err := json.Unmarshal([]byte(payload[first:last+1]), &parsed); err != nil {
		log.Printf("[记忆整理] JSON 解析失败: %v", err)
		return nil
	}
	return parsed
}
// executeConsolidationActions applies the LLM's consolidation plan to the
// memory store and returns the number of actions that were carried out.
//
// memories is the snapshot the plan was generated from; it is used to resolve
// the (possibly shortened) IDs the LLM refers to and to fill in fields the LLM
// omitted. Actions that cannot be applied safely are skipped and logged:
//   - delete/update without an ID, or with an ID that does not resolve;
//   - merge with blank content or fewer than two distinct resolvable IDs;
//   - create with blank content;
//   - unknown action names.
//
// Because the plan comes from an LLM, destructive operations are deliberately
// conservative: a malformed action must never delete or blank out a memory.
func (t *Thinker) executeConsolidationActions(ctx context.Context, actions []consolidationAction, memories []model.MemoryEntry) int {
	// Index the snapshot by full ID and by the 8-character short ID that the
	// prompt shows to the LLM.
	memByShortID := make(map[string]*model.MemoryEntry, len(memories))
	memByFullID := make(map[string]*model.MemoryEntry, len(memories))
	for i := range memories {
		m := &memories[i]
		memByShortID[m.ID[:min(8, len(m.ID))]] = m
		memByFullID[m.ID] = m
	}

	// capImportance enforces the documented upper bound of the 1-10 scale.
	capImportance := func(v int) int {
		if v > 10 {
			return 10
		}
		return v
	}

	executed := 0
	for _, a := range actions {
		hasContent := strings.TrimSpace(a.Content) != ""

		switch a.Action {
		case "delete":
			// An empty ID must never reach the resolver's prefix match: every
			// ID has the empty prefix, so it would pick an arbitrary memory.
			if strings.TrimSpace(a.ID) == "" {
				log.Printf("[记忆整理] 跳过 delete: 缺少记忆 ID")
				continue
			}
			id := resolveID(a.ID, memByShortID, memByFullID)
			if id == "" {
				log.Printf("[记忆整理] 跳过 delete: 找不到记忆 %s", a.ID)
				continue
			}
			if err := t.memClient.Delete(ctx, id); err != nil {
				log.Printf("[记忆整理] 删除 %s 失败: %v", a.ID, err)
				continue
			}
			log.Printf("[记忆整理] 已删除: %s (原因: %s)", a.ID, a.Reason)
			executed++

		case "merge":
			if len(a.IDs) < 2 {
				continue
			}
			if !hasContent {
				// Merging would overwrite the keeper with nothing and delete the rest.
				log.Printf("[记忆整理] 跳过 merge: 合并内容为空 (ids=%v)", a.IDs)
				continue
			}
			// Resolve every ID once. Duplicates (e.g. a short and a full form of
			// the same ID) are dropped so the keeper can never be deleted below.
			var resolved []string
			picked := make(map[string]bool, len(a.IDs))
			for _, mid := range a.IDs {
				if strings.TrimSpace(mid) == "" {
					continue
				}
				rid := resolveID(mid, memByShortID, memByFullID)
				if rid == "" || picked[rid] {
					continue
				}
				picked[rid] = true
				resolved = append(resolved, rid)
			}
			if len(resolved) < 2 {
				continue
			}

			// The first resolved ID survives and receives the merged content.
			keeper := resolved[0]
			kept := memByFullID[keeper]

			cat := model.MemoryCategory(a.Category)
			if cat == "" && kept != nil {
				cat = kept.Category
			}
			imp := a.Importance
			if imp <= 0 && kept != nil {
				// A merged memory is backed by several sources, so bump it by one.
				imp = kept.Importance + 1
			}
			imp = capImportance(imp)
			pri := a.Priority
			if pri == 0 && kept != nil {
				pri = int(kept.Priority)
			}

			if err := t.memClient.Update(ctx, &model.MemoryEntry{
				ID:         keeper,
				Content:    a.Content,
				Category:   cat,
				Importance: imp,
				Priority:   model.MemoryPriority(pri),
				Keywords:   a.Keywords,
				Source:     "consolidated",
			}); err != nil {
				log.Printf("[记忆整理] 更新合并目标 %s 失败: %v", keeper, err)
				continue
			}
			// Remove the memories that were folded into the keeper. A failure
			// here is logged but does not undo the merge.
			for _, did := range resolved[1:] {
				if err := t.memClient.Delete(ctx, did); err != nil {
					log.Printf("[记忆整理] 删除被合并记忆 %s 失败: %v", did, err)
				}
			}
			log.Printf("[记忆整理] 已合并: %v -> %s (原因: %s)", resolved, keeper, a.Reason)
			executed++

		case "update":
			if strings.TrimSpace(a.ID) == "" {
				log.Printf("[记忆整理] 跳过 update: 缺少记忆 ID")
				continue
			}
			id := resolveID(a.ID, memByShortID, memByFullID)
			if id == "" {
				log.Printf("[记忆整理] 跳过 update: 找不到记忆 %s", a.ID)
				continue
			}
			existing := memByFullID[id]

			// Keep the stored text when the LLM only adjusted metadata.
			content := a.Content
			if !hasContent {
				if existing == nil || strings.TrimSpace(existing.Content) == "" {
					log.Printf("[记忆整理] 跳过 update: 内容为空 %s", a.ID)
					continue
				}
				content = existing.Content
			}
			cat := model.MemoryCategory(a.Category)
			if cat == "" && existing != nil {
				cat = existing.Category
			}
			imp := a.Importance
			if imp <= 0 && existing != nil {
				imp = existing.Importance
			}
			imp = capImportance(imp)
			pri := a.Priority
			if pri == 0 && existing != nil {
				pri = int(existing.Priority)
			}

			if err := t.memClient.Update(ctx, &model.MemoryEntry{
				ID:         id,
				Content:    content,
				Category:   cat,
				Importance: imp,
				Priority:   model.MemoryPriority(pri),
				Keywords:   a.Keywords,
				Source:     "consolidated",
			}); err != nil {
				log.Printf("[记忆整理] 更新 %s 失败: %v", id, err)
				continue
			}
			log.Printf("[记忆整理] 已更新: %s (原因: %s)", id, a.Reason)
			executed++

		case "create":
			if !hasContent {
				log.Printf("[记忆整理] 跳过 create: 内容为空")
				continue
			}
			cat := model.MemoryCategory(a.Category)
			if cat == "" {
				cat = model.CategoryKnowledge
			}
			imp := a.Importance
			if imp <= 0 {
				imp = 5
			}
			imp = capImportance(imp)

			if err := t.memClient.Save(ctx, &model.MemoryEntry{
				UserID:     t.adminUserID,
				Content:    a.Content,
				Category:   cat,
				Importance: imp,
				Priority:   model.MemoryNormal,
				Keywords:   a.Keywords,
				Source:     "consolidation",
			}); err != nil {
				log.Printf("[记忆整理] 创建记忆失败: %v", err)
				continue
			}
			log.Printf("[记忆整理] 已创建: %s (原因: %s)", a.Content, a.Reason)
			executed++

		default:
			log.Printf("[记忆整理] 跳过未知操作: %q", a.Action)
		}
	}
	return executed
}
// resolveID maps an ID as written by the LLM (full, 8-character short form, or
// any other prefix) to the full ID of a known memory.
//
// byShort is keyed by short ID and byFull by full ID. Lookup order is exact
// full ID, exact short ID, then prefix match. It returns "" when id is blank,
// when nothing matches, or when a prefix matches more than one memory.
//
// The blank and ambiguity checks matter because callers delete and overwrite
// memories based on the result: every ID has the empty string as a prefix, and
// Go map iteration order is random, so without them a malformed action would
// hit an arbitrary memory.
func resolveID(id string, byShort, byFull map[string]*model.MemoryEntry) string {
	id = strings.TrimSpace(id)
	if id == "" {
		return ""
	}
	if _, ok := byFull[id]; ok {
		return id
	}
	if m, ok := byShort[id]; ok {
		return m.ID
	}

	// Prefix match: accept only a unique hit.
	match := ""
	for fullID := range byFull {
		if !strings.HasPrefix(fullID, id) {
			continue
		}
		if match != "" {
			return "" // ambiguous prefix; refuse to guess
		}
		match = fullID
	}
	return match
}
// fuzzyMemorySearch widens a memory lookup: it asks the LLM for keywords
// related to query (see expandMemoryKeywords), runs one text query per keyword
// for userID, and returns the union of the hits without duplicates.
//
// Keywords come from LLM output, so blank and repeated entries are skipped
// instead of being sent to the store. A failing keyword query is logged and
// skipped; the loop stops early once ctx is done. Returns nil when no memory
// client is configured or no keywords were produced.
func (t *Thinker) fuzzyMemorySearch(ctx context.Context, userID, query string) []memory.MemoryEntry {
	if t.memClient == nil {
		return nil
	}

	keywords := t.expandMemoryKeywords(ctx, query)
	if len(keywords) == 0 {
		return nil
	}
	log.Printf("[后台思考] 模糊记忆关键词: %v", keywords)

	var allResults []memory.MemoryEntry
	seen := make(map[string]bool)
	triedKeywords := make(map[string]bool, len(keywords))
	for _, kw := range keywords {
		kw = strings.TrimSpace(kw)
		if kw == "" || triedKeywords[kw] {
			continue
		}
		triedKeywords[kw] = true

		// Do not keep issuing queries that can only fail after cancellation.
		if ctx.Err() != nil {
			break
		}

		results, err := t.memClient.QueryByText(ctx, userID, kw, "", 0, 5)
		if err != nil {
			log.Printf("[后台思考] 模糊搜索 '%s' 失败: %v", kw, err)
			continue
		}
		for _, m := range results {
			if !seen[m.ID] {
				seen[m.ID] = true
				allResults = append(allResults, m)
			}
		}
	}
	return allResults
}
// expandMemoryKeywords asks the LLM for 3-5 related keywords (synonyms, broader
// or narrower concepts) that can be used to search memories for message.
//
// The reply is expected to be a JSON string array; any text around the
// outermost [...] span is ignored. Returns nil when the LLM call fails or the
// reply cannot be decoded (both are logged).
func (t *Thinker) expandMemoryKeywords(ctx context.Context, message string) []string {
	const promptFormat = "从以下对话消息中提取 3-5 个可用于模糊搜索记忆的关键词。这些关键词应该是:\n" +
		"- 与话题相关的抽象概念\n- 同义词和相关词\n- 更宽泛或更具体的相关概念\n" +
		"- 不要包含消息中已经出现的原词\n\n" +
		"用户消息:「%s」\n\n" +
		"只输出 JSON 字符串数组,例如:[\"关键词1\",\"关键词2\"]"

	conversation := []model.LLMMessage{
		{Role: model.RoleSystem, Content: "你是记忆搜索专家。输出 JSON 字符串数组。"},
		{Role: model.RoleUser, Content: fmt.Sprintf(promptFormat, message)},
	}
	resp, err := t.llmAdapter.Chat(ctx, conversation)
	if err != nil {
		log.Printf("[后台思考] 关键词扩展失败: %v", err)
		return nil
	}

	// Narrow the reply to its outermost bracketed span, if it has one.
	raw := strings.TrimSpace(resp.Content)
	open, closing := strings.Index(raw, "["), strings.LastIndex(raw, "]")
	if open >= 0 && closing > open {
		raw = raw[open : closing+1]
	}

	var keywords []string
	if err := json.Unmarshal([]byte(raw), &keywords); err != nil {
		log.Printf("[后台思考] 解析关键词 JSON 失败: %v (raw=%s)", err, resp.Content)
		return nil
	}
	return keywords
}
// lastUserMessage returns the content of the most recent user-role message in
// history, truncated to at most 200 runes. It returns "" when history contains
// no user message.
func lastUserMessage(history []model.LLMMessage) string {
	const maxRunes = 200

	// Walk backwards to the newest user message.
	pos := len(history) - 1
	for pos >= 0 && history[pos].Role != model.RoleUser {
		pos--
	}
	if pos < 0 {
		return ""
	}

	content := history[pos].Content
	// Truncate by rune so multi-byte characters are never split.
	if asRunes := []rune(content); len(asRunes) > maxRunes {
		return string(asRunes[:maxRunes])
	}
	return content
}
// formatDeviceContext 格式化设备状态为文本
func formatDeviceContext(devices []tools.IoTDevice) string {
if len(devices) == 0 {
@@ -0,0 +1,184 @@
package background
import (
"encoding/json"
"fmt"
"log"
"os"
"strconv"
"strings"
"sync"
"time"
)
// ScheduleRule defines a time-based interval rule. A rule applies when the
// current weekday is listed in Days, the current time falls inside TimeRange,
// and it does not fall inside any Except range (see ScheduleLoader.GetInterval).
type ScheduleRule struct {
// Name is a label for the rule; it is not used when matching.
Name string `json:"name"`
// Days lists weekday names, compared case-insensitively against the
// lowercased English name from time.Weekday (e.g. "monday").
Days []string `json:"days"`
// TimeRange is "HH:MM-HH:MM"; the end is exclusive and a start later than
// the end denotes an overnight range.
TimeRange string `json:"time_range"`
// Except lists "HH:MM-HH:MM" ranges during which the rule does not apply.
Except []string `json:"except"`
// IntervalMinutes is the interval returned while the rule applies.
IntervalMinutes int `json:"interval_minutes"`
}
// ThinkingScheduleConfig is the full schedule configuration as stored in the
// schedule JSON file.
type ThinkingScheduleConfig struct {
// Version is logged when the file is loaded; it is not otherwise interpreted here.
Version string `json:"version"`
// DefaultIntervalMinutes is returned by GetInterval when no rule applies.
DefaultIntervalMinutes int `json:"default_interval_minutes"`
// Rules are evaluated in order; the first rule that applies wins.
Rules []ScheduleRule `json:"rules"`
}
// ScheduleLoader loads the thinking schedule from a JSON file and calculates
// the current interval based on time of day and day of week.
type ScheduleLoader struct {
// mu is taken by HasConfig and GetInterval when reading config.
mu sync.RWMutex
// path is the location of the schedule JSON file.
path string
// config is nil when no schedule is loaded (missing, empty or unparsable file).
config *ThinkingScheduleConfig
}
// NewScheduleLoader builds a loader for the schedule file at path and performs
// the initial load. A missing or empty file is not an error: the loader is
// returned without a config. On a read or parse failure the loader is still
// returned, together with the error.
func NewScheduleLoader(path string) (*ScheduleLoader, error) {
	loader := &ScheduleLoader{path: path}
	err := loader.load()
	return loader, err
}
// load reads and parses the schedule file at l.path and replaces l.config.
//
// A missing or empty file clears the config and is not an error. A file that
// cannot be read returns an error and leaves the config untouched; a file that
// cannot be parsed clears the config and returns an error.
//
// Every write to l.config happens under l.mu, so load is safe to call while
// other goroutines use HasConfig or GetInterval.
func (l *ScheduleLoader) load() error {
	// setConfig publishes cfg (possibly nil) under the write lock.
	setConfig := func(cfg *ThinkingScheduleConfig) {
		l.mu.Lock()
		l.config = cfg
		l.mu.Unlock()
	}

	data, err := os.ReadFile(l.path)
	if err != nil {
		if os.IsNotExist(err) {
			setConfig(nil)
			return nil
		}
		return fmt.Errorf("read thinking schedule: %w", err)
	}
	if len(data) == 0 {
		setConfig(nil)
		return nil
	}

	var cfg ThinkingScheduleConfig
	if err := json.Unmarshal(data, &cfg); err != nil {
		setConfig(nil)
		return fmt.Errorf("parse thinking schedule: %w", err)
	}
	setConfig(&cfg)

	log.Printf("[思考调度] 已加载配置文件: version=%s, default=%dmin, rules=%d", cfg.Version, cfg.DefaultIntervalMinutes, len(cfg.Rules))
	return nil
}
// HasConfig reports whether a schedule configuration is currently loaded.
func (l *ScheduleLoader) HasConfig() bool {
	l.mu.RLock()
	loaded := l.config != nil
	l.mu.RUnlock()
	return loaded
}
// GetInterval returns the thinking interval, in minutes, that applies at now.
//
// Rules are checked in order and the first one whose day, time range and
// exceptions all match decides the result; otherwise the configured default
// interval is returned. Returns 0 when no schedule is loaded, in which case the
// caller should use its own default.
func (l *ScheduleLoader) GetInterval(now time.Time) int {
	l.mu.RLock()
	defer l.mu.RUnlock()

	cfg := l.config
	if cfg == nil {
		return 0
	}

	day := strings.ToLower(now.Weekday().String()) // "monday", "tuesday", ...
	minuteOfDay := 60*now.Hour() + now.Minute()

	for i := range cfg.Rules {
		rule := &cfg.Rules[i]
		applies := matchDay(day, rule.Days) &&
			matchTimeRange(minuteOfDay, rule.TimeRange) &&
			!matchExceptRange(minuteOfDay, rule.Except)
		if applies {
			return rule.IntervalMinutes
		}
	}
	return cfg.DefaultIntervalMinutes
}
// matchDay reports whether currentDay appears in days. Entries in days are
// lowercased before comparison, so currentDay is expected to be lowercase.
func matchDay(currentDay string, days []string) bool {
	for i := 0; i < len(days); i++ {
		if currentDay == strings.ToLower(days[i]) {
			return true
		}
	}
	return false
}
// matchTimeRange reports whether currentMinutes (minutes since midnight) lies
// inside timeRange ("HH:MM-HH:MM"). The start is inclusive and the end
// exclusive. When the start is later than the end the range wraps past
// midnight (e.g. 23:00-07:00). An unparsable range never matches.
func matchTimeRange(currentMinutes int, timeRange string) bool {
	from, to, valid := parseTimeRange(timeRange)
	switch {
	case !valid:
		return false
	case from > to:
		// Overnight: late evening or early morning.
		return currentMinutes >= from || currentMinutes < to
	default:
		return from <= currentMinutes && currentMinutes < to
	}
}
// matchExceptRange returns true if currentMinutes falls in any except range.
func matchExceptRange(currentMinutes int, exceptRanges []string) bool {
for _, er := range exceptRanges {
start, end, ok := parseTimeRange(er)
if !ok {
continue
}
if start <= end {
if currentMinutes >= start && currentMinutes < end {
return true
}
} else {
if currentMinutes >= start || currentMinutes < end {
return true
}
}
}
return false
}
// parseTimeRange splits "HH:MM-HH:MM" at the first '-' and converts both sides
// to minutes since midnight. Whitespace around either side is ignored. The
// final result is false when the separator is missing or either side is not a
// valid time.
func parseTimeRange(r string) (int, int, bool) {
	left, right, found := strings.Cut(r, "-")
	if !found {
		return 0, 0, false
	}

	from, fromOK := parseHM(strings.TrimSpace(left))
	to, toOK := parseHM(strings.TrimSpace(right))
	if !fromOK || !toOK {
		return 0, 0, false
	}
	return from, to, true
}
// parseHM parses "HH:MM" into minutes since midnight.
//
// Hours must be 0-23 and minutes 0-59, with one exception: "24:00" is accepted
// as end of day and yields 1440. Because range ends are exclusive, this is what
// lets a range such as "00:00-24:00" or "22:00-24:00" include the last minute
// of the day. The boolean result is false for any other input.
func parseHM(s string) (int, bool) {
	hourText, minuteText, found := strings.Cut(s, ":")
	if !found {
		return 0, false
	}
	h, err := strconv.Atoi(hourText)
	if err != nil {
		return 0, false
	}
	m, err := strconv.Atoi(minuteText)
	if err != nil {
		return 0, false
	}

	// End-of-day marker: only exactly 24:00 is valid beyond 23:59.
	if h == 24 && m == 0 {
		return 24 * 60, true
	}
	if h < 0 || h > 23 || m < 0 || m > 59 {
		return 0, false
	}
	return h*60 + m, true
}
+25 -6
View File
@@ -24,9 +24,10 @@ type IoTDeviceSummary interface {
// ConversationStore is an in-memory conversation history store keyed by session ID.
type ConversationStore struct {
mu sync.RWMutex
messages map[string][]model.LLMMessage // key = sessionID
maxHistory int
mu sync.RWMutex
messages map[string][]model.LLMMessage // key = sessionID
maxHistory int
databaseURL string // lazy-load from DB on cache miss
}
// NewConversationStore 创建会话历史存储
@@ -37,6 +38,13 @@ func NewConversationStore(maxHistory int) *ConversationStore {
}
}
// SetDatabaseURL records the database URL that GetHistory uses to lazily load
// a session's history when it is not in the in-memory cache.
func (cs *ConversationStore) SetDatabaseURL(url string) {
	cs.mu.Lock()
	cs.databaseURL = url
	cs.mu.Unlock()
}
// AddMessage 添加消息到会话历史
func (cs *ConversationStore) AddMessage(sessionID string, msg model.LLMMessage) {
cs.mu.Lock()
@@ -59,12 +67,23 @@ func (cs *ConversationStore) AddMessage(sessionID string, msg model.LLMMessage)
cs.messages[sessionID] = msgs
}
// GetHistory 获取会话历史
// GetHistory 获取会话历史
// 如果内存缓存为空且配置了 databaseURL,会尝试从 DB 懒加载历史。
func (cs *ConversationStore) GetHistory(sessionID string, limit int) []model.LLMMessage {
cs.mu.RLock()
defer cs.mu.RUnlock()
msgs := cs.messages[sessionID]
dbURL := cs.databaseURL
cs.mu.RUnlock()
if len(msgs) == 0 && dbURL != "" {
// 懒加载:从 DB 恢复该会话的历史
if err := cs.LoadFromDB(dbURL, sessionID, limit); err == nil {
cs.mu.RLock()
msgs = cs.messages[sessionID]
cs.mu.RUnlock()
}
}
if len(msgs) == 0 {
return nil
}
@@ -0,0 +1,204 @@
package host
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
)
// DirectBackend executes commands directly on the host by delegating to a
// Sandbox, and restricts its own file operations to a set of allowed
// directories.
// NOTE(review): the command allowlist mentioned for this backend is presumably
// enforced inside Sandbox.Exec — confirm there.
type DirectBackend struct {
// sandbox runs commands for Exec; NewDirectBackend accepts nil here.
sandbox *Sandbox
// allowedDirs limits file operations; an empty list means no restriction
// (see validatePath).
allowedDirs []string
}
// NewDirectBackend returns a backend that runs commands on the host through
// sandbox. The allowed directories are taken from the sandbox configuration;
// with a nil sandbox the backend starts without directory restrictions.
func NewDirectBackend(sandbox *Sandbox) *DirectBackend {
	if sandbox == nil {
		return &DirectBackend{}
	}
	return &DirectBackend{
		sandbox:     sandbox,
		allowedDirs: sandbox.cfg.AllowedDirs,
	}
}
// Name returns the backend identifier, "direct".
func (b *DirectBackend) Name() string {
	return "direct"
}
// SetAllowedDirs replaces the directories that file operations may touch and,
// when a sandbox is attached, mirrors the list into the sandbox configuration.
// The slice is stored as given, not copied.
func (b *DirectBackend) SetAllowedDirs(dirs []string) {
	if sb := b.sandbox; sb != nil {
		sb.cfg.AllowedDirs = dirs
	}
	b.allowedDirs = dirs
}
// Exec runs command through the sandbox, passing ctx, workDir and timeout
// along unchanged.
//
// NewDirectBackend accepts a nil sandbox (useful for file-only access); in that
// case Exec returns an error instead of dereferencing the nil sandbox.
func (b *DirectBackend) Exec(ctx context.Context, command, workDir string, timeout time.Duration) (*ExecResult, error) {
	if b.sandbox == nil {
		return nil, fmt.Errorf("direct backend: no sandbox configured")
	}
	return b.sandbox.Exec(ctx, command, workDir, timeout)
}
// ReadFile returns the contents of the file at path, which must lie within the
// allowed directories. maxBytes bounds the accepted file size; values <= 0
// select the 1 MiB default. Directories and files larger than the limit are
// rejected with an error.
func (b *DirectBackend) ReadFile(path string, maxBytes int) (string, error) {
	limit := maxBytes
	if limit <= 0 {
		limit = 1024 * 1024
	}

	if err := b.validatePath(path); err != nil {
		return "", err
	}

	st, err := os.Stat(path)
	switch {
	case err != nil:
		return "", fmt.Errorf("cannot stat file: %w", err)
	case st.IsDir():
		return "", fmt.Errorf("path is a directory: %s", path)
	case st.Size() > int64(limit):
		return "", fmt.Errorf("file too large: %d bytes (max %d)", st.Size(), limit)
	}

	raw, err := os.ReadFile(path)
	if err != nil {
		return "", fmt.Errorf("cannot read file: %w", err)
	}
	// The file may have grown between Stat and ReadFile; never return more
	// than the limit.
	if len(raw) > limit {
		raw = raw[:limit]
	}
	return string(raw), nil
}
// WriteFile stores content at path, which must lie within the allowed
// directories. Missing parent directories are created with mode 0755 and the
// file is written with mode 0644. Content longer than maxBytes is rejected;
// maxBytes <= 0 selects the 1 MiB default.
func (b *DirectBackend) WriteFile(path, content string, maxBytes int) error {
	limit := maxBytes
	if limit <= 0 {
		limit = 1024 * 1024
	}
	if size := len(content); size > limit {
		return fmt.Errorf("content too large: %d bytes (max %d)", size, limit)
	}

	if err := b.validatePath(path); err != nil {
		return err
	}

	if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
		return fmt.Errorf("cannot create directory: %w", err)
	}
	return os.WriteFile(path, []byte(content), 0644)
}
// ListDir returns one DirEntry per item in the directory at path, which must
// lie within the allowed directories. Size and modification time are best
// effort: when an entry's metadata cannot be read they stay at their zero
// values (the zero time is still formatted as RFC 3339).
func (b *DirectBackend) ListDir(path string) ([]DirEntry, error) {
	if err := b.validatePath(path); err != nil {
		return nil, err
	}

	items, err := os.ReadDir(path)
	if err != nil {
		return nil, fmt.Errorf("cannot read directory: %w", err)
	}

	listing := make([]DirEntry, 0, len(items))
	for _, item := range items {
		var (
			size     int64
			modified time.Time
		)
		// An Info error is ignored on purpose; the entry is listed anyway.
		if fi, _ := item.Info(); fi != nil {
			size, modified = fi.Size(), fi.ModTime()
		}
		listing = append(listing, DirEntry{
			Name:    item.Name(),
			IsDir:   item.IsDir(),
			Size:    size,
			ModTime: modified.Format(time.RFC3339),
		})
	}
	return listing, nil
}
// SystemInfo returns basic information about the host: hostname, OS,
// architecture, CPU count, Go version, working directory and backend name.
//
// Memory (and on Windows the OS name) is added on a best-effort basis: on
// Windows by parsing `systeminfo` output, elsewhere by reading /proc/meminfo.
// Missing sources are silently skipped.
//
// `systeminfo` can take several seconds, so it runs under a timeout; if it does
// not finish in time the extra fields are simply omitted. The labels matched in
// its output are the English ones, so on a localized Windows these fields may
// be absent.
func (b *DirectBackend) SystemInfo() map[string]interface{} {
	// Both lookups are best effort; empty strings are acceptable values.
	hostname, _ := os.Hostname()
	wd, _ := os.Getwd()

	info := map[string]interface{}{
		"hostname":   hostname,
		"os":         runtime.GOOS,
		"arch":       runtime.GOARCH,
		"num_cpu":    runtime.NumCPU(),
		"go_version": runtime.Version(),
		"work_dir":   wd,
		"backend":    "direct",
	}

	if runtime.GOOS == "windows" {
		ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
		defer cancel()

		out, err := exec.CommandContext(ctx, "systeminfo").Output()
		if err != nil {
			return info
		}
		// labels maps a systeminfo label to the key it is reported under.
		labels := []struct{ label, key string }{
			{"Total Physical Memory", "total_memory"},
			{"OS Name", "os_name"},
		}
		for _, line := range strings.Split(string(out), "\n") {
			line = strings.TrimSpace(line)
			for _, l := range labels {
				if !strings.Contains(line, l.label) {
					continue
				}
				if parts := strings.SplitN(line, ":", 2); len(parts) == 2 {
					info[l.key] = strings.TrimSpace(parts[1])
				}
			}
		}
		return info
	}

	if data, err := os.ReadFile("/proc/meminfo"); err == nil {
		for _, line := range strings.Split(string(data), "\n") {
			if strings.HasPrefix(line, "MemTotal:") {
				info["total_memory"] = strings.TrimSpace(strings.TrimPrefix(line, "MemTotal:"))
				break
			}
		}
	}
	return info
}
// DiskUsage reports metadata for path, which must lie within the allowed
// directories: the path itself, whether it is a directory, the size reported
// by os.Stat and the modification time in RFC 3339 format.
// Note that for a directory this is the size of the directory entry, not the
// total size of its contents.
func (b *DirectBackend) DiskUsage(path string) (map[string]interface{}, error) {
	if err := b.validatePath(path); err != nil {
		return nil, err
	}

	st, err := os.Stat(path)
	if err != nil {
		return nil, fmt.Errorf("cannot stat path: %w", err)
	}

	usage := make(map[string]interface{}, 4)
	usage["path"] = path
	usage["is_dir"] = st.IsDir()
	usage["size"] = st.Size()
	usage["mod_time"] = st.ModTime().Format(time.RFC3339)
	return usage, nil
}
// validatePath returns nil when path may be accessed: either no allowed
// directories are configured, or path is an allowed directory or lies beneath
// one. Otherwise it returns an error.
//
// Both the target and each allowed directory are made absolute and have their
// symlinks resolved before comparison, so a symlink inside an allowed directory
// cannot be used to reach files outside it. Containment is decided with
// filepath.Rel rather than a string prefix, which also handles an allowed
// directory that is the filesystem root.
//
// The check and the later file operation are separate steps, so this does not
// protect against a path being swapped for a symlink in between.
func (b *DirectBackend) validatePath(path string) error {
	absPath, err := filepath.Abs(path)
	if err != nil {
		return fmt.Errorf("cannot resolve path: %w", err)
	}
	if len(b.allowedDirs) == 0 {
		return nil
	}

	target, ok := resolveExistingPrefix(absPath)
	if !ok {
		return fmt.Errorf("path not in allowed directories: %s", path)
	}
	for _, allowed := range b.allowedDirs {
		absAllowed, err := filepath.Abs(allowed)
		if err != nil {
			continue
		}
		root, ok := resolveExistingPrefix(absAllowed)
		if ok && pathWithin(root, target) {
			return nil
		}
	}
	return fmt.Errorf("path not in allowed directories: %s", path)
}

// pathWithin reports whether target equals root or lies beneath it.
func pathWithin(root, target string) bool {
	rel, err := filepath.Rel(root, target)
	if err != nil {
		return false
	}
	if rel == "." {
		return true
	}
	return rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator))
}

// resolveExistingPrefix resolves symlinks in the longest existing prefix of the
// absolute path p and re-appends the not-yet-existing remainder, so paths that
// are about to be created can be checked too. It returns false when a
// component is a symlink that cannot be resolved (e.g. a dangling link), since
// its real destination is unknown.
func resolveExistingPrefix(p string) (string, bool) {
	remainder := ""
	cur := p
	for {
		real, err := filepath.EvalSymlinks(cur)
		if err == nil {
			return filepath.Join(real, remainder), true
		}
		if fi, lerr := os.Lstat(cur); lerr == nil && fi.Mode()&os.ModeSymlink != 0 {
			return "", false
		}
		parent := filepath.Dir(cur)
		if parent == cur {
			// Nothing resolvable up to the root; fall back to the cleaned path.
			return p, true
		}
		remainder = filepath.Join(filepath.Base(cur), remainder)
		cur = parent
	}
}
@@ -0,0 +1,274 @@
package host
import (
"bytes"
"context"
"fmt"
"os/exec"
"strings"
"time"
)
// DockerBackend executes commands inside a Docker container by shelling out to
// the docker CLI, so commands run in the container's Linux environment rather
// than on the host.
type DockerBackend struct {
// container is the name of the container that commands are executed in.
container string
// image is used to create the container when it does not exist yet.
image string
// timeout is the default applied by Exec when the caller passes a
// non-positive timeout.
timeout time.Duration
}
// NewDockerBackend returns a backend that runs commands in the named
// container; the container is created from image on first use if it does not
// exist. A non-positive defaultTimeout selects 30 seconds.
func NewDockerBackend(container, image string, defaultTimeout time.Duration) *DockerBackend {
	backend := &DockerBackend{
		container: container,
		image:     image,
		timeout:   30 * time.Second,
	}
	if defaultTimeout > 0 {
		backend.timeout = defaultTimeout
	}
	return backend
}
// Name returns the backend identifier, "docker".
func (b *DockerBackend) Name() string {
	return "docker"
}
// ensureContainer makes sure the configured container is running.
//
// It inspects the container's running state: a running container needs
// nothing, a stopped one is started, and in every other case (including a
// failed inspect) a new detached container is created from the configured
// image with "sleep infinity" as its command and restart policy
// "unless-stopped".
func (b *DockerBackend) ensureContainer() error {
	state, inspectErr := exec.Command("docker", "inspect", "-f", "{{.State.Running}}", b.container).Output()
	if inspectErr == nil {
		switch strings.TrimSpace(string(state)) {
		case "true":
			return nil
		case "false":
			// Container exists but is stopped.
			if msg, err := exec.Command("docker", "start", b.container).CombinedOutput(); err != nil {
				return fmt.Errorf("cannot start container %s: %s — %w", b.container, string(msg), err)
			}
			return nil
		}
	}

	// Not found (or state unknown): create and start a fresh container.
	runArgs := []string{
		"run", "-d", "--name", b.container,
		"--restart", "unless-stopped",
		b.image, "sleep", "infinity",
	}
	if msg, err := exec.Command("docker", runArgs...).CombinedOutput(); err != nil {
		return fmt.Errorf("cannot create container %s from image %s: %s — %w",
			b.container, b.image, string(msg), err)
	}
	return nil
}
// Exec runs command with `sh -c` inside the container, optionally after
// changing into workDir, and captures stdout and stderr separately.
//
// A non-positive timeout selects the backend default. On timeout the result
// has TimedOut set and ExitCode -1, and an error is returned. Otherwise
// ExitCode is the command's exit status (-1 if it could not be determined) and
// the error from running the command, if any, is returned alongside the result.
func (b *DockerBackend) Exec(ctx context.Context, command, workDir string, timeout time.Duration) (*ExecResult, error) {
	if command == "" {
		return nil, fmt.Errorf("empty command")
	}
	if err := b.ensureContainer(); err != nil {
		return nil, err
	}
	if timeout <= 0 {
		timeout = b.timeout
	}

	runCtx, stop := context.WithTimeout(ctx, timeout)
	defer stop()

	// Prefix a cd when a working directory was requested.
	shellLine := command
	if workDir != "" {
		shellLine = fmt.Sprintf("cd %s && %s", shellEscapeDocker(workDir), command)
	}

	var outBuf, errBuf bytes.Buffer
	proc := exec.CommandContext(runCtx, "docker", "exec", b.container, "sh", "-c", shellLine)
	proc.Stdout, proc.Stderr = &outBuf, &errBuf

	began := time.Now()
	runErr := proc.Run()
	took := time.Since(began)

	res := &ExecResult{
		Duration: took.Round(time.Millisecond).String(),
		Stdout:   outBuf.String(),
		Stderr:   errBuf.String(),
	}

	if runCtx.Err() == context.DeadlineExceeded {
		res.TimedOut = true
		res.ExitCode = -1
		return res, fmt.Errorf("command timed out after %s", timeout)
	}

	switch e := runErr.(type) {
	case nil:
		res.ExitCode = 0
	case *exec.ExitError:
		res.ExitCode = e.ExitCode()
	default:
		// The command could not be started or waited for.
		res.ExitCode = -1
	}
	return res, runErr
}
// ReadFile returns up to maxBytes bytes from the start of the file at path
// inside the container. maxBytes <= 0 selects the 1 MiB default. The call is
// limited to 10 seconds.
//
// The file is read with `head -c` so the byte limit is enforced inside the
// container: reading a huge or endless file (e.g. /dev/zero) no longer buffers
// unbounded output in this process. "--" keeps a path that starts with '-'
// from being parsed as an option.
func (b *DockerBackend) ReadFile(path string, maxBytes int) (string, error) {
	if maxBytes <= 0 {
		maxBytes = 1024 * 1024
	}
	if err := b.ensureContainer(); err != nil {
		return "", err
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	cmd := exec.CommandContext(ctx, "docker", "exec", b.container,
		"head", "-c", fmt.Sprintf("%d", maxBytes), "--", path)
	out, err := cmd.Output()
	if err != nil {
		return "", fmt.Errorf("cannot read file %s: %w", path, err)
	}
	// Defensive: never hand back more than the caller asked for.
	if len(out) > maxBytes {
		out = out[:maxBytes]
	}
	return string(out), nil
}
// WriteFile stores content at path inside the container, creating missing
// parent directories first. Content longer than maxBytes is rejected;
// maxBytes <= 0 selects the 1 MiB default. The call is limited to 10 seconds.
//
// The content is streamed over stdin to `cat > path`. The dirname substitution
// is double-quoted so paths containing spaces are not word-split, and the
// command's output is included in the returned error.
func (b *DockerBackend) WriteFile(path, content string, maxBytes int) error {
	if maxBytes <= 0 {
		maxBytes = 1024 * 1024
	}
	if len(content) > maxBytes {
		return fmt.Errorf("content too large: %d bytes (max %d)", len(content), maxBytes)
	}
	if err := b.ensureContainer(); err != nil {
		return err
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	target := shellEscapeDocker(path)
	script := fmt.Sprintf(`mkdir -p "$(dirname -- %s)" && cat > %s`, target, target)

	cmd := exec.CommandContext(ctx, "docker", "exec", "-i", b.container, "sh", "-c", script)
	cmd.Stdin = strings.NewReader(content)
	if out, err := cmd.CombinedOutput(); err != nil {
		return fmt.Errorf("cannot write file %s: %s — %w", path, strings.TrimSpace(string(out)), err)
	}
	return nil
}
// ListDir lists the directory at path inside the container by parsing
// `ls -la` output. The "." and ".." entries are omitted, and ModTime is not
// populated. Errors from ls itself are suppressed by the shell pipeline, so a
// missing directory yields an empty list. The call is limited to 10 seconds.
//
// Parsing notes:
//   - ls runs with LC_ALL=C so the date occupies exactly three columns;
//   - the name is everything after the eighth column, so names containing
//     spaces are preserved;
//   - for symlinks the " -> target" suffix is stripped so the link's own name
//     is reported.
func (b *DockerBackend) ListDir(path string) ([]DirEntry, error) {
	if err := b.ensureContainer(); err != nil {
		return nil, err
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	cmd := exec.CommandContext(ctx, "docker", "exec", b.container, "sh", "-c",
		fmt.Sprintf("LC_ALL=C ls -la -- %s 2>/dev/null | tail -n +2 || echo ''", shellEscapeDocker(path)))
	out, err := cmd.Output()
	if err != nil {
		return nil, fmt.Errorf("cannot list dir %s: %w", path, err)
	}

	lines := strings.Split(strings.TrimSpace(string(out)), "\n")
	result := make([]DirEntry, 0, len(lines))
	for _, line := range lines {
		if line == "" || strings.HasPrefix(line, "total ") {
			continue
		}
		// Expected shape: drwxr-xr-x 2 root root 4096 Jan 1 12:00 name
		fields := strings.Fields(line)
		if len(fields) < 9 {
			continue
		}
		name := lsEntryName(line, fields[0])
		if name == "" || name == "." || name == ".." {
			continue
		}

		// Size is best effort; it stays 0 when the column is not a plain number.
		var size int64
		fmt.Sscanf(fields[4], "%d", &size)

		result = append(result, DirEntry{
			Name:  name,
			IsDir: strings.HasPrefix(fields[0], "d"),
			Size:  size,
		})
	}
	return result, nil
}

// lsEntryName extracts the file name from one `ls -l` line: it skips the eight
// leading columns and the single separator that follows them, then, for
// symlinks (mode starting with 'l'), drops the " -> target" suffix. It returns
// "" when the line has no name column.
func lsEntryName(line, mode string) string {
	rest := line
	for i := 0; i < 8; i++ {
		rest = strings.TrimLeft(rest, " \t")
		cut := strings.IndexAny(rest, " \t")
		if cut < 0 {
			return ""
		}
		rest = rest[cut:]
	}
	// rest starts with the single separator before the name.
	name := rest[1:]
	if strings.HasPrefix(mode, "l") {
		if arrow := strings.Index(name, " -> "); arrow >= 0 {
			name = name[:arrow]
		}
	}
	return name
}
// SystemInfo collects information from inside the container. The result always
// holds the backend name, container name and image. If the container cannot be
// ensured, an "error" entry is added and nothing else is probed. Otherwise
// uname, hostname, memory (`free -h`) and root disk usage (second line of
// `df -h /`) are added when the respective command succeeds. All probes share
// one 10-second deadline.
func (b *DockerBackend) SystemInfo() map[string]interface{} {
	info := map[string]interface{}{
		"backend":   "docker",
		"container": b.container,
		"image":     b.image,
	}

	if err := b.ensureContainer(); err != nil {
		info["error"] = err.Error()
		return info
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	// probe runs one command in the container and returns its trimmed stdout.
	probe := func(args ...string) (string, bool) {
		argv := append([]string{"exec", b.container}, args...)
		out, err := exec.CommandContext(ctx, "docker", argv...).Output()
		if err != nil {
			return "", false
		}
		return strings.TrimSpace(string(out)), true
	}

	if v, ok := probe("uname", "-a"); ok {
		info["uname"] = v
	}
	if v, ok := probe("hostname"); ok {
		info["hostname"] = v
	}
	if v, ok := probe("free", "-h"); ok {
		info["memory"] = v
	}
	if v, ok := probe("df", "-h", "/"); ok {
		// Line 0 is the header; line 1 describes the root filesystem.
		if rows := strings.Split(v, "\n"); len(rows) > 1 {
			info["disk"] = strings.TrimSpace(rows[1])
		}
	}
	return info
}
// DiskUsage runs `stat` on path inside the container and returns the path
// together with the trimmed raw stat output under the "stat" key. The call is
// limited to 10 seconds.
func (b *DockerBackend) DiskUsage(path string) (map[string]interface{}, error) {
	if err := b.ensureContainer(); err != nil {
		return nil, err
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	raw, err := exec.CommandContext(ctx, "docker", "exec", b.container, "stat", path).Output()
	if err != nil {
		return nil, fmt.Errorf("cannot stat path %s: %w", path, err)
	}

	usage := make(map[string]interface{}, 2)
	usage["path"] = path
	usage["stat"] = strings.TrimSpace(string(raw))
	return usage, nil
}
// shellEscapeDocker wraps s in single quotes for use as one POSIX shell word.
// Each embedded single quote is written as '\'' (close the quote, emit an
// escaped quote, reopen the quote).
func shellEscapeDocker(s string) string {
	pieces := strings.Split(s, "'")
	return "'" + strings.Join(pieces, `'\''`) + "'"
}
@@ -0,0 +1,323 @@
package host
import (
"bytes"
"context"
"fmt"
"os/exec"
"path/filepath"
"strings"
"time"
)
// WSLBackend executes commands inside a WSL2 distribution by invoking wsl.exe,
// so commands run in the distribution's Linux environment rather than directly
// on the Windows host.
type WSLBackend struct {
// distro is the WSL distribution name passed to `wsl.exe -d`.
distro string
// username is the Linux account ensureUser creates when it is missing;
// NewWSLBackend defaults it to "cyrene".
username string
// password is set for the account when it is created; when empty,
// ensureUser falls back to a fixed default.
password string
// timeout is the default command timeout; NewWSLBackend sets it to 30s when
// given a non-positive value.
timeout time.Duration
// userEnsured caches that the account is known to exist. It is read and
// written without synchronization.
userEnsured bool
}
// NewWSLBackend returns a backend for the given WSL distribution that works
// with the given Linux user. An empty username defaults to "cyrene" and a
// non-positive defaultTimeout to 30 seconds. The user is created lazily by
// ensureUser rather than here.
func NewWSLBackend(distro, username, password string, defaultTimeout time.Duration) *WSLBackend {
	backend := &WSLBackend{
		distro:   distro,
		username: "cyrene",
		password: password,
		timeout:  30 * time.Second,
	}
	if username != "" {
		backend.username = username
	}
	if defaultTimeout > 0 {
		backend.timeout = defaultTimeout
	}
	return backend
}
// Name identifies this backend implementation; it always returns "wsl".
func (b *WSLBackend) Name() string {
	return "wsl"
}
// ensureUser makes sure the configured Linux account exists inside the WSL
// distribution, creating it on the first call if necessary.
//
// The account is created with a home directory and /bin/bash as shell, gets
// the configured password and is added to the sudo group. When no password
// is configured the fixed fallback "cyrene" is used (it is NOT random).
// A successful check or creation is cached in b.userEnsured so later calls
// return immediately.
//
// Returns an error (including the combined command output) when the
// account cannot be created.
func (b *WSLBackend) ensureUser() error {
	if b.userEnsured {
		return nil
	}

	// `id <user>` exits 0 only when the account already exists.
	checkCtx, cancelCheck := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancelCheck()
	checkCmd := exec.CommandContext(checkCtx, "wsl.exe", "-d", b.distro, "--", "id", b.username)
	if checkCmd.Run() == nil {
		b.userEnsured = true
		return nil
	}

	pwd := b.password
	if pwd == "" {
		pwd = "cyrene"
	}

	// Both the user name and the "user:password" line are single-quoted via
	// shellEscape so that configuration values containing spaces, quotes or
	// shell metacharacters cannot alter the script.
	quotedUser := shellEscape(b.username)
	script := fmt.Sprintf(
		"useradd -m -s /bin/bash %s && echo %s | chpasswd && usermod -aG sudo %s",
		quotedUser, shellEscape(b.username+":"+pwd), quotedUser,
	)

	createCtx, cancelCreate := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancelCreate()
	createCmd := exec.CommandContext(createCtx, "wsl.exe", "-d", b.distro, "--", "bash", "-c", script)
	if out, err := createCmd.CombinedOutput(); err != nil {
		return fmt.Errorf("cannot create user %s: %s — %w", b.username, string(out), err)
	}

	b.userEnsured = true
	return nil
}
// Exec runs command through `bash -c` inside the WSL distribution.
//
// When workDir is non-empty it is converted with windowsToWSLPath and the
// command is prefixed with a `cd` into it. A non-positive timeout selects
// the backend default. The returned ExecResult always carries stdout,
// stderr and the elapsed time; ExitCode is the process exit code, or -1
// when the command timed out or could not be run at all.
//
// A non-zero exit is reported both in ExitCode and as a non-nil error. On
// timeout TimedOut is set and a dedicated timeout error is returned.
//
// NOTE(review): the wsl.exe invocation passes no `-u`, so the command
// presumably runs as the distribution's default user rather than as
// b.username — confirm this is intended.
func (b *WSLBackend) Exec(ctx context.Context, command, workDir string, timeout time.Duration) (*ExecResult, error) {
	if command == "" {
		return nil, fmt.Errorf("empty command")
	}
	if err := b.ensureUser(); err != nil {
		return nil, err
	}
	if timeout <= 0 {
		timeout = b.timeout
	}

	execCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	// Assemble the bash script, optionally changing directory first.
	script := command
	if workDir != "" {
		script = fmt.Sprintf("cd %s && %s", shellEscape(windowsToWSLPath(workDir)), command)
	}

	var outBuf, errBuf bytes.Buffer
	proc := exec.CommandContext(execCtx, "wsl.exe", "-d", b.distro, "--", "bash", "-c", script)
	proc.Stdout, proc.Stderr = &outBuf, &errBuf

	begin := time.Now()
	runErr := proc.Run()
	took := time.Since(begin)

	res := &ExecResult{
		Duration: took.Round(time.Millisecond).String(),
		Stdout:   outBuf.String(),
		Stderr:   errBuf.String(),
	}

	// A deadline hit takes precedence over whatever error Run reported.
	if execCtx.Err() == context.DeadlineExceeded {
		res.TimedOut = true
		res.ExitCode = -1
		return res, fmt.Errorf("command timed out after %s", timeout)
	}

	exitErr, exited := runErr.(*exec.ExitError)
	switch {
	case runErr == nil:
		res.ExitCode = 0
	case exited:
		res.ExitCode = exitErr.ExitCode()
	default:
		res.ExitCode = -1
	}
	return res, runErr
}
// ReadFile returns at most maxBytes bytes from the start of a file in the
// WSL filesystem. A non-positive maxBytes selects a 1 MiB limit. Windows
// drive paths are converted with windowsToWSLPath first.
//
// The limit is enforced inside WSL with `head -c`, so an oversized file is
// never transferred into this process in full. The call is bounded by a 10s
// timeout and fails when the user cannot be ensured or the file cannot be
// read.
func (b *WSLBackend) ReadFile(path string, maxBytes int) (string, error) {
	if maxBytes <= 0 {
		maxBytes = 1024 * 1024
	}
	if err := b.ensureUser(); err != nil {
		return "", err
	}
	wslPath := windowsToWSLPath(path)

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	cmd := exec.CommandContext(ctx, "wsl.exe", "-d", b.distro, "--",
		"head", "-c", fmt.Sprint(maxBytes), wslPath)
	out, err := cmd.Output()
	if err != nil {
		return "", fmt.Errorf("cannot read file %s: %w", path, err)
	}
	// Defensive: keep the contract even if the tool returned more than asked.
	if len(out) > maxBytes {
		out = out[:maxBytes]
	}
	return string(out), nil
}
// WriteFile writes content to a file in the WSL filesystem, creating the
// parent directory first. A non-positive maxBytes selects a 1 MiB limit;
// larger content is rejected before anything is executed. Windows drive
// paths are converted with windowsToWSLPath.
//
// The content is streamed over stdin into `cat > <path>`, so it never
// appears on a command line. Each of the two WSL invocations is bounded by
// a 10s timeout.
func (b *WSLBackend) WriteFile(path, content string, maxBytes int) error {
	if maxBytes <= 0 {
		maxBytes = 1024 * 1024
	}
	if len(content) > maxBytes {
		return fmt.Errorf("content too large: %d bytes (max %d)", len(content), maxBytes)
	}
	if err := b.ensureUser(); err != nil {
		return err
	}
	wslPath := windowsToWSLPath(path)

	// This code runs on a Windows host, where filepath.Dir cleans the path
	// with backslash separators. Convert back to forward slashes so mkdir
	// receives a real Linux path instead of e.g. `\tmp`.
	dir := filepath.ToSlash(filepath.Dir(wslPath))

	mkdirCtx, cancelMkdir := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancelMkdir()
	// Best effort: if the directory cannot be created the write below fails
	// and reports the definitive error.
	_ = exec.CommandContext(mkdirCtx, "wsl.exe", "-d", b.distro, "--", "mkdir", "-p", dir).Run()

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	cmd := exec.CommandContext(ctx, "wsl.exe", "-d", b.distro, "--", "bash", "-c",
		fmt.Sprintf("cat > %s", shellEscape(wslPath)))
	cmd.Stdin = strings.NewReader(content)
	if _, err := cmd.Output(); err != nil {
		return fmt.Errorf("cannot write file %s: %w", path, err)
	}
	return nil
}
// ListDir lists the entries of a directory in the WSL filesystem.
//
// It runs `stat` over `<dir>/*` inside WSL and parses one line per entry.
// The glob is expanded by bash, which by default does not match dot-files,
// so hidden entries are not listed. When stat fails (for example because
// the directory is empty or does not exist) the script prints an empty line
// and an empty slice is returned without error.
//
// The stat format puts the file name LAST (type|size|mtime|name) so that a
// name containing '|' cannot shift the other fields. Lines that do not
// parse are skipped.
func (b *WSLBackend) ListDir(path string) ([]DirEntry, error) {
	if err := b.ensureUser(); err != nil {
		return nil, err
	}
	wslPath := windowsToWSLPath(path)

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	cmd := exec.CommandContext(ctx, "wsl.exe", "-d", b.distro, "--", "bash", "-c",
		fmt.Sprintf("stat -c '%%F|%%s|%%Y|%%n' %s/* 2>/dev/null || echo ''", shellEscape(wslPath)))
	out, err := cmd.Output()
	if err != nil {
		return nil, fmt.Errorf("cannot list dir %s: %w", path, err)
	}

	lines := strings.Split(strings.TrimSpace(string(out)), "\n")
	result := make([]DirEntry, 0, len(lines))
	for _, line := range lines {
		if line == "" {
			continue
		}
		// SplitN with 4 keeps any further '|' inside the trailing name field.
		parts := strings.SplitN(line, "|", 4)
		if len(parts) < 4 {
			continue
		}

		var size, modTimeUnix int64
		if _, err := fmt.Sscanf(parts[1], "%d", &size); err != nil {
			continue
		}
		if _, err := fmt.Sscanf(parts[2], "%d", &modTimeUnix); err != nil {
			continue
		}

		// stat prints the full path; keep only what follows the last '/'.
		// Done by hand because filepath.Base applies Windows separator rules
		// on the host, which would mis-split Linux names containing '\'.
		name := parts[3]
		if i := strings.LastIndex(name, "/"); i >= 0 {
			name = name[i+1:]
		}

		result = append(result, DirEntry{
			Name:    name,
			IsDir:   strings.Contains(parts[0], "directory"),
			Size:    size,
			ModTime: time.Unix(modTimeUnix, 0).Format(time.RFC3339),
		})
	}
	return result, nil
}
// SystemInfo collects basic facts about the WSL distribution.
//
// The result always contains "backend" ("wsl") and "distro". If the user
// cannot be ensured, "error" holds the failure text and nothing else is
// probed. Otherwise "uname", "hostname", "memory" (raw `free -h` output)
// and "disk" (second line of `df -h /`) are added for each probe that
// succeeds; failed probes are silently omitted. All probes share one 10s
// timeout.
func (b *WSLBackend) SystemInfo() map[string]interface{} {
	info := map[string]interface{}{
		"backend": "wsl",
		"distro":  b.distro,
	}
	if err := b.ensureUser(); err != nil {
		info["error"] = err.Error()
		return info
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	// probe runs argv inside the distro and yields its trimmed stdout.
	probe := func(argv ...string) (string, bool) {
		args := append([]string{"-d", b.distro, "--"}, argv...)
		raw, err := exec.CommandContext(ctx, "wsl.exe", args...).Output()
		if err != nil {
			return "", false
		}
		return strings.TrimSpace(string(raw)), true
	}

	if v, ok := probe("uname", "-a"); ok {
		info["uname"] = v
	}
	if v, ok := probe("hostname"); ok {
		info["hostname"] = v
	}
	if v, ok := probe("free", "-h"); ok {
		info["memory"] = v
	}
	if v, ok := probe("df", "-h", "/"); ok {
		// Row 0 is the df header; row 1 describes the root filesystem.
		if rows := strings.Split(v, "\n"); len(rows) > 1 {
			info["disk"] = strings.TrimSpace(rows[1])
		}
	}
	return info
}
// DiskUsage runs `stat` on a path inside WSL and returns the caller's path
// ("path"), the converted Linux path ("wsl_path") and the trimmed, unparsed
// stat output ("stat"). The call is bounded by a 10s timeout.
//
// Like every other operation of this backend it first calls ensureUser, so
// the account is also set up when DiskUsage happens to be the first call.
func (b *WSLBackend) DiskUsage(path string) (map[string]interface{}, error) {
	if err := b.ensureUser(); err != nil {
		return nil, err
	}
	wslPath := windowsToWSLPath(path)

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	cmd := exec.CommandContext(ctx, "wsl.exe", "-d", b.distro, "--", "stat", wslPath)
	out, err := cmd.Output()
	if err != nil {
		return nil, fmt.Errorf("cannot stat path %s: %w", path, err)
	}

	// The stat output is passed through as text; no fields are parsed.
	return map[string]interface{}{
		"path":     path,
		"wsl_path": wslPath,
		"stat":     strings.TrimSpace(string(out)),
	}, nil
}
// windowsToWSLPath converts a Windows drive path to its WSL mount path.
//
//	C:\Users\foo → /mnt/c/Users/foo
//	C:/Users/foo → /mnt/c/Users/foo
//
// Paths that already start with "/" and paths without a drive-letter prefix
// are returned unchanged. Both "\" and "/" are accepted as separators after
// the drive, and leading separators are collapsed so the result never
// contains "//" right after the drive directory.
func windowsToWSLPath(path string) string {
	if strings.HasPrefix(path, "/") {
		return path // Already a Unix path.
	}
	if len(path) < 2 || path[1] != ':' {
		return path
	}
	// Only an ASCII letter is a valid drive designator.
	c := path[0]
	if !(c >= 'a' && c <= 'z') && !(c >= 'A' && c <= 'Z') {
		return path
	}

	drive := strings.ToLower(path[:1])
	rest := strings.ReplaceAll(path[2:], "\\", "/")
	rest = strings.TrimLeft(rest, "/")
	return "/mnt/" + drive + "/" + rest
}
// shellEscape returns s wrapped in single quotes so a POSIX shell treats it
// as one literal word. Each single quote inside s is emitted as '\'' —
// close the quoted span, add an escaped quote, reopen the span.
func shellEscape(s string) string {
	var quoted strings.Builder
	quoted.Grow(len(s) + 2)
	quoted.WriteByte('\'')
	for i := 0; i < len(s); i++ {
		if s[i] == '\'' {
			quoted.WriteString(`'\''`)
			continue
		}
		quoted.WriteByte(s[i])
	}
	quoted.WriteByte('\'')
	return quoted.String()
}
@@ -0,0 +1,143 @@
package host
import (
"context"
"os"
"strings"
"testing"
"time"
)
// TestWSLBackendIntegration exercises a real WSL distribution end to end
// through Manager backed by WSLBackend. It is skipped unless the WSL_DISTRO
// environment variable names the distribution to use, so it never runs on
// machines without WSL. The subtests share one backend and one Manager.
func TestWSLBackendIntegration(t *testing.T) {
distro := os.Getenv("WSL_DISTRO")
if distro == "" {
t.Skip("WSL_DISTRO not set, skipping WSL integration test (set WSL_DISTRO=cyrene-wsl to run)")
}
backend := NewWSLBackend(distro, "cyrene", "test123", 30*time.Second)
mgr := NewManager(backend)
ctx := context.Background()
// 1. Basic command
t.Run("echo", func(t *testing.T) {
r, err := mgr.Exec(ctx, "echo 'hello from WSL OS env'", "", 10*time.Second)
if err != nil {
t.Fatalf("exec failed: %v", err)
}
if r.ExitCode != 0 {
t.Fatalf("exit=%d, stderr=%s", r.ExitCode, r.Stderr)
}
if !strings.Contains(r.Stdout, "hello from WSL OS env") {
t.Fatalf("unexpected stdout: %s", r.Stdout)
}
t.Logf("echo OK: %s (duration=%s)", strings.TrimSpace(r.Stdout), r.Duration)
})
// 2. Complex commands - package manager
// Only checks that the command runs; the output is logged, not asserted.
// NOTE(review): presumably assumes a Debian/Ubuntu-based distro — confirm.
t.Run("apt", func(t *testing.T) {
r, err := mgr.Exec(ctx, "apt --version 2>&1", "", 10*time.Second)
if err != nil {
t.Fatalf("exec failed: %v", err)
}
t.Logf("apt OK: %s", strings.TrimSpace(r.Stdout))
})
// 3. Python (should be pre-installed on Ubuntu)
t.Run("python", func(t *testing.T) {
r, err := mgr.Exec(ctx, "python3 --version 2>&1", "", 10*time.Second)
if err != nil {
t.Fatalf("exec failed: %v", err)
}
t.Logf("python OK: %s", strings.TrimSpace(r.Stdout))
})
// 4. Pipeline & shell features
// The \n escapes below are inside a Go interpreted string, so the shell
// receives real newline characters within the single-quoted argument.
t.Run("pipeline", func(t *testing.T) {
r, err := mgr.Exec(ctx, "echo 'a\nb\nc\nd' | wc -l", "", 10*time.Second)
if err != nil {
t.Fatalf("exec failed: %v", err)
}
if r.ExitCode != 0 {
t.Fatalf("exit=%d", r.ExitCode)
}
t.Logf("pipeline OK: %s", strings.TrimSpace(r.Stdout))
})
// 5. File write & read
// Round-trips a small file and requires the content to match exactly.
t.Run("file_rw", func(t *testing.T) {
err := mgr.WriteFile("/tmp/cyrene-wsl-test.txt", "Hello from Cyrene OS!", 1024*1024)
if err != nil {
t.Fatalf("write failed: %v", err)
}
content, err := mgr.ReadFile("/tmp/cyrene-wsl-test.txt", 1024*1024)
if err != nil {
t.Fatalf("read failed: %v", err)
}
if content != "Hello from Cyrene OS!" {
t.Fatalf("content mismatch: %q", content)
}
t.Logf("file r/w OK: %q", content)
})
// 6. Directory listing
// Requires at least one entry in /etc; two well-known files are only logged.
t.Run("listdir", func(t *testing.T) {
entries, err := mgr.ListDir("/etc")
if err != nil {
t.Fatalf("listdir failed: %v", err)
}
if len(entries) == 0 {
t.Fatal("expected entries in /etc")
}
t.Logf("listdir OK: %d entries in /etc", len(entries))
for _, e := range entries {
if e.Name == "os-release" || e.Name == "hostname" {
t.Logf(" - %s (isDir=%v, size=%d)", e.Name, e.IsDir, e.Size)
}
}
})
// 7. System info
// Asserts only the static keys; the probed values are optional and logged.
t.Run("sysinfo", func(t *testing.T) {
info := mgr.SystemInfo()
if info["backend"] != "wsl" {
t.Fatalf("unexpected backend: %v", info["backend"])
}
if info["distro"] != distro {
t.Fatalf("unexpected distro: %v", info["distro"])
}
t.Logf("sysinfo OK: backend=%v, distro=%v", info["backend"], info["distro"])
if uname, ok := info["uname"]; ok {
t.Logf(" uname: %v", uname)
}
if hostname, ok := info["hostname"]; ok {
t.Logf(" hostname: %v", hostname)
}
if mem, ok := info["memory"]; ok {
t.Logf(" memory: %v", mem)
}
})
// 8. workDir
t.Run("workdir", func(t *testing.T) {
r, err := mgr.Exec(ctx, "pwd", "/tmp", 10*time.Second)
if err != nil {
t.Fatalf("exec failed: %v", err)
}
if !strings.Contains(r.Stdout, "/tmp") {
t.Fatalf("expected /tmp, got: %s", r.Stdout)
}
t.Logf("workdir OK: pwd=%s", strings.TrimSpace(r.Stdout))
})
// 9. Timeout
// A 10s sleep with a 1s limit must return an error and set TimedOut.
t.Run("timeout", func(t *testing.T) {
r, err := mgr.Exec(ctx, "sleep 10", "", 1*time.Second)
if err == nil {
t.Fatal("expected timeout")
}
if !r.TimedOut {
t.Fatal("expected TimedOut=true")
}
t.Logf("timeout OK: timed_out=%v", r.TimedOut)
})
}
+39 -174
View File
@@ -2,207 +2,72 @@ package host
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
)
// Manager provides controlled access to the host machine.
// It wraps a Sandbox for command execution and adds file system
// operations with path allow-list enforcement.
// HostBackend defines the interface for command execution and file system
// operations. Implementations include DirectBackend (host OS), WSLBackend
// (Windows Subsystem for Linux), and DockerBackend (container).
type HostBackend interface {
// Exec runs command, optionally inside workDir, bounded by timeout, and
// reports its output and exit status in an ExecResult.
Exec(ctx context.Context, command, workDir string, timeout time.Duration) (*ExecResult, error)
// ReadFile returns the content of the file at path, limited by maxBytes.
// NOTE(review): whether an oversized file is truncated or rejected
// appears to differ between backends — confirm per implementation.
ReadFile(path string, maxBytes int) (string, error)
// WriteFile stores content at path; maxBytes caps the accepted size.
WriteFile(path, content string, maxBytes int) error
// ListDir returns the entries of the directory at path.
ListDir(path string) ([]DirEntry, error)
// SystemInfo returns backend-specific key/value facts about the environment.
SystemInfo() map[string]interface{}
// DiskUsage returns backend-specific information about path.
DiskUsage(path string) (map[string]interface{}, error)
// Name returns a short identifier of the backend (e.g. "wsl").
Name() string
}
// Manager provides controlled access to the host machine. It delegates
// to a HostBackend implementation which may be direct, WSL, or Docker.
type Manager struct {
sandbox *Sandbox
allowedDirs []string
backend HostBackend
}
// NewManager creates a new host Manager.
func NewManager(sandbox *Sandbox) *Manager {
m := &Manager{sandbox: sandbox}
if sandbox != nil {
m.allowedDirs = sandbox.cfg.AllowedDirs
}
return m
// NewManager creates a new host Manager with the given backend.
func NewManager(backend HostBackend) *Manager {
return &Manager{backend: backend}
}
// SetAllowedDirs updates the list of directories accessible for file operations.
// SetAllowedDirs updates directory restrictions. Only effective for
// DirectBackend; WSL and Docker backends are no-ops.
func (m *Manager) SetAllowedDirs(dirs []string) {
m.allowedDirs = dirs
m.sandbox.cfg.AllowedDirs = dirs
if db, ok := m.backend.(*DirectBackend); ok {
db.SetAllowedDirs(dirs)
}
}
// Exec runs a command in the sandbox.
// Exec runs a command via the configured backend.
func (m *Manager) Exec(ctx context.Context, command, workDir string, timeout time.Duration) (*ExecResult, error) {
return m.sandbox.Exec(ctx, command, workDir, timeout)
return m.backend.Exec(ctx, command, workDir, timeout)
}
// ReadFile reads the contents of a file within allowed directories.
// ReadFile reads a file via the configured backend.
func (m *Manager) ReadFile(path string, maxBytes int) (string, error) {
if maxBytes <= 0 {
maxBytes = 1024 * 1024
}
if err := m.validatePath(path); err != nil {
return "", err
}
info, err := os.Stat(path)
if err != nil {
return "", fmt.Errorf("cannot stat file: %w", err)
}
if info.IsDir() {
return "", fmt.Errorf("path is a directory: %s", path)
}
if info.Size() > int64(maxBytes) {
return "", fmt.Errorf("file too large: %d bytes (max %d)", info.Size(), maxBytes)
}
data, err := os.ReadFile(path)
if err != nil {
return "", fmt.Errorf("cannot read file: %w", err)
}
if len(data) > maxBytes {
data = data[:maxBytes]
}
return string(data), nil
return m.backend.ReadFile(path, maxBytes)
}
// WriteFile writes data to a file within allowed directories.
// WriteFile writes a file via the configured backend.
func (m *Manager) WriteFile(path, content string, maxBytes int) error {
if maxBytes <= 0 {
maxBytes = 1024 * 1024
}
if len(content) > maxBytes {
return fmt.Errorf("content too large: %d bytes (max %d)", len(content), maxBytes)
}
if err := m.validatePath(path); err != nil {
return err
}
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("cannot create directory: %w", err)
}
return os.WriteFile(path, []byte(content), 0644)
return m.backend.WriteFile(path, content, maxBytes)
}
// ListDir lists directory contents within allowed directories.
// ListDir lists a directory via the configured backend.
func (m *Manager) ListDir(path string) ([]DirEntry, error) {
if err := m.validatePath(path); err != nil {
return nil, err
}
entries, err := os.ReadDir(path)
if err != nil {
return nil, fmt.Errorf("cannot read directory: %w", err)
}
result := make([]DirEntry, 0, len(entries))
for _, e := range entries {
info, _ := e.Info()
size := int64(0)
modTime := time.Time{}
if info != nil {
size = info.Size()
modTime = info.ModTime()
}
result = append(result, DirEntry{
Name: e.Name(),
IsDir: e.IsDir(),
Size: size,
ModTime: modTime.Format(time.RFC3339),
})
}
return result, nil
return m.backend.ListDir(path)
}
// DirEntry represents a filesystem directory entry.
type DirEntry struct {
Name string `json:"name"`
IsDir bool `json:"is_dir"`
Size int64 `json:"size"`
ModTime string `json:"mod_time"`
}
// SystemInfo returns basic system information.
// SystemInfo returns system information from the configured backend.
func (m *Manager) SystemInfo() map[string]interface{} {
hostname, _ := os.Hostname()
wd, _ := os.Getwd()
info := map[string]interface{}{
"hostname": hostname,
"os": runtime.GOOS,
"arch": runtime.GOARCH,
"num_cpu": runtime.NumCPU(),
"go_version": runtime.Version(),
"work_dir": wd,
}
if runtime.GOOS == "windows" {
cmd := exec.Command("systeminfo")
out, err := cmd.Output()
if err == nil {
lines := strings.Split(string(out), "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if strings.Contains(line, "Total Physical Memory") {
parts := strings.SplitN(line, ":", 2)
if len(parts) == 2 {
info["total_memory"] = strings.TrimSpace(parts[1])
}
}
if strings.Contains(line, "OS Name") {
parts := strings.SplitN(line, ":", 2)
if len(parts) == 2 {
info["os_name"] = strings.TrimSpace(parts[1])
}
}
}
}
} else {
if data, err := os.ReadFile("/proc/meminfo"); err == nil {
for _, line := range strings.Split(string(data), "\n") {
if strings.HasPrefix(line, "MemTotal:") {
info["total_memory"] = strings.TrimSpace(strings.TrimPrefix(line, "MemTotal:"))
break
}
}
}
}
return info
return m.backend.SystemInfo()
}
// DiskUsage returns disk usage for the given path.
// DiskUsage returns disk usage info from the configured backend.
func (m *Manager) DiskUsage(path string) (map[string]interface{}, error) {
if err := m.validatePath(path); err != nil {
return nil, err
}
info, err := os.Stat(path)
if err != nil {
return nil, fmt.Errorf("cannot stat path: %w", err)
}
result := map[string]interface{}{
"path": path,
"is_dir": info.IsDir(),
"size": info.Size(),
"mod_time": info.ModTime().Format(time.RFC3339),
}
return result, nil
return m.backend.DiskUsage(path)
}
func (m *Manager) validatePath(path string) error {
absPath, err := filepath.Abs(path)
if err != nil {
return fmt.Errorf("cannot resolve path: %w", err)
}
if len(m.allowedDirs) == 0 {
return nil
}
for _, allowed := range m.allowedDirs {
absAllowed, err := filepath.Abs(allowed)
if err != nil {
continue
}
if strings.HasPrefix(absPath, absAllowed+string(os.PathSeparator)) || absPath == absAllowed {
return nil
}
}
return fmt.Errorf("path not in allowed directories: %s", path)
// BackendName returns the name of the active backend.
func (m *Manager) BackendName() string {
return m.backend.Name()
}
+8
View File
@@ -50,6 +50,14 @@ func NewSandbox(cfg SandboxConfig) *Sandbox {
return &Sandbox{cfg: cfg}
}
// DirEntry represents a filesystem directory entry as returned by
// HostBackend.ListDir implementations and serialized to JSON.
type DirEntry struct {
// Name is the entry's base name (no directory component).
Name string `json:"name"`
// IsDir reports whether the entry is a directory.
IsDir bool `json:"is_dir"`
// Size is the entry size in bytes.
Size int64 `json:"size"`
// ModTime is the modification time as text; omitted from JSON when empty.
// The WSL backend fills it in RFC 3339 format.
ModTime string `json:"mod_time,omitempty"`
}
// ExecResult holds the result of a sandboxed command execution.
type ExecResult struct {
Stdout string `json:"stdout"`
@@ -62,7 +62,7 @@ func TestManagerFileOps(t *testing.T) {
tmpDir := os.TempDir()
cfg.AllowedDirs = []string{tmpDir}
sandbox := NewSandbox(cfg)
mgr := NewManager(sandbox)
mgr := NewManager(NewDirectBackend(sandbox))
mgr.SetAllowedDirs([]string{tmpDir})
testPath := filepath.Join(tmpDir, "cyrene-test-file.txt")
@@ -102,7 +102,7 @@ func TestManagerFileOps(t *testing.T) {
func TestManagerSystemInfo(t *testing.T) {
cfg := DefaultSandboxConfig()
sandbox := NewSandbox(cfg)
mgr := NewManager(sandbox)
mgr := NewManager(NewDirectBackend(sandbox))
info := mgr.SystemInfo()
if info["hostname"] == nil || info["hostname"] == "" {
@@ -121,7 +121,7 @@ func TestPathValidation(t *testing.T) {
cfg := DefaultSandboxConfig()
cfg.AllowedDirs = []string{os.TempDir()}
sandbox := NewSandbox(cfg)
mgr := NewManager(sandbox)
mgr := NewManager(NewDirectBackend(sandbox))
mgr.SetAllowedDirs([]string{os.TempDir()})
// Should fail: access outside allowed dirs
+68 -4
View File
@@ -4,15 +4,16 @@ import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"github.com/yourname/cyrene-ai/pkg/logger"
"net/http"
"strings"
"time"
"github.com/yourname/cyrene-ai/ai-core/internal/model"
"github.com/yourname/cyrene-ai/pkg/logger"
)
// OpenAIConfig OpenAI适配器配置
@@ -267,12 +268,13 @@ func (p *OpenAIProvider) doChat(ctx context.Context, messages []model.LLMMessage
LogCall(r)
}()
// 转换消息格式
// 转换消息格式(先解析图片 URL 为 data URL
oaiMessages := make([]openAIMessage, len(messages))
for i, msg := range messages {
resolvedImages := p.resolveImages(msg.Images)
oaiMsg := openAIMessage{
Role: string(msg.Role),
Content: buildContent(msg.Content, msg.Images),
Content: buildContent(msg.Content, resolvedImages),
Name: msg.Name,
ToolCallID: msg.ToolCallID,
ReasoningContent: msg.ReasoningContent,
@@ -377,9 +379,10 @@ func (p *OpenAIProvider) doChat(ctx context.Context, messages []model.LLMMessage
func (p *OpenAIProvider) doChatStream(ctx context.Context, messages []model.LLMMessage, modelName string, tools []OpenAITool) (*http.Response, error) {
oaiMessages := make([]openAIMessage, len(messages))
for i, msg := range messages {
resolvedImages := p.resolveImages(msg.Images)
oaiMsg := openAIMessage{
Role: string(msg.Role),
Content: buildContent(msg.Content, msg.Images),
Content: buildContent(msg.Content, resolvedImages),
Name: msg.Name,
ToolCallID: msg.ToolCallID,
ReasoningContent: msg.ReasoningContent,
@@ -455,6 +458,67 @@ func contentString(v interface{}) string {
return ""
}
// resolveImages returns a copy of images in which every entry that is not
// already a "data:" URL has been downloaded and inlined as a base64 data
// URL, so that an external LLM API does not need to reach the original
// location. When a download fails the failure is logged and the original
// URL is kept as a fallback. An empty input is returned as-is.
func (p *OpenAIProvider) resolveImages(images []string) []string {
	if len(images) == 0 {
		return images
	}

	out := make([]string, 0, len(images))
	for _, src := range images {
		final := src
		if !strings.HasPrefix(src, "data:") {
			inlined, err := p.downloadAsDataURL(src)
			if err == nil {
				final = inlined
			} else {
				// Keep the original URL as a fallback.
				logger.Printf("[openai] 图片下载失败, 保留原始 URL: %s, err=%v", src, err)
			}
		}
		out = append(out, final)
	}
	return out
}
// downloadAsDataURL fetches url with the provider's HTTP client and returns
// the body encoded as "data:<mime>;base64,<payload>".
//
// The request is bounded by a 30s timeout. Any status other than 200 is an
// error, as is a body larger than 20 MiB. The MIME type is taken from the
// Content-Type response header and, when that header is empty, sniffed from
// the body with http.DetectContentType.
func (p *OpenAIProvider) downloadAsDataURL(url string) (string, error) {
	// Upper bound on the accepted body size (20 MiB).
	const maxSize = 20 * 1024 * 1024

	reqCtx, stop := context.WithTimeout(context.Background(), 30*time.Second)
	defer stop()

	request, err := http.NewRequestWithContext(reqCtx, "GET", url, nil)
	if err != nil {
		return "", fmt.Errorf("创建请求失败: %w", err)
	}

	response, err := p.httpClient.Do(request)
	if err != nil {
		return "", fmt.Errorf("下载失败: %w", err)
	}
	defer response.Body.Close()

	if code := response.StatusCode; code != http.StatusOK {
		return "", fmt.Errorf("HTTP %d", code)
	}

	// Read one byte past the limit so an oversized body can be detected.
	payload, err := io.ReadAll(io.LimitReader(response.Body, maxSize+1))
	if err != nil {
		return "", fmt.Errorf("读取失败: %w", err)
	}
	if len(payload) > maxSize {
		return "", fmt.Errorf("图片过大: %d bytes", len(payload))
	}

	contentType := response.Header.Get("Content-Type")
	if contentType == "" {
		contentType = http.DetectContentType(payload)
	}

	encoded := base64.StdEncoding.EncodeToString(payload)
	return fmt.Sprintf("data:%s;base64,%s", contentType, encoded), nil
}
// buildContent converts text + optional images to API content format.
// Returns a plain string if no images, or a multimodal array otherwise.
func buildContent(text string, images []string) interface{} {
+1
View File
@@ -19,6 +19,7 @@ const (
PurposeToolCalling ModelPurpose = "tool_calling"
PurposeMemoryExtraction ModelPurpose = "memory_extraction"
PurposeVision ModelPurpose = "vision"
PurposeOCR ModelPurpose = "ocr"
)
// ErrModelNotRequired is returned when an optional model is unavailable.
+25
View File
@@ -167,6 +167,31 @@ func (c *Client) GetByID(ctx context.Context, id string) (*model.MemoryEntry, er
return &result.Memory, nil
}
// Update sends the editable fields of entry to the memory service with a
// PUT to /api/v1/memories/<entry.ID>.
//
// It returns an error when entry is nil, when the request body cannot be
// encoded, when the request fails, or when the service answers with a
// status of 300 or above (the response body is included in that error).
func (c *Client) Update(ctx context.Context, entry *model.MemoryEntry) error {
	if entry == nil {
		return fmt.Errorf("更新记忆失败: entry is nil")
	}

	body, err := json.Marshal(map[string]interface{}{
		"content":    entry.Content,
		"summary":    entry.Summary,
		"category":   string(entry.Category),
		"priority":   int(entry.Priority),
		"importance": entry.Importance,
		"keywords":   entry.Keywords,
		"source":     entry.Source,
	})
	if err != nil {
		return fmt.Errorf("更新记忆失败: %w", err)
	}

	resp, err := c.doRequest(ctx, http.MethodPut, c.baseURL+"/api/v1/memories/"+entry.ID, body)
	if err != nil {
		return fmt.Errorf("更新记忆失败: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 300 {
		// Best effort: the body only enriches the error message.
		respBody, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("更新记忆失败 (%d): %s", resp.StatusCode, string(respBody))
	}
	return nil
}
// Delete 删除记忆
func (c *Client) Delete(ctx context.Context, id string) error {
resp, err := c.doRequest(ctx, http.MethodDelete, c.baseURL+"/api/v1/memories/"+id, nil)
@@ -8,6 +8,17 @@ type EnrichmentData struct {
ThoughtOutline string
IoTSummary string
KnowledgeInfo string
// Pending tool results from async execution (keyed by tool call ID)
PendingToolResults []PendingToolResult
}
// PendingToolResult holds the result of a tool that completed asynchronously.
// Results are collected per session via SessionEnrichmentStore.AppendToolResult.
type PendingToolResult struct {
// ToolCallID is the ID of the tool call this result belongs to.
ToolCallID string `json:"tool_call_id"`
// ToolName is the name of the tool that produced the result.
ToolName string `json:"tool_name"`
// Result is the tool's output as text.
Result string `json:"result"`
// Success reports whether the tool finished successfully.
Success bool `json:"success"`
}
// SessionEnrichmentStore is a thread-safe per-session cache for async
@@ -25,8 +36,15 @@ func NewEnrichmentStore() *SessionEnrichmentStore {
}
}
// Get returns stored enrichment for a session and clears it (one-shot consumption).
// Get returns stored enrichment for a session (does NOT clear; results may be reused).
func (s *SessionEnrichmentStore) Get(sessionID string) *EnrichmentData {
s.mu.RLock()
defer s.mu.RUnlock()
return s.data[sessionID]
}
// Pop returns stored enrichment for a session and clears it (one-shot consumption).
func (s *SessionEnrichmentStore) Pop(sessionID string) *EnrichmentData {
s.mu.Lock()
defer s.mu.Unlock()
d, ok := s.data[sessionID]
@@ -45,3 +63,32 @@ func (s *SessionEnrichmentStore) Store(sessionID string, d *EnrichmentData) {
s.data[sessionID] = d
s.mu.Unlock()
}
// AppendToolResult records r as a pending tool result for sessionID. If the
// session has no enrichment entry yet, an empty one is created first. The
// whole operation runs under the store's write lock.
func (s *SessionEnrichmentStore) AppendToolResult(sessionID string, r PendingToolResult) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if _, exists := s.data[sessionID]; !exists {
		s.data[sessionID] = &EnrichmentData{}
	}
	entry := s.data[sessionID]
	entry.PendingToolResults = append(entry.PendingToolResults, r)
}
// ---- Global pending tool store (used by Synthesizer for async tool results) ----
// globalPendingToolStore is the process-wide store; it stays nil until
// InitGlobalPendingToolStore has run.
var globalPendingToolStore *SessionEnrichmentStore
// pendingToolStoreOnce guarantees the store is created at most once.
var pendingToolStoreOnce sync.Once
// InitGlobalPendingToolStore initializes the singleton. It is safe to call
// repeatedly; only the first call creates the store.
func InitGlobalPendingToolStore() {
pendingToolStoreOnce.Do(func() {
globalPendingToolStore = NewEnrichmentStore()
})
}
// GetGlobalPendingToolStore returns the singleton, or nil if not initialized.
// NOTE(review): this reads the variable without going through
// pendingToolStoreOnce — confirm initialization always happens before any
// concurrent reader.
func GetGlobalPendingToolStore() *SessionEnrichmentStore {
return globalPendingToolStore
}
@@ -38,6 +38,7 @@ type Orchestrator struct {
msgScheduler *scheduler.MessageScheduler
emotionTracker *persona.EmotionTracker
toolRegistry *plgManager.ToolRegistry
visionProvider llm.LLMProvider // 视觉模型 (图片预处理/OCR)
}
// SetResponseCache sets the response cache (optional, for Phase 0.2).
@@ -71,6 +72,11 @@ func (o *Orchestrator) SetToolRegistry(tr *plgManager.ToolRegistry) {
o.synthesizer.toolRegistry = tr
}
// SetVisionProvider sets the vision model provider for image preprocessing.
// While it is nil, ProcessInput drops incoming images instead of analyzing
// them.
func (o *Orchestrator) SetVisionProvider(vp llm.LLMProvider) {
o.visionProvider = vp
}
// getBus returns the bus or a nop fallback.
func (o *Orchestrator) getBus() bus.Bus {
if o.eventBus == nil {
@@ -149,7 +155,27 @@ func (o *Orchestrator) ProcessInput(
UserID: params.UserID,
})
// 1. 意图分析
// 0.5 Image preprocessing: let the vision model analyze the images and
// inject its description into the message.
if len(params.Images) > 0 && o.visionProvider != nil {
	startTime := time.Now()
	// Capture the original length before params.Message is overwritten so
	// the log below reports the real before/after sizes.
	originalRunes := len([]rune(params.Message))
	augmented := o.preprocessImages(ctx, params.Message, params.Images)
	if augmented != params.Message {
		params.Message = augmented
		logger.Printf("[orchestrator] 图片预处理耗时: %v, 原消息=%d字, 增强后=%d字",
			time.Since(startTime), originalRunes, len([]rune(augmented)))
	}
	// Drop the raw images after preprocessing so they are not forwarded to
	// a chat model that does not support multimodal input.
	params.Images = nil
} else if len(params.Images) > 0 {
	// No vision model configured: if the message is otherwise empty, replace
	// it with a notice that the image cannot be recognized, then drop the
	// images to avoid errors downstream.
	if params.Message == "" {
		params.Message = "(用户发送了一张图片,但当前未配置视觉模型,无法识别图片内容)"
	}
	logger.Printf("[orchestrator] 视觉模型未配置,丢弃 %d 张图片", len(params.Images))
	params.Images = nil
}
// 1. 意图分析
startTime := time.Now()
intent, err := o.intentAnalyzer.Analyze(ctx, params.Message)
if err != nil || intent == nil {
@@ -247,17 +273,39 @@ func (o *Orchestrator) ProcessInput(
resultCh = o.subManager.Dispatch(subCtx, intent, params.Message, createParams)
}
// 3.5 确保全局工具结果存储已初始化
InitGlobalPendingToolStore()
// 4. 加载上一轮异步完成的子会话富化结果
var prevEnrichment *EnrichmentData
if o.enrichmentStore != nil {
prevEnrichment = o.enrichmentStore.Get(params.SessionID)
if prevEnrichment != nil {
logger.Printf("[orchestrator] 加载上一轮富化结果: memory=%t thought=%t iot=%t knowledge=%t",
prevEnrichment.MemorySummary != "",
prevEnrichment.ThoughtOutline != "",
prevEnrichment.IoTSummary != "",
prevEnrichment.KnowledgeInfo != "")
prevEnrichment = o.enrichmentStore.Pop(params.SessionID)
// Also merge any pending tool results from the global store
if globalStore := GetGlobalPendingToolStore(); globalStore != nil {
if toolData := globalStore.Pop(params.SessionID); toolData != nil && len(toolData.PendingToolResults) > 0 {
if prevEnrichment == nil {
prevEnrichment = &EnrichmentData{}
}
prevEnrichment.PendingToolResults = append(prevEnrichment.PendingToolResults, toolData.PendingToolResults...)
logger.Printf("[orchestrator] 合并后台工具结果 %d 条", len(toolData.PendingToolResults))
}
}
} else {
// Still check global store even if enrichmentStore is not set
if globalStore := GetGlobalPendingToolStore(); globalStore != nil {
if toolData := globalStore.Pop(params.SessionID); toolData != nil && len(toolData.PendingToolResults) > 0 {
prevEnrichment = toolData
logger.Printf("[orchestrator] 加载后台工具结果 %d 条", len(toolData.PendingToolResults))
}
}
}
if prevEnrichment != nil {
logger.Printf("[orchestrator] 加载上一轮富化结果: memory=%t thought=%t iot=%t knowledge=%t tools=%d",
prevEnrichment.MemorySummary != "",
prevEnrichment.ThoughtOutline != "",
prevEnrichment.IoTSummary != "",
prevEnrichment.KnowledgeInfo != "",
len(prevEnrichment.PendingToolResults))
}
// 5. 先构建基础综合参数(不含子会话结果),开始合成
@@ -284,6 +332,7 @@ func (o *Orchestrator) ProcessInput(
synthParams.ThoughtOutline = prevEnrichment.ThoughtOutline
synthParams.IoTSummary = prevEnrichment.IoTSummary
synthParams.KnowledgeInfo = prevEnrichment.KnowledgeInfo
synthParams.PendingToolResults = prevEnrichment.PendingToolResults
}
// 异步收集子会话结果,存入 enrichmentStore 供下一轮使用
@@ -324,7 +373,7 @@ func (o *Orchestrator) ProcessInput(
}()
// 5. 调用 Synthesizer 流式生成最终回复
chunkCh, err := o.synthesizer.Synthesize(ctx, synthParams)
chunkCh, err := o.synthesizer.Synthesize(ctx, synthParams, eventCh)
if err != nil {
logger.Printf("[orchestrator] 综合器启动失败: %v", err)
eventCh <- model.StreamEvent{
@@ -601,6 +650,46 @@ func (o *Orchestrator) CacheMessage(sessionID string, role model.Role, content s
}
}
// preprocessImages asks the vision model to analyze each image and folds
// the answers into the user message.
//
//   - message == "": the result is the image descriptions joined by blank
//     lines (the descriptions become the message).
//   - otherwise: the original message followed by one annotation per
//     analyzed image, labelled with that image's 1-based position in images.
//
// Images whose analysis fails (the error is logged) or returns empty
// content are skipped; labels still refer to the original image positions,
// so a failure never shifts the numbering of later images. When no image
// yields a description the message is returned unchanged.
func (o *Orchestrator) preprocessImages(ctx context.Context, message string, images []string) string {
	var prompt string
	if message == "" {
		prompt = "请详细描述这张图片的内容,包括场景、物体、人物、文字(如有)、颜色、氛围等所有视觉信息。"
	} else {
		prompt = fmt.Sprintf("用户的问题是:「%s」\n\n请根据用户的问题,分析这张图片中相关的视觉信息,帮助回答用户的问题。如果图片中有文字,请完整提取。", message)
	}

	// analysis pairs a description with the index of the image it describes.
	type analysis struct {
		index int
		text  string
	}
	analyses := make([]analysis, 0, len(images))
	for i, img := range images {
		resp, err := o.visionProvider.Chat(ctx, []model.LLMMessage{
			{Role: model.RoleUser, Content: prompt, Images: []string{img}},
		})
		if err != nil {
			logger.Printf("[orchestrator] 图片 %d 预处理失败: %v", i, err)
			continue
		}
		if resp.Content != "" {
			analyses = append(analyses, analysis{index: i, text: resp.Content})
		}
	}

	if len(analyses) == 0 {
		return message
	}

	if message == "" {
		texts := make([]string, 0, len(analyses))
		for _, a := range analyses {
			texts = append(texts, a.text)
		}
		return strings.Join(texts, "\n\n")
	}

	var augmented strings.Builder
	augmented.WriteString(message)
	for _, a := range analyses {
		fmt.Fprintf(&augmented, "\n\n[图片%d的视觉分析]: %s", a.index+1, a.text)
	}
	return augmented.String()
}
// Ensure time, memory are used
var _ = time.Now
var _ = memory.NewRetriever
@@ -14,7 +14,7 @@ var codeBlockPattern = regexp.MustCompile("`{3}([^\n]*)\n([\\s\\S]*?)`{3}")
var markdownPatterns = []*regexp.Regexp{
regexp.MustCompile(`^#{1,6}\s`), // headings
regexp.MustCompile(`\*\*[^*]+\*\*`), // bold
regexp.MustCompile(`(?<!\*)\*[^*]+\*(?!\*)`), // italic (single *)
regexp.MustCompile(`(?:^|[^*])\*([^*]+)\*(?:[^*]|$)`), // italic (*text*)
regexp.MustCompile(`\[([^\]]+)\]\(([^\)]+)\)`), // links [text](url)
regexp.MustCompile(`^[\-\*]\s`), // unordered list
regexp.MustCompile(`^\d+\.\s`), // ordered list
@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"strings"
"time"
"github.com/yourname/cyrene-ai/ai-core/internal/llm"
"github.com/yourname/cyrene-ai/ai-core/internal/model"
@@ -30,23 +31,25 @@ func NewSynthesizer(llmAdapter *llm.Adapter, toolRegistry *plgManager.ToolRegist
// SynthesizeParams holds the inputs passed to Synthesizer.Synthesize.
type SynthesizeParams struct {
UserID string
SessionID string
UserMessage string
Images []string // image base64 data URLs (multimodal)
Nickname string
PersonaPrompt string // full persona prompt
DialogHistory []model.LLMMessage // dialog history
MemorySummary string // memory retrieval summary
ThoughtOutline string // general conversation thinking
IoTSummary string // IoT operation summary
DeviceContext string // device state context
KnowledgeInfo string // knowledge base retrieval summary
Mode string // text / voice_assistant
UserID string
SessionID string
UserMessage string
Images []string // image base64 data URLs (multimodal)
Nickname string
PersonaPrompt string // full persona prompt
DialogHistory []model.LLMMessage // dialog history
MemorySummary string // memory retrieval summary
ThoughtOutline string // general conversation thinking
IoTSummary string // IoT operation summary
DeviceContext string // device state context
KnowledgeInfo string // knowledge base retrieval summary
PendingToolResults []PendingToolResult // tool results that finished asynchronously during the previous turn
Mode string // text / voice_assistant
}
// Synthesize 综合所有子会话结果,流式生成最终回复
func (s *Synthesizer) Synthesize(ctx context.Context, params SynthesizeParams) (<-chan llm.StreamChunk, error) {
// Synthesize 综合所有子会话结果,流式生成最终回复
// eventCh receives tool progress events; pass nil to suppress.
func (s *Synthesizer) Synthesize(ctx context.Context, params SynthesizeParams, eventCh chan<- model.StreamEvent) (<-chan llm.StreamChunk, error) {
messages := s.buildSynthesizeMessages(params)
logger.Printf("[synthesizer] 开始综合 (上下文 %d 条消息)", len(messages))
@@ -62,7 +65,9 @@ func (s *Synthesizer) Synthesize(ctx context.Context, params SynthesizeParams) (
return nil, err
}
maxRounds := 5
const toolDeadline = 8 * time.Second
const maxRounds = 5
for round := 0; len(resp.ToolCalls) > 0 && round < maxRounds; round++ {
logger.Printf("[synthesizer] LLM 请求 %d 个工具调用 (round=%d)", len(resp.ToolCalls), round)
@@ -80,7 +85,12 @@ func (s *Synthesizer) Synthesize(ctx context.Context, params SynthesizeParams) (
args = make(map[string]interface{})
}
result, execErr := s.toolRegistry.Execute(ctx, tc.Name, args)
s.emitToolProgress(eventCh, tc.Name, "started", 0, "正在执行 "+tc.Name)
toolCtx, cancel := context.WithTimeout(ctx, toolDeadline)
result, execErr := s.toolRegistry.Execute(toolCtx, tc.Name, args)
cancel()
if execErr != nil {
logger.Printf("[synthesizer] 工具 %s 执行失败: %v", tc.Name, execErr)
}
@@ -88,6 +98,19 @@ func (s *Synthesizer) Synthesize(ctx context.Context, params SynthesizeParams) (
result = &plgSDK.ToolResult{ToolName: tc.Name, Success: false, Error: execErr.Error()}
}
// Async fallback: if tool timed out, store for next turn
if toolCtx.Err() == context.DeadlineExceeded {
s.emitToolProgress(eventCh, tc.Name, "running", 0.5, tc.Name+" 执行时间较长,转入后台继续...")
go s.executeAsyncAndStore(tc, args, params.SessionID, eventCh)
result = &plgSDK.ToolResult{
ToolName: tc.Name,
Success: true,
Output: fmt.Sprintf("[后台执行中] %s 正在后台运行,结果将在下一轮对话中返回。你可以继续聊天。", tc.Name),
}
} else {
s.emitToolProgress(eventCh, tc.Name, "completed", 1.0, "")
}
resultJSON, _ := json.Marshal(result)
messages = append(messages, model.LLMMessage{
Role: model.RoleTool,
@@ -120,6 +143,51 @@ func (s *Synthesizer) Synthesize(ctx context.Context, params SynthesizeParams) (
return ch, nil
}
// emitToolProgress publishes a StreamToolProgress event describing the state
// of a tool call. It is a no-op when eventCh is nil, and it never blocks: if
// the channel cannot accept the event immediately, the event is dropped.
//
// NOTE(review): a send on a closed channel panics even inside a select with a
// default case. executeAsyncAndStore calls this from a background goroutine,
// so confirm eventCh is not closed while such a goroutine may still run.
func (s *Synthesizer) emitToolProgress(eventCh chan<- model.StreamEvent, name, status string, progress float64, message string) {
	if eventCh == nil {
		return
	}

	info := &model.ToolProgressInfo{
		ToolName: name,
		Status:   status,
		Progress: progress,
		Message:  message,
	}
	evt := model.StreamEvent{
		Type:         model.StreamToolProgress,
		ToolProgress: info,
	}

	select {
	case eventCh <- evt:
	default:
		// Receiver is not ready; progress events are best-effort.
	}
}
// executeAsyncAndStore runs a tool in the background, detached from the
// request context and bounded by a 60-second timeout, then appends the outcome
// to the global pending-tool store under sessionID so the next turn can pick
// it up.
//
// Failures are stored too (with Success=false): the caller has already told
// the model that the result will arrive next turn, so an execution error must
// not silently vanish. A progress event ("completed" or "failed") is emitted
// in either case.
//
// NOTE(review): this starts the tool again from scratch rather than resuming
// the call that timed out in Synthesize; confirm that is safe for tools with
// side effects.
func (s *Synthesizer) executeAsyncAndStore(tc model.ToolCall, args map[string]interface{}, sessionID string, eventCh chan<- model.StreamEvent) {
	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()

	result, err := s.toolRegistry.Execute(ctx, tc.Name, args)
	if err != nil {
		logger.Printf("[synthesizer] 后台工具 %s 执行失败: %v", tc.Name, err)
		s.emitToolProgress(eventCh, tc.Name, "failed", 1.0, tc.Name+" 后台执行失败: "+err.Error())
		// Record the failure in the same shape Synthesize uses for failed calls.
		result = &plgSDK.ToolResult{ToolName: tc.Name, Success: false, Error: err.Error()}
	} else {
		s.emitToolProgress(eventCh, tc.Name, "completed", 1.0, tc.Name+" 后台执行完成")
	}

	succeeded := result != nil && result.Success
	resultJSON, marshalErr := json.Marshal(result)
	if marshalErr != nil {
		logger.Printf("[synthesizer] 后台工具 %s 结果序列化失败: %v", tc.Name, marshalErr)
		resultJSON = []byte(marshalErr.Error())
		succeeded = false
	}

	store := GetGlobalPendingToolStore()
	if store == nil {
		logger.Printf("[synthesizer] 后台工具 %s 的结果无法保存: 全局存储未初始化", tc.Name)
		return
	}
	store.AppendToolResult(sessionID, PendingToolResult{
		ToolCallID: tc.ID,
		ToolName:   tc.Name,
		Result:     string(resultJSON),
		Success:    succeeded,
	})
}
// buildSynthesizeMessages 构建综合用的 LLM 消息列表
func (s *Synthesizer) buildSynthesizeMessages(params SynthesizeParams) []model.LLMMessage {
var messages []model.LLMMessage
@@ -174,6 +242,23 @@ func (s *Synthesizer) buildSynthesizeMessages(params SynthesizeParams) []model.L
})
}
// 注入上一轮异步工具执行结果
if len(params.PendingToolResults) > 0 {
var sb strings.Builder
sb.WriteString("【上一轮后台工具执行结果】\n")
for _, ptr := range params.PendingToolResults {
status := "成功"
if !ptr.Success {
status = "失败"
}
sb.WriteString(fmt.Sprintf("- %s (%s): %s\n", ptr.ToolName, status, ptr.Result))
}
messages = append(messages, model.LLMMessage{
Role: model.RoleSystem,
Content: sb.String(),
})
}
// 注入对话历史
if len(params.DialogHistory) > 0 {
messages = append(messages, params.DialogHistory...)
@@ -2,12 +2,15 @@ package subsession
import (
"context"
"encoding/json"
"fmt"
"github.com/yourname/cyrene-ai/pkg/logger"
"strings"
"time"
"github.com/yourname/cyrene-ai/ai-core/internal/llm"
"github.com/yourname/cyrene-ai/ai-core/internal/memory"
"github.com/yourname/cyrene-ai/ai-core/internal/model"
"github.com/yourname/cyrene-ai/pkg/logger"
)
// MemoryRetriever 记忆检索接口
@@ -16,9 +19,12 @@ type MemoryRetriever interface {
}
// MemoryProvider is the memory-retrieval sub-session provider.
// Responsibility: retrieve user memories relevant to the current conversation, rank and de-duplicate them, and return a structured summary.
// Responsibility: retrieve user memories relevant to the current conversation, rank and de-duplicate them, and return a structured summary.
// It also supports LLM-driven fuzzy keyword-expansion search.
type MemoryProvider struct {
retriever MemoryRetriever
retriever MemoryRetriever
llmAdapter *llm.Adapter
memClient *memory.Client
}
// NewMemoryProvider 创建记忆检索子会话提供者
@@ -28,6 +34,12 @@ func NewMemoryProvider(retriever MemoryRetriever) *MemoryProvider {
}
}
// SetFuzzySearch wires in the LLM adapter and memory client used by the fuzzy
// keyword-expansion phase of retrieval. fuzzySearch is skipped while either
// of the two is nil.
func (p *MemoryProvider) SetFuzzySearch(llmAdapter *llm.Adapter, memClient *memory.Client) {
	p.llmAdapter, p.memClient = llmAdapter, memClient
}
// Type reports the sub-session kind handled by this provider, which is always
// model.SubSessionMemory.
func (p *MemoryProvider) Type() model.SubSessionType {
return model.SubSessionMemory
}
@@ -93,6 +105,7 @@ func (p *MemoryProvider) Execute(ctx context.Context, subCtx []model.LLMMessage)
return result, nil
}
// Phase 1: exact/keyword retrieval
memories, err := p.retriever.Retrieve(ctx, userID, userMessage)
if err != nil {
logger.Printf("[memory-subsession] 记忆检索失败: %v", err)
@@ -101,6 +114,20 @@ func (p *MemoryProvider) Execute(ctx context.Context, subCtx []model.LLMMessage)
return result, nil
}
seen := make(map[string]bool)
for _, m := range memories {
seen[m.ID] = true
}
// Phase 2: LLM-driven fuzzy keyword expansion + semantic search
fuzzyMemories := p.fuzzySearch(ctx, userID, userMessage)
for _, m := range fuzzyMemories {
if !seen[m.ID] {
seen[m.ID] = true
memories = append(memories, m)
}
}
// 转换为 MemorySnippet
snippets := make([]model.MemorySnippet, 0, len(memories))
for _, m := range memories {
@@ -117,7 +144,7 @@ func (p *MemoryProvider) Execute(ctx context.Context, subCtx []model.LLMMessage)
if len(snippets) == 0 {
result.Summary = "(没有找到相关记忆)"
} else {
result.Summary = fmt.Sprintf("检索到 %d 条相关记忆", len(snippets))
result.Summary = fmt.Sprintf("检索到 %d 条相关记忆(含模糊匹配)", len(snippets))
// 按重要性列出前几条
topCount := len(snippets)
if topCount > 3 {
@@ -138,6 +165,74 @@ func (p *MemoryProvider) Execute(ctx context.Context, subCtx []model.LLMMessage)
}
result.Memories = snippets
logger.Printf("[memory-subsession] 完成: %s", result.Summary)
logger.Printf("[memory-subsession] 完成: %s (精确=%d, 模糊=%d)", result.Summary, len(memories)-len(fuzzyMemories), len(fuzzyMemories))
return result, nil
}
// fuzzySearch asks the LLM for keywords related to userMessage and queries the
// memory client once per keyword, returning the union of the hits with
// duplicates (by entry ID) removed, in first-seen order.
//
// It returns nil when fuzzy search has not been configured (see
// SetFuzzySearch) or when no keywords could be produced. A failed query for
// one keyword is logged and skipped. Blank keywords are ignored, and the loop
// stops as soon as ctx is cancelled instead of issuing further queries that
// cannot succeed.
func (p *MemoryProvider) fuzzySearch(ctx context.Context, userID, userMessage string) []memory.MemoryEntry {
	if p.llmAdapter == nil || p.memClient == nil {
		return nil
	}

	keywords := p.expandKeywords(ctx, userMessage)
	if len(keywords) == 0 {
		return nil
	}
	logger.Printf("[memory-subsession] 模糊关键词: %v", keywords)

	var merged []memory.MemoryEntry
	seen := make(map[string]bool)
	for _, kw := range keywords {
		if err := ctx.Err(); err != nil {
			logger.Printf("[memory-subsession] 模糊搜索中止: %v", err)
			break
		}
		kw = strings.TrimSpace(kw)
		if kw == "" {
			continue
		}

		// NOTE(review): the trailing arguments ("", 0, 5) are passed through
		// unchanged; presumably a filter, an offset/threshold and a per-keyword
		// result limit — confirm against memory.Client.QueryByText.
		hits, err := p.memClient.QueryByText(ctx, userID, kw, "", 0, 5)
		if err != nil {
			logger.Printf("[memory-subsession] 模糊搜索 '%s' 失败: %v", kw, err)
			continue
		}
		for _, m := range hits {
			if seen[m.ID] {
				continue
			}
			seen[m.ID] = true
			merged = append(merged, m)
		}
	}
	return merged
}
// expandKeywords asks the LLM for fuzzy/related search keywords derived from
// message and returns them cleaned up: surrounding whitespace trimmed, blank
// and duplicate entries removed, and at most maxKeywords entries kept (the
// prompt asks for 3-5, but the model's output is not trusted to comply).
//
// It returns nil when message is blank, when the LLM call fails, or when the
// reply does not contain a parseable JSON string array. The array is located
// by taking the text between the first "[" and the last "]" so that replies
// wrapped in prose or code fences still parse.
func (p *MemoryProvider) expandKeywords(ctx context.Context, message string) []string {
	const maxKeywords = 5

	if strings.TrimSpace(message) == "" {
		// Nothing to expand; avoid a pointless LLM round trip.
		return nil
	}

	prompt := fmt.Sprintf(
		"从以下对话消息中提取 3-5 个可用于模糊搜索记忆的关键词。这些关键词应该是:\n"+
			"- 与话题相关的抽象概念\n- 同义词和相关词\n- 更宽泛或更具体的相关概念\n"+
			"- 不要包含消息中已经出现的原词\n\n"+
			"用户消息:「%s」\n\n"+
			"只输出 JSON 字符串数组,例如:[\"关键词1\",\"关键词2\"]", message)

	resp, err := p.llmAdapter.Chat(ctx, []model.LLMMessage{
		{Role: model.RoleSystem, Content: "你是记忆搜索专家。输出 JSON 字符串数组。"},
		{Role: model.RoleUser, Content: prompt},
	})
	if err != nil {
		logger.Printf("[memory-subsession] 关键词扩展失败: %v", err)
		return nil
	}

	text := strings.TrimSpace(resp.Content)
	// Extract JSON array
	if idx := strings.Index(text, "["); idx >= 0 {
		if end := strings.LastIndex(text, "]"); end > idx {
			text = text[idx : end+1]
		}
	}

	var parsed []string
	if err := json.Unmarshal([]byte(text), &parsed); err != nil {
		logger.Printf("[memory-subsession] 解析关键词 JSON 失败: %v (raw=%s)", err, resp.Content)
		return nil
	}

	keywords := make([]string, 0, maxKeywords)
	seen := make(map[string]struct{}, maxKeywords)
	for _, kw := range parsed {
		kw = strings.TrimSpace(kw)
		if kw == "" {
			continue
		}
		if _, dup := seen[kw]; dup {
			continue
		}
		seen[kw] = struct{}{}
		keywords = append(keywords, kw)
		if len(keywords) == maxKeywords {
			break
		}
	}
	return keywords
}
@@ -0,0 +1,217 @@
package tools
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/yourname/cyrene-ai/ai-core/internal/host"
)
// OSExecTool allows the AI to execute arbitrary commands in a full OS
// environment (WSL or Docker container). Unlike host_exec which runs in
// a restricted sandbox, this provides unrestricted OS access.
type OSExecTool struct {
// manager runs the commands and reports the name of the backend in use.
manager *host.Manager
}
// NewOSExecTool returns an OSExecTool that executes commands through manager.
func NewOSExecTool(manager *host.Manager) *OSExecTool {
	tool := new(OSExecTool)
	tool.manager = manager
	return tool
}
// Definition returns the os_exec tool schema advertised to the model: a
// required "command" string plus optional "work_dir" and "timeout_sec".
//
// The timeout description states a 60-second default, matching the default
// applied in Execute (it previously advertised 30 seconds).
func (t *OSExecTool) Definition() ToolDefinition {
	return ToolDefinition{
		Name:        "os_exec",
		Description: "在完整的操作系统环境(WSL/Docker容器)中执行任意命令。适用于复杂操作:安装软件包、编译大型项目、运行脚本、管理服务等。拥有完整的Linux系统权限,无命令限制。日常简单操作请使用 host_exec。",
		Parameters: map[string]interface{}{
			"type": "object",
			"properties": map[string]interface{}{
				"command": map[string]interface{}{
					"type":        "string",
					"description": "要执行的命令,例如 'pip install pandas && python analyze.py' 或 'apt-get update && apt-get install -y ffmpeg'",
				},
				"work_dir": map[string]interface{}{
					"type":        "string",
					"description": "工作目录。不指定则使用默认目录。",
				},
				"timeout_sec": map[string]interface{}{
					"type":        "integer",
					"description": "超时时间(秒),默认60秒,最大300秒。复杂任务请设置更长的超时。",
				},
			},
			"required": []string{"command"},
		},
	}
}
// Execute runs args["command"] through the host manager and reports the
// outcome as a JSON document (command, backend, exit_code, duration,
// timed_out, stdout, stderr) in ToolResult.Data.
//
// Arguments:
//   - command (string, required): the command line to run.
//   - work_dir (string, optional): working directory, passed through as-is.
//   - timeout_sec (number, optional): defaults to 60 and is capped at 300, the
//     maximum advertised in the tool schema. Missing, non-numeric, or
//     non-positive values fall back to the default.
//
// Success is true only when the command exits with code 0 and did not time
// out. Tool-level failures are returned inside the ToolResult; the error
// return value is always nil.
func (t *OSExecTool) Execute(ctx context.Context, args map[string]interface{}) (*ToolResult, error) {
	const (
		defaultTimeoutSec = 60 // Default longer timeout for complex operations
		maxTimeoutSec     = 300
	)

	cmd, _ := args["command"].(string)
	if cmd == "" {
		return &ToolResult{
			ToolName: "os_exec",
			Success:  false,
			Error:    "command 参数不能为空",
		}, nil
	}
	workDir, _ := args["work_dir"].(string)

	// JSON-decoded arguments arrive as float64; accept Go integers as well so
	// in-process callers are not silently ignored.
	var requested float64
	switch v := args["timeout_sec"].(type) {
	case float64:
		requested = v
	case int:
		requested = float64(v)
	case int64:
		requested = float64(v)
	}
	timeoutSec := defaultTimeoutSec
	switch {
	case requested > maxTimeoutSec:
		timeoutSec = maxTimeoutSec
	case requested >= 1:
		timeoutSec = int(requested)
	}
	timeout := time.Duration(timeoutSec) * time.Second

	result, err := t.manager.Exec(ctx, cmd, workDir, timeout)
	if result == nil {
		// Without a result there is nothing to report but the error itself.
		msg := "命令执行未返回结果"
		if err != nil {
			msg = err.Error()
		}
		return &ToolResult{
			ToolName: "os_exec",
			Success:  false,
			Error:    msg,
		}, nil
	}

	data, marshalErr := json.Marshal(map[string]interface{}{
		"command":   cmd,
		"backend":   t.manager.BackendName(),
		"exit_code": result.ExitCode,
		"duration":  result.Duration,
		"timed_out": result.TimedOut,
		"stdout":    result.Stdout,
		"stderr":    result.Stderr,
	})
	if marshalErr != nil {
		return &ToolResult{
			ToolName: "os_exec",
			Success:  false,
			Error:    fmt.Sprintf("序列化执行结果失败: %v", marshalErr),
		}, nil
	}

	out := &ToolResult{
		ToolName: "os_exec",
		Success:  result.ExitCode == 0 && !result.TimedOut,
		Data:     string(data),
	}
	if !out.Success && err != nil {
		// Surface the underlying error next to the captured output.
		out.Error = err.Error()
	}
	return out, nil
}
// OSFileTool provides unrestricted file system access within the OS environment.
type OSFileTool struct {
// manager performs the actual read, write and directory-listing operations.
manager *host.Manager
}
// NewOSFileTool returns an OSFileTool whose file operations go through manager.
func NewOSFileTool(manager *host.Manager) *OSFileTool {
	tool := new(OSFileTool)
	tool.manager = manager
	return tool
}
// Definition returns the os_file tool schema advertised to the model:
// required "action" (read, write or list) and "path", plus an optional
// "content" used by write.
func (t *OSFileTool) Definition() ToolDefinition {
	actionProp := map[string]interface{}{
		"type":        "string",
		"description": "操作类型: read, write, list",
		"enum":        []string{"read", "write", "list"},
	}
	pathProp := map[string]interface{}{
		"type":        "string",
		"description": "文件或目录路径",
	}
	contentProp := map[string]interface{}{
		"type":        "string",
		"description": "写入内容 (仅 write 操作需要)",
	}

	schema := map[string]interface{}{
		"type": "object",
		"properties": map[string]interface{}{
			"action":  actionProp,
			"path":    pathProp,
			"content": contentProp,
		},
		"required": []string{"action", "path"},
	}

	return ToolDefinition{
		Name:        "os_file",
		Description: "在完整OS环境中读写文件。支持在整个文件系统中自由操作:读取/写入/列出文件,无目录限制。适用于批量文件处理、日志分析、配置文件管理等复杂文件操作。日常简单文件操作请使用 host_file。",
		Parameters:  schema,
	}
}
// Execute performs one file operation selected by args["action"] on
// args["path"] and returns a JSON summary in ToolResult.Data:
//
//   - "read":  returns the file content and its length; the manager is asked
//     for at most 1 MiB.
//   - "write": writes args["content"] (1 MiB limit passed to the manager).
//     content must be present and a string; an explicit empty string is
//     allowed, but a missing or mistyped value is rejected so a malformed
//     call cannot silently truncate the file.
//   - "list":  returns the directory entries and their count.
//
// Any other action is rejected. Tool-level failures are returned inside the
// ToolResult; the error return value is always nil.
func (t *OSFileTool) Execute(ctx context.Context, args map[string]interface{}) (*ToolResult, error) {
	fail := func(msg string) (*ToolResult, error) {
		return &ToolResult{ToolName: "os_file", Success: false, Error: msg}, nil
	}
	succeed := func(payload map[string]interface{}) (*ToolResult, error) {
		data, err := json.Marshal(payload)
		if err != nil {
			return fail(fmt.Sprintf("序列化结果失败: %v", err))
		}
		return &ToolResult{ToolName: "os_file", Success: true, Data: string(data)}, nil
	}

	action, _ := args["action"].(string)
	path, _ := args["path"].(string)
	if action == "" || path == "" {
		return fail("action 和 path 参数不能为空")
	}

	switch action {
	case "read":
		content, err := t.manager.ReadFile(path, 1024*1024)
		if err != nil {
			return fail(err.Error())
		}
		return succeed(map[string]interface{}{
			"path":    path,
			"content": content,
			"size":    len(content),
		})

	case "write":
		content, ok := args["content"].(string)
		if !ok {
			return fail("write 操作需要字符串类型的 content 参数")
		}
		if err := t.manager.WriteFile(path, content, 1024*1024); err != nil {
			return fail(err.Error())
		}
		return succeed(map[string]interface{}{
			"path":    path,
			"written": len(content),
			"status":  "ok",
		})

	case "list":
		entries, err := t.manager.ListDir(path)
		if err != nil {
			return fail(err.Error())
		}
		return succeed(map[string]interface{}{
			"path":    path,
			"entries": entries,
			"count":   len(entries),
		})

	default:
		return fail(fmt.Sprintf("不支持的操作: %s", action))
	}
}
// OSSystemTool provides OS-level system information.
type OSSystemTool struct {
// manager supplies the system information returned by Execute.
manager *host.Manager
}
// NewOSSystemTool returns an OSSystemTool that reads system information from manager.
func NewOSSystemTool(manager *host.Manager) *OSSystemTool {
	tool := new(OSSystemTool)
	tool.manager = manager
	return tool
}
// Definition returns the os_system tool schema advertised to the model. The
// only parameter is an optional "query" string restricted to info, memory,
// cpu or disk.
func (t *OSSystemTool) Definition() ToolDefinition {
	queryProp := map[string]interface{}{
		"type":        "string",
		"description": "查询类型: info(完整信息), memory(内存), cpu(CPU), disk(磁盘)",
		"enum":        []string{"info", "memory", "cpu", "disk"},
	}

	schema := map[string]interface{}{
		"type": "object",
		"properties": map[string]interface{}{
			"query": queryProp,
		},
	}

	return ToolDefinition{
		Name:        "os_system",
		Description: "获取完整OS环境的系统信息,包括操作系统详情、CPU架构、内存使用、磁盘空间等。与 host_system 不同,此工具返回的是WSL/容器内的完整Linux系统信息。",
		Parameters:  schema,
	}
}
// Execute returns the manager's system information serialized as JSON in
// ToolResult.Data. A serialization failure is reported as a failed result
// rather than as a success with empty data. The error return value is always
// nil.
//
// NOTE(review): the "query" argument advertised in Definition (info, memory,
// cpu, disk) is not consulted here; the full SystemInfo payload is returned
// for every query. Confirm whether per-query filtering is intended.
func (t *OSSystemTool) Execute(ctx context.Context, args map[string]interface{}) (*ToolResult, error) {
	info := t.manager.SystemInfo()

	data, err := json.Marshal(info)
	if err != nil {
		return &ToolResult{
			ToolName: "os_system",
			Success:  false,
			Error:    fmt.Sprintf("序列化系统信息失败: %v", err),
		}, nil
	}

	return &ToolResult{
		ToolName: "os_system",
		Success:  true,
		Data:     string(data),
	}, nil
}
@@ -0,0 +1,102 @@
package tools
import (
"context"
"os"
"strings"
"testing"
"github.com/yourname/cyrene-ai/ai-core/internal/host"
)
// TestOSExecToolWSL is an integration test for os_exec, os_file and os_system
// against a real WSL backend. It is skipped unless the WSL_DISTRO environment
// variable names the distribution to use.
func TestOSExecToolWSL(t *testing.T) {
	distro := os.Getenv("WSL_DISTRO")
	if distro == "" {
		t.Skip("WSL_DISTRO not set, skipping OS tool integration test")
	}

	// NOTE(review): 30e9 is presumably a 30-second timeout expressed in
	// nanoseconds — confirm against host.NewWSLBackend.
	backend := host.NewWSLBackend(distro, "cyrene", "test123", 30e9)
	mgr := host.NewManager(backend)
	ctx := context.Background()

	const testFile = "/tmp/cyrene-os-tool-test.txt"

	// mustRun executes a tool call and fails the test on any error, nil
	// result, or unsuccessful result. The result is checked for nil before
	// its fields are read, so a broken tool produces a test failure instead
	// of a panic inside the failure message.
	mustRun := func(t *testing.T, label string, exec func(context.Context, map[string]interface{}) (*ToolResult, error), args map[string]interface{}) *ToolResult {
		t.Helper()
		res, err := exec(ctx, args)
		if err != nil {
			t.Fatalf("%s: execute error: %v", label, err)
		}
		if res == nil {
			t.Fatalf("%s: nil result", label)
		}
		if !res.Success {
			t.Fatalf("%s failed: %s", label, res.Error)
		}
		return res
	}

	// Test os_exec
	t.Run("os_exec", func(t *testing.T) {
		tool := NewOSExecTool(mgr)
		if def := tool.Definition(); def.Name != "os_exec" {
			t.Fatalf("unexpected name: %s", def.Name)
		}

		result := mustRun(t, "os_exec", tool.Execute, map[string]interface{}{
			"command": "echo 'os_exec works!' && uname -a",
		})
		if !strings.Contains(result.Data, "os_exec works!") {
			t.Fatalf("unexpected output: %s", result.Data)
		}
		t.Logf("os_exec OK: data len=%d", len(result.Data))
	})

	// Test os_file
	t.Run("os_file", func(t *testing.T) {
		tool := NewOSFileTool(mgr)
		if def := tool.Definition(); def.Name != "os_file" {
			t.Fatalf("unexpected name: %s", def.Name)
		}

		// Best-effort removal of the scratch file; a cleanup failure must not
		// fail the test, so the result is deliberately ignored.
		t.Cleanup(func() {
			_, _ = NewOSExecTool(mgr).Execute(ctx, map[string]interface{}{
				"command": "rm -f " + testFile,
			})
		})

		// Write
		mustRun(t, "os_file write", tool.Execute, map[string]interface{}{
			"action":  "write",
			"path":    testFile,
			"content": "OS tool integration test",
		})

		// Read
		r := mustRun(t, "os_file read", tool.Execute, map[string]interface{}{
			"action": "read",
			"path":   testFile,
		})
		if !strings.Contains(r.Data, "OS tool integration test") {
			t.Fatalf("content mismatch: %s", r.Data)
		}

		// List
		mustRun(t, "os_file list", tool.Execute, map[string]interface{}{
			"action": "list",
			"path":   "/tmp",
		})
		t.Logf("os_file OK: write+read+list all pass")
	})

	// Test os_system
	t.Run("os_system", func(t *testing.T) {
		tool := NewOSSystemTool(mgr)
		if def := tool.Definition(); def.Name != "os_system" {
			t.Fatalf("unexpected name: %s", def.Name)
		}

		result := mustRun(t, "os_system", tool.Execute, map[string]interface{}{})
		if !strings.Contains(result.Data, "wsl") {
			t.Fatalf("expected wsl backend info: %s", result.Data)
		}
		t.Logf("os_system OK: data len=%d", len(result.Data))
	})
}
@@ -13,7 +13,7 @@ func TestHostExecToolDefinition(t *testing.T) {
cfg := host.DefaultSandboxConfig()
cfg.AllowedDirs = []string{os.TempDir()}
sandbox := host.NewSandbox(cfg)
mgr := host.NewManager(sandbox)
mgr := host.NewManager(host.NewDirectBackend(sandbox))
tool := NewHostExecTool(mgr)
def := tool.Definition()
@@ -40,7 +40,7 @@ func TestHostFileToolDefinition(t *testing.T) {
tmpDir := os.TempDir()
cfg.AllowedDirs = []string{tmpDir}
sandbox := host.NewSandbox(cfg)
mgr := host.NewManager(sandbox)
mgr := host.NewManager(host.NewDirectBackend(sandbox))
mgr.SetAllowedDirs([]string{tmpDir})
tool := NewHostFileTool(mgr)
@@ -67,7 +67,7 @@ func TestHostFileToolDefinition(t *testing.T) {
func TestHostSystemToolDefinition(t *testing.T) {
cfg := host.DefaultSandboxConfig()
sandbox := host.NewSandbox(cfg)
mgr := host.NewManager(sandbox)
mgr := host.NewManager(host.NewDirectBackend(sandbox))
tool := NewHostSystemTool(mgr)
def := tool.Definition()
@@ -40,7 +40,7 @@ func TestEncodeImageToDataURL_InvalidPath(t *testing.T) {
}
func TestVisionToolDefinition(t *testing.T) {
tool := NewVisionTool()
tool := NewVisionTool(nil)
def := tool.Definition()
if def.Name != "vision_analyze" {
t.Fatalf("unexpected tool name: %s", def.Name)
@@ -68,7 +68,7 @@ func TestVisionToolExecute(t *testing.T) {
}
defer os.Remove(tmpPath)
tool := NewVisionTool()
tool := NewVisionTool(nil)
ctx := context.Background()
result, err := tool.Execute(ctx, map[string]interface{}{
"image_path": tmpPath,