fix: Phase 6联调 — 插件管理器端口修正 + 多模型配置系统整合 + 历史消息刷新修复
## 调试日志
### 1. 插件管理器启动失败
- **症状**: DevTools 显示插件管理器一直"已停止",手动启动正常
- **排查**: 对比 process-manager.js 传入的环境变量 vs plugin-manager config.go 读取的变量
- **根因**: config.js 传入 PLUGIN_MANAGER_PORT=8094,但 config.go 读取 os.Getenv("PORT"),env 名不匹配。且 process.env 中 PORT 泄露时被误读为 9090,与 DevTools 端口冲突
- **修复**: config.js 将 PLUGIN_MANAGER_PORT → PORT,使 env 名与代码一致 (c3055f4)
### 2. 历史消息刷新后消失
- **症状**: 浏览器刷新后聊天历史清空
- **排查**: WebSocket history_response handler 中 if (msg.messages) 对空数组 [] 为 truthy
- **根因**: 后端返回空的 history_response (缓存为空) 时,空数组覆盖了 HTTP 已加载的消息
- **修复**: useWebSocket.ts 改为 if (msg.messages && msg.messages.length > 0),空数组走 else-if 分支仅打日志,不覆盖已有消息
### 3. Phase 6 多模型配置系统
- Gateway: ModelsConfigStore (JSON文件持久化) + Admin CRUD API (providers/models/routing)
- ai-core: ModelSelector 支持按 purpose 选择 + fallback_chain,无配置时回退 .env
- DevTools: 模型配置管理面板 (Providers/Models/Routing 三Tab)、在线模型查询代理、路由表单 checkbox 多选、关键词搜索过滤
- .gitignore: models.json + platform_configs.json
### 4. 多端客户端追踪
- Hub 新增 knownClients 映射 (clientID → KnownClient),在线/离线状态追踪
- 客户端备注持久化到 PostgreSQL
- DevTools 客户端管理面板
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,234 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ModelsConfigStore manages persistence of model configuration to a JSON file.
|
||||
type ModelsConfigStore struct {
|
||||
mu sync.RWMutex
|
||||
path string
|
||||
config *ModelsConfig
|
||||
}
|
||||
|
||||
// NewModelsConfigStore creates a ModelsConfigStore, creating an empty config file if it doesn't exist.
|
||||
func NewModelsConfigStore(path string) (*ModelsConfigStore, error) {
|
||||
s := &ModelsConfigStore{
|
||||
path: path,
|
||||
config: &ModelsConfig{
|
||||
Version: "1.0",
|
||||
Providers: make(map[string]*ProviderConfig),
|
||||
Models: make(map[string]*ModelConfig),
|
||||
Routing: make(map[string]*RoutingRule),
|
||||
},
|
||||
}
|
||||
if err := s.load(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *ModelsConfigStore) load() error {
|
||||
data, err := os.ReadFile(s.path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return s.save() // Initialize empty file.
|
||||
}
|
||||
return fmt.Errorf("read model config file: %w", err)
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
var cfg ModelsConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return fmt.Errorf("parse model config file: %w", err)
|
||||
}
|
||||
if cfg.Providers == nil {
|
||||
cfg.Providers = make(map[string]*ProviderConfig)
|
||||
}
|
||||
if cfg.Models == nil {
|
||||
cfg.Models = make(map[string]*ModelConfig)
|
||||
}
|
||||
if cfg.Routing == nil {
|
||||
cfg.Routing = make(map[string]*RoutingRule)
|
||||
}
|
||||
if cfg.Version == "" {
|
||||
cfg.Version = "1.0"
|
||||
}
|
||||
s.config = &cfg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ModelsConfigStore) save() error {
|
||||
data, err := json.MarshalIndent(s.config, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal model config: %w", err)
|
||||
}
|
||||
tmpPath := s.path + ".tmp"
|
||||
if err := os.WriteFile(tmpPath, data, 0640); err != nil {
|
||||
return fmt.Errorf("write model config file: %w", err)
|
||||
}
|
||||
return os.Rename(tmpPath, s.path)
|
||||
}
|
||||
|
||||
// HasConfig returns true if there are any providers or models configured.
|
||||
func (s *ModelsConfigStore) HasConfig() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return len(s.config.Providers) > 0 || len(s.config.Models) > 0
|
||||
}
|
||||
|
||||
// ---- Providers ----
|
||||
|
||||
func (s *ModelsConfigStore) ListProviders() []*ProviderConfig {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
result := make([]*ProviderConfig, 0, len(s.config.Providers))
|
||||
for _, p := range s.config.Providers {
|
||||
result = append(result, p)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *ModelsConfigStore) GetProvider(name string) (*ProviderConfig, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
p, ok := s.config.Providers[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("provider not found: %s", name)
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (s *ModelsConfigStore) SetProvider(cfg *ProviderConfig) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if cfg.Name == "" {
|
||||
return fmt.Errorf("provider name is required")
|
||||
}
|
||||
if cfg.BaseURL == "" {
|
||||
return fmt.Errorf("provider base_url is required")
|
||||
}
|
||||
cfg.UpdatedAt = time.Now()
|
||||
s.config.Providers[cfg.Name] = cfg
|
||||
return s.save()
|
||||
}
|
||||
|
||||
func (s *ModelsConfigStore) DeleteProvider(name string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if _, ok := s.config.Providers[name]; !ok {
|
||||
return fmt.Errorf("provider not found: %s", name)
|
||||
}
|
||||
delete(s.config.Providers, name)
|
||||
return s.save()
|
||||
}
|
||||
|
||||
// ---- Models ----
|
||||
|
||||
func (s *ModelsConfigStore) ListModels() []*ModelConfig {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
result := make([]*ModelConfig, 0, len(s.config.Models))
|
||||
for _, m := range s.config.Models {
|
||||
result = append(result, m)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *ModelsConfigStore) GetModel(id string) (*ModelConfig, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
m, ok := s.config.Models[id]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("model not found: %s", id)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (s *ModelsConfigStore) SetModel(cfg *ModelConfig) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if cfg.ID == "" {
|
||||
return fmt.Errorf("model id is required")
|
||||
}
|
||||
if cfg.Provider == "" {
|
||||
return fmt.Errorf("model provider is required")
|
||||
}
|
||||
cfg.UpdatedAt = time.Now()
|
||||
if cfg.Params == nil {
|
||||
cfg.Params = make(map[string]interface{})
|
||||
}
|
||||
if cfg.Tags == nil {
|
||||
cfg.Tags = []string{}
|
||||
}
|
||||
s.config.Models[cfg.ID] = cfg
|
||||
return s.save()
|
||||
}
|
||||
|
||||
func (s *ModelsConfigStore) DeleteModel(id string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if _, ok := s.config.Models[id]; !ok {
|
||||
return fmt.Errorf("model not found: %s", id)
|
||||
}
|
||||
delete(s.config.Models, id)
|
||||
return s.save()
|
||||
}
|
||||
|
||||
// ---- Routing ----
|
||||
|
||||
func (s *ModelsConfigStore) ListRouting() []*RoutingRule {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
result := make([]*RoutingRule, 0, len(s.config.Routing))
|
||||
for _, r := range s.config.Routing {
|
||||
result = append(result, r)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *ModelsConfigStore) GetRouting(purpose string) (*RoutingRule, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
r, ok := s.config.Routing[purpose]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("routing not found: %s", purpose)
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (s *ModelsConfigStore) SetRouting(rule *RoutingRule) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if rule.Purpose == "" {
|
||||
return fmt.Errorf("routing purpose is required")
|
||||
}
|
||||
if len(rule.FallbackChain) == 0 {
|
||||
return fmt.Errorf("routing fallback_chain is required")
|
||||
}
|
||||
s.config.Routing[rule.Purpose] = rule
|
||||
return s.save()
|
||||
}
|
||||
|
||||
func (s *ModelsConfigStore) DeleteRouting(purpose string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if _, ok := s.config.Routing[purpose]; !ok {
|
||||
return fmt.Errorf("routing not found: %s", purpose)
|
||||
}
|
||||
delete(s.config.Routing, purpose)
|
||||
return s.save()
|
||||
}
|
||||
|
||||
// GetConfig returns a copy of the full config (for ai-core loader compatibility).
|
||||
func (s *ModelsConfigStore) GetConfig() *ModelsConfig {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
// Return shallow copy; callers should treat as read-only.
|
||||
return s.config
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package config
|
||||
|
||||
import "time"
|
||||
|
||||
// ProviderConfig defines an LLM service provider (e.g. deepseek, openai).
|
||||
type ProviderConfig struct {
|
||||
Name string `json:"name"`
|
||||
BaseURL string `json:"base_url"`
|
||||
APIKey string `json:"api_key"`
|
||||
TimeoutSec int `json:"timeout_sec"`
|
||||
MaxRetries int `json:"max_retries"`
|
||||
APIVersion string `json:"api_version,omitempty"`
|
||||
ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// ModelConfig defines a specific model under a provider.
|
||||
type ModelConfig struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
Description string `json:"description"`
|
||||
Priority int `json:"priority"`
|
||||
Tags []string `json:"tags"`
|
||||
Params map[string]interface{} `json:"params"`
|
||||
Enabled bool `json:"enabled"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// RoutingRule maps a purpose to an ordered fallback chain of model IDs.
|
||||
type RoutingRule struct {
|
||||
Purpose string `json:"purpose"`
|
||||
FallbackChain []string `json:"fallback_chain"`
|
||||
Required bool `json:"required"`
|
||||
}
|
||||
|
||||
// ModelsConfig is the top-level configuration document.
|
||||
type ModelsConfig struct {
|
||||
Version string `json:"version"`
|
||||
Providers map[string]*ProviderConfig `json:"providers"`
|
||||
Models map[string]*ModelConfig `json:"models"`
|
||||
Routing map[string]*RoutingRule `json:"routing"`
|
||||
}
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"github.com/yourname/cyrene-ai/pkg/logger"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -50,6 +50,9 @@ func (h *ChatHandler) HandleWebSocket(c *gin.Context) {
|
||||
// 从query参数获取token和session_id
|
||||
token := c.Query("token")
|
||||
sessionID := c.Query("session_id")
|
||||
clientID := c.Query("client_id")
|
||||
deviceName := c.Query("device_name")
|
||||
userAgent := c.Request.UserAgent()
|
||||
|
||||
if token == "" {
|
||||
// 也尝试从Authorization头读取
|
||||
@@ -93,7 +96,7 @@ func (h *ChatHandler) HandleWebSocket(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 创建客户端
|
||||
client := ws.NewClient(h.hub, conn, userID, sessionID)
|
||||
client := ws.NewClient(h.hub, conn, userID, sessionID, clientID, deviceName, userAgent)
|
||||
|
||||
// 注册到Hub
|
||||
h.hub.Register(client)
|
||||
@@ -128,7 +131,7 @@ func (h *ChatHandler) handleChatMessage(client *ws.Client, msg ws.ClientMessage)
|
||||
|
||||
// 持久化用户消息到数据库(在 WebSocket 发送之前)
|
||||
if h.sessionStore != nil && h.sessionStore.IsAvailable() {
|
||||
if err := h.sessionStore.AddMessage(client.SessionID, "user", "chat", msg.Content); err != nil {
|
||||
if err := h.sessionStore.AddMessage(client.SessionID, "user", "chat", msg.Content, client.ClientID); err != nil {
|
||||
logger.Printf("[chat] 持久化用户消息失败: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -168,6 +171,10 @@ func (h *ChatHandler) handleChatMessage(client *ws.Client, msg ws.ClientMessage)
|
||||
Role: "user",
|
||||
Content: msg.Content,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
ClientInfo: &ws.ClientInfo{
|
||||
ClientID: client.ClientID,
|
||||
DeviceName: client.DeviceName,
|
||||
},
|
||||
}
|
||||
if len(msg.Attachments) > 0 {
|
||||
userMsg.Attachments = msg.Attachments
|
||||
@@ -229,13 +236,13 @@ func (h *ChatHandler) streamResponse(client *ws.Client, mode string, reqBody []b
|
||||
// 增大 scanner buffer 以处理大块 SSE 数据
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
|
||||
// 通知前端 AI 开始生成回复
|
||||
client.SendMessage(ws.ServerMessage{
|
||||
Type: "stream_start",
|
||||
MessageID: "msg_" + generateID(),
|
||||
SessionID: client.SessionID,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
})
|
||||
// 通知前端 AI 开始生成回复
|
||||
client.SendMessage(ws.ServerMessage{
|
||||
Type: "stream_start",
|
||||
MessageID: "msg_" + generateID(),
|
||||
SessionID: client.SessionID,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
})
|
||||
|
||||
var fullText string
|
||||
var msgID string
|
||||
@@ -312,24 +319,30 @@ func (h *ChatHandler) streamResponse(client *ws.Client, mode string, reqBody []b
|
||||
reviewMsgID := fmt.Sprintf("%s_r%d", msgID, i)
|
||||
// 持久化每条审查消息
|
||||
if h.sessionStore != nil && h.sessionStore.IsAvailable() {
|
||||
if err := h.sessionStore.AddMessage(client.SessionID, role, msgType, rm.Content); err != nil {
|
||||
if err := h.sessionStore.AddMessage(client.SessionID, role, msgType, rm.Content, client.ClientID); err != nil {
|
||||
logger.Printf("[chat] 持久化审查消息失败: %v", err)
|
||||
}
|
||||
}
|
||||
clientInfo := &ws.ClientInfo{
|
||||
ClientID: client.ClientID,
|
||||
DeviceName: client.DeviceName,
|
||||
}
|
||||
h.hub.CacheMessage(client.UserID, client.SessionID, ws.Message{
|
||||
ID: reviewMsgID,
|
||||
Role: role,
|
||||
Content: rm.Content,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
ID: reviewMsgID,
|
||||
Role: role,
|
||||
Content: rm.Content,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
ClientInfo: clientInfo,
|
||||
})
|
||||
client.SendMessage(ws.ServerMessage{
|
||||
Type: "response",
|
||||
MessageID: reviewMsgID,
|
||||
Content: rm.Content,
|
||||
Role: role,
|
||||
MsgType: msgType,
|
||||
SessionID: client.SessionID,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
Type: "response",
|
||||
MessageID: reviewMsgID,
|
||||
Content: rm.Content,
|
||||
Role: role,
|
||||
MsgType: msgType,
|
||||
SessionID: client.SessionID,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
ClientInfo: clientInfo,
|
||||
})
|
||||
// 使用 MessageScheduler 计算的 per-message 延迟
|
||||
if rm.DelayMs > 0 {
|
||||
@@ -416,7 +429,7 @@ func (h *ChatHandler) streamResponse(client *ws.Client, mode string, reqBody []b
|
||||
// 如果有审查消息,每条已单独持久化,跳过 fullText 以避免重复
|
||||
if !hasReview && fullText != "" {
|
||||
if h.sessionStore != nil && h.sessionStore.IsAvailable() {
|
||||
if err := h.sessionStore.AddMessage(client.SessionID, "assistant", "chat", fullText); err != nil {
|
||||
if err := h.sessionStore.AddMessage(client.SessionID, "assistant", "chat", fullText, client.ClientID); err != nil {
|
||||
logger.Printf("[chat] 持久化 AI 回复失败: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -426,6 +439,10 @@ func (h *ChatHandler) streamResponse(client *ws.Client, mode string, reqBody []b
|
||||
Role: "assistant",
|
||||
Content: fullText,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
ClientInfo: &ws.ClientInfo{
|
||||
ClientID: client.ClientID,
|
||||
DeviceName: client.DeviceName,
|
||||
},
|
||||
})
|
||||
}
|
||||
// RecordMessage 使用不带 [review] 标记的文本
|
||||
@@ -451,7 +468,6 @@ func (h *ChatHandler) handleVoiceInput(client *ws.Client, msg ws.ClientMessage)
|
||||
client.SendMessage(response)
|
||||
}
|
||||
|
||||
|
||||
// handleHistoryRequest 处理历史消息请求
|
||||
func (h *ChatHandler) handleHistoryRequest(client *ws.Client, msg ws.ClientMessage) {
|
||||
// 优先使用请求中的 session_id,否则使用客户端的 session_id
|
||||
@@ -469,19 +485,28 @@ func (h *ChatHandler) handleHistoryRequest(client *ws.Client, msg ws.ClientMessa
|
||||
logger.Printf("[history] 从数据库恢复会话历史: session=%s, %d 条消息", sessionID, len(dbMessages))
|
||||
// 恢复到内存缓存
|
||||
for _, dbMsg := range dbMessages {
|
||||
var ci *ws.ClientInfo
|
||||
if dbMsg.ClientID != "" {
|
||||
ci = h.hub.ClientInfo(dbMsg.ClientID)
|
||||
if ci == nil {
|
||||
ci = &ws.ClientInfo{ClientID: dbMsg.ClientID}
|
||||
}
|
||||
}
|
||||
messages = append(messages, ws.Message{
|
||||
ID: fmt.Sprintf("db_%d", dbMsg.ID),
|
||||
Role: dbMsg.Role,
|
||||
MsgType: dbMsg.MsgType,
|
||||
Content: dbMsg.Content,
|
||||
Timestamp: dbMsg.CreatedAt.UnixMilli(),
|
||||
ID: fmt.Sprintf("db_%d", dbMsg.ID),
|
||||
Role: dbMsg.Role,
|
||||
MsgType: dbMsg.MsgType,
|
||||
Content: dbMsg.Content,
|
||||
Timestamp: dbMsg.CreatedAt.UnixMilli(),
|
||||
ClientInfo: ci,
|
||||
})
|
||||
h.hub.CacheMessage(client.UserID, sessionID, ws.Message{
|
||||
ID: fmt.Sprintf("db_%d", dbMsg.ID),
|
||||
Role: dbMsg.Role,
|
||||
MsgType: dbMsg.MsgType,
|
||||
Content: dbMsg.Content,
|
||||
Timestamp: dbMsg.CreatedAt.UnixMilli(),
|
||||
ID: fmt.Sprintf("db_%d", dbMsg.ID),
|
||||
Role: dbMsg.Role,
|
||||
MsgType: dbMsg.MsgType,
|
||||
Content: dbMsg.Content,
|
||||
Timestamp: dbMsg.CreatedAt.UnixMilli(),
|
||||
ClientInfo: ci,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -534,72 +559,216 @@ func (h *ChatHandler) HandleProactiveMessage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户是否在线
|
||||
// Parse content to split (action) from chat text.
|
||||
segments := parseProactiveContent(req.Content)
|
||||
|
||||
// Check online status.
|
||||
onlineCount := h.hub.UserClientCount(req.UserID)
|
||||
if onlineCount == 0 {
|
||||
// Phase 2: 离线时排队,等待用户重连后推送
|
||||
data, _ := json.Marshal(ws.ServerMessage{
|
||||
Type: "response",
|
||||
MessageID: "proactive_" + generateID(),
|
||||
Content: req.Content,
|
||||
Role: "assistant",
|
||||
MsgType: "proactive",
|
||||
SessionID: req.SessionID,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
})
|
||||
h.hub.QueueProactiveMessage(req.UserID, data)
|
||||
logger.Printf("[proactive] 用户离线,消息已排队: user=%s", req.UserID)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"reason": "queued",
|
||||
"message": "用户离线,消息已排队等待重连后推送",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 构建主动消息
|
||||
msgID := "proactive_" + generateID()
|
||||
msg := ws.ServerMessage{
|
||||
Type: "response",
|
||||
MessageID: msgID,
|
||||
Content: req.Content,
|
||||
Role: "assistant",
|
||||
MsgType: "proactive",
|
||||
SessionID: req.SessionID,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
logger.Printf("[proactive] 序列化消息失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "内部错误"})
|
||||
return
|
||||
}
|
||||
|
||||
h.hub.SendToUser(req.UserID, data)
|
||||
|
||||
// 同时缓存到对话历史(使用 admin 的主 session)
|
||||
sessionID := req.SessionID
|
||||
if sessionID == "" {
|
||||
sessionID = "session_admin_main"
|
||||
}
|
||||
h.hub.CacheMessage(req.UserID, sessionID, ws.Message{
|
||||
ID: msgID,
|
||||
Role: "assistant",
|
||||
Content: req.Content,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
})
|
||||
h.hub.RecordMessage(sessionID, "assistant", req.Content)
|
||||
|
||||
logger.Printf("[proactive] 主动消息已推送: user=%s, online=%d, content_len=%d", req.UserID, onlineCount, len(req.Content))
|
||||
timestamp := time.Now().UnixMilli()
|
||||
|
||||
for i, seg := range segments {
|
||||
msgID := fmt.Sprintf("proactive_%s_%d", generateID(), i)
|
||||
msgType := "chat"
|
||||
role := "assistant"
|
||||
if seg.msgType == "action" {
|
||||
msgType = "action"
|
||||
role = "action"
|
||||
}
|
||||
|
||||
msg := ws.ServerMessage{
|
||||
Type: "response",
|
||||
MessageID: msgID,
|
||||
Content: seg.content,
|
||||
Role: role,
|
||||
MsgType: msgType,
|
||||
SessionID: sessionID,
|
||||
Timestamp: timestamp + int64(i),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
logger.Printf("[proactive] 序列化消息失败: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if onlineCount == 0 {
|
||||
h.hub.QueueProactiveMessage(req.UserID, data)
|
||||
} else {
|
||||
h.hub.SendToUser(req.UserID, data)
|
||||
if i < len(segments)-1 {
|
||||
delay := 200 + int(time.Now().UnixNano()%200)
|
||||
time.Sleep(time.Duration(delay) * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
// Persist to database so proactive messages survive restarts.
|
||||
if h.sessionStore != nil && h.sessionStore.IsAvailable() {
|
||||
if err := h.sessionStore.AddMessage(sessionID, role, msgType, seg.content, ""); err != nil {
|
||||
logger.Printf("[proactive] 持久化消息失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Cache each segment to conversation history.
|
||||
h.hub.CacheMessage(req.UserID, sessionID, ws.Message{
|
||||
ID: msgID,
|
||||
Role: role,
|
||||
MsgType: msgType,
|
||||
Content: seg.content,
|
||||
Timestamp: timestamp,
|
||||
})
|
||||
h.hub.RecordMessage(sessionID, role, seg.content)
|
||||
}
|
||||
|
||||
logger.Printf("[proactive] 主动消息已推送: user=%s, online=%d, segments=%d", req.UserID, onlineCount, len(segments))
|
||||
|
||||
reason := "delivered"
|
||||
if onlineCount == 0 {
|
||||
reason = "queued"
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "消息已推送",
|
||||
"segments": len(segments),
|
||||
"delivered": onlineCount,
|
||||
"reason": reason,
|
||||
})
|
||||
}
|
||||
|
||||
// proactiveSegment holds a parsed piece of a proactive message.
|
||||
type proactiveSegment struct {
|
||||
msgType string // "chat" or "action"
|
||||
content string
|
||||
}
|
||||
|
||||
// parseProactiveContent splits text by (parenthesized actions).
|
||||
// "(笑) 你好呀 (调暗灯光) 今天过得如何" →
|
||||
// [action: "笑", chat: "你好呀", action: "调暗灯光", chat: "今天过得如何"]
|
||||
func parseProactiveContent(text string) []proactiveSegment {
|
||||
if text == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var segments []proactiveSegment
|
||||
remaining := []rune(text)
|
||||
|
||||
for len(remaining) > 0 {
|
||||
actionStart := -1 // index in remaining
|
||||
actionEnd := -1 // index after closing paren
|
||||
var actionContent string
|
||||
|
||||
for i, r := range remaining {
|
||||
if r == '(' || r == '(' {
|
||||
actionStart = i
|
||||
closeRune := ')'
|
||||
if r == '(' {
|
||||
closeRune = ')'
|
||||
}
|
||||
for j := i + 1; j < len(remaining); j++ {
|
||||
if remaining[j] == closeRune {
|
||||
actionEnd = j + 1
|
||||
actionContent = string(remaining[i+1 : j])
|
||||
break
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if actionStart >= 0 {
|
||||
if actionStart > 0 {
|
||||
prefix := strings.TrimSpace(string(remaining[:actionStart]))
|
||||
if prefix != "" {
|
||||
segments = append(segments, proactiveSegment{msgType: "chat", content: prefix})
|
||||
}
|
||||
}
|
||||
content := strings.TrimSpace(actionContent)
|
||||
if content != "" {
|
||||
segments = append(segments, proactiveSegment{msgType: "action", content: content})
|
||||
}
|
||||
remaining = remaining[actionEnd:]
|
||||
} else {
|
||||
text := strings.TrimSpace(string(remaining))
|
||||
if text != "" {
|
||||
segments = append(segments, proactiveSegment{msgType: "chat", content: text})
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(segments) == 0 && text != "" {
|
||||
segments = append(segments, proactiveSegment{msgType: "chat", content: strings.TrimSpace(text)})
|
||||
}
|
||||
|
||||
return segments
|
||||
}
|
||||
|
||||
// ========== 多端客户端管理 API ==========
|
||||
|
||||
// HandleListClients returns all known clients for the authenticated user.
|
||||
// GET /api/v1/admin/clients
|
||||
func (h *ChatHandler) HandleListClients(c *gin.Context) {
|
||||
userID := c.Query("user_id")
|
||||
if userID == "" {
|
||||
userID = "admin"
|
||||
}
|
||||
clients := h.hub.GetKnownClients(userID)
|
||||
|
||||
// Merge with persisted notes from DB
|
||||
if h.sessionStore != nil && h.sessionStore.IsAvailable() {
|
||||
dbClients, err := h.sessionStore.GetClients(userID)
|
||||
if err == nil {
|
||||
noteByID := make(map[string]string)
|
||||
for _, dc := range dbClients {
|
||||
noteByID[dc.ClientID] = dc.Note
|
||||
}
|
||||
for i := range clients {
|
||||
if note, ok := noteByID[clients[i].ClientID]; ok && note != "" {
|
||||
clients[i].Note = note
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"clients": clients,
|
||||
"total": len(clients),
|
||||
})
|
||||
}
|
||||
|
||||
// HandleUpdateClientNote sets a label/note on a client.
|
||||
// PUT /api/v1/admin/clients/:id/note
|
||||
func (h *ChatHandler) HandleUpdateClientNote(c *gin.Context) {
|
||||
clientID := c.Param("id")
|
||||
var req struct {
|
||||
Note string `json:"note"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效"})
|
||||
return
|
||||
}
|
||||
|
||||
// Update in-memory
|
||||
if !h.hub.UpdateClientNote(clientID, req.Note) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "客户端未找到"})
|
||||
return
|
||||
}
|
||||
|
||||
// Persist to DB
|
||||
if h.sessionStore != nil && h.sessionStore.IsAvailable() {
|
||||
if err := h.sessionStore.UpdateClientNote(clientID, req.Note); err != nil {
|
||||
logger.Printf("[clients] 持久化备注失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "client_id": clientID, "note": req.Note})
|
||||
}
|
||||
|
||||
func generateID() string {
|
||||
return time.Now().Format("20060102150405") + randomStr(6)
|
||||
}
|
||||
@@ -609,7 +778,7 @@ func randomStr(n int) string {
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// fallback: deterministic but hard to predict
|
||||
for i := range b {
|
||||
b[i] = byte(time.Now().UnixNano()%256)
|
||||
b[i] = byte(time.Now().UnixNano() % 256)
|
||||
}
|
||||
}
|
||||
return hex.EncodeToString(b)[:n]
|
||||
@@ -637,4 +806,3 @@ func parseMultiMessage(text string) []string {
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,297 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/config"
|
||||
)
|
||||
|
||||
// ModelConfigHandler exposes admin CRUD endpoints for model configuration.
|
||||
type ModelConfigHandler struct {
|
||||
store *config.ModelsConfigStore
|
||||
}
|
||||
|
||||
func NewModelConfigHandler(store *config.ModelsConfigStore) *ModelConfigHandler {
|
||||
return &ModelConfigHandler{store: store}
|
||||
}
|
||||
|
||||
// ---- Providers ----
|
||||
|
||||
func (h *ModelConfigHandler) ListProviders(c *gin.Context) {
|
||||
providers := h.store.ListProviders()
|
||||
if providers == nil {
|
||||
providers = []*config.ProviderConfig{}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"providers": providers, "total": len(providers)})
|
||||
}
|
||||
|
||||
func (h *ModelConfigHandler) GetProvider(c *gin.Context) {
|
||||
name := c.Param("name")
|
||||
p, err := h.store.GetProvider(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, p)
|
||||
}
|
||||
|
||||
func (h *ModelConfigHandler) SetProvider(c *gin.Context) {
|
||||
name := c.Param("name")
|
||||
var body config.ProviderConfig
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid JSON: " + err.Error()})
|
||||
return
|
||||
}
|
||||
body.Name = name
|
||||
if err := h.store.SetProvider(&body); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"status": "saved", "name": name})
|
||||
}
|
||||
|
||||
func (h *ModelConfigHandler) DeleteProvider(c *gin.Context) {
|
||||
name := c.Param("name")
|
||||
if err := h.store.DeleteProvider(name); err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"status": "deleted", "name": name})
|
||||
}
|
||||
|
||||
// ---- Models ----
|
||||
|
||||
func (h *ModelConfigHandler) ListModels(c *gin.Context) {
|
||||
models := h.store.ListModels()
|
||||
if models == nil {
|
||||
models = []*config.ModelConfig{}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"models": models, "total": len(models)})
|
||||
}
|
||||
|
||||
func (h *ModelConfigHandler) GetModel(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
m, err := h.store.GetModel(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, m)
|
||||
}
|
||||
|
||||
func (h *ModelConfigHandler) SetModel(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var body config.ModelConfig
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid JSON: " + err.Error()})
|
||||
return
|
||||
}
|
||||
body.ID = id
|
||||
if err := h.store.SetModel(&body); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"status": "saved", "id": id})
|
||||
}
|
||||
|
||||
func (h *ModelConfigHandler) DeleteModel(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if err := h.store.DeleteModel(id); err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"status": "deleted", "id": id})
|
||||
}
|
||||
|
||||
// ---- Routing ----
|
||||
|
||||
func (h *ModelConfigHandler) ListRouting(c *gin.Context) {
|
||||
routing := h.store.ListRouting()
|
||||
if routing == nil {
|
||||
routing = []*config.RoutingRule{}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"routing": routing, "total": len(routing)})
|
||||
}
|
||||
|
||||
func (h *ModelConfigHandler) GetRouting(c *gin.Context) {
|
||||
purpose := c.Param("purpose")
|
||||
r, err := h.store.GetRouting(purpose)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, r)
|
||||
}
|
||||
|
||||
func (h *ModelConfigHandler) SetRouting(c *gin.Context) {
|
||||
purpose := c.Param("purpose")
|
||||
var body config.RoutingRule
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid JSON: " + err.Error()})
|
||||
return
|
||||
}
|
||||
body.Purpose = purpose
|
||||
if err := h.store.SetRouting(&body); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"status": "saved", "purpose": purpose})
|
||||
}
|
||||
|
||||
func (h *ModelConfigHandler) DeleteRouting(c *gin.Context) {
|
||||
purpose := c.Param("purpose")
|
||||
if err := h.store.DeleteRouting(purpose); err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"status": "deleted", "purpose": purpose})
|
||||
}
|
||||
|
||||
// ---- Health Check ----
|
||||
|
||||
func (h *ModelConfigHandler) TestProvider(c *gin.Context) {
|
||||
var body struct {
|
||||
Provider string `json:"provider"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid JSON: " + err.Error()})
|
||||
return
|
||||
}
|
||||
p, err := h.store.GetProvider(body.Provider)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"provider": p.Name,
|
||||
"base_url": p.BaseURL,
|
||||
"message": "Provider 配置已保存,连接测试请通过实际 LLM 调用验证",
|
||||
})
|
||||
}
|
||||
|
||||
// ---- Remote Model List Proxy ----
|
||||
|
||||
// ProxyListModels forwards a request to the provider's models endpoint using the stored API key.
|
||||
func (h *ModelConfigHandler) ProxyListModels(c *gin.Context) {
|
||||
providerName := c.Param("name")
|
||||
modelsURL := c.Query("url")
|
||||
if modelsURL == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing 'url' query parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
p, err := h.store.GetProvider(providerName)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if p.APIKey == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "provider 未配置 API Key"})
|
||||
return
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
req, err := http.NewRequest("GET", modelsURL, nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建请求失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+p.APIKey)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "请求模型列表失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) // 2 MB limit
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "读取响应失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": fmt.Sprintf("Provider API 返回错误 (HTTP %d)", resp.StatusCode),
|
||||
"body": string(body),
|
||||
"models_url": modelsURL,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse the response body which may use different formats:
|
||||
// OpenAI: {"object":"list","data":[{"id":"...","object":"model",...}]}
|
||||
// DashScope: {"request_id":"...","data":{"models":[{"model_id":"..."}]}}
|
||||
// Generic: {"data":[{"id":"..."}]} or {"data":[{"model_id":"..."}]}
|
||||
ids := parseModelListResponse(body)
|
||||
if len(ids) == 0 {
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": "无法从 Provider 响应中解析模型列表 (不支持的格式)",
|
||||
"raw": string(body),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"provider": providerName,
|
||||
"url": modelsURL,
|
||||
"models": ids,
|
||||
"total": len(ids),
|
||||
})
|
||||
}
|
||||
|
||||
// parseModelListResponse attempts to extract model IDs from various provider response formats.
|
||||
// Supported formats:
|
||||
// - OpenAI-compatible: {"object":"list","data":[{"id":"gpt-4o",...}]}
|
||||
// - DashScope: {"data":{"models":[{"model_id":"qwen-turbo",...}]}}
|
||||
// - Generic: {"data":[{"id":"..."}]} or {"data":[{"model_id":"..."}]}
|
||||
func parseModelListResponse(body []byte) []string {
|
||||
var raw map[string]interface{}
|
||||
if err := json.Unmarshal(body, &raw); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Strategy 1: data is an array of objects — try "id" then "model_id"
|
||||
if dataArr, ok := raw["data"].([]interface{}); ok {
|
||||
ids := extractIDs(dataArr, "id")
|
||||
if len(ids) > 0 {
|
||||
return ids
|
||||
}
|
||||
return extractIDs(dataArr, "model_id")
|
||||
}
|
||||
|
||||
// Strategy 2: data is an object with a "models" array (DashScope format)
|
||||
if dataObj, ok := raw["data"].(map[string]interface{}); ok {
|
||||
if modelsArr, ok := dataObj["models"].([]interface{}); ok {
|
||||
ids := extractIDs(modelsArr, "model_id")
|
||||
if len(ids) > 0 {
|
||||
return ids
|
||||
}
|
||||
return extractIDs(modelsArr, "id")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractIDs(items []interface{}, key string) []string {
|
||||
ids := make([]string, 0, len(items))
|
||||
for _, item := range items {
|
||||
if obj, ok := item.(map[string]interface{}); ok {
|
||||
if v, ok := obj[key]; ok {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
ids = append(ids, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
@@ -322,8 +322,9 @@ func (h *SessionHandler) GetMessages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 降级:从 Hub 内存缓存读取
|
||||
messages := h.hub.GetConversation("", sessionID)
|
||||
// 降级:从 Hub 内存缓存读取(使用当前认证用户的 ID 作为缓存键前缀)
|
||||
userID := middleware.GetUserID(c)
|
||||
messages := h.hub.GetConversation(userID, sessionID)
|
||||
if messages == nil {
|
||||
messages = []ws.Message{}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
)
|
||||
|
||||
// Setup 注册所有路由
|
||||
func Setup(r *gin.Engine, hub *ws.Hub, cfg *config.Config, sessionStore *store.SessionStore, reminderStore *store.ReminderStore, briefingStore *store.BriefingStore, automationStore *store.AutomationStore, fileStore *store.FileStore, ruleEngine *engine.RuleEngine, knowledgeStore *store.KnowledgeStore, imageHandler *handler.ImageHandler, db interface{}) {
|
||||
func Setup(r *gin.Engine, hub *ws.Hub, cfg *config.Config, sessionStore *store.SessionStore, reminderStore *store.ReminderStore, briefingStore *store.BriefingStore, automationStore *store.AutomationStore, fileStore *store.FileStore, ruleEngine *engine.RuleEngine, knowledgeStore *store.KnowledgeStore, imageHandler *handler.ImageHandler, db interface{}, modelConfigStore *config.ModelsConfigStore) {
|
||||
// 限流器
|
||||
rateLimiter := middleware.NewRateLimiter(10, 20) // 每秒10个请求,突发20
|
||||
|
||||
@@ -36,6 +36,7 @@ func Setup(r *gin.Engine, hub *ws.Hub, cfg *config.Config, sessionStore *store.S
|
||||
fileHandler := handler.NewFileHandler(fileStore)
|
||||
automationHandler := handler.NewAutomationHandler(automationStore, ruleEngine)
|
||||
knowledgeHandler := handler.NewKnowledgeHandler(knowledgeStore, fileStore)
|
||||
modelConfigHandler := handler.NewModelConfigHandler(modelConfigStore)
|
||||
if imageHandler == nil {
|
||||
imageHandler = handler.NewImageHandler(cfg, fileStore)
|
||||
}
|
||||
@@ -199,6 +200,29 @@ func Setup(r *gin.Engine, hub *ws.Hub, cfg *config.Config, sessionStore *store.S
|
||||
admin.GET("/sessions", sessionHandler.ListActiveSessions)
|
||||
admin.GET("/sessions/active", sessionHandler.GetActiveSessions)
|
||||
admin.GET("/sessions/:id", sessionHandler.GetSession)
|
||||
|
||||
// 多端客户端管理
|
||||
admin.GET("/clients", chatHandler.HandleListClients)
|
||||
admin.PUT("/clients/:id/note", chatHandler.HandleUpdateClientNote)
|
||||
|
||||
// 模型配置管理
|
||||
models := admin.Group("/models")
|
||||
{
|
||||
models.GET("/providers", modelConfigHandler.ListProviders)
|
||||
models.GET("/providers/:name", modelConfigHandler.GetProvider)
|
||||
models.POST("/providers/:name", modelConfigHandler.SetProvider)
|
||||
models.DELETE("/providers/:name", modelConfigHandler.DeleteProvider)
|
||||
models.GET("/models", modelConfigHandler.ListModels)
|
||||
models.GET("/models/:id", modelConfigHandler.GetModel)
|
||||
models.POST("/models/:id", modelConfigHandler.SetModel)
|
||||
models.DELETE("/models/:id", modelConfigHandler.DeleteModel)
|
||||
models.GET("/routing", modelConfigHandler.ListRouting)
|
||||
models.GET("/routing/:purpose", modelConfigHandler.GetRouting)
|
||||
models.POST("/routing/:purpose", modelConfigHandler.SetRouting)
|
||||
models.DELETE("/routing/:purpose", modelConfigHandler.DeleteRouting)
|
||||
models.POST("/health-check", modelConfigHandler.TestProvider)
|
||||
models.GET("/fetch-models/:name", modelConfigHandler.ProxyListModels)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -26,9 +26,21 @@ type Message struct {
|
||||
Role string `json:"role"`
|
||||
MsgType string `json:"msg_type"`
|
||||
Content string `json:"content"`
|
||||
ClientID string `json:"client_id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// ClientRecord 客户端记录 (持久化)
|
||||
type ClientRecord struct {
|
||||
ClientID string `json:"client_id"`
|
||||
UserID string `json:"user_id"`
|
||||
DeviceName string `json:"device_name"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
Note string `json:"note"`
|
||||
FirstSeenAt time.Time `json:"first_seen_at"`
|
||||
LastSeenAt time.Time `json:"last_seen_at"`
|
||||
}
|
||||
|
||||
// SessionStore 会话持久化存储
|
||||
type SessionStore struct {
|
||||
db *sql.DB
|
||||
@@ -92,6 +104,19 @@ func (s *SessionStore) migrate() error {
|
||||
|
||||
// 为已存在的数据库添加 msg_type 列 (Phase 0.1)
|
||||
`ALTER TABLE messages ADD COLUMN IF NOT EXISTS msg_type VARCHAR(16) DEFAULT 'chat'`,
|
||||
// 为已存在的数据库添加 client_id 列 (Phase 5: 多端客户端追踪)
|
||||
`ALTER TABLE messages ADD COLUMN IF NOT EXISTS client_id VARCHAR(128) DEFAULT ''`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS clients (
|
||||
client_id VARCHAR(128) PRIMARY KEY,
|
||||
user_id VARCHAR(128) NOT NULL,
|
||||
device_name VARCHAR(256) DEFAULT '',
|
||||
user_agent VARCHAR(512) DEFAULT '',
|
||||
note VARCHAR(256) DEFAULT '',
|
||||
first_seen_at TIMESTAMP DEFAULT NOW(),
|
||||
last_seen_at TIMESTAMP DEFAULT NOW()
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_clients_user_id ON clients(user_id)`,
|
||||
}
|
||||
|
||||
for _, q := range queries {
|
||||
@@ -205,10 +230,10 @@ func (s *SessionStore) DeleteAllUserSessions(userID string) error {
|
||||
}
|
||||
|
||||
// AddMessage 添加一条消息到会话
|
||||
func (s *SessionStore) AddMessage(sessionID, role, msgType, content string) error {
|
||||
func (s *SessionStore) AddMessage(sessionID, role, msgType, content, clientID string) error {
|
||||
_, err := s.db.Exec(
|
||||
`INSERT INTO messages (session_id, role, msg_type, content) VALUES ($1, $2, $3, $4)`,
|
||||
sessionID, role, msgType, content,
|
||||
`INSERT INTO messages (session_id, role, msg_type, content, client_id) VALUES ($1, $2, $3, $4, $5)`,
|
||||
sessionID, role, msgType, content, clientID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("添加消息失败: %w", err)
|
||||
@@ -226,7 +251,7 @@ func (s *SessionStore) GetMessages(sessionID string, limit, offset int) ([]Messa
|
||||
}
|
||||
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, session_id, role, COALESCE(msg_type, 'chat'), content, created_at
|
||||
`SELECT id, session_id, role, COALESCE(msg_type, 'chat'), content, COALESCE(client_id, ''), created_at
|
||||
FROM messages WHERE session_id = $1
|
||||
ORDER BY created_at ASC
|
||||
LIMIT $2 OFFSET $3`,
|
||||
@@ -240,7 +265,7 @@ func (s *SessionStore) GetMessages(sessionID string, limit, offset int) ([]Messa
|
||||
var messages []Message
|
||||
for rows.Next() {
|
||||
var msg Message
|
||||
if err := rows.Scan(&msg.ID, &msg.SessionID, &msg.Role, &msg.MsgType, &msg.Content, &msg.CreatedAt); err != nil {
|
||||
if err := rows.Scan(&msg.ID, &msg.SessionID, &msg.Role, &msg.MsgType, &msg.Content, &msg.ClientID, &msg.CreatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描消息行失败: %w", err)
|
||||
}
|
||||
messages = append(messages, msg)
|
||||
@@ -341,3 +366,57 @@ func (s *SessionStore) IsAvailable() bool {
|
||||
}
|
||||
return s.db.Ping() == nil
|
||||
}
|
||||
|
||||
// ========== 多端客户端追踪 ==========
|
||||
|
||||
// UpsertClient inserts or updates a client record.
|
||||
func (s *SessionStore) UpsertClient(clientID, userID, deviceName, userAgent string) error {
|
||||
_, err := s.db.Exec(
|
||||
`INSERT INTO clients (client_id, user_id, device_name, user_agent, first_seen_at, last_seen_at)
|
||||
VALUES ($1, $2, $3, $4, NOW(), NOW())
|
||||
ON CONFLICT (client_id) DO UPDATE SET
|
||||
device_name = EXCLUDED.device_name,
|
||||
user_agent = EXCLUDED.user_agent,
|
||||
last_seen_at = NOW()`,
|
||||
clientID, userID, deviceName, userAgent,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("upsert client failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetClients returns all known clients for a user.
|
||||
func (s *SessionStore) GetClients(userID string) ([]ClientRecord, error) {
|
||||
rows, err := s.db.Query(
|
||||
`SELECT client_id, user_id, device_name, user_agent, note, first_seen_at, last_seen_at
|
||||
FROM clients WHERE user_id = $1 ORDER BY last_seen_at DESC`,
|
||||
userID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query clients failed: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var result []ClientRecord
|
||||
for rows.Next() {
|
||||
var cr ClientRecord
|
||||
if err := rows.Scan(&cr.ClientID, &cr.UserID, &cr.DeviceName, &cr.UserAgent, &cr.Note, &cr.FirstSeenAt, &cr.LastSeenAt); err != nil {
|
||||
return nil, fmt.Errorf("scan client row failed: %w", err)
|
||||
}
|
||||
result = append(result, cr)
|
||||
}
|
||||
if result == nil {
|
||||
result = []ClientRecord{}
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
// UpdateClientNote sets the user-defined note for a client.
|
||||
func (s *SessionStore) UpdateClientNote(clientID, note string) error {
|
||||
_, err := s.db.Exec(`UPDATE clients SET note = $1 WHERE client_id = $2`, note, clientID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update client note failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -24,21 +24,27 @@ const (
|
||||
|
||||
// Client WebSocket客户端
|
||||
type Client struct {
|
||||
Hub *Hub
|
||||
Conn *websocket.Conn
|
||||
Send chan []byte
|
||||
UserID string
|
||||
SessionID string
|
||||
Hub *Hub
|
||||
Conn *websocket.Conn
|
||||
Send chan []byte
|
||||
UserID string
|
||||
SessionID string
|
||||
ClientID string
|
||||
DeviceName string
|
||||
UserAgent string
|
||||
}
|
||||
|
||||
// NewClient 创建WebSocket客户端
|
||||
func NewClient(hub *Hub, conn *websocket.Conn, userID, sessionID string) *Client {
|
||||
func NewClient(hub *Hub, conn *websocket.Conn, userID, sessionID, clientID, deviceName, userAgent string) *Client {
|
||||
return &Client{
|
||||
Hub: hub,
|
||||
Conn: conn,
|
||||
Send: make(chan []byte, 256),
|
||||
UserID: userID,
|
||||
SessionID: sessionID,
|
||||
Hub: hub,
|
||||
Conn: conn,
|
||||
Send: make(chan []byte, 256),
|
||||
UserID: userID,
|
||||
SessionID: sessionID,
|
||||
ClientID: clientID,
|
||||
DeviceName: deviceName,
|
||||
UserAgent: userAgent,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -39,10 +39,23 @@ type Message struct {
|
||||
MsgType string `json:"msg_type,omitempty"`
|
||||
Attachments []MessageAttachment `json:"attachments,omitempty"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
ClientInfo *ClientInfo `json:"client_info,omitempty"`
|
||||
}
|
||||
|
||||
const maxRecentMessages = 20
|
||||
|
||||
// KnownClient tracks a device that has ever connected (online or offline).
|
||||
type KnownClient struct {
|
||||
ClientID string `json:"client_id"`
|
||||
UserID string `json:"user_id"`
|
||||
DeviceName string `json:"device_name"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
Note string `json:"note"` // user-assigned label
|
||||
Online bool `json:"online"`
|
||||
LastSeenAt time.Time `json:"last_seen_at"`
|
||||
FirstSeenAt time.Time `json:"first_seen_at"`
|
||||
}
|
||||
|
||||
// Hub WebSocket连接池
|
||||
type Hub struct {
|
||||
mu sync.RWMutex
|
||||
@@ -76,6 +89,9 @@ type Hub struct {
|
||||
pendingProactive map[string][]json.RawMessage // userID -> queued messages
|
||||
aiCoreURL string
|
||||
internalToken string
|
||||
|
||||
// 多端客户端追踪: clientID -> KnownClient (在线+离线)
|
||||
knownClients map[string]*KnownClient
|
||||
}
|
||||
|
||||
// SetStore 设置持久化存储 (可选)
|
||||
@@ -100,6 +116,7 @@ func NewHub() *Hub {
|
||||
iotStopCh: make(chan struct{}),
|
||||
idleTimeout: 30 * time.Minute, // 默认30分钟
|
||||
pendingProactive: make(map[string][]json.RawMessage),
|
||||
knownClients: make(map[string]*KnownClient),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -254,6 +271,34 @@ func (h *Hub) Run() {
|
||||
MessageCount: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// 多端客户端追踪
|
||||
if client.ClientID != "" {
|
||||
now := time.Now()
|
||||
if kc, ok := h.knownClients[client.ClientID]; ok {
|
||||
kc.Online = true
|
||||
kc.LastSeenAt = now
|
||||
kc.DeviceName = client.DeviceName
|
||||
kc.UserAgent = client.UserAgent
|
||||
} else {
|
||||
h.knownClients[client.ClientID] = &KnownClient{
|
||||
ClientID: client.ClientID,
|
||||
UserID: client.UserID,
|
||||
DeviceName: client.DeviceName,
|
||||
UserAgent: client.UserAgent,
|
||||
Online: true,
|
||||
LastSeenAt: now,
|
||||
FirstSeenAt: now,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 持久化客户端记录到数据库
|
||||
if client.ClientID != "" && h.store != nil && h.store.IsAvailable() {
|
||||
if err := h.store.UpsertClient(client.ClientID, client.UserID, client.DeviceName, client.UserAgent); err != nil {
|
||||
logger.Printf("[WS] 持久化客户端记录失败: %v", err)
|
||||
}
|
||||
}
|
||||
// Phase 2: 检测是否为重连 (之前处于离线状态)
|
||||
wasOffline := len(h.userClients[client.UserID]) == 1 // 刚加入,之前为0
|
||||
h.mu.Unlock()
|
||||
@@ -308,6 +353,23 @@ func (h *Hub) Run() {
|
||||
s.State = "idle"
|
||||
}
|
||||
}
|
||||
|
||||
// 多端客户端追踪: 检查同一 clientID 是否还有其他连接
|
||||
if client.ClientID != "" {
|
||||
hasOtherClientConn := false
|
||||
for c := range h.clients {
|
||||
if c.ClientID == client.ClientID {
|
||||
hasOtherClientConn = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOtherClientConn {
|
||||
if kc, ok := h.knownClients[client.ClientID]; ok {
|
||||
kc.Online = false
|
||||
kc.LastSeenAt = time.Now()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
@@ -362,6 +424,23 @@ func (h *Hub) Run() {
|
||||
s.State = "idle"
|
||||
}
|
||||
}
|
||||
|
||||
// 多端客户端追踪
|
||||
if client.ClientID != "" {
|
||||
hasOtherClientConn := false
|
||||
for c := range h.clients {
|
||||
if c.ClientID == client.ClientID {
|
||||
hasOtherClientConn = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOtherClientConn {
|
||||
if kc, ok := h.knownClients[client.ClientID]; ok {
|
||||
kc.Online = false
|
||||
kc.LastSeenAt = time.Now()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
@@ -766,3 +845,62 @@ func (h *Hub) DeleteConversation(userID, sessionID string) {
|
||||
key := cacheKey(userID, sessionID)
|
||||
h.conversationCache.Delete(key)
|
||||
}
|
||||
|
||||
// ========== 多端客户端追踪 ==========
|
||||
|
||||
// GetKnownClients returns all known clients (online + offline).
|
||||
func (h *Hub) GetKnownClients(userID string) []KnownClient {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
result := make([]KnownClient, 0)
|
||||
for _, kc := range h.knownClients {
|
||||
if userID == "" || kc.UserID == userID {
|
||||
cp := *kc
|
||||
cp.UserAgent = "" // don't leak UA in list
|
||||
result = append(result, cp)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// UpdateClientNote sets a user-defined note/label on a client.
|
||||
func (h *Hub) UpdateClientNote(clientID, note string) bool {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
kc, ok := h.knownClients[clientID]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
kc.Note = note
|
||||
return true
|
||||
}
|
||||
|
||||
// ClientInfo returns the ClientInfo for a given client.
|
||||
func (h *Hub) ClientInfo(clientID string) *ClientInfo {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
kc, ok := h.knownClients[clientID]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return &ClientInfo{
|
||||
ClientID: kc.ClientID,
|
||||
DeviceName: kc.DeviceName,
|
||||
UserAgent: kc.UserAgent,
|
||||
}
|
||||
}
|
||||
|
||||
// buildClientInfo builds a ClientInfo from a Client.
|
||||
func buildClientInfo(c *Client) *ClientInfo {
|
||||
if c.ClientID == "" {
|
||||
return nil
|
||||
}
|
||||
return &ClientInfo{
|
||||
ClientID: c.ClientID,
|
||||
DeviceName: c.DeviceName,
|
||||
UserAgent: c.UserAgent,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,6 +21,9 @@ type ClientMessage struct {
|
||||
AudioData string `json:"audio_data,omitempty"` // base64
|
||||
Attachments []MessageAttachment `json:"attachments,omitempty"` // 图片等附件
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
ClientID string `json:"client_id,omitempty"` // 客户端唯一标识 (多端区分)
|
||||
DeviceName string `json:"device_name,omitempty"` // 设备备注名称
|
||||
UserAgent string `json:"user_agent,omitempty"` // 浏览器 UA
|
||||
}
|
||||
|
||||
// ReviewMessage 审查后的结构化消息(动作/聊天分离)
|
||||
@@ -30,6 +33,13 @@ type ReviewMessage struct {
|
||||
DelayMs int `json:"delay_ms,omitempty"` // ms to wait before sending (0 = immediate)
|
||||
}
|
||||
|
||||
// ClientInfo carries the originating client's device metadata.
|
||||
type ClientInfo struct {
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
DeviceName string `json:"device_name,omitempty"`
|
||||
UserAgent string `json:"user_agent,omitempty"`
|
||||
}
|
||||
|
||||
// 服务端 → 客户端消息
|
||||
type ServerMessage struct {
|
||||
Type string `json:"type"` // response | segment | audio | error | device_update | pong | history_response | stream_chunk | stream_end | background_thinking | notification | multi_message | stream_segments | review | thinking | tool_progress | system_info
|
||||
@@ -55,6 +65,7 @@ type ServerMessage struct {
|
||||
ToolProgress *ToolProgressInfo `json:"tool_progress,omitempty"` // 工具执行进度
|
||||
SystemInfo *SystemInfoPayload `json:"system_info,omitempty"` // 系统通知信息
|
||||
ProtocolVersion int `json:"protocol_version,omitempty"` // 协议版本
|
||||
ClientInfo *ClientInfo `json:"client_info,omitempty"` // 消息来源客户端信息
|
||||
}
|
||||
|
||||
// ToolProgressInfo 工具执行进度
|
||||
|
||||
Reference in New Issue
Block a user