fix: 第二轮修复 — 数据库启动检查、会话持久化、URL路由、设备排序等

1. DevTools 启动前检查数据库状态,失败时自动尝试启动
2. ai-core 添加数据库断线重连机制 (30秒间隔)
3. Dashboard 添加数据库状态卡片 (启动/停止/重启)
4. Gateway 会话空闲超时管理 (30分钟标记空闲)
5. 会话/消息 PostgreSQL 持久化 (SessionStore + REST API)
6. 前端服务端会话持久化 + URL hash 路由 + 侧边栏管理
7. 管理员回到主对话按钮
8. IoT 设备卡片固定排序
9. 更新相关文档
This commit is contained in:
2026-05-17 17:18:02 +08:00
parent 745b1c6aad
commit e7b7eff0d8
21 changed files with 1735 additions and 284 deletions
+8 -13
View File
@@ -64,21 +64,16 @@ func main() {
var memExtractor *memory.Extractor
if cfg.DatabaseURL != "" {
memStore, err = memory.NewStore(cfg.DatabaseURL)
if err != nil {
log.Printf("⚠ 记忆存储初始化失败 (将跳过记忆功能): %v", err)
} else {
defer memStore.Close()
log.Println("记忆存储已就绪")
memStore = memory.NewStore(cfg.DatabaseURL)
defer memStore.Close()
memRetriever = memory.NewRetriever(memStore, nil)
memRetriever = memory.NewRetriever(memStore, nil)
// 记忆提取器使用LLM
memExtractor = memory.NewExtractor(memStore, func(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error) {
return llmAdapter.Chat(ctx, messages)
})
log.Println("记忆提取器已就绪")
}
// 记忆提取器使用LLM
memExtractor = memory.NewExtractor(memStore, func(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error) {
return llmAdapter.Chat(ctx, messages)
})
log.Println("记忆提取器已就绪")
}
// 初始化会话历史存储
+153 -16
View File
@@ -4,22 +4,100 @@ import (
"context"
"database/sql"
"fmt"
"log"
"sync"
"time"
"github.com/yourname/cyrene-ai/ai-core/internal/model"
_ "github.com/lib/pq"
)
const reconnectInterval = 30 * time.Second
// Store 记忆持久化存储(PostgreSQL + pgvector
type Store struct {
db *sql.DB
databaseURL string
mu sync.RWMutex
db *sql.DB
}
// errDBNotReady 数据库未就绪时返回的友好错误
var errDBNotReady = fmt.Errorf("记忆系统未就绪: 数据库连接不可用,正在后台重试连接")
// NewStore 创建记忆存储
func NewStore(connStr string) (*Store, error) {
db, err := sql.Open("postgres", connStr)
// 连接失败时不返回 error,而是启动后台重连循环
func NewStore(connStr string) *Store {
s := &Store{
databaseURL: connStr,
}
// 尝试初始连接
if err := s.Reconnect(); err != nil {
log.Printf("[memory] ⚠ 记忆存储初始化: 数据库连接失败 (%v),将在后台每30秒重试", err)
} else {
log.Println("[memory] 记忆存储已就绪")
}
// 启动后台重连 goroutine
go s.reconnectLoop()
return s
}
// reconnectLoop 后台重连循环
func (s *Store) reconnectLoop() {
ticker := time.NewTicker(reconnectInterval)
defer ticker.Stop()
for range ticker.C {
s.mu.RLock()
ready := s.db != nil
s.mu.RUnlock()
if ready {
// 数据库已连接,检查连接是否仍然有效
s.mu.RLock()
db := s.db
s.mu.RUnlock()
if db != nil {
if err := db.Ping(); err != nil {
log.Printf("[memory] ⚠ 数据库连接丢失: %v,开始重连", err)
s.mu.Lock()
if s.db != nil {
s.db.Close()
s.db = nil
}
s.mu.Unlock()
}
}
}
if !s.IsReady() {
if err := s.Reconnect(); err != nil {
log.Printf("[memory] ⚠ 数据库重连失败: %v", err)
}
}
}
}
// Reconnect 尝试重连数据库并执行迁移
func (s *Store) Reconnect() error {
s.mu.Lock()
defer s.mu.Unlock()
// 如果已有有效连接,先检查
if s.db != nil {
if err := s.db.Ping(); err == nil {
return nil // 仍然有效
}
// 连接已失效,关闭旧连接
s.db.Close()
s.db = nil
}
db, err := sql.Open("postgres", s.databaseURL)
if err != nil {
return nil, fmt.Errorf("连接数据库失败: %w", err)
return fmt.Errorf("连接数据库失败: %w", err)
}
db.SetMaxOpenConns(25)
@@ -27,15 +105,36 @@ func NewStore(connStr string) (*Store, error) {
db.SetConnMaxLifetime(5 * time.Minute)
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("数据库ping失败: %w", err)
db.Close()
return fmt.Errorf("数据库ping失败: %w", err)
}
s := &Store{db: db}
s.db = db
// 执行建表迁移
if err := s.migrate(); err != nil {
return nil, fmt.Errorf("数据库迁移失败: %w", err)
log.Printf("[memory] ⚠ 数据库迁移失败: %v", err)
s.db.Close()
s.db = nil
return fmt.Errorf("数据库迁移失败: %w", err)
}
return s, nil
log.Println("[memory] ✅ 数据库重连成功,记忆系统已就绪")
return nil
}
// IsReady 返回数据库是否可用
func (s *Store) IsReady() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.db != nil
}
// getDB 获取当前数据库连接(带读锁保护)
func (s *Store) getDB() *sql.DB {
s.mu.RLock()
defer s.mu.RUnlock()
return s.db
}
// migrate 创建表结构
@@ -73,6 +172,11 @@ func (s *Store) migrate() error {
// Save 保存记忆
func (s *Store) Save(ctx context.Context, entry *model.MemoryEntry) error {
db := s.getDB()
if db == nil {
return errDBNotReady
}
query := `INSERT INTO memories (user_id, content, summary, category, priority, session_id, source, embedding, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
RETURNING id, created_at`
@@ -86,7 +190,7 @@ func (s *Store) Save(ctx context.Context, entry *model.MemoryEntry) error {
embedding = fmt.Sprintf("[%s]", joinFloats(vec))
}
return s.db.QueryRowContext(ctx, query,
return db.QueryRowContext(ctx, query,
entry.UserID, entry.Content, entry.Summary,
string(entry.Category), int(entry.Priority),
entry.SessionID, entry.Source, embedding, entry.ExpiresAt,
@@ -95,13 +199,18 @@ func (s *Store) Save(ctx context.Context, entry *model.MemoryEntry) error {
// GetByID 根据ID获取记忆
func (s *Store) GetByID(ctx context.Context, id string) (*model.MemoryEntry, error) {
db := s.getDB()
if db == nil {
return nil, errDBNotReady
}
query := `SELECT id, user_id, content, summary, category, priority, session_id, source,
access_count, last_access, created_at, expires_at
FROM memories WHERE id = $1`
entry := &model.MemoryEntry{}
var category string
err := s.db.QueryRowContext(ctx, query, id).Scan(
err := db.QueryRowContext(ctx, query, id).Scan(
&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
&category, &entry.Priority, &entry.SessionID, &entry.Source,
&entry.AccessCount, &entry.LastAccess, &entry.CreatedAt, &entry.ExpiresAt,
@@ -122,6 +231,11 @@ func (s *Store) GetByID(ctx context.Context, id string) (*model.MemoryEntry, err
// Query 按条件查询记忆
func (s *Store) Query(ctx context.Context, q model.MemoryQuery) ([]model.MemoryEntry, error) {
db := s.getDB()
if db == nil {
return nil, errDBNotReady
}
if q.Limit <= 0 {
q.Limit = 10
}
@@ -147,7 +261,7 @@ func (s *Store) Query(ctx context.Context, q model.MemoryQuery) ([]model.MemoryE
query += fmt.Sprintf(" ORDER BY priority DESC, created_at DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1)
args = append(args, q.Limit, q.Offset)
rows, err := s.db.QueryContext(ctx, query, args...)
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("查询记忆失败: %w", err)
}
@@ -173,13 +287,21 @@ func (s *Store) Query(ctx context.Context, q model.MemoryQuery) ([]model.MemoryE
// Delete 删除记忆
func (s *Store) Delete(ctx context.Context, id string) error {
_, err := s.db.ExecContext(ctx, `DELETE FROM memories WHERE id = $1`, id)
db := s.getDB()
if db == nil {
return errDBNotReady
}
_, err := db.ExecContext(ctx, `DELETE FROM memories WHERE id = $1`, id)
return err
}
// PurgeExpired 清理过期记忆
func (s *Store) PurgeExpired(ctx context.Context) (int64, error) {
result, err := s.db.ExecContext(ctx,
db := s.getDB()
if db == nil {
return 0, errDBNotReady
}
result, err := db.ExecContext(ctx,
`DELETE FROM memories WHERE expires_at IS NOT NULL AND expires_at < NOW()`)
if err != nil {
return 0, err
@@ -189,6 +311,11 @@ func (s *Store) PurgeExpired(ctx context.Context) (int64, error) {
// SearchByVector 向量相似度搜索
func (s *Store) SearchByVector(ctx context.Context, userID string, embedding []float64, limit int) ([]model.MemoryEntry, error) {
db := s.getDB()
if db == nil {
return nil, errDBNotReady
}
if limit <= 0 {
limit = 5
}
@@ -202,7 +329,7 @@ func (s *Store) SearchByVector(ctx context.Context, userID string, embedding []f
ORDER BY embedding <=> $1
LIMIT $3`
rows, err := s.db.QueryContext(ctx, query, vecStr, userID, limit)
rows, err := db.QueryContext(ctx, query, vecStr, userID, limit)
if err != nil {
return nil, fmt.Errorf("向量搜索失败: %w", err)
}
@@ -229,13 +356,23 @@ func (s *Store) SearchByVector(ctx context.Context, userID string, embedding []f
}
func (s *Store) incrementAccess(ctx context.Context, id string) {
s.db.ExecContext(ctx,
db := s.getDB()
if db == nil {
return
}
db.ExecContext(ctx,
`UPDATE memories SET access_count = access_count + 1, last_access = NOW() WHERE id = $1`, id)
}
// Close 关闭数据库连接
func (s *Store) Close() error {
return s.db.Close()
s.mu.Lock()
defer s.mu.Unlock()
if s.db != nil {
return s.db.Close()
}
return nil
}
// joinFloats 将 float64 切片转为逗号分隔字符串
+25 -1
View File
@@ -14,6 +14,7 @@ import (
"github.com/yourname/cyrene-ai/gateway/internal/config"
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
"github.com/yourname/cyrene-ai/gateway/internal/router"
"github.com/yourname/cyrene-ai/gateway/internal/store"
"github.com/yourname/cyrene-ai/gateway/internal/ws"
)
@@ -26,6 +27,17 @@ func main() {
// 加载配置
cfg := config.Load()
// 初始化数据库持久化存储 (降级:连接失败不崩溃)
var sessionStore *store.SessionStore
databaseURL := cfg.DatabaseURL()
if s, err := store.NewSessionStore(databaseURL); err != nil {
log.Printf("⚠ 会话持久化存储初始化失败 (数据库不可用): %v", err)
log.Println("⚠ Gateway 将以仅内存模式运行 — 会话数据在重启后丢失")
} else {
sessionStore = s
log.Println("✅ 会话持久化存储已启用 (PostgreSQL)")
}
// 初始化Gin
if cfg.Env == "production" {
gin.SetMode(gin.ReleaseMode)
@@ -39,13 +51,18 @@ func main() {
// 初始化WebSocket Hub
hub := ws.NewHub()
hub.SetStore(sessionStore)
hub.SetIdleTimeout(cfg.SessionIdleTimeoutMin)
go hub.Run()
// 启动闲置会话清理 (标记超时会话为 idle,不删除)
hub.StartIdleCleanup()
// 启动 IoT 设备状态广播(每10秒向所有WebSocket客户端推送设备状态)
hub.StartIoTBroadcast(cfg.IoTDebugServiceURL)
// 注册路由
router.Setup(r, hub, cfg)
router.Setup(r, hub, cfg, sessionStore)
// 启动服务
srv := &http.Server{
@@ -68,6 +85,13 @@ func main() {
hub.StopIoTBroadcast()
// 关闭数据库连接
if sessionStore != nil {
if err := sessionStore.Close(); err != nil {
log.Printf("⚠ 关闭数据库连接失败: %v", err)
}
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
srv.Shutdown(ctx)
+1
View File
@@ -7,6 +7,7 @@ require (
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/gorilla/websocket v1.5.3
github.com/joho/godotenv v1.5.1
github.com/lib/pq v1.10.9
)
require (
+2
View File
@@ -42,6 +42,8 @@ github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZY
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
+16 -1
View File
@@ -1,6 +1,7 @@
package config
import (
"fmt"
"os"
"time"
@@ -49,6 +50,9 @@ type Config struct {
// WebSocket
WSMaxConnections int
// 会话闲置超时 (分钟) — 超过此时间后会话标记为 idle 但不删除
SessionIdleTimeoutMin int
// Webhook (第三方平台接入)
WebhookAPIKey string
}
@@ -87,12 +91,23 @@ func Load() *Config {
LLMAPIKey: getEnv("LLM_API_KEY", ""),
LLMModel: getEnv("LLM_MODEL", "gpt-4o"),
WSMaxConnections: getEnvInt("WS_MAX_CONNECTIONS", 1000),
WSMaxConnections: getEnvInt("WS_MAX_CONNECTIONS", 1000),
SessionIdleTimeoutMin: getEnvInt("SESSION_IDLE_TIMEOUT_MIN", 30),
WebhookAPIKey: getEnv("WEBHOOK_API_KEY", ""),
}
}
// DatabaseURL 构建 PostgreSQL 连接字符串
func (c *Config) DatabaseURL() string {
return fmt.Sprintf(
"postgres://%s:%s@%s:%s/%s?sslmode=disable",
c.PostgresUser, c.PostgresPass,
c.PostgresHost, c.PostgresPort,
c.PostgresDB,
)
}
// GenerateToken 生成JWT token
func (c *Config) GenerateToken(userID string) (string, error) {
claims := jwt.MapClaims{
@@ -1,143 +1,267 @@
package handler
import (
"log"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
"github.com/yourname/cyrene-ai/gateway/internal/store"
"github.com/yourname/cyrene-ai/gateway/internal/ws"
)
// SessionHandler 会话管理处理器
type SessionHandler struct {
// MVP阶段使用内存存储,后续迁移到PostgreSQL
sessions map[string][]SessionInfo // userID -> sessions
hub *ws.Hub
}
// SessionInfo 会话信息
type SessionInfo struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Title string `json:"title"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
MessageCount int `json:"message_count"`
IsActive bool `json:"is_active"`
store *store.SessionStore // PostgreSQL 持久化存储
hub *ws.Hub
useDB bool // 数据库是否可用
}
// NewSessionHandler 创建会话处理器
func NewSessionHandler(hub *ws.Hub) *SessionHandler {
func NewSessionHandler(hub *ws.Hub, s *store.SessionStore) *SessionHandler {
return &SessionHandler{
sessions: make(map[string][]SessionInfo),
hub: hub,
store: s,
hub: hub,
useDB: s != nil && s.IsAvailable(),
}
}
// ========== POST /api/v1/sessions — 创建会话 ==========
type createSessionRequest struct {
UserID string `json:"user_id"`
SessionID string `json:"session_id"`
Title string `json:"title"`
IsMain bool `json:"is_main"`
}
// Create 创建新会话
func (h *SessionHandler) Create(c *gin.Context) {
userID := middleware.GetUserID(c)
var req struct {
Title string `json:"title"`
}
var req createSessionRequest
if err := c.ShouldBindJSON(&req); err != nil {
// 允许空body
req.Title = "新的对话"
// 允许空 body
}
if req.UserID != "" {
userID = req.UserID
}
if req.Title == "" {
req.Title = "新的对话"
}
session := SessionInfo{
ID: "session_" + randomID(12),
UserID: userID,
Title: req.Title,
CreatedAt: time.Now().UnixMilli(),
UpdatedAt: time.Now().UnixMilli(),
if req.SessionID == "" {
req.SessionID = "session_" + randomID(12)
}
h.sessions[userID] = append([]SessionInfo{session}, h.sessions[userID]...)
c.JSON(http.StatusCreated, session)
}
// List 获取会话列表
func (h *SessionHandler) List(c *gin.Context) {
userID := middleware.GetUserID(c)
sessions, ok := h.sessions[userID]
if !ok {
sessions = []SessionInfo{}
}
c.JSON(http.StatusOK, gin.H{
"sessions": sessions,
})
}
// Delete 删除会话
func (h *SessionHandler) Delete(c *gin.Context) {
userID := middleware.GetUserID(c)
sessionID := c.Param("id")
sessions := h.sessions[userID]
for i, s := range sessions {
if s.ID == sessionID {
h.sessions[userID] = append(sessions[:i], sessions[i+1:]...)
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
if h.useDB {
if err := h.store.CreateSession(userID, req.SessionID, req.Title, req.IsMain); err != nil {
log.Printf("[SessionHandler] 创建会话失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建会话失败", "errorType": "db_error"})
return
}
}
c.JSON(http.StatusNotFound, gin.H{
"error": "会话不存在",
"errorType": "session_not_found",
"hint": "会话可能已被删除,或 Gateway 重启后内存数据已清空",
c.JSON(http.StatusCreated, gin.H{
"id": req.SessionID,
"user_id": userID,
"title": req.Title,
"is_main": req.IsMain,
"created_at": time.Now().UnixMilli(),
"updated_at": time.Now().UnixMilli(),
})
}
// ========== GET /api/v1/sessions?user_id=xxx — 获取用户会话列表 ==========
// List 获取会话列表 (按 updated_at DESC 排序)
func (h *SessionHandler) List(c *gin.Context) {
userID := c.Query("user_id")
if userID == "" {
userID = middleware.GetUserID(c)
}
if h.useDB {
sessions, err := h.store.GetUserSessions(userID)
if err != nil {
log.Printf("[SessionHandler] 查询会话列表失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询会话失败", "errorType": "db_error"})
return
}
// 转换为列表格式
result := make([]gin.H, 0, len(sessions))
for _, s := range sessions {
result = append(result, gin.H{
"id": s.ID,
"user_id": s.UserID,
"title": s.Title,
"is_main": s.IsMain,
"created_at": s.CreatedAt.UnixMilli(),
"updated_at": s.UpdatedAt.UnixMilli(),
})
}
c.JSON(http.StatusOK, gin.H{"sessions": result})
return
}
// 降级:返回空列表
c.JSON(http.StatusOK, gin.H{"sessions": []gin.H{}})
}
// ========== GET /api/v1/sessions/:id — 获取单个会话 ==========
// Get 获取单个会话信息
func (h *SessionHandler) Get(c *gin.Context) {
userID := middleware.GetUserID(c)
sessionID := c.Param("id")
for _, s := range h.sessions[userID] {
if s.ID == sessionID {
c.JSON(http.StatusOK, s)
if h.useDB {
session, err := h.store.GetSession(sessionID)
if err != nil {
log.Printf("[SessionHandler] 查询会话失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询会话失败", "errorType": "db_error"})
return
}
if session == nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "会话不存在",
"errorType": "session_not_found",
"hint": "该会话可能已被删除或尚未创建",
})
return
}
c.JSON(http.StatusOK, gin.H{
"id": session.ID,
"user_id": session.UserID,
"title": session.Title,
"is_main": session.IsMain,
"created_at": session.CreatedAt.UnixMilli(),
"updated_at": session.UpdatedAt.UnixMilli(),
})
return
}
c.JSON(http.StatusNotFound, gin.H{
"error": "会话存储不可用",
"errorType": "store_unavailable",
"hint": "数据库连接未建立,Gateway 运行在仅内存模式",
})
}
// ========== DELETE /api/v1/sessions/:id — 删除会话 ==========
// Delete 删除会话 (不删除记忆)
func (h *SessionHandler) Delete(c *gin.Context) {
sessionID := c.Param("id")
if h.useDB {
if err := h.store.DeleteSession(sessionID); err != nil {
log.Printf("[SessionHandler] 删除会话失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除会话失败", "errorType": "db_error"})
return
}
}
c.JSON(http.StatusNotFound, gin.H{
"error": "会话不存在",
"errorType": "session_not_found",
"hint": "会话可能已被删除,或 Gateway 重启后内存数据已清空",
})
// 同时清理 Hub 中的缓存
h.hub.DeleteConversation("", sessionID)
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
}
// ========== DELETE /api/v1/sessions?user_id=xxx — 删除用户所有会话 ==========
// DeleteAll 删除用户所有会话 (不删除记忆)
func (h *SessionHandler) DeleteAll(c *gin.Context) {
userID := c.Query("user_id")
if userID == "" {
userID = middleware.GetUserID(c)
}
if h.useDB {
if err := h.store.DeleteAllUserSessions(userID); err != nil {
log.Printf("[SessionHandler] 删除用户所有会话失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除会话失败", "errorType": "db_error"})
return
}
}
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
}
// ========== GET /api/v1/sessions/:id/messages?limit=50 — 获取会话消息 ==========
// GetMessages 获取会话的完整消息列表
func (h *SessionHandler) GetMessages(c *gin.Context) {
userID := middleware.GetUserID(c)
sessionID := c.Param("id")
limit := 50
if l := c.Query("limit"); l != "" {
parsed := 0
for _, ch := range l {
if ch < '0' || ch > '9' {
break
}
parsed = parsed*10 + int(ch-'0')
}
if parsed > 0 {
limit = parsed
}
}
messages := h.hub.GetConversation(userID, sessionID)
if h.useDB {
messages, err := h.store.GetMessages(sessionID, limit)
if err != nil {
log.Printf("[SessionHandler] 查询消息失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询消息失败", "errorType": "db_error"})
return
}
// 转换为统一格式
result := make([]gin.H, 0, len(messages))
for _, m := range messages {
result = append(result, gin.H{
"id": m.ID,
"session_id": m.SessionID,
"role": m.Role,
"content": m.Content,
"created_at": m.CreatedAt.UnixMilli(),
})
}
c.JSON(http.StatusOK, gin.H{"messages": result})
return
}
// 降级:从 Hub 内存缓存读取
messages := h.hub.GetConversation("", sessionID)
if messages == nil {
messages = []ws.Message{}
}
c.JSON(http.StatusOK, gin.H{"messages": messages})
}
c.JSON(http.StatusOK, gin.H{
"messages": messages,
})
// ========== DELETE /api/v1/sessions/:id/messages — 清空会话消息 ==========
// ClearMessages 清空会话所有消息但不删除会话本身
func (h *SessionHandler) ClearMessages(c *gin.Context) {
sessionID := c.Param("id")
if h.useDB {
if err := h.store.ClearSessionMessages(sessionID); err != nil {
log.Printf("[SessionHandler] 清空消息失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "清空消息失败", "errorType": "db_error"})
return
}
}
// 同时清理 Hub 内存缓存
h.hub.DeleteConversation("", sessionID)
c.JSON(http.StatusOK, gin.H{"status": "cleared"})
}
// ========== Admin 端点 ==========
// ListActiveSessions 获取当前所有活跃 WebSocket 会话列表 (管理员)
func (h *SessionHandler) ListActiveSessions(c *gin.Context) {
sessions := h.hub.GetActiveSessions()
sessions := h.hub.GetAllActiveSessions()
if sessions == nil {
sessions = []*ws.SessionState{}
}
@@ -188,6 +312,5 @@ func randomID(n int) string {
for i := range b {
b[i] = letters[i%len(letters)]
}
// 使用纳秒时间戳增加唯一性
return string(b)
}
+10 -7
View File
@@ -9,17 +9,18 @@ import (
"github.com/yourname/cyrene-ai/gateway/internal/config"
"github.com/yourname/cyrene-ai/gateway/internal/handler"
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
"github.com/yourname/cyrene-ai/gateway/internal/store"
"github.com/yourname/cyrene-ai/gateway/internal/ws"
)
// Setup 注册所有路由
func Setup(r *gin.Engine, hub *ws.Hub, cfg *config.Config) {
func Setup(r *gin.Engine, hub *ws.Hub, cfg *config.Config, sessionStore *store.SessionStore) {
// 限流器
rateLimiter := middleware.NewRateLimiter(10, 20) // 每秒10个请求,突发20
// 初始化处理器
authHandler := handler.NewAuthHandler(cfg)
sessionHandler := handler.NewSessionHandler(hub)
sessionHandler := handler.NewSessionHandler(hub, sessionStore)
memoryHandler := handler.NewMemoryHandler(cfg.AICoreURL)
chatHandler := handler.NewChatHandler(cfg, hub)
webhookHandler := handler.NewWebhookHandler(cfg, hub)
@@ -54,11 +55,13 @@ func Setup(r *gin.Engine, hub *ws.Hub, cfg *config.Config) {
// 会话管理
sessions := protected.Group("/sessions")
{
sessions.POST("", sessionHandler.Create)
sessions.GET("", sessionHandler.List)
sessions.GET("/:id", sessionHandler.Get)
sessions.DELETE("/:id", sessionHandler.Delete)
sessions.GET("/:id/messages", sessionHandler.GetMessages)
sessions.POST("", sessionHandler.Create) // POST /api/v1/sessions
sessions.GET("", sessionHandler.List) // GET /api/v1/sessions?user_id=xxx
sessions.DELETE("", sessionHandler.DeleteAll) // DELETE /api/v1/sessions?user_id=xxx
sessions.GET("/:id", sessionHandler.Get) // GET /api/v1/sessions/:id
sessions.DELETE("/:id", sessionHandler.Delete) // DELETE /api/v1/sessions/:id
sessions.GET("/:id/messages", sessionHandler.GetMessages) // GET /api/v1/sessions/:id/messages?limit=50
sessions.DELETE("/:id/messages", sessionHandler.ClearMessages) // DELETE /api/v1/sessions/:id/messages
}
// 记忆管理
@@ -0,0 +1,270 @@
package store
import (
"database/sql"
"fmt"
"log"
"time"
_ "github.com/lib/pq"
)
// Session 会话模型
type Session struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Title string `json:"title"`
IsMain bool `json:"is_main"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// Message 消息模型
type Message struct {
ID int `json:"id"`
SessionID string `json:"session_id"`
Role string `json:"role"`
Content string `json:"content"`
CreatedAt time.Time `json:"created_at"`
}
// SessionStore 会话持久化存储
type SessionStore struct {
db *sql.DB
}
// NewSessionStore 初始化数据库连接并自动建表
// 如果连接失败,返回 nil 和错误(调用方可以选择降级为仅内存模式)
func NewSessionStore(databaseURL string) (*SessionStore, error) {
db, err := sql.Open("postgres", databaseURL)
if err != nil {
return nil, fmt.Errorf("无法打开数据库连接: %w", err)
}
// 配置连接池
db.SetMaxOpenConns(25)
db.SetMaxIdleConns(5)
db.SetConnMaxLifetime(5 * time.Minute)
// 验证连接
if err := db.Ping(); err != nil {
db.Close()
return nil, fmt.Errorf("数据库连接验证失败: %w", err)
}
store := &SessionStore{db: db}
// 自动建表
if err := store.migrate(); err != nil {
db.Close()
return nil, fmt.Errorf("数据库迁移失败: %w", err)
}
log.Println("[SessionStore] PostgreSQL 持久化存储已初始化")
return store, nil
}
// migrate 自动创建表结构
func (s *SessionStore) migrate() error {
queries := []string{
`CREATE TABLE IF NOT EXISTS sessions (
id VARCHAR(64) PRIMARY KEY,
user_id VARCHAR(128) NOT NULL,
title VARCHAR(256) DEFAULT '',
is_main BOOLEAN DEFAULT FALSE,
created_at TIMESTAMP DEFAULT NOW(),
updated_at TIMESTAMP DEFAULT NOW()
)`,
`CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id)`,
`CREATE INDEX IF NOT EXISTS idx_sessions_updated_at ON sessions(updated_at DESC)`,
`CREATE TABLE IF NOT EXISTS messages (
id SERIAL PRIMARY KEY,
session_id VARCHAR(64) REFERENCES sessions(id) ON DELETE CASCADE,
role VARCHAR(16) NOT NULL,
content TEXT NOT NULL,
created_at TIMESTAMP DEFAULT NOW()
)`,
`CREATE INDEX IF NOT EXISTS idx_messages_session_id ON messages(session_id)`,
`CREATE INDEX IF NOT EXISTS idx_messages_created_at ON messages(session_id, created_at)`,
}
for _, q := range queries {
if _, err := s.db.Exec(q); err != nil {
return fmt.Errorf("迁移SQL执行失败: %w\nSQL: %s", err, q)
}
}
return nil
}
// CreateSession 创建新会话
func (s *SessionStore) CreateSession(userID, sessionID, title string, isMain bool) error {
now := time.Now()
_, err := s.db.Exec(
`INSERT INTO sessions (id, user_id, title, is_main, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $5)
ON CONFLICT (id) DO UPDATE SET updated_at = $5`,
sessionID, userID, title, isMain, now,
)
if err != nil {
return fmt.Errorf("创建会话失败: %w", err)
}
return nil
}
// GetUserSessions 获取用户的所有会话(按 updated_at DESC 排序)
func (s *SessionStore) GetUserSessions(userID string) ([]Session, error) {
rows, err := s.db.Query(
`SELECT id, user_id, title, is_main, created_at, updated_at
FROM sessions WHERE user_id = $1
ORDER BY updated_at DESC`,
userID,
)
if err != nil {
return nil, fmt.Errorf("查询用户会话失败: %w", err)
}
defer rows.Close()
var sessions []Session
for rows.Next() {
var sess Session
if err := rows.Scan(&sess.ID, &sess.UserID, &sess.Title, &sess.IsMain, &sess.CreatedAt, &sess.UpdatedAt); err != nil {
return nil, fmt.Errorf("扫描会话行失败: %w", err)
}
sessions = append(sessions, sess)
}
if sessions == nil {
sessions = []Session{}
}
return sessions, rows.Err()
}
// GetSession 获取单个会话
func (s *SessionStore) GetSession(sessionID string) (*Session, error) {
var sess Session
err := s.db.QueryRow(
`SELECT id, user_id, title, is_main, created_at, updated_at
FROM sessions WHERE id = $1`,
sessionID,
).Scan(&sess.ID, &sess.UserID, &sess.Title, &sess.IsMain, &sess.CreatedAt, &sess.UpdatedAt)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, fmt.Errorf("查询会话失败: %w", err)
}
return &sess, nil
}
// UpdateSessionTitle 更新会话标题
func (s *SessionStore) UpdateSessionTitle(sessionID, title string) error {
_, err := s.db.Exec(
`UPDATE sessions SET title = $1, updated_at = NOW() WHERE id = $2`,
title, sessionID,
)
if err != nil {
return fmt.Errorf("更新会话标题失败: %w", err)
}
return nil
}
// UpdateSessionTime 更新会话的 updated_at 时间戳
func (s *SessionStore) UpdateSessionTime(sessionID string) error {
_, err := s.db.Exec(
`UPDATE sessions SET updated_at = NOW() WHERE id = $1`,
sessionID,
)
if err != nil {
return fmt.Errorf("更新会话时间失败: %w", err)
}
return nil
}
// DeleteSession 删除会话(级联删除消息,但不删除记忆)
func (s *SessionStore) DeleteSession(sessionID string) error {
_, err := s.db.Exec(`DELETE FROM sessions WHERE id = $1`, sessionID)
if err != nil {
return fmt.Errorf("删除会话失败: %w", err)
}
return nil
}
// DeleteAllUserSessions 删除用户的所有会话(但不删除记忆)
func (s *SessionStore) DeleteAllUserSessions(userID string) error {
_, err := s.db.Exec(`DELETE FROM sessions WHERE user_id = $1`, userID)
if err != nil {
return fmt.Errorf("删除用户所有会话失败: %w", err)
}
return nil
}
// AddMessage 添加一条消息到会话
func (s *SessionStore) AddMessage(sessionID, role, content string) error {
_, err := s.db.Exec(
`INSERT INTO messages (session_id, role, content) VALUES ($1, $2, $3)`,
sessionID, role, content,
)
if err != nil {
return fmt.Errorf("添加消息失败: %w", err)
}
return nil
}
// GetMessages 获取会话的消息列表(按时间正序)
func (s *SessionStore) GetMessages(sessionID string, limit int) ([]Message, error) {
if limit <= 0 {
limit = 50
}
rows, err := s.db.Query(
`SELECT id, session_id, role, content, created_at
FROM messages WHERE session_id = $1
ORDER BY created_at ASC
LIMIT $2`,
sessionID, limit,
)
if err != nil {
return nil, fmt.Errorf("查询消息失败: %w", err)
}
defer rows.Close()
var messages []Message
for rows.Next() {
var msg Message
if err := rows.Scan(&msg.ID, &msg.SessionID, &msg.Role, &msg.Content, &msg.CreatedAt); err != nil {
return nil, fmt.Errorf("扫描消息行失败: %w", err)
}
messages = append(messages, msg)
}
if messages == nil {
messages = []Message{}
}
return messages, rows.Err()
}
// ClearSessionMessages 清空会话的所有消息但不删除会话本身
func (s *SessionStore) ClearSessionMessages(sessionID string) error {
_, err := s.db.Exec(`DELETE FROM messages WHERE session_id = $1`, sessionID)
if err != nil {
return fmt.Errorf("清空会话消息失败: %w", err)
}
return nil
}
// Close 关闭数据库连接
func (s *SessionStore) Close() error {
if s.db != nil {
return s.db.Close()
}
return nil
}
// IsAvailable 检查存储是否可用(数据库连接正常)
func (s *SessionStore) IsAvailable() bool {
if s.db == nil {
return false
}
return s.db.Ping() == nil
}
+93 -2
View File
@@ -8,6 +8,8 @@ import (
"os"
"sync"
"time"
"github.com/yourname/cyrene-ai/gateway/internal/store"
)
// SessionState 会话状态
@@ -60,6 +62,22 @@ type Hub struct {
iotServiceURL string
iotStopCh chan struct{}
iotPollRunning bool
// 持久化存储 (可选,数据库连接失败时为 nil)
store *store.SessionStore
// 闲置超时时间
idleTimeout time.Duration
}
// SetStore 设置持久化存储 (可选)
func (h *Hub) SetStore(s *store.SessionStore) {
h.store = s
}
// SetIdleTimeout 设置闲置超时时间
func (h *Hub) SetIdleTimeout(minutes int) {
h.idleTimeout = time.Duration(minutes) * time.Minute
}
// NewHub 创建WebSocket Hub
@@ -72,9 +90,79 @@ func NewHub() *Hub {
userClients: make(map[string]map[*Client]bool),
sessions: make(map[string]*SessionState),
iotStopCh: make(chan struct{}),
idleTimeout: 30 * time.Minute, // 默认30分钟
}
}
// StartIdleCleanup 启动闲置会话清理 goroutine
// 每5分钟检查一次,将超过 idleTimeout 无活动的会话标记为 idle
func (h *Hub) StartIdleCleanup() {
go func() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
h.cleanupIdleSessions()
}
}()
log.Printf("[WS] 闲置会话清理已启动 (超时: %v)", h.idleTimeout)
}
// cleanupIdleSessions 标记超时会话为 idle(不删除状态)
func (h *Hub) cleanupIdleSessions() {
h.mu.Lock()
defer h.mu.Unlock()
now := time.Now()
idleCount := 0
for sessionID, s := range h.sessions {
// 检查该 session 是否还有活跃连接
hasActiveConn := false
for _, clients := range h.userClients {
for c := range clients {
if c.SessionID == sessionID {
hasActiveConn = true
break
}
}
if hasActiveConn {
break
}
}
// 如果没有活跃连接且超过闲置超时,标记为 idle
if !hasActiveConn && now.Sub(s.LastActivity) > h.idleTimeout {
if s.State != "idle" {
s.State = "idle"
idleCount++
}
}
}
if idleCount > 0 {
log.Printf("[WS] 闲置清理: %d 个会话标记为 idle", idleCount)
}
}
// GetAllActiveSessions 返回所有会话状态(包括 idle),供 DevTools 监看使用
func (h *Hub) GetAllActiveSessions() []*SessionState {
h.mu.RLock()
defer h.mu.RUnlock()
if h.sessions == nil || len(h.sessions) == 0 {
return []*SessionState{}
}
result := make([]*SessionState, 0, len(h.sessions))
for _, s := range h.sessions {
cp := *s
cp.RecentMessages = nil
result = append(result, &cp)
}
return result
}
// Run 启动Hub主循环
func (h *Hub) Run() {
for {
@@ -119,7 +207,7 @@ func (h *Hub) Run() {
}
}
// 检查该session是否还有其他连接,没有则移除会话状态
// 检查该session是否还有其他连接,没有则标记为 idle 而非删除
hasOtherConn := false
if clients, ok := h.userClients[client.UserID]; ok {
for c := range clients {
@@ -130,7 +218,10 @@ func (h *Hub) Run() {
}
}
if !hasOtherConn {
delete(h.sessions, client.SessionID)
// 不再删除 session 状态,而是标记为 idle 保留在内存中
if s, ok := h.sessions[client.SessionID]; ok {
s.State = "idle"
}
}
}
h.mu.Unlock()
+54 -12
View File
@@ -744,19 +744,22 @@ async function renderDashboard() {
<div class="cards-grid cards-4" id="dashboard-svc-cards"></div>
</div>
<!-- 数据库连接状态 -->
<div class="card">
<!-- 数据库状态卡片 -->
<div class="card" id="db-card">
<div class="card-header">
<span class="card-title">🗄️ 数据库连接</span>
${data.database?.checked ? `
<span class="badge ${data.database.postgresAlive ? 'badge-running' : 'badge-error'}">
PostgreSQL ${data.database.postgresAlive ? '通联' : '断开'}
</span>
<span class="badge ${data.database.tunnelRunning ? 'badge-running' : 'badge-stopped'}" style="margin-left:6px">
隧道 ${data.database.tunnelRunning ? '运行中' : '未运行'}
</span>
` : '<span class="badge badge-stopped">待检查</span>'}
<a href="#" onclick="switchPanel('database');return false" style="font-size:11px;color:var(--accent);text-decoration:none">🔍 详情 →</a>
<span class="card-title">🗄️ 数据库</span>
<span class="badge badge-stopped" id="db-status-badge">检查中...</span>
</div>
<div class="metrics">
<div class="metric"><div class="value" id="db-type-display">PostgreSQL</div><div class="label">类型</div></div>
<div class="metric"><div class="value" id="db-port-display">5432</div><div class="label">端口</div></div>
<div class="metric"><div class="value" id="db-uptime-display">—</div><div class="label">状态</div></div>
</div>
<div class="btn-group" style="margin-top:10px">
<button class="btn btn-xs btn-green" onclick="controlDB('start')">▶ 启动</button>
<button class="btn btn-xs btn-red" onclick="controlDB('stop')">⏹ 停止</button>
<button class="btn btn-xs" onclick="controlDB('restart')">🔄 重启</button>
<a href="#" onclick="switchPanel('database');return false" style="font-size:10px;color:var(--accent);text-decoration:none;margin-left:auto;align-self:center">🔍 详情 →</a>
</div>
</div>
@@ -789,6 +792,9 @@ async function renderDashboard() {
// 渲染服务卡片
renderDashboardSvcCards(svcs);
// 渲染数据库卡片
renderDBCard();
// 渲染性能快照
const perfContainer = document.getElementById('dashboard-perf');
const perf = data.performance?.perService || {};
@@ -1701,6 +1707,42 @@ async function tunnelAction(action) {
}
}
// ========== 数据库卡片控制 ==========
async function renderDBCard() {
const data = await api('/api/db/status');
const badge = document.getElementById('db-status-badge');
const typeDisplay = document.getElementById('db-type-display');
const portDisplay = document.getElementById('db-port-display');
const uptimeDisplay = document.getElementById('db-uptime-display');
if (data.error) {
if (badge) { badge.textContent = '错误'; badge.className = 'badge badge-error'; }
if (uptimeDisplay) uptimeDisplay.textContent = '错误';
return;
}
const online = data.online;
if (badge) {
badge.textContent = online ? '🟢 在线' : '🔴 离线';
badge.className = 'badge ' + (online ? 'badge-running' : 'badge-error');
}
if (typeDisplay) typeDisplay.textContent = 'PostgreSQL';
if (portDisplay) portDisplay.textContent = data.port || 5432;
if (uptimeDisplay) uptimeDisplay.textContent = online ? '已连接' : '未连接';
}
async function controlDB(action) {
showToast('正在' + action + '数据库...', 'info');
const data = await api('/api/db/' + action, { method: 'POST' });
if (data.error) {
showToast('操作失败: ' + data.error, 'error');
} else {
showToast('数据库 ' + action + ' 完成', 'success');
// 等待2秒后刷新状态
setTimeout(renderDBCard, 2000);
}
}
</script>
<script src="iot-panel.js"></script>
<script>
+10
View File
@@ -52,6 +52,16 @@ async function renderIoTPanel() {
}
var devices = result.devices;
// 固定设备排列顺序: 先按类型,同类型再按 device_id
var typeOrder = { 'sensor': 1, 'ac': 2, 'light': 3, 'curtain': 4, 'lock': 5, 'camera': 6, 'speaker': 7, 'thermostat': 8 };
devices.sort(function(a, b) {
var oa = typeOrder[a.type] || 99;
var ob = typeOrder[b.type] || 99;
if (oa !== ob) return oa - ob;
return (a.id || a.entity_id || '').localeCompare(b.id || b.entity_id || '');
});
var badge = document.getElementById('iot-badge');
if (badge) {
badge.textContent = devices.length;
+76
View File
@@ -696,6 +696,82 @@ app.post('/api/tunnel/:action', (req, res) => {
}
});
// ---- 数据库控制 (Docker Compose) ----
const DB_COMPOSE_FILE = path.join(ROOT, 'docker-compose.dev.db.yml');
const DB_PORT = 5432;
// GET /api/db/status
app.get('/api/db/status', (_req, res) => {
try {
const online = checkPort(DB_PORT);
res.json({
online,
port: DB_PORT,
checked_at: new Date().toISOString(),
});
} catch (err) {
res.json({
online: false,
port: DB_PORT,
checked_at: new Date().toISOString(),
});
}
});
// POST /api/db/start
app.post('/api/db/start', (_req, res) => {
try {
const out = execSync(`docker compose -f "${DB_COMPOSE_FILE}" up -d`, {
encoding: 'utf-8',
timeout: 60000,
stdio: 'pipe',
});
res.json({ success: true, action: 'start', output: out.trim() });
} catch (err) {
const stderr = err.stderr?.toString() || err.message;
res.status(500).json({ success: false, action: 'start', error: stderr });
}
});
// POST /api/db/stop
app.post('/api/db/stop', (_req, res) => {
try {
const out = execSync(`docker compose -f "${DB_COMPOSE_FILE}" down`, {
encoding: 'utf-8',
timeout: 30000,
stdio: 'pipe',
});
res.json({ success: true, action: 'stop', output: out.trim() });
} catch (err) {
const stderr = err.stderr?.toString() || err.message;
res.status(500).json({ success: false, action: 'stop', error: stderr });
}
});
// POST /api/db/restart
app.post('/api/db/restart', (_req, res) => {
try {
const downOut = execSync(`docker compose -f "${DB_COMPOSE_FILE}" down`, {
encoding: 'utf-8',
timeout: 30000,
stdio: 'pipe',
});
const upOut = execSync(`docker compose -f "${DB_COMPOSE_FILE}" up -d`, {
encoding: 'utf-8',
timeout: 60000,
stdio: 'pipe',
});
res.json({
success: true,
action: 'restart',
output: `down: ${downOut.trim()}\nup: ${upOut.trim()}`,
});
} catch (err) {
const stderr = err.stderr?.toString() || err.message;
res.status(500).json({ success: false, action: 'restart', error: stderr });
}
});
// ========== 启动 ==========
// 启动性能监控
performanceMonitor.start();
+72
View File
@@ -7,8 +7,16 @@ import { spawn, execSync } from 'child_process';
import { EventEmitter } from 'events';
import fs from 'fs';
import net from 'net';
import path from 'path';
import { fileURLToPath } from 'url';
import { SERVICES, logFile } from './config.js';
const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);
const ROOT = path.resolve(__dirname, '../..');
const DB_PORTS = [5432, 5433, 5434];
const DB_COMPOSE_FILE = path.join(ROOT, 'docker-compose.dev.db.yml');
/**
* 通过 TCP 连接尝试判断端口是否被占用若被占用则尝试用 fuser 释放
*/
@@ -36,6 +44,64 @@ function releasePort(port) {
});
}
/**
* 检查端口是否可连接 (TCP connect, 超时2秒)
*/
function isPortOpen(port) {
return new Promise((resolve) => {
const sock = new net.Socket();
sock.setTimeout(2000);
sock.on('connect', () => { sock.destroy(); resolve(true); });
sock.on('error', () => { sock.destroy(); resolve(false); });
sock.on('timeout', () => { sock.destroy(); resolve(false); });
sock.connect(port, '127.0.0.1');
});
}
/**
* 确保数据库在线
* 检查 DB_PORTS 中至少有一个端口可用若不可用则尝试 docker compose up
* 等待最多 30 秒检查数据库就绪
* @param {string} serviceId - 正在启动的服务 ID
* @param {EventEmitter} emitter - 用于发送日志事件
*/
async function ensureDBOnline(serviceId, emitter) {
// 1. 快速检查:任意数据库端口是否已在线
for (const port of DB_PORTS) {
if (await isPortOpen(port)) {
emitter.emit('log', serviceId, 'system', `数据库端口 ${port} 已在线`);
return;
}
}
// 2. 数据库不在线,尝试 docker compose up
emitter.emit('log', serviceId, 'system', '数据库未启动,正在通过 Docker Compose 启动...');
try {
execSync(`docker compose -f "${DB_COMPOSE_FILE}" up -d`, {
timeout: 60000,
stdio: 'pipe',
});
emitter.emit('log', serviceId, 'system', 'Docker Compose 启动命令已执行,等待数据库就绪...');
} catch (err) {
const stderr = err.stderr?.toString() || err.message;
emitter.emit('log', serviceId, 'error', `Docker Compose 启动失败: ${stderr}`);
}
// 3. 等待最多 30 秒检查数据库就绪
for (let i = 0; i < 30; i++) {
await new Promise((r) => setTimeout(r, 1000));
for (const port of DB_PORTS) {
if (await isPortOpen(port)) {
emitter.emit('log', serviceId, 'system', `数据库端口 ${port} 已就绪 (等待 ${i + 1}s)`);
return;
}
}
}
// 4. 30 秒后仍不可用
emitter.emit('log', serviceId, 'error', '⚠️ 数据库无法启动,请手动检查 Docker。将继续启动后端服务...');
}
class ProcessManager extends EventEmitter {
constructor() {
super();
@@ -65,6 +131,12 @@ class ProcessManager extends EventEmitter {
throw new Error(`${svc.name} 已在运行中`);
}
// 对 gateway 和 ai-core 做数据库前置检查
if (serviceId === 'gateway' || serviceId === 'ai-core') {
this.emit('log', serviceId, 'system', '检查数据库连接状态...');
await ensureDBOnline(serviceId, this);
}
// 启动前释放端口,避免 "address already in use"
if (svc.port) {
this.emit('log', serviceId, 'system', `检查端口 ${svc.port}...`);
+134 -2
View File
@@ -1,13 +1,37 @@
import { useState } from 'react';
import { useState, useEffect, useCallback, useRef } from 'react';
import { AppLayout } from '@/components/layout/AppLayout';
import { ChatContainer } from '@/components/chat/ChatContainer';
import { ChatInput } from '@/components/chat/ChatInput';
import { useAuth } from '@/hooks/useAuth';
import { useChat } from '@/hooks/useChat';
import { useSessionStore, isAdminUser } from '@/store/sessionStore';
import { useChatStore } from '@/store/chatStore';
import { fetchMessages } from '@/api/sessions';
/** URL Hash 工具 */
const SESSION_HASH_PREFIX = 'session=';
function getSessionIdFromHash(): string | null {
const hash = window.location.hash.slice(1); // 去掉 #
if (hash.startsWith(SESSION_HASH_PREFIX)) {
return hash.slice(SESSION_HASH_PREFIX.length);
}
return null;
}
function setHashSessionId(sessionId: string | null) {
if (sessionId) {
window.location.hash = SESSION_HASH_PREFIX + sessionId;
} else {
// 清除 hash
history.replaceState(null, '', window.location.pathname + window.location.search);
}
}
export default function App() {
const { isLoggedIn, login, register, loading: authLoading } = useAuth();
const { isLoggedIn, login, register, loading: authLoading, userId } = useAuth();
const { send } = useChat();
const { loadSessionsFromServer, ensureMainSession, setCurrentSessionId, setMessages, loadMessagesFromServer, sessions, currentSessionId } = useSessionStore();
const [authMode, setAuthMode] = useState<'login' | 'register'>('login');
const [username, setUsername] = useState('');
@@ -17,6 +41,114 @@ export default function App() {
const [error, setError] = useState('');
const [successMsg, setSuccessMsg] = useState('');
const initializedRef = useRef(false);
// ========== URL Hash 路由 ==========
/** 根据 hash 恢复或选择初始会话 */
const initSession = useCallback(async () => {
if (!userId || initializedRef.current) return;
initializedRef.current = true;
const admin = isAdminUser(userId);
// 1. 从服务端加载会话列表
await loadSessionsFromServer(userId);
const currentSessions = useSessionStore.getState().sessions;
// 2. 检查 URL hash
const hashId = getSessionIdFromHash();
if (hashId) {
// 尝试加载 hash 指定的会话
const found = currentSessions.find((s) => s.id === hashId);
if (found) {
setCurrentSessionId(found.id);
await loadMessagesFromServer(found.id);
return;
}
// 会话可能已被删除,尝试从 API 获取消息(404 时 catch
try {
const resp = await fetchMessages(hashId);
if (resp.messages && resp.messages.length > 0) {
// 消息存在说明会话仍有效(虽然不在列表里,可能是刚创建的)
setCurrentSessionId(hashId);
const msgs = resp.messages.map((m: any, i: number) => ({
id: m.id ? String(m.id) : `hist_${i}_${Date.now()}`,
role: m.role,
content: m.content,
timestamp: typeof m.created_at === 'number' ? m.created_at : Date.now(),
isStreaming: false,
}));
setMessages(msgs);
useChatStore.getState().setMessages(msgs);
return;
}
} catch {
// 会话不存在,回退
}
// 回退:清除 hash,加载最新/主对话
setHashSessionId(null);
}
// 3. 无 hash 或 hash 无效:加载最新会话
if (admin) {
// 管理员:确保主对话存在
const mainSession = await ensureMainSession(userId);
if (mainSession) {
setCurrentSessionId(mainSession.id);
setHashSessionId(mainSession.id);
await loadMessagesFromServer(mainSession.id);
return;
}
}
// 普通用户:选择最新会话
if (currentSessions.length > 0) {
const latest = currentSessions[0]; // 已按 updated_at DESC 排序
setCurrentSessionId(latest.id);
setHashSessionId(latest.id);
await loadMessagesFromServer(latest.id);
}
}, [userId, loadSessionsFromServer, ensureMainSession, setCurrentSessionId, setMessages, loadMessagesFromServer]);
// 登录后初始化
useEffect(() => {
if (isLoggedIn && userId) {
initSession();
}
}, [isLoggedIn, userId, initSession]);
// 监听 hashchange 事件 (浏览器前进/后退)
useEffect(() => {
if (!isLoggedIn) return;
const handleHashChange = async () => {
const hashId = getSessionIdFromHash();
const currentId = useSessionStore.getState().currentSessionId;
if (hashId && hashId !== currentId) {
// hash 变化,切换会话
setCurrentSessionId(hashId);
await loadMessagesFromServer(hashId);
}
};
window.addEventListener('hashchange', handleHashChange);
return () => window.removeEventListener('hashchange', handleHashChange);
}, [isLoggedIn, setCurrentSessionId, loadMessagesFromServer]);
// 当前会话变化时更新 URL hash(仅在登录后、非 hashchange 驱动时)
useEffect(() => {
if (!isLoggedIn || !currentSessionId) return;
const hashId = getSessionIdFromHash();
if (hashId !== currentSessionId) {
setHashSessionId(currentSessionId);
}
}, [isLoggedIn, currentSessionId]);
// ========== 认证相关 ==========
const handleLogin = async () => {
setError('');
const result = await login(username, password);
+110 -8
View File
@@ -1,8 +1,110 @@
// 会话API
export {
createSession,
listSessions,
getSession,
deleteSession,
fetchSessionMessages,
} from './client';
// 会话持久化 API — 对接 Gateway REST API
import { request } from './client';
import type { Session, SessionListResponse, SessionMessagesResponse, SessionResponse } from '@/types/session';
/**
*
* GET /api/v1/sessions?user_id={userId}
*/
export async function fetchSessions(userId: string): Promise<Session[]> {
const resp = await request<SessionListResponse>(
`/sessions?user_id=${encodeURIComponent(userId)}`
);
if (resp.error) {
console.error('[sessions] 获取会话列表失败:', resp.error);
return [];
}
// Gateway 返回 { sessions: [...] }
return (resp.data as SessionListResponse)?.sessions || [];
}
/**
*
* GET /api/v1/sessions/{sessionId}/messages?limit={limit}
*/
export async function fetchMessages(
sessionId: string,
limit: number = 50
): Promise<SessionMessagesResponse> {
const resp = await request<SessionMessagesResponse>(
`/sessions/${encodeURIComponent(sessionId)}/messages?limit=${limit}`
);
if (resp.error) {
console.error('[sessions] 获取消息失败:', resp.error);
return { messages: [] };
}
return (resp.data as SessionMessagesResponse) || { messages: [] };
}
/**
*
* POST /api/v1/sessions
*/
export async function createSession(
userId: string,
sessionId: string,
title: string = '新的对话',
isMain: boolean = false
): Promise<SessionResponse | null> {
const resp = await request<SessionResponse>('/sessions', {
method: 'POST',
body: {
user_id: userId,
session_id: sessionId,
title,
is_main: isMain,
},
});
if (resp.error) {
console.error('[sessions] 创建会话失败:', resp.error);
return null;
}
return resp.data as SessionResponse;
}
/**
* ()
* DELETE /api/v1/sessions/{sessionId}
*/
export async function deleteSession(sessionId: string): Promise<boolean> {
const resp = await request(`/sessions/${encodeURIComponent(sessionId)}`, {
method: 'DELETE',
});
if (resp.error) {
console.error('[sessions] 删除会话失败:', resp.error);
return false;
}
return true;
}
/**
* ()
* DELETE /api/v1/sessions?user_id={userId}
*/
export async function deleteAllSessions(userId: string): Promise<boolean> {
const resp = await request(`/sessions?user_id=${encodeURIComponent(userId)}`, {
method: 'DELETE',
});
if (resp.error) {
console.error('[sessions] 删除所有会话失败:', resp.error);
return false;
}
return true;
}
/**
*
* DELETE /api/v1/sessions/{sessionId}/messages
*/
export async function clearSessionMessages(sessionId: string): Promise<boolean> {
const resp = await request(
`/sessions/${encodeURIComponent(sessionId)}/messages`,
{ method: 'DELETE' }
);
if (resp.error) {
console.error('[sessions] 清空消息失败:', resp.error);
return false;
}
return true;
}
+180 -42
View File
@@ -1,6 +1,8 @@
import { useState, useCallback } from 'react';
import { useSession } from '@/hooks/useSession';
import { useSessionStore } from '@/store/sessionStore';
import { CyreneAvatar } from '@/components/persona/CyreneAvatar';
import type { Session } from '@/types/session';
interface SidebarProps {
onClose?: () => void;
@@ -11,12 +13,25 @@ export function Sidebar({ onClose }: SidebarProps) {
sessions,
createSession,
deleteSession,
deleteAllSessions,
clearMainSession,
setCurrentSession,
isAdmin,
} = useSession();
const currentSessionId = useSessionStore((s) => s.currentSessionId);
const displaySessions = sessions;
const activeSessionId = currentSessionId;
// 确认弹窗状态
const [confirmAction, setConfirmAction] = useState<{
type: 'delete' | 'clearMain' | 'deleteAll';
sessionId?: string;
} | null>(null);
// 按 updated_at 降序排列
const displaySessions = [...sessions].sort((a, b) => {
const ta = typeof a.updated_at === 'string' ? parseInt(a.updated_at, 10) : (a.updated_at as unknown as number);
const tb = typeof b.updated_at === 'string' ? parseInt(b.updated_at, 10) : (b.updated_at as unknown as number);
return (tb || 0) - (ta || 0);
});
const handleNewChat = async () => {
const session = await createSession();
@@ -28,11 +43,43 @@ export function Sidebar({ onClose }: SidebarProps) {
if (onClose) onClose();
};
const handleDeleteSession = (e: { stopPropagation: () => void }, id: string) => {
e.stopPropagation();
deleteSession(id);
const handleMainSession = async () => {
// 找到主对话
const mainSession = displaySessions.find((s) => s.is_main);
if (mainSession) {
setCurrentSession(mainSession.id);
if (onClose) onClose();
}
};
const handleDeleteClick = (e: React.MouseEvent, id: string) => {
e.stopPropagation();
const session = displaySessions.find((s) => s.id === id);
// 管理员主对话不可删除
if (isAdmin && session?.is_main) return;
setConfirmAction({ type: 'delete', sessionId: id });
};
const handleConfirmAction = async () => {
if (!confirmAction) return;
switch (confirmAction.type) {
case 'delete':
if (confirmAction.sessionId) {
await deleteSession(confirmAction.sessionId);
}
break;
case 'clearMain':
await clearMainSession();
break;
case 'deleteAll':
await deleteAllSessions();
break;
}
setConfirmAction(null);
};
const cancelConfirm = () => setConfirmAction(null);
/** 格式化时间戳为可读字符串 */
const formatTime = (ts: string | number): string => {
if (!ts) return '';
@@ -52,7 +99,28 @@ export function Sidebar({ onClose }: SidebarProps) {
return (
<aside className="h-full bg-white/90 dark:bg-gray-900/90 border-r border-pink-100 dark:border-pink-900 flex flex-col">
{/* 侧边栏头部 */}
{/* 主对话按钮 (仅管理员可见) */}
{isAdmin && (
<div className="p-3 border-b border-pink-100 dark:border-pink-900 flex gap-2">
<button
onClick={handleMainSession}
className="flex-1 flex items-center justify-center gap-1.5 px-3 py-2 bg-amber-50 hover:bg-amber-100 dark:bg-amber-900/20 dark:hover:bg-amber-900/30 text-amber-600 dark:text-amber-400 rounded-xl text-sm font-medium transition-colors border border-amber-200 dark:border-amber-800"
title="回到主对话"
>
<span>🏠</span>
<span></span>
</button>
<button
onClick={() => setConfirmAction({ type: 'clearMain' })}
className="flex items-center justify-center px-2 py-2 bg-red-50 hover:bg-red-100 dark:bg-red-900/20 dark:hover:bg-red-900/30 text-red-400 hover:text-red-500 rounded-xl text-sm transition-colors border border-red-200 dark:border-red-800"
title="清空主对话消息"
>
<span>🗑</span>
</button>
</div>
)}
{/* 新对话按钮 */}
<div className="p-4 border-b border-pink-100 dark:border-pink-900">
<button
onClick={handleNewChat}
@@ -70,46 +138,80 @@ export function Sidebar({ onClose }: SidebarProps) {
</p>
) : (
displaySessions.map((session) => (
<div
key={session.id}
onClick={() => handleSelectSession(session.id)}
className={`
group flex items-center justify-between px-4 py-2.5 mx-2 rounded-lg cursor-pointer transition-colors
${
activeSessionId === session.id
? 'bg-pink-50 dark:bg-pink-900/30 text-pink-600'
: 'hover:bg-gray-50 dark:hover:bg-gray-800 text-gray-600 dark:text-gray-300'
}
`}
>
<div className="flex items-center gap-2 min-w-0 flex-1">
<CyreneAvatar size="sm" />
<div className="min-w-0">
<p className="text-sm font-medium truncate">{session.title || '新的对话'}</p>
<p className="text-xs text-gray-400 truncate">
{session.message_count != null ? `${session.message_count} 条消息` : ''}
{session.updated_at ? ` · ${formatTime(session.updated_at)}` : ''}
</p>
</div>
</div>
<button
onClick={(e) => handleDeleteSession(e, session.id)}
className="opacity-0 group-hover:opacity-100 p-1 text-gray-400 hover:text-red-400 transition-all"
title="删除会话"
displaySessions.map((session) => {
const isMainSession = session.is_main;
const isActive = currentSessionId === session.id;
// 管理员主对话不可删除
const canDelete = !(isAdmin && isMainSession);
return (
<div
key={session.id}
onClick={() => handleSelectSession(session.id)}
className={`
group flex items-center justify-between px-4 py-2.5 mx-2 rounded-lg cursor-pointer transition-colors
${
isActive
? 'bg-pink-50 dark:bg-pink-900/30 text-pink-600'
: 'hover:bg-gray-50 dark:hover:bg-gray-800 text-gray-600 dark:text-gray-300'
}
`}
>
<svg xmlns="http://www.w3.org/2000/svg" className="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16" />
</svg>
</button>
</div>
))
<div className="flex items-center gap-2 min-w-0 flex-1">
<CyreneAvatar size="sm" />
<div className="min-w-0">
<p className="text-sm font-medium truncate">
{isMainSession ? '🏠 ' : ''}
{session.title || '新的对话'}
</p>
<p className="text-xs text-gray-400 truncate">
{session.message_count != null
? `${session.message_count} 条消息`
: ''}
{session.updated_at
? `${session.message_count != null ? ' · ' : ''}${formatTime(session.updated_at)}`
: ''}
</p>
</div>
</div>
{canDelete && (
<button
onClick={(e) => handleDeleteClick(e, session.id)}
className="opacity-0 group-hover:opacity-100 p-1 text-gray-400 hover:text-red-400 transition-all"
title="删除会话"
>
<svg
xmlns="http://www.w3.org/2000/svg"
className="h-4 w-4"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
strokeLinecap="round"
strokeLinejoin="round"
strokeWidth={2}
d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"
/>
</svg>
</button>
)}
</div>
);
})
)}
</div>
{/* 底部信息 */}
<div className="p-4 border-t border-pink-100 dark:border-pink-900">
<div className="flex items-center gap-2 text-xs text-gray-400">
{/* 底部:一键清空所有对话 */}
<div className="p-3 border-t border-pink-100 dark:border-pink-900">
<button
onClick={() => setConfirmAction({ type: 'deleteAll' })}
className="w-full flex items-center justify-center gap-1.5 px-3 py-2 bg-red-50 hover:bg-red-100 dark:bg-red-900/20 dark:hover:bg-red-900/30 text-red-400 hover:text-red-500 rounded-xl text-xs font-medium transition-colors border border-red-200 dark:border-red-800"
>
<span>🗑</span>
<span></span>
</button>
<div className="flex items-center gap-2 mt-3 text-xs text-gray-400">
<CyreneAvatar size="sm" />
<div>
<p className="font-medium text-pink-400"> AI</p>
@@ -117,6 +219,42 @@ export function Sidebar({ onClose }: SidebarProps) {
</div>
</div>
</div>
{/* 确认弹窗 */}
{confirmAction && (
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/40">
<div className="bg-white dark:bg-gray-800 rounded-2xl shadow-xl p-6 m-4 max-w-sm w-full border border-pink-100 dark:border-pink-700">
<h3 className="text-lg font-semibold text-gray-800 dark:text-gray-200 mb-2">
{confirmAction.type === 'delete'
? '删除会话'
: confirmAction.type === 'clearMain'
? '清空主对话'
: '清空所有对话'}
</h3>
<p className="text-sm text-gray-500 dark:text-gray-400 mb-6">
{confirmAction.type === 'delete'
? '确定要删除这个会话吗?此操作不会删除记忆。'
: confirmAction.type === 'clearMain'
? '确定要清空主对话的所有消息吗?此操作不会删除记忆。'
: '确定要删除所有对话吗?此操作不会删除记忆。管理员将回到主对话,普通用户进入新对话。'}
</p>
<div className="flex justify-end gap-3">
<button
onClick={cancelConfirm}
className="px-4 py-2 rounded-lg text-sm text-gray-500 hover:bg-gray-100 dark:hover:bg-gray-700 transition-colors"
>
</button>
<button
onClick={handleConfirmAction}
className="px-4 py-2 rounded-lg text-sm bg-red-400 hover:bg-red-500 text-white font-medium transition-colors"
>
</button>
</div>
</div>
</div>
)}
</aside>
);
}
+97 -46
View File
@@ -1,9 +1,24 @@
import { useCallback, useEffect } from 'react';
import { useSessionStore } from '@/store/sessionStore';
import { useSessionStore, isAdminUser } from '@/store/sessionStore';
import { useChatStore } from '@/store/chatStore';
import { listSessions as apiListSessions, createSession as apiCreateSession, deleteSession as apiDeleteSession } from '@/api/client';
import { createSession as apiCreateSession } from '@/api/sessions';
import type { Session } from '@/types/session';
/** 生成简易随机ID */
function randomID(n: number = 12): string {
const letters = 'abcdefghijklmnopqrstuvwxyz0123456789';
let result = '';
for (let i = 0; i < n; i++) {
result += letters.charAt(Math.floor(Math.random() * letters.length));
}
return `session_${result}`;
}
/** 获取当前用户ID */
function getUserId(): string {
return localStorage.getItem('user_id') || '';
}
export function useSession() {
const {
sessions,
@@ -14,69 +29,105 @@ export function useSession() {
removeSession,
setCurrentSessionId,
setLoading,
loadSessionsFromServer,
loadMessagesFromServer,
clearMainSessionMessages,
deleteSessionAndRefresh,
deleteAllSessionsAndReset,
ensureMainSession,
} = useSessionStore();
const { clearMessages } = useChatStore();
const userId = getUserId();
const isAdmin = isAdminUser(userId);
/** 从服务端加载会话列表 */
const loadSessions = useCallback(async () => {
setLoading(true);
try {
const resp = await apiListSessions();
if (resp.data) {
const data = resp.data as { sessions: Session[] };
setSessions(data.sessions || []);
}
} catch {
// ignore
} finally {
setLoading(false);
}
}, [setSessions, setLoading]);
if (!userId) return;
await loadSessionsFromServer(userId);
}, [userId, loadSessionsFromServer]);
// 初始加载
useEffect(() => {
loadSessions();
}, [loadSessions]);
const createSession = useCallback(async (title?: string) => {
const resp = await apiCreateSession(title);
if (resp.data) {
const session = resp.data as Session;
addSession(session);
// setCurrentSessionId 内部已处理消息清空和历史加载,无需额外 clearMessages
await setCurrentSessionId(session.id);
return session;
}
return null;
}, [addSession, setCurrentSessionId]);
const deleteSession = useCallback(async (id: string) => {
await apiDeleteSession(id);
// 如果删除的是当前活跃会话,先切换到其他会话
if (currentSessionId === id) {
const remaining = useSessionStore.getState().sessions.filter((s: Session) => s.id !== id);
if (remaining.length > 0) {
// 切换到列表中的第一个会话
await setCurrentSessionId(remaining[0].id);
} else {
clearMessages();
setCurrentSessionId(null);
/** 创建新会话 (isMain 默认为 false) */
const createSession = useCallback(
async (title?: string, isMain: boolean = false) => {
const sid = randomID();
const created = await apiCreateSession(
userId,
sid,
title || '新的对话',
isMain
);
if (created) {
const newSession: Session = {
id: created.id,
user_id: created.user_id,
title: created.title,
is_main: created.is_main,
created_at: String(created.created_at),
updated_at: String(created.updated_at),
message_count: 0,
};
addSession(newSession);
setCurrentSessionId(newSession.id);
return newSession;
}
}
removeSession(id);
}, [removeSession, currentSessionId, clearMessages, setCurrentSessionId]);
return null;
},
[userId, addSession, setCurrentSessionId]
);
const setCurrentSession = useCallback((id: string) => {
// setCurrentSessionId 内部已处理消息清空和历史加载
setCurrentSessionId(id);
}, [setCurrentSessionId]);
/** 清空主对话消息 */
const clearMainSession = useCallback(async () => {
const mainSession = useSessionStore
.getState()
.sessions.find((s: Session) => s.is_main);
if (!mainSession) {
// 尝试确保主对话存在
const ensured = await ensureMainSession(userId);
if (!ensured) return false;
return clearMainSessionMessages(ensured.id);
}
return clearMainSessionMessages(mainSession.id);
}, [userId, clearMainSessionMessages, ensureMainSession]);
/** 删除单个会话 */
const deleteSession = useCallback(
async (id: string) => {
await deleteSessionAndRefresh(id, userId);
},
[userId, deleteSessionAndRefresh]
);
/** 删除所有会话 */
const deleteAllSessions = useCallback(async () => {
await deleteAllSessionsAndReset(userId);
}, [userId, deleteAllSessionsAndReset]);
/** 切换当前会话 */
const setCurrentSession = useCallback(
async (id: string) => {
setCurrentSessionId(id);
// 加载该会话的消息历史
await loadMessagesFromServer(id);
},
[setCurrentSessionId, loadMessagesFromServer]
);
return {
sessions,
currentSessionId,
loading,
isAdmin,
userId,
loadSessions,
createSession,
clearMainSession,
deleteSession,
deleteAllSessions,
setCurrentSession,
};
}
+50 -22
View File
@@ -2,6 +2,7 @@ import { useEffect, useRef, useCallback, useState } from 'react';
import { useChatStore } from '@/store/chatStore';
import { useSessionStore } from '@/store/sessionStore';
import { getToken } from '@/api/client';
import { fetchMessages } from '@/api/sessions';
import type { WSClientMessage, WSServerMessage } from '@/types/chat';
const WS_BASE_URL =
@@ -12,7 +13,8 @@ export function useWebSocket() {
const wsRef = useRef<WebSocket | null>(null);
const reconnectTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null);
const shouldReconnectRef = useRef(true);
const activeSessionRef = useRef<string | null>(null); // 追踪当前活跃会话,防止竞态
const activeSessionRef = useRef<string | null>(null);
const loadingRef = useRef(false); // 防止重复加载消息
// 订阅 sessionStore 中的 currentSessionId 变化
const currentSessionId = useSessionStore((s) => s.currentSessionId);
@@ -40,7 +42,7 @@ export function useWebSocket() {
shouldReconnectRef.current = true;
console.log('[WS] 已连接, session_id:', sessionID);
// 连接后发送会话恢复消息,恢复历史上下文
// 连接后发送会话恢复消息,恢复后端上下文
const sid = useSessionStore.getState().currentSessionId;
if (sid) {
const resumeMsg: WSClientMessage = {
@@ -81,10 +83,39 @@ export function useWebSocket() {
wsRef.current = ws;
}, []);
// 初始连接 + 会话切换时重连
// 会话切换时:先通过 REST API 加载历史消息,再建立 WS 连接
useEffect(() => {
activeSessionRef.current = currentSessionId;
connect();
const loadAndConnect = async () => {
// 如果是从 URL 恢复的 session,先加载历史消息
if (currentSessionId && !loadingRef.current) {
loadingRef.current = true;
try {
const resp = await fetchMessages(currentSessionId);
const rawMessages = resp.messages || [];
const msgs = rawMessages.map((m: any, i: number) => ({
id: m.id ? String(m.id) : `hist_${i}_${Date.now()}`,
role: m.role,
content: m.content,
timestamp:
typeof m.created_at === 'number' ? m.created_at : Date.now(),
isStreaming: false,
}));
useSessionStore.getState().setMessages(msgs);
useChatStore.getState().setMessages(msgs);
} catch {
// 加载失败不影响后续连接
} finally {
loadingRef.current = false;
}
}
connect();
};
loadAndConnect();
return () => {
if (reconnectTimerRef.current) {
clearTimeout(reconnectTimerRef.current);
@@ -96,12 +127,13 @@ export function useWebSocket() {
const sendMessage = useCallback((msg: WSClientMessage) => {
if (wsRef.current?.readyState === WebSocket.OPEN) {
// 自动附上 session_id
const sessionID = useSessionStore.getState().currentSessionId;
wsRef.current.send(JSON.stringify({
...msg,
session_id: msg.session_id || sessionID || undefined,
}));
wsRef.current.send(
JSON.stringify({
...msg,
session_id: msg.session_id || sessionID || undefined,
})
);
}
}, []);
@@ -109,7 +141,8 @@ export function useWebSocket() {
}
function handleServerMessage(msg: WSServerMessage) {
const { addMessage, appendToLastMessage, finishStreaming, setTyping } = useChatStore.getState();
const { addMessage, appendToLastMessage, finishStreaming, setTyping } =
useChatStore.getState();
const { setMessages } = useSessionStore.getState();
const chatState = useChatStore.getState();
@@ -128,21 +161,22 @@ function handleServerMessage(msg: WSServerMessage) {
case 'stream_chunk':
if (msg.content) {
// 首个 chunk 到达时创建消息并隐藏 typing indicator
const { messages } = useChatStore.getState();
const lastMsg = messages[messages.length - 1];
if (!lastMsg || lastMsg.role !== 'assistant' || !lastMsg.isStreaming) {
// 创建新的流式消息
if (
!lastMsg ||
lastMsg.role !== 'assistant' ||
!lastMsg.isStreaming
) {
addMessage({
id: msg.message_id || ('msg_' + Date.now()),
id: msg.message_id || 'msg_' + Date.now(),
role: 'assistant',
content: msg.content,
timestamp: msg.timestamp,
isStreaming: true,
});
setTyping(false); // 首个 chunk 到达,隐藏 typing 指示器
setTyping(false);
} else {
// 追加到现有流式消息
appendToLastMessage(msg.content);
}
}
@@ -154,28 +188,23 @@ function handleServerMessage(msg: WSServerMessage) {
case 'history_response':
if (msg.messages) {
// 确保每条消息都有 id
const msgsWithIds = msg.messages.map((m: any, i: number) => ({
...m,
id: m.id || `hist_${i}_${Date.now()}`,
}));
// 同步历史消息到两个 store
setMessages(msgsWithIds);
useChatStore.getState().setMessages(msgsWithIds);
}
// 确保历史加载后 typing indicator 关闭
setTyping(false);
break;
case 'device_update':
// 处理 IoT 设备状态更新
if (msg.devices && msg.devices.length > 0) {
chatState.setIoTDevices(msg.devices);
}
break;
case 'background_thinking':
// 处理后端推送的后台思考状态
if (msg.thinking_status) {
chatState.setBackgroundThinkingStatus(msg.thinking_status);
}
@@ -187,7 +216,6 @@ function handleServerMessage(msg: WSServerMessage) {
break;
case 'pong':
// 忽略心跳响应
break;
default:
+160 -32
View File
@@ -1,15 +1,38 @@
import { create } from 'zustand';
import type { Session } from '@/types/session';
import type { Message } from '@/types/chat';
import { fetchSessionMessages as apiFetchMessages } from '@/api/client';
import {
fetchSessions,
fetchMessages,
createSession as apiCreateSession,
deleteSession as apiDeleteSession,
deleteAllSessions as apiDeleteAllSessions,
clearSessionMessages as apiClearMessages,
} from '@/api/sessions';
import { useChatStore } from '@/store/chatStore';
/** 生成简易随机ID */
function randomID(n: number = 12): string {
const letters = 'abcdefghijklmnopqrstuvwxyz0123456789';
let result = '';
for (let i = 0; i < n; i++) {
result += letters.charAt(Math.floor(Math.random() * letters.length));
}
return `session_${result}`;
}
/** 判断是否为管理员用户 (user_id 以 "admin_" 开头) */
export function isAdminUser(userId: string | null): boolean {
return userId?.startsWith('admin_') ?? false;
}
interface SessionStore {
sessions: Session[];
currentSessionId: string | null;
loading: boolean;
messages: Message[];
// 基础操作
setSessions: (sessions: Session[]) => void;
addSession: (session: Session) => void;
removeSession: (id: string) => void;
@@ -17,9 +40,17 @@ interface SessionStore {
setLoading: (loading: boolean) => void;
setMessages: (messages: Message[]) => void;
clearMessages: () => void;
// 服务端持久化操作
loadSessionsFromServer: (userId: string) => Promise<void>;
loadMessagesFromServer: (sessionId: string) => Promise<void>;
clearMainSessionMessages: (sessionId: string) => Promise<boolean>;
deleteSessionAndRefresh: (id: string, userId: string) => Promise<void>;
deleteAllSessionsAndReset: (userId: string) => Promise<void>;
ensureMainSession: (userId: string) => Promise<Session | null>;
}
export const useSessionStore = create<SessionStore>((set) => ({
export const useSessionStore = create<SessionStore>((set, get) => ({
sessions: [],
currentSessionId: null,
loading: false,
@@ -34,46 +65,143 @@ export const useSessionStore = create<SessionStore>((set) => ({
currentSessionId: state.currentSessionId === id ? null : state.currentSessionId,
messages: state.currentSessionId === id ? [] : state.messages,
})),
setCurrentSessionId: async (id) => {
// 立即清除旧消息,防止闪旧数据
set({ currentSessionId: id, messages: [], loading: true });
useChatStore.getState().clearMessages();
// 清除旧消息(同时清 chatStore)
if (id === null) {
set({ messages: [], loading: false });
return;
}
// 从后端加载历史消息
try {
const resp = await apiFetchMessages(id);
if (resp.data) {
const data = resp.data as { messages: Message[] };
const msgs = (data.messages || []).map((m: Message, i: number) => ({
...m,
id: m.id || `hist_${i}_${Date.now()}`,
}));
set({ messages: msgs, loading: false });
// 同步到 chatStore 以便 ChatContainer 渲染
useChatStore.getState().setMessages(msgs);
} else {
set({ messages: [], loading: false });
useChatStore.getState().clearMessages();
}
} catch {
set({ messages: [], loading: false });
setCurrentSessionId: (id) => {
set({ currentSessionId: id });
// 切换会话时清空旧消息,等待加载
if (id !== get().currentSessionId) {
set({ messages: [], loading: true });
useChatStore.getState().clearMessages();
}
},
setLoading: (loading) => set({ loading }),
setMessages: (messages) => {
set({ messages });
// 同步到 chatStore
useChatStore.getState().setMessages(messages);
},
clearMessages: () => {
set({ messages: [] });
useChatStore.getState().clearMessages();
},
// ========== 服务端持久化操作 ==========
/**
*
*/
loadSessionsFromServer: async (userId: string) => {
set({ loading: true });
try {
const sessions = await fetchSessions(userId);
set({ sessions, loading: false });
} catch {
set({ loading: false });
}
},
/**
*
*/
loadMessagesFromServer: async (sessionId: string) => {
set({ loading: true });
try {
const resp = await fetchMessages(sessionId);
const rawMessages = resp.messages || [];
const msgs: Message[] = rawMessages.map((m: any, i: number) => ({
id: m.id ? String(m.id) : `hist_${i}_${Date.now()}`,
role: m.role,
content: m.content,
timestamp: typeof m.created_at === 'number' ? m.created_at : Date.now(),
isStreaming: false,
}));
set({ messages: msgs, loading: false });
useChatStore.getState().setMessages(msgs);
} catch {
set({ messages: [], loading: false });
useChatStore.getState().clearMessages();
}
},
/**
*
*/
clearMainSessionMessages: async (sessionId: string) => {
const ok = await apiClearMessages(sessionId);
if (ok) {
set({ messages: [] });
useChatStore.getState().clearMessages();
}
return ok;
},
/**
*
*/
deleteSessionAndRefresh: async (id: string, userId: string) => {
const ok = await apiDeleteSession(id);
if (!ok) return;
const state = get();
const remaining = state.sessions.filter((s) => s.id !== id);
const wasCurrent = state.currentSessionId === id;
// 更新本地列表
set({ sessions: remaining });
if (wasCurrent) {
if (remaining.length > 0) {
// 切换到列表中的第一个会话
const nextId = remaining[0].id;
set({ currentSessionId: nextId });
await get().loadMessagesFromServer(nextId);
} else {
// 没有会话了:管理员回到主对话,普通用户创建新对话
set({ currentSessionId: null, messages: [] });
useChatStore.getState().clearMessages();
}
}
},
/**
*
*/
deleteAllSessionsAndReset: async (userId: string) => {
const ok = await apiDeleteAllSessions(userId);
if (!ok) return;
set({ sessions: [], currentSessionId: null, messages: [] });
useChatStore.getState().clearMessages();
},
/**
*
*/
ensureMainSession: async (userId: string) => {
const state = get();
// 先检查本地列表
const existing = state.sessions.find((s) => s.is_main);
if (existing) return existing;
// 本地没有,尝试从服务端加载
await get().loadSessionsFromServer(userId);
const refreshed = get().sessions.find((s) => s.is_main);
if (refreshed) return refreshed;
// 服务端也没有,创建主对话
const sid = randomID();
const created = await apiCreateSession(userId, sid, '主对话', true);
if (created) {
const newSession: Session = {
id: created.id,
user_id: created.user_id,
title: created.title,
is_main: created.is_main,
created_at: String(created.created_at),
updated_at: String(created.updated_at),
message_count: 0,
};
set((s) => ({ sessions: [newSession, ...s.sessions] }));
return newSession;
}
return null;
},
}));
+13 -2
View File
@@ -5,15 +5,16 @@ export interface Session {
id: string;
user_id: string;
title: string;
is_main: boolean;
created_at: string;
updated_at: string;
message_count: number;
is_active: boolean;
message_count?: number;
}
/** 创建会话参数 */
export interface CreateSessionParams {
title?: string;
is_main?: boolean;
}
/** 会话列表响应 */
@@ -26,6 +27,16 @@ export interface SessionMessagesResponse {
messages: import('@/types/chat').Message[];
}
/** 单个会话响应 */
export interface SessionResponse {
id: string;
user_id: string;
title: string;
is_main: boolean;
created_at: string;
updated_at: string;
}
/** 认证相关 */
export interface AuthResponse {
user_id: string;