package subsession

import (
	"context"
	"crypto/rand"
	"errors"
	"fmt"
	"log"
	"sync"

	"github.com/yourname/cyrene-ai/ai-core/internal/llm"
	"github.com/yourname/cyrene-ai/ai-core/internal/model"
)

// Manager is the sub-session manager.
//
// It registers Providers, dispatches a request to every provider that can
// handle it, runs them in parallel under a per-provider timeout, and collects
// their results on a channel. The provider registry is guarded by mu.
type Manager struct {
	mu        sync.RWMutex
	providers map[model.SubSessionType]Provider
	llmClient LLMClient
}

// NewManager creates a sub-session manager with an empty provider registry.
func NewManager(llmClient LLMClient) *Manager {
	return &Manager{
		providers: make(map[model.SubSessionType]Provider),
		llmClient: llmClient,
	}
}

// Register registers a sub-session provider under provider.Type().
//
// A provider already registered for the same type is replaced. Use
// RegisterWithOverride when the replacement is intentional so the log says so.
func (m *Manager) Register(provider Provider) {
	m.register(provider, "[subsession] 注册子会话提供者: %s (优先级=%d, 超时=%v)")
}

// RegisterWithOverride registers, or replaces, the provider for provider.Type().
// It stores the provider exactly like Register; only the log line differs.
func (m *Manager) RegisterWithOverride(provider Provider) {
	m.register(provider, "[subsession] 注册(覆盖)子会话提供者: %s (优先级=%d, 超时=%v)")
}

// register stores provider in the registry and logs it using logFormat, which
// receives the provider's type, priority and timeout (in that order).
func (m *Manager) register(provider Provider, logFormat string) {
	t := provider.Type()

	m.mu.Lock()
	m.providers[t] = provider
	m.mu.Unlock()

	log.Printf(logFormat, t, provider.Priority(), provider.Timeout())
}

// GetProvider returns the provider registered for type t, and whether one exists.
func (m *Manager) GetProvider(t model.SubSessionType) (Provider, bool) {
	m.mu.RLock()
	defer m.mu.RUnlock()
	p, ok := m.providers[t]
	return p, ok
}

// ListProviders returns the types of all registered providers.
// The order is unspecified (it follows Go map iteration order).
func (m *Manager) ListProviders() []model.SubSessionType {
	m.mu.RLock()
	defer m.mu.RUnlock()
	types := make([]model.SubSessionType, 0, len(m.providers))
	for t := range m.providers {
		types = append(types, t)
	}
	return types
}

// Dispatch fans the request out to every registered provider whose CanHandle
// returns true and runs each one in its own goroutine.
//
// Each provider runs under a context derived from ctx with the provider's own
// Timeout(). Every started provider sends exactly one model.SubSessionResult to
// the returned channel. On a context-creation error, an execution error, a
// timeout, a cancellation or a panic, that result has Error set. The channel is
// closed after all started providers have finished.
//
// CanHandle is evaluated synchronously in the caller's goroutine.
func (m *Manager) Dispatch(
	ctx context.Context,
	intent *model.IntentResult,
	userMessage string,
	params CreateContextParams,
) <-chan model.SubSessionResult {
	// Snapshot the registry so providers run without holding the lock.
	m.mu.RLock()
	providers := make([]Provider, 0, len(m.providers))
	for _, p := range m.providers {
		providers = append(providers, p)
	}
	m.mu.RUnlock()

	// Sized for the worst case (every provider runs). Each worker sends exactly
	// once, so sends never block even if the caller stops reading.
	resultCh := make(chan model.SubSessionResult, len(providers))
	var wg sync.WaitGroup

	for _, provider := range providers {
		if !provider.CanHandle(ctx, intent, userMessage) {
			log.Printf("[subsession] 跳过子会话 %s: CanHandle 返回 false", provider.Type())
			continue
		}

		wg.Add(1)
		go func(p Provider) {
			defer wg.Done()
			resultCh <- runProvider(ctx, p, params)
		}(provider)
	}

	// Close the channel once every worker has delivered its result.
	go func() {
		wg.Wait()
		close(resultCh)
	}()

	return resultCh
}

// runProvider runs one provider to completion and always returns a result for it.
//
// A panic inside the provider is recovered and reported through the result's
// Error field, so a misbehaving provider cannot silently vanish from Dispatch's
// output.
func runProvider(ctx context.Context, p Provider, params CreateContextParams) (result model.SubSessionResult) {
	// typ is captured before any provider call that might panic. The recover
	// handler must not call p.Type() again, because that call may be the one
	// that panicked.
	var typ model.SubSessionType
	defer func() {
		if r := recover(); r != nil {
			log.Printf("[subsession] dispatch goroutine panic 恢复 (type=%s): %v", typ, r)
			result = model.SubSessionResult{Type: typ, Error: fmt.Sprintf("子会话 panic: %v", r)}
		}
	}()

	typ = p.Type()
	result.Type = typ
	timeout := p.Timeout()

	// Each sub-session gets its own deadline derived from the caller's context.
	subCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	// Build the LLM context for this provider.
	llmMessages, err := p.CreateContext(subCtx, params)
	if err != nil {
		result.Error = fmt.Sprintf("创建上下文失败: %v", err)
		log.Printf("[subsession] %s 创建上下文失败: %v", typ, err)
		return result
	}

	log.Printf("[subsession] %s 开始执行 (上下文 %d 条消息)", typ, len(llmMessages))

	subResult, err := p.Execute(subCtx, llmMessages)
	if err != nil {
		result.Error = fmt.Sprintf("执行失败: %v", err)
		log.Printf("[subsession] %s 执行失败: %v", typ, err)
		return result
	}

	// A result that arrives after the sub-context is done is discarded. The
	// deadline case and the parent-cancellation case are reported separately,
	// so a cancelled request is not mislabeled as a timeout.
	switch ctxErr := subCtx.Err(); {
	case errors.Is(ctxErr, context.DeadlineExceeded):
		result.Error = "子会话超时"
		log.Printf("[subsession] %s 超时 (limit=%v)", typ, timeout)
	case ctxErr != nil:
		result.Error = "子会话已取消"
		log.Printf("[subsession] %s 已取消: %v", typ, ctxErr)
	case subResult != nil:
		result = *subResult
		result.Type = typ
		log.Printf("[subsession] %s 完成: 摘要=%s", typ, truncate(result.Summary, 50))
	}
	return result
}

// generateID returns a random identifier: "sub-" followed by 24 hex digits
// (12 bytes from crypto/rand).
func generateID() string {
	b := make([]byte, 12)
	if _, err := rand.Read(b); err != nil {
		// Older Go versions could return an error here. Continuing would produce
		// an all-zero, colliding ID, so fail loudly instead. Go 1.24+ never
		// returns an error from Read.
		panic(fmt.Sprintf("subsession: crypto/rand failed: %v", err))
	}
	return fmt.Sprintf("sub-%x", b)
}

// truncate shortens s to at most maxLen runes and appends "..." when it cuts.
// A negative maxLen is treated as 0 rather than panicking on the slice bound.
func truncate(s string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	runes := []rune(s)
	if len(runes) <= maxLen {
		return s
	}
	return string(runes[:maxLen]) + "..."
}

// Keep the llm import referenced; nothing else in this file uses it.
var _ = llm.NewAdapter