fix: 第二轮修复 — 数据库启动检查、会话持久化、URL路由、设备排序等
1. DevTools 启动前检查数据库状态,失败时自动尝试启动 2. ai-core 添加数据库断线重连机制 (30秒间隔) 3. Dashboard 添加数据库状态卡片 (启动/停止/重启) 4. Gateway 会话空闲超时管理 (30分钟标记空闲) 5. 会话/消息 PostgreSQL 持久化 (SessionStore + REST API) 6. 前端服务端会话持久化 + URL hash 路由 + 侧边栏管理 7. 管理员回到主对话按钮 8. IoT 设备卡片固定排序 9. 更新相关文档
This commit is contained in:
@@ -64,21 +64,16 @@ func main() {
|
||||
var memExtractor *memory.Extractor
|
||||
|
||||
if cfg.DatabaseURL != "" {
|
||||
memStore, err = memory.NewStore(cfg.DatabaseURL)
|
||||
if err != nil {
|
||||
log.Printf("⚠ 记忆存储初始化失败 (将跳过记忆功能): %v", err)
|
||||
} else {
|
||||
defer memStore.Close()
|
||||
log.Println("记忆存储已就绪")
|
||||
memStore = memory.NewStore(cfg.DatabaseURL)
|
||||
defer memStore.Close()
|
||||
|
||||
memRetriever = memory.NewRetriever(memStore, nil)
|
||||
memRetriever = memory.NewRetriever(memStore, nil)
|
||||
|
||||
// 记忆提取器使用LLM
|
||||
memExtractor = memory.NewExtractor(memStore, func(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error) {
|
||||
return llmAdapter.Chat(ctx, messages)
|
||||
})
|
||||
log.Println("记忆提取器已就绪")
|
||||
}
|
||||
// 记忆提取器使用LLM
|
||||
memExtractor = memory.NewExtractor(memStore, func(ctx context.Context, messages []model.LLMMessage) (*model.LLMResponse, error) {
|
||||
return llmAdapter.Chat(ctx, messages)
|
||||
})
|
||||
log.Println("记忆提取器已就绪")
|
||||
}
|
||||
|
||||
// 初始化会话历史存储
|
||||
|
||||
@@ -4,22 +4,100 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/yourname/cyrene-ai/ai-core/internal/model"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
const reconnectInterval = 30 * time.Second
|
||||
|
||||
// Store 记忆持久化存储(PostgreSQL + pgvector)
|
||||
type Store struct {
|
||||
db *sql.DB
|
||||
databaseURL string
|
||||
mu sync.RWMutex
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// errDBNotReady 数据库未就绪时返回的友好错误
|
||||
var errDBNotReady = fmt.Errorf("记忆系统未就绪: 数据库连接不可用,正在后台重试连接")
|
||||
|
||||
// NewStore 创建记忆存储
|
||||
func NewStore(connStr string) (*Store, error) {
|
||||
db, err := sql.Open("postgres", connStr)
|
||||
// 连接失败时不返回 error,而是启动后台重连循环
|
||||
func NewStore(connStr string) *Store {
|
||||
s := &Store{
|
||||
databaseURL: connStr,
|
||||
}
|
||||
|
||||
// 尝试初始连接
|
||||
if err := s.Reconnect(); err != nil {
|
||||
log.Printf("[memory] ⚠ 记忆存储初始化: 数据库连接失败 (%v),将在后台每30秒重试", err)
|
||||
} else {
|
||||
log.Println("[memory] 记忆存储已就绪")
|
||||
}
|
||||
|
||||
// 启动后台重连 goroutine
|
||||
go s.reconnectLoop()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// reconnectLoop 后台重连循环
|
||||
func (s *Store) reconnectLoop() {
|
||||
ticker := time.NewTicker(reconnectInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
s.mu.RLock()
|
||||
ready := s.db != nil
|
||||
s.mu.RUnlock()
|
||||
|
||||
if ready {
|
||||
// 数据库已连接,检查连接是否仍然有效
|
||||
s.mu.RLock()
|
||||
db := s.db
|
||||
s.mu.RUnlock()
|
||||
if db != nil {
|
||||
if err := db.Ping(); err != nil {
|
||||
log.Printf("[memory] ⚠ 数据库连接丢失: %v,开始重连", err)
|
||||
s.mu.Lock()
|
||||
if s.db != nil {
|
||||
s.db.Close()
|
||||
s.db = nil
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !s.IsReady() {
|
||||
if err := s.Reconnect(); err != nil {
|
||||
log.Printf("[memory] ⚠ 数据库重连失败: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reconnect 尝试重连数据库并执行迁移
|
||||
func (s *Store) Reconnect() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// 如果已有有效连接,先检查
|
||||
if s.db != nil {
|
||||
if err := s.db.Ping(); err == nil {
|
||||
return nil // 仍然有效
|
||||
}
|
||||
// 连接已失效,关闭旧连接
|
||||
s.db.Close()
|
||||
s.db = nil
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", s.databaseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("连接数据库失败: %w", err)
|
||||
return fmt.Errorf("连接数据库失败: %w", err)
|
||||
}
|
||||
|
||||
db.SetMaxOpenConns(25)
|
||||
@@ -27,15 +105,36 @@ func NewStore(connStr string) (*Store, error) {
|
||||
db.SetConnMaxLifetime(5 * time.Minute)
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("数据库ping失败: %w", err)
|
||||
db.Close()
|
||||
return fmt.Errorf("数据库ping失败: %w", err)
|
||||
}
|
||||
|
||||
s := &Store{db: db}
|
||||
s.db = db
|
||||
|
||||
// 执行建表迁移
|
||||
if err := s.migrate(); err != nil {
|
||||
return nil, fmt.Errorf("数据库迁移失败: %w", err)
|
||||
log.Printf("[memory] ⚠ 数据库迁移失败: %v", err)
|
||||
s.db.Close()
|
||||
s.db = nil
|
||||
return fmt.Errorf("数据库迁移失败: %w", err)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
log.Println("[memory] ✅ 数据库重连成功,记忆系统已就绪")
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsReady 返回数据库是否可用
|
||||
func (s *Store) IsReady() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.db != nil
|
||||
}
|
||||
|
||||
// getDB 获取当前数据库连接(带读锁保护)
|
||||
func (s *Store) getDB() *sql.DB {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.db
|
||||
}
|
||||
|
||||
// migrate 创建表结构
|
||||
@@ -73,6 +172,11 @@ func (s *Store) migrate() error {
|
||||
|
||||
// Save 保存记忆
|
||||
func (s *Store) Save(ctx context.Context, entry *model.MemoryEntry) error {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return errDBNotReady
|
||||
}
|
||||
|
||||
query := `INSERT INTO memories (user_id, content, summary, category, priority, session_id, source, embedding, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
RETURNING id, created_at`
|
||||
@@ -86,7 +190,7 @@ func (s *Store) Save(ctx context.Context, entry *model.MemoryEntry) error {
|
||||
embedding = fmt.Sprintf("[%s]", joinFloats(vec))
|
||||
}
|
||||
|
||||
return s.db.QueryRowContext(ctx, query,
|
||||
return db.QueryRowContext(ctx, query,
|
||||
entry.UserID, entry.Content, entry.Summary,
|
||||
string(entry.Category), int(entry.Priority),
|
||||
entry.SessionID, entry.Source, embedding, entry.ExpiresAt,
|
||||
@@ -95,13 +199,18 @@ func (s *Store) Save(ctx context.Context, entry *model.MemoryEntry) error {
|
||||
|
||||
// GetByID 根据ID获取记忆
|
||||
func (s *Store) GetByID(ctx context.Context, id string) (*model.MemoryEntry, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return nil, errDBNotReady
|
||||
}
|
||||
|
||||
query := `SELECT id, user_id, content, summary, category, priority, session_id, source,
|
||||
access_count, last_access, created_at, expires_at
|
||||
FROM memories WHERE id = $1`
|
||||
|
||||
entry := &model.MemoryEntry{}
|
||||
var category string
|
||||
err := s.db.QueryRowContext(ctx, query, id).Scan(
|
||||
err := db.QueryRowContext(ctx, query, id).Scan(
|
||||
&entry.ID, &entry.UserID, &entry.Content, &entry.Summary,
|
||||
&category, &entry.Priority, &entry.SessionID, &entry.Source,
|
||||
&entry.AccessCount, &entry.LastAccess, &entry.CreatedAt, &entry.ExpiresAt,
|
||||
@@ -122,6 +231,11 @@ func (s *Store) GetByID(ctx context.Context, id string) (*model.MemoryEntry, err
|
||||
|
||||
// Query 按条件查询记忆
|
||||
func (s *Store) Query(ctx context.Context, q model.MemoryQuery) ([]model.MemoryEntry, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return nil, errDBNotReady
|
||||
}
|
||||
|
||||
if q.Limit <= 0 {
|
||||
q.Limit = 10
|
||||
}
|
||||
@@ -147,7 +261,7 @@ func (s *Store) Query(ctx context.Context, q model.MemoryQuery) ([]model.MemoryE
|
||||
query += fmt.Sprintf(" ORDER BY priority DESC, created_at DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1)
|
||||
args = append(args, q.Limit, q.Offset)
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
rows, err := db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询记忆失败: %w", err)
|
||||
}
|
||||
@@ -173,13 +287,21 @@ func (s *Store) Query(ctx context.Context, q model.MemoryQuery) ([]model.MemoryE
|
||||
|
||||
// Delete 删除记忆
|
||||
func (s *Store) Delete(ctx context.Context, id string) error {
|
||||
_, err := s.db.ExecContext(ctx, `DELETE FROM memories WHERE id = $1`, id)
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return errDBNotReady
|
||||
}
|
||||
_, err := db.ExecContext(ctx, `DELETE FROM memories WHERE id = $1`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// PurgeExpired 清理过期记忆
|
||||
func (s *Store) PurgeExpired(ctx context.Context) (int64, error) {
|
||||
result, err := s.db.ExecContext(ctx,
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return 0, errDBNotReady
|
||||
}
|
||||
result, err := db.ExecContext(ctx,
|
||||
`DELETE FROM memories WHERE expires_at IS NOT NULL AND expires_at < NOW()`)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -189,6 +311,11 @@ func (s *Store) PurgeExpired(ctx context.Context) (int64, error) {
|
||||
|
||||
// SearchByVector 向量相似度搜索
|
||||
func (s *Store) SearchByVector(ctx context.Context, userID string, embedding []float64, limit int) ([]model.MemoryEntry, error) {
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return nil, errDBNotReady
|
||||
}
|
||||
|
||||
if limit <= 0 {
|
||||
limit = 5
|
||||
}
|
||||
@@ -202,7 +329,7 @@ func (s *Store) SearchByVector(ctx context.Context, userID string, embedding []f
|
||||
ORDER BY embedding <=> $1
|
||||
LIMIT $3`
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, vecStr, userID, limit)
|
||||
rows, err := db.QueryContext(ctx, query, vecStr, userID, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("向量搜索失败: %w", err)
|
||||
}
|
||||
@@ -229,13 +356,23 @@ func (s *Store) SearchByVector(ctx context.Context, userID string, embedding []f
|
||||
}
|
||||
|
||||
func (s *Store) incrementAccess(ctx context.Context, id string) {
|
||||
s.db.ExecContext(ctx,
|
||||
db := s.getDB()
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
db.ExecContext(ctx,
|
||||
`UPDATE memories SET access_count = access_count + 1, last_access = NOW() WHERE id = $1`, id)
|
||||
}
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func (s *Store) Close() error {
|
||||
return s.db.Close()
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.db != nil {
|
||||
return s.db.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// joinFloats 将 float64 切片转为逗号分隔字符串
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/config"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/router"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/store"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/ws"
|
||||
)
|
||||
|
||||
@@ -26,6 +27,17 @@ func main() {
|
||||
// 加载配置
|
||||
cfg := config.Load()
|
||||
|
||||
// 初始化数据库持久化存储 (降级:连接失败不崩溃)
|
||||
var sessionStore *store.SessionStore
|
||||
databaseURL := cfg.DatabaseURL()
|
||||
if s, err := store.NewSessionStore(databaseURL); err != nil {
|
||||
log.Printf("⚠ 会话持久化存储初始化失败 (数据库不可用): %v", err)
|
||||
log.Println("⚠ Gateway 将以仅内存模式运行 — 会话数据在重启后丢失")
|
||||
} else {
|
||||
sessionStore = s
|
||||
log.Println("✅ 会话持久化存储已启用 (PostgreSQL)")
|
||||
}
|
||||
|
||||
// 初始化Gin
|
||||
if cfg.Env == "production" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
@@ -39,13 +51,18 @@ func main() {
|
||||
|
||||
// 初始化WebSocket Hub
|
||||
hub := ws.NewHub()
|
||||
hub.SetStore(sessionStore)
|
||||
hub.SetIdleTimeout(cfg.SessionIdleTimeoutMin)
|
||||
go hub.Run()
|
||||
|
||||
// 启动闲置会话清理 (标记超时会话为 idle,不删除)
|
||||
hub.StartIdleCleanup()
|
||||
|
||||
// 启动 IoT 设备状态广播(每10秒向所有WebSocket客户端推送设备状态)
|
||||
hub.StartIoTBroadcast(cfg.IoTDebugServiceURL)
|
||||
|
||||
// 注册路由
|
||||
router.Setup(r, hub, cfg)
|
||||
router.Setup(r, hub, cfg, sessionStore)
|
||||
|
||||
// 启动服务
|
||||
srv := &http.Server{
|
||||
@@ -68,6 +85,13 @@ func main() {
|
||||
|
||||
hub.StopIoTBroadcast()
|
||||
|
||||
// 关闭数据库连接
|
||||
if sessionStore != nil {
|
||||
if err := sessionStore.Close(); err != nil {
|
||||
log.Printf("⚠ 关闭数据库连接失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
srv.Shutdown(ctx)
|
||||
|
||||
@@ -7,6 +7,7 @@ require (
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/lib/pq v1.10.9
|
||||
)
|
||||
|
||||
require (
|
||||
|
||||
@@ -42,6 +42,8 @@ github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZY
|
||||
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
@@ -49,6 +50,9 @@ type Config struct {
|
||||
// WebSocket
|
||||
WSMaxConnections int
|
||||
|
||||
// 会话闲置超时 (分钟) — 超过此时间后会话标记为 idle 但不删除
|
||||
SessionIdleTimeoutMin int
|
||||
|
||||
// Webhook (第三方平台接入)
|
||||
WebhookAPIKey string
|
||||
}
|
||||
@@ -87,12 +91,23 @@ func Load() *Config {
|
||||
LLMAPIKey: getEnv("LLM_API_KEY", ""),
|
||||
LLMModel: getEnv("LLM_MODEL", "gpt-4o"),
|
||||
|
||||
WSMaxConnections: getEnvInt("WS_MAX_CONNECTIONS", 1000),
|
||||
WSMaxConnections: getEnvInt("WS_MAX_CONNECTIONS", 1000),
|
||||
SessionIdleTimeoutMin: getEnvInt("SESSION_IDLE_TIMEOUT_MIN", 30),
|
||||
|
||||
WebhookAPIKey: getEnv("WEBHOOK_API_KEY", ""),
|
||||
}
|
||||
}
|
||||
|
||||
// DatabaseURL 构建 PostgreSQL 连接字符串
|
||||
func (c *Config) DatabaseURL() string {
|
||||
return fmt.Sprintf(
|
||||
"postgres://%s:%s@%s:%s/%s?sslmode=disable",
|
||||
c.PostgresUser, c.PostgresPass,
|
||||
c.PostgresHost, c.PostgresPort,
|
||||
c.PostgresDB,
|
||||
)
|
||||
}
|
||||
|
||||
// GenerateToken 生成JWT token
|
||||
func (c *Config) GenerateToken(userID string) (string, error) {
|
||||
claims := jwt.MapClaims{
|
||||
|
||||
@@ -1,143 +1,267 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/store"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/ws"
|
||||
)
|
||||
|
||||
// SessionHandler 会话管理处理器
|
||||
type SessionHandler struct {
|
||||
// MVP阶段使用内存存储,后续迁移到PostgreSQL
|
||||
sessions map[string][]SessionInfo // userID -> sessions
|
||||
hub *ws.Hub
|
||||
}
|
||||
|
||||
// SessionInfo 会话信息
|
||||
type SessionInfo struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Title string `json:"title"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
MessageCount int `json:"message_count"`
|
||||
IsActive bool `json:"is_active"`
|
||||
store *store.SessionStore // PostgreSQL 持久化存储
|
||||
hub *ws.Hub
|
||||
useDB bool // 数据库是否可用
|
||||
}
|
||||
|
||||
// NewSessionHandler 创建会话处理器
|
||||
func NewSessionHandler(hub *ws.Hub) *SessionHandler {
|
||||
func NewSessionHandler(hub *ws.Hub, s *store.SessionStore) *SessionHandler {
|
||||
return &SessionHandler{
|
||||
sessions: make(map[string][]SessionInfo),
|
||||
hub: hub,
|
||||
store: s,
|
||||
hub: hub,
|
||||
useDB: s != nil && s.IsAvailable(),
|
||||
}
|
||||
}
|
||||
|
||||
// ========== POST /api/v1/sessions — 创建会话 ==========
|
||||
|
||||
type createSessionRequest struct {
|
||||
UserID string `json:"user_id"`
|
||||
SessionID string `json:"session_id"`
|
||||
Title string `json:"title"`
|
||||
IsMain bool `json:"is_main"`
|
||||
}
|
||||
|
||||
// Create 创建新会话
|
||||
func (h *SessionHandler) Create(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
|
||||
var req struct {
|
||||
Title string `json:"title"`
|
||||
}
|
||||
var req createSessionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
// 允许空body
|
||||
req.Title = "新的对话"
|
||||
// 允许空 body
|
||||
}
|
||||
if req.UserID != "" {
|
||||
userID = req.UserID
|
||||
}
|
||||
if req.Title == "" {
|
||||
req.Title = "新的对话"
|
||||
}
|
||||
|
||||
session := SessionInfo{
|
||||
ID: "session_" + randomID(12),
|
||||
UserID: userID,
|
||||
Title: req.Title,
|
||||
CreatedAt: time.Now().UnixMilli(),
|
||||
UpdatedAt: time.Now().UnixMilli(),
|
||||
if req.SessionID == "" {
|
||||
req.SessionID = "session_" + randomID(12)
|
||||
}
|
||||
|
||||
h.sessions[userID] = append([]SessionInfo{session}, h.sessions[userID]...)
|
||||
|
||||
c.JSON(http.StatusCreated, session)
|
||||
}
|
||||
|
||||
// List 获取会话列表
|
||||
func (h *SessionHandler) List(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
|
||||
sessions, ok := h.sessions[userID]
|
||||
if !ok {
|
||||
sessions = []SessionInfo{}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"sessions": sessions,
|
||||
})
|
||||
}
|
||||
|
||||
// Delete 删除会话
|
||||
func (h *SessionHandler) Delete(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
sessionID := c.Param("id")
|
||||
|
||||
sessions := h.sessions[userID]
|
||||
for i, s := range sessions {
|
||||
if s.ID == sessionID {
|
||||
h.sessions[userID] = append(sessions[:i], sessions[i+1:]...)
|
||||
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
|
||||
if h.useDB {
|
||||
if err := h.store.CreateSession(userID, req.SessionID, req.Title, req.IsMain); err != nil {
|
||||
log.Printf("[SessionHandler] 创建会话失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建会话失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "会话不存在",
|
||||
"errorType": "session_not_found",
|
||||
"hint": "会话可能已被删除,或 Gateway 重启后内存数据已清空",
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"id": req.SessionID,
|
||||
"user_id": userID,
|
||||
"title": req.Title,
|
||||
"is_main": req.IsMain,
|
||||
"created_at": time.Now().UnixMilli(),
|
||||
"updated_at": time.Now().UnixMilli(),
|
||||
})
|
||||
}
|
||||
|
||||
// ========== GET /api/v1/sessions?user_id=xxx — 获取用户会话列表 ==========
|
||||
|
||||
// List 获取会话列表 (按 updated_at DESC 排序)
|
||||
func (h *SessionHandler) List(c *gin.Context) {
|
||||
userID := c.Query("user_id")
|
||||
if userID == "" {
|
||||
userID = middleware.GetUserID(c)
|
||||
}
|
||||
|
||||
if h.useDB {
|
||||
sessions, err := h.store.GetUserSessions(userID)
|
||||
if err != nil {
|
||||
log.Printf("[SessionHandler] 查询会话列表失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询会话失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
// 转换为列表格式
|
||||
result := make([]gin.H, 0, len(sessions))
|
||||
for _, s := range sessions {
|
||||
result = append(result, gin.H{
|
||||
"id": s.ID,
|
||||
"user_id": s.UserID,
|
||||
"title": s.Title,
|
||||
"is_main": s.IsMain,
|
||||
"created_at": s.CreatedAt.UnixMilli(),
|
||||
"updated_at": s.UpdatedAt.UnixMilli(),
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"sessions": result})
|
||||
return
|
||||
}
|
||||
|
||||
// 降级:返回空列表
|
||||
c.JSON(http.StatusOK, gin.H{"sessions": []gin.H{}})
|
||||
}
|
||||
|
||||
// ========== GET /api/v1/sessions/:id — 获取单个会话 ==========
|
||||
|
||||
// Get 获取单个会话信息
|
||||
func (h *SessionHandler) Get(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
sessionID := c.Param("id")
|
||||
|
||||
for _, s := range h.sessions[userID] {
|
||||
if s.ID == sessionID {
|
||||
c.JSON(http.StatusOK, s)
|
||||
if h.useDB {
|
||||
session, err := h.store.GetSession(sessionID)
|
||||
if err != nil {
|
||||
log.Printf("[SessionHandler] 查询会话失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询会话失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
if session == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "会话不存在",
|
||||
"errorType": "session_not_found",
|
||||
"hint": "该会话可能已被删除或尚未创建",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"id": session.ID,
|
||||
"user_id": session.UserID,
|
||||
"title": session.Title,
|
||||
"is_main": session.IsMain,
|
||||
"created_at": session.CreatedAt.UnixMilli(),
|
||||
"updated_at": session.UpdatedAt.UnixMilli(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "会话存储不可用",
|
||||
"errorType": "store_unavailable",
|
||||
"hint": "数据库连接未建立,Gateway 运行在仅内存模式",
|
||||
})
|
||||
}
|
||||
|
||||
// ========== DELETE /api/v1/sessions/:id — 删除会话 ==========
|
||||
|
||||
// Delete 删除会话 (不删除记忆)
|
||||
func (h *SessionHandler) Delete(c *gin.Context) {
|
||||
sessionID := c.Param("id")
|
||||
|
||||
if h.useDB {
|
||||
if err := h.store.DeleteSession(sessionID); err != nil {
|
||||
log.Printf("[SessionHandler] 删除会话失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除会话失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "会话不存在",
|
||||
"errorType": "session_not_found",
|
||||
"hint": "会话可能已被删除,或 Gateway 重启后内存数据已清空",
|
||||
})
|
||||
// 同时清理 Hub 中的缓存
|
||||
h.hub.DeleteConversation("", sessionID)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
|
||||
}
|
||||
|
||||
// ========== DELETE /api/v1/sessions?user_id=xxx — 删除用户所有会话 ==========
|
||||
|
||||
// DeleteAll 删除用户所有会话 (不删除记忆)
|
||||
func (h *SessionHandler) DeleteAll(c *gin.Context) {
|
||||
userID := c.Query("user_id")
|
||||
if userID == "" {
|
||||
userID = middleware.GetUserID(c)
|
||||
}
|
||||
|
||||
if h.useDB {
|
||||
if err := h.store.DeleteAllUserSessions(userID); err != nil {
|
||||
log.Printf("[SessionHandler] 删除用户所有会话失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除会话失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
|
||||
}
|
||||
|
||||
// ========== GET /api/v1/sessions/:id/messages?limit=50 — 获取会话消息 ==========
|
||||
|
||||
// GetMessages 获取会话的完整消息列表
|
||||
func (h *SessionHandler) GetMessages(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
sessionID := c.Param("id")
|
||||
limit := 50
|
||||
if l := c.Query("limit"); l != "" {
|
||||
parsed := 0
|
||||
for _, ch := range l {
|
||||
if ch < '0' || ch > '9' {
|
||||
break
|
||||
}
|
||||
parsed = parsed*10 + int(ch-'0')
|
||||
}
|
||||
if parsed > 0 {
|
||||
limit = parsed
|
||||
}
|
||||
}
|
||||
|
||||
messages := h.hub.GetConversation(userID, sessionID)
|
||||
if h.useDB {
|
||||
messages, err := h.store.GetMessages(sessionID, limit)
|
||||
if err != nil {
|
||||
log.Printf("[SessionHandler] 查询消息失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询消息失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
// 转换为统一格式
|
||||
result := make([]gin.H, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
result = append(result, gin.H{
|
||||
"id": m.ID,
|
||||
"session_id": m.SessionID,
|
||||
"role": m.Role,
|
||||
"content": m.Content,
|
||||
"created_at": m.CreatedAt.UnixMilli(),
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"messages": result})
|
||||
return
|
||||
}
|
||||
|
||||
// 降级:从 Hub 内存缓存读取
|
||||
messages := h.hub.GetConversation("", sessionID)
|
||||
if messages == nil {
|
||||
messages = []ws.Message{}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"messages": messages})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"messages": messages,
|
||||
})
|
||||
// ========== DELETE /api/v1/sessions/:id/messages — 清空会话消息 ==========
|
||||
|
||||
// ClearMessages 清空会话所有消息但不删除会话本身
|
||||
func (h *SessionHandler) ClearMessages(c *gin.Context) {
|
||||
sessionID := c.Param("id")
|
||||
|
||||
if h.useDB {
|
||||
if err := h.store.ClearSessionMessages(sessionID); err != nil {
|
||||
log.Printf("[SessionHandler] 清空消息失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "清空消息失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 同时清理 Hub 内存缓存
|
||||
h.hub.DeleteConversation("", sessionID)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "cleared"})
|
||||
}
|
||||
|
||||
// ========== Admin 端点 ==========
|
||||
|
||||
// ListActiveSessions 获取当前所有活跃 WebSocket 会话列表 (管理员)
|
||||
func (h *SessionHandler) ListActiveSessions(c *gin.Context) {
|
||||
sessions := h.hub.GetActiveSessions()
|
||||
sessions := h.hub.GetAllActiveSessions()
|
||||
if sessions == nil {
|
||||
sessions = []*ws.SessionState{}
|
||||
}
|
||||
@@ -188,6 +312,5 @@ func randomID(n int) string {
|
||||
for i := range b {
|
||||
b[i] = letters[i%len(letters)]
|
||||
}
|
||||
// 使用纳秒时间戳增加唯一性
|
||||
return string(b)
|
||||
}
|
||||
|
||||
@@ -9,17 +9,18 @@ import (
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/config"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/handler"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/store"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/ws"
|
||||
)
|
||||
|
||||
// Setup 注册所有路由
|
||||
func Setup(r *gin.Engine, hub *ws.Hub, cfg *config.Config) {
|
||||
func Setup(r *gin.Engine, hub *ws.Hub, cfg *config.Config, sessionStore *store.SessionStore) {
|
||||
// 限流器
|
||||
rateLimiter := middleware.NewRateLimiter(10, 20) // 每秒10个请求,突发20
|
||||
|
||||
// 初始化处理器
|
||||
authHandler := handler.NewAuthHandler(cfg)
|
||||
sessionHandler := handler.NewSessionHandler(hub)
|
||||
sessionHandler := handler.NewSessionHandler(hub, sessionStore)
|
||||
memoryHandler := handler.NewMemoryHandler(cfg.AICoreURL)
|
||||
chatHandler := handler.NewChatHandler(cfg, hub)
|
||||
webhookHandler := handler.NewWebhookHandler(cfg, hub)
|
||||
@@ -54,11 +55,13 @@ func Setup(r *gin.Engine, hub *ws.Hub, cfg *config.Config) {
|
||||
// 会话管理
|
||||
sessions := protected.Group("/sessions")
|
||||
{
|
||||
sessions.POST("", sessionHandler.Create)
|
||||
sessions.GET("", sessionHandler.List)
|
||||
sessions.GET("/:id", sessionHandler.Get)
|
||||
sessions.DELETE("/:id", sessionHandler.Delete)
|
||||
sessions.GET("/:id/messages", sessionHandler.GetMessages)
|
||||
sessions.POST("", sessionHandler.Create) // POST /api/v1/sessions
|
||||
sessions.GET("", sessionHandler.List) // GET /api/v1/sessions?user_id=xxx
|
||||
sessions.DELETE("", sessionHandler.DeleteAll) // DELETE /api/v1/sessions?user_id=xxx
|
||||
sessions.GET("/:id", sessionHandler.Get) // GET /api/v1/sessions/:id
|
||||
sessions.DELETE("/:id", sessionHandler.Delete) // DELETE /api/v1/sessions/:id
|
||||
sessions.GET("/:id/messages", sessionHandler.GetMessages) // GET /api/v1/sessions/:id/messages?limit=50
|
||||
sessions.DELETE("/:id/messages", sessionHandler.ClearMessages) // DELETE /api/v1/sessions/:id/messages
|
||||
}
|
||||
|
||||
// 记忆管理
|
||||
|
||||
@@ -0,0 +1,270 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
// Session 会话模型
|
||||
type Session struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Title string `json:"title"`
|
||||
IsMain bool `json:"is_main"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// Message 消息模型
|
||||
type Message struct {
|
||||
ID int `json:"id"`
|
||||
SessionID string `json:"session_id"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// SessionStore 会话持久化存储
|
||||
type SessionStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewSessionStore 初始化数据库连接并自动建表
|
||||
// 如果连接失败,返回 nil 和错误(调用方可以选择降级为仅内存模式)
|
||||
func NewSessionStore(databaseURL string) (*SessionStore, error) {
|
||||
db, err := sql.Open("postgres", databaseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("无法打开数据库连接: %w", err)
|
||||
}
|
||||
|
||||
// 配置连接池
|
||||
db.SetMaxOpenConns(25)
|
||||
db.SetMaxIdleConns(5)
|
||||
db.SetConnMaxLifetime(5 * time.Minute)
|
||||
|
||||
// 验证连接
|
||||
if err := db.Ping(); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("数据库连接验证失败: %w", err)
|
||||
}
|
||||
|
||||
store := &SessionStore{db: db}
|
||||
|
||||
// 自动建表
|
||||
if err := store.migrate(); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("数据库迁移失败: %w", err)
|
||||
}
|
||||
|
||||
log.Println("[SessionStore] PostgreSQL 持久化存储已初始化")
|
||||
return store, nil
|
||||
}
|
||||
|
||||
// migrate 自动创建表结构
|
||||
func (s *SessionStore) migrate() error {
|
||||
queries := []string{
|
||||
`CREATE TABLE IF NOT EXISTS sessions (
|
||||
id VARCHAR(64) PRIMARY KEY,
|
||||
user_id VARCHAR(128) NOT NULL,
|
||||
title VARCHAR(256) DEFAULT '',
|
||||
is_main BOOLEAN DEFAULT FALSE,
|
||||
created_at TIMESTAMP DEFAULT NOW(),
|
||||
updated_at TIMESTAMP DEFAULT NOW()
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_sessions_updated_at ON sessions(updated_at DESC)`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS messages (
|
||||
id SERIAL PRIMARY KEY,
|
||||
session_id VARCHAR(64) REFERENCES sessions(id) ON DELETE CASCADE,
|
||||
role VARCHAR(16) NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_messages_session_id ON messages(session_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_messages_created_at ON messages(session_id, created_at)`,
|
||||
}
|
||||
|
||||
for _, q := range queries {
|
||||
if _, err := s.db.Exec(q); err != nil {
|
||||
return fmt.Errorf("迁移SQL执行失败: %w\nSQL: %s", err, q)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateSession 创建新会话
|
||||
func (s *SessionStore) CreateSession(userID, sessionID, title string, isMain bool) error {
|
||||
now := time.Now()
|
||||
_, err := s.db.Exec(
|
||||
`INSERT INTO sessions (id, user_id, title, is_main, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $5)
|
||||
ON CONFLICT (id) DO UPDATE SET updated_at = $5`,
|
||||
sessionID, userID, title, isMain, now,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建会话失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserSessions 获取用户的所有会话(按 updated_at DESC 排序)
|
||||
func (s *SessionStore) GetUserSessions(userID string) ([]Session, error) {
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, user_id, title, is_main, created_at, updated_at
|
||||
FROM sessions WHERE user_id = $1
|
||||
ORDER BY updated_at DESC`,
|
||||
userID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询用户会话失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var sessions []Session
|
||||
for rows.Next() {
|
||||
var sess Session
|
||||
if err := rows.Scan(&sess.ID, &sess.UserID, &sess.Title, &sess.IsMain, &sess.CreatedAt, &sess.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描会话行失败: %w", err)
|
||||
}
|
||||
sessions = append(sessions, sess)
|
||||
}
|
||||
|
||||
if sessions == nil {
|
||||
sessions = []Session{}
|
||||
}
|
||||
return sessions, rows.Err()
|
||||
}
|
||||
|
||||
// GetSession 获取单个会话
|
||||
func (s *SessionStore) GetSession(sessionID string) (*Session, error) {
|
||||
var sess Session
|
||||
err := s.db.QueryRow(
|
||||
`SELECT id, user_id, title, is_main, created_at, updated_at
|
||||
FROM sessions WHERE id = $1`,
|
||||
sessionID,
|
||||
).Scan(&sess.ID, &sess.UserID, &sess.Title, &sess.IsMain, &sess.CreatedAt, &sess.UpdatedAt)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("查询会话失败: %w", err)
|
||||
}
|
||||
return &sess, nil
|
||||
}
|
||||
|
||||
// UpdateSessionTitle 更新会话标题
|
||||
func (s *SessionStore) UpdateSessionTitle(sessionID, title string) error {
|
||||
_, err := s.db.Exec(
|
||||
`UPDATE sessions SET title = $1, updated_at = NOW() WHERE id = $2`,
|
||||
title, sessionID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新会话标题失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateSessionTime 更新会话的 updated_at 时间戳
|
||||
func (s *SessionStore) UpdateSessionTime(sessionID string) error {
|
||||
_, err := s.db.Exec(
|
||||
`UPDATE sessions SET updated_at = NOW() WHERE id = $1`,
|
||||
sessionID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新会话时间失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteSession 删除会话(级联删除消息,但不删除记忆)
|
||||
func (s *SessionStore) DeleteSession(sessionID string) error {
|
||||
_, err := s.db.Exec(`DELETE FROM sessions WHERE id = $1`, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除会话失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteAllUserSessions 删除用户的所有会话(但不删除记忆)
|
||||
func (s *SessionStore) DeleteAllUserSessions(userID string) error {
|
||||
_, err := s.db.Exec(`DELETE FROM sessions WHERE user_id = $1`, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除用户所有会话失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddMessage 添加一条消息到会话
|
||||
func (s *SessionStore) AddMessage(sessionID, role, content string) error {
|
||||
_, err := s.db.Exec(
|
||||
`INSERT INTO messages (session_id, role, content) VALUES ($1, $2, $3)`,
|
||||
sessionID, role, content,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("添加消息失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMessages 获取会话的消息列表(按时间正序)
|
||||
func (s *SessionStore) GetMessages(sessionID string, limit int) ([]Message, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, session_id, role, content, created_at
|
||||
FROM messages WHERE session_id = $1
|
||||
ORDER BY created_at ASC
|
||||
LIMIT $2`,
|
||||
sessionID, limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询消息失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var messages []Message
|
||||
for rows.Next() {
|
||||
var msg Message
|
||||
if err := rows.Scan(&msg.ID, &msg.SessionID, &msg.Role, &msg.Content, &msg.CreatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描消息行失败: %w", err)
|
||||
}
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
|
||||
if messages == nil {
|
||||
messages = []Message{}
|
||||
}
|
||||
return messages, rows.Err()
|
||||
}
|
||||
|
||||
// ClearSessionMessages 清空会话的所有消息但不删除会话本身
|
||||
func (s *SessionStore) ClearSessionMessages(sessionID string) error {
|
||||
_, err := s.db.Exec(`DELETE FROM messages WHERE session_id = $1`, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("清空会话消息失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func (s *SessionStore) Close() error {
|
||||
if s.db != nil {
|
||||
return s.db.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsAvailable 检查存储是否可用(数据库连接正常)
|
||||
func (s *SessionStore) IsAvailable() bool {
|
||||
if s.db == nil {
|
||||
return false
|
||||
}
|
||||
return s.db.Ping() == nil
|
||||
}
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/store"
|
||||
)
|
||||
|
||||
// SessionState 会话状态
|
||||
@@ -60,6 +62,22 @@ type Hub struct {
|
||||
iotServiceURL string
|
||||
iotStopCh chan struct{}
|
||||
iotPollRunning bool
|
||||
|
||||
// 持久化存储 (可选,数据库连接失败时为 nil)
|
||||
store *store.SessionStore
|
||||
|
||||
// 闲置超时时间
|
||||
idleTimeout time.Duration
|
||||
}
|
||||
|
||||
// SetStore 设置持久化存储 (可选)
|
||||
func (h *Hub) SetStore(s *store.SessionStore) {
|
||||
h.store = s
|
||||
}
|
||||
|
||||
// SetIdleTimeout 设置闲置超时时间
|
||||
func (h *Hub) SetIdleTimeout(minutes int) {
|
||||
h.idleTimeout = time.Duration(minutes) * time.Minute
|
||||
}
|
||||
|
||||
// NewHub 创建WebSocket Hub
|
||||
@@ -72,9 +90,79 @@ func NewHub() *Hub {
|
||||
userClients: make(map[string]map[*Client]bool),
|
||||
sessions: make(map[string]*SessionState),
|
||||
iotStopCh: make(chan struct{}),
|
||||
idleTimeout: 30 * time.Minute, // 默认30分钟
|
||||
}
|
||||
}
|
||||
|
||||
// StartIdleCleanup 启动闲置会话清理 goroutine
|
||||
// 每5分钟检查一次,将超过 idleTimeout 无活动的会话标记为 idle
|
||||
func (h *Hub) StartIdleCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
h.cleanupIdleSessions()
|
||||
}
|
||||
}()
|
||||
log.Printf("[WS] 闲置会话清理已启动 (超时: %v)", h.idleTimeout)
|
||||
}
|
||||
|
||||
// cleanupIdleSessions 标记超时会话为 idle(不删除状态)
|
||||
func (h *Hub) cleanupIdleSessions() {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
idleCount := 0
|
||||
|
||||
for sessionID, s := range h.sessions {
|
||||
// 检查该 session 是否还有活跃连接
|
||||
hasActiveConn := false
|
||||
for _, clients := range h.userClients {
|
||||
for c := range clients {
|
||||
if c.SessionID == sessionID {
|
||||
hasActiveConn = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasActiveConn {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有活跃连接且超过闲置超时,标记为 idle
|
||||
if !hasActiveConn && now.Sub(s.LastActivity) > h.idleTimeout {
|
||||
if s.State != "idle" {
|
||||
s.State = "idle"
|
||||
idleCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if idleCount > 0 {
|
||||
log.Printf("[WS] 闲置清理: %d 个会话标记为 idle", idleCount)
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllActiveSessions 返回所有会话状态(包括 idle),供 DevTools 监看使用
|
||||
func (h *Hub) GetAllActiveSessions() []*SessionState {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
if h.sessions == nil || len(h.sessions) == 0 {
|
||||
return []*SessionState{}
|
||||
}
|
||||
|
||||
result := make([]*SessionState, 0, len(h.sessions))
|
||||
for _, s := range h.sessions {
|
||||
cp := *s
|
||||
cp.RecentMessages = nil
|
||||
result = append(result, &cp)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Run 启动Hub主循环
|
||||
func (h *Hub) Run() {
|
||||
for {
|
||||
@@ -119,7 +207,7 @@ func (h *Hub) Run() {
|
||||
}
|
||||
}
|
||||
|
||||
// 检查该session是否还有其他连接,没有则移除会话状态
|
||||
// 检查该session是否还有其他连接,没有则标记为 idle 而非删除
|
||||
hasOtherConn := false
|
||||
if clients, ok := h.userClients[client.UserID]; ok {
|
||||
for c := range clients {
|
||||
@@ -130,7 +218,10 @@ func (h *Hub) Run() {
|
||||
}
|
||||
}
|
||||
if !hasOtherConn {
|
||||
delete(h.sessions, client.SessionID)
|
||||
// 不再删除 session 状态,而是标记为 idle 保留在内存中
|
||||
if s, ok := h.sessions[client.SessionID]; ok {
|
||||
s.State = "idle"
|
||||
}
|
||||
}
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
+54
-12
@@ -744,19 +744,22 @@ async function renderDashboard() {
|
||||
<div class="cards-grid cards-4" id="dashboard-svc-cards"></div>
|
||||
</div>
|
||||
|
||||
<!-- 数据库连接状态 -->
|
||||
<div class="card">
|
||||
<!-- 数据库状态卡片 -->
|
||||
<div class="card" id="db-card">
|
||||
<div class="card-header">
|
||||
<span class="card-title">🗄️ 数据库连接</span>
|
||||
${data.database?.checked ? `
|
||||
<span class="badge ${data.database.postgresAlive ? 'badge-running' : 'badge-error'}">
|
||||
PostgreSQL ${data.database.postgresAlive ? '通联' : '断开'}
|
||||
</span>
|
||||
<span class="badge ${data.database.tunnelRunning ? 'badge-running' : 'badge-stopped'}" style="margin-left:6px">
|
||||
隧道 ${data.database.tunnelRunning ? '运行中' : '未运行'}
|
||||
</span>
|
||||
` : '<span class="badge badge-stopped">待检查</span>'}
|
||||
<a href="#" onclick="switchPanel('database');return false" style="font-size:11px;color:var(--accent);text-decoration:none">🔍 详情 →</a>
|
||||
<span class="card-title">🗄️ 数据库</span>
|
||||
<span class="badge badge-stopped" id="db-status-badge">检查中...</span>
|
||||
</div>
|
||||
<div class="metrics">
|
||||
<div class="metric"><div class="value" id="db-type-display">PostgreSQL</div><div class="label">类型</div></div>
|
||||
<div class="metric"><div class="value" id="db-port-display">5432</div><div class="label">端口</div></div>
|
||||
<div class="metric"><div class="value" id="db-uptime-display">—</div><div class="label">状态</div></div>
|
||||
</div>
|
||||
<div class="btn-group" style="margin-top:10px">
|
||||
<button class="btn btn-xs btn-green" onclick="controlDB('start')">▶ 启动</button>
|
||||
<button class="btn btn-xs btn-red" onclick="controlDB('stop')">⏹ 停止</button>
|
||||
<button class="btn btn-xs" onclick="controlDB('restart')">🔄 重启</button>
|
||||
<a href="#" onclick="switchPanel('database');return false" style="font-size:10px;color:var(--accent);text-decoration:none;margin-left:auto;align-self:center">🔍 详情 →</a>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -789,6 +792,9 @@ async function renderDashboard() {
|
||||
// 渲染服务卡片
|
||||
renderDashboardSvcCards(svcs);
|
||||
|
||||
// 渲染数据库卡片
|
||||
renderDBCard();
|
||||
|
||||
// 渲染性能快照
|
||||
const perfContainer = document.getElementById('dashboard-perf');
|
||||
const perf = data.performance?.perService || {};
|
||||
@@ -1701,6 +1707,42 @@ async function tunnelAction(action) {
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 数据库卡片控制 ==========
|
||||
async function renderDBCard() {
|
||||
const data = await api('/api/db/status');
|
||||
const badge = document.getElementById('db-status-badge');
|
||||
const typeDisplay = document.getElementById('db-type-display');
|
||||
const portDisplay = document.getElementById('db-port-display');
|
||||
const uptimeDisplay = document.getElementById('db-uptime-display');
|
||||
|
||||
if (data.error) {
|
||||
if (badge) { badge.textContent = '错误'; badge.className = 'badge badge-error'; }
|
||||
if (uptimeDisplay) uptimeDisplay.textContent = '错误';
|
||||
return;
|
||||
}
|
||||
|
||||
const online = data.online;
|
||||
if (badge) {
|
||||
badge.textContent = online ? '🟢 在线' : '🔴 离线';
|
||||
badge.className = 'badge ' + (online ? 'badge-running' : 'badge-error');
|
||||
}
|
||||
if (typeDisplay) typeDisplay.textContent = 'PostgreSQL';
|
||||
if (portDisplay) portDisplay.textContent = data.port || 5432;
|
||||
if (uptimeDisplay) uptimeDisplay.textContent = online ? '已连接' : '未连接';
|
||||
}
|
||||
|
||||
async function controlDB(action) {
|
||||
showToast('正在' + action + '数据库...', 'info');
|
||||
const data = await api('/api/db/' + action, { method: 'POST' });
|
||||
if (data.error) {
|
||||
showToast('操作失败: ' + data.error, 'error');
|
||||
} else {
|
||||
showToast('数据库 ' + action + ' 完成', 'success');
|
||||
// 等待2秒后刷新状态
|
||||
setTimeout(renderDBCard, 2000);
|
||||
}
|
||||
}
|
||||
|
||||
</script>
|
||||
<script src="iot-panel.js"></script>
|
||||
<script>
|
||||
|
||||
@@ -52,6 +52,16 @@ async function renderIoTPanel() {
|
||||
}
|
||||
|
||||
var devices = result.devices;
|
||||
|
||||
// 固定设备排列顺序: 先按类型,同类型再按 device_id
|
||||
var typeOrder = { 'sensor': 1, 'ac': 2, 'light': 3, 'curtain': 4, 'lock': 5, 'camera': 6, 'speaker': 7, 'thermostat': 8 };
|
||||
devices.sort(function(a, b) {
|
||||
var oa = typeOrder[a.type] || 99;
|
||||
var ob = typeOrder[b.type] || 99;
|
||||
if (oa !== ob) return oa - ob;
|
||||
return (a.id || a.entity_id || '').localeCompare(b.id || b.entity_id || '');
|
||||
});
|
||||
|
||||
var badge = document.getElementById('iot-badge');
|
||||
if (badge) {
|
||||
badge.textContent = devices.length;
|
||||
|
||||
@@ -696,6 +696,82 @@ app.post('/api/tunnel/:action', (req, res) => {
|
||||
}
|
||||
});
|
||||
|
||||
// ---- 数据库控制 (Docker Compose) ----
|
||||
const DB_COMPOSE_FILE = path.join(ROOT, 'docker-compose.dev.db.yml');
|
||||
const DB_PORT = 5432;
|
||||
|
||||
// GET /api/db/status
|
||||
app.get('/api/db/status', (_req, res) => {
|
||||
try {
|
||||
const online = checkPort(DB_PORT);
|
||||
res.json({
|
||||
online,
|
||||
port: DB_PORT,
|
||||
checked_at: new Date().toISOString(),
|
||||
});
|
||||
} catch (err) {
|
||||
res.json({
|
||||
online: false,
|
||||
port: DB_PORT,
|
||||
checked_at: new Date().toISOString(),
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// POST /api/db/start
|
||||
app.post('/api/db/start', (_req, res) => {
|
||||
try {
|
||||
const out = execSync(`docker compose -f "${DB_COMPOSE_FILE}" up -d`, {
|
||||
encoding: 'utf-8',
|
||||
timeout: 60000,
|
||||
stdio: 'pipe',
|
||||
});
|
||||
res.json({ success: true, action: 'start', output: out.trim() });
|
||||
} catch (err) {
|
||||
const stderr = err.stderr?.toString() || err.message;
|
||||
res.status(500).json({ success: false, action: 'start', error: stderr });
|
||||
}
|
||||
});
|
||||
|
||||
// POST /api/db/stop
|
||||
app.post('/api/db/stop', (_req, res) => {
|
||||
try {
|
||||
const out = execSync(`docker compose -f "${DB_COMPOSE_FILE}" down`, {
|
||||
encoding: 'utf-8',
|
||||
timeout: 30000,
|
||||
stdio: 'pipe',
|
||||
});
|
||||
res.json({ success: true, action: 'stop', output: out.trim() });
|
||||
} catch (err) {
|
||||
const stderr = err.stderr?.toString() || err.message;
|
||||
res.status(500).json({ success: false, action: 'stop', error: stderr });
|
||||
}
|
||||
});
|
||||
|
||||
// POST /api/db/restart
|
||||
app.post('/api/db/restart', (_req, res) => {
|
||||
try {
|
||||
const downOut = execSync(`docker compose -f "${DB_COMPOSE_FILE}" down`, {
|
||||
encoding: 'utf-8',
|
||||
timeout: 30000,
|
||||
stdio: 'pipe',
|
||||
});
|
||||
const upOut = execSync(`docker compose -f "${DB_COMPOSE_FILE}" up -d`, {
|
||||
encoding: 'utf-8',
|
||||
timeout: 60000,
|
||||
stdio: 'pipe',
|
||||
});
|
||||
res.json({
|
||||
success: true,
|
||||
action: 'restart',
|
||||
output: `down: ${downOut.trim()}\nup: ${upOut.trim()}`,
|
||||
});
|
||||
} catch (err) {
|
||||
const stderr = err.stderr?.toString() || err.message;
|
||||
res.status(500).json({ success: false, action: 'restart', error: stderr });
|
||||
}
|
||||
});
|
||||
|
||||
// ========== 启动 ==========
|
||||
// 启动性能监控
|
||||
performanceMonitor.start();
|
||||
|
||||
@@ -7,8 +7,16 @@ import { spawn, execSync } from 'child_process';
|
||||
import { EventEmitter } from 'events';
|
||||
import fs from 'fs';
|
||||
import net from 'net';
|
||||
import path from 'path';
|
||||
import { fileURLToPath } from 'url';
|
||||
import { SERVICES, logFile } from './config.js';
|
||||
|
||||
const __filename = fileURLToPath(import.meta.url);
|
||||
const __dirname = path.dirname(__filename);
|
||||
const ROOT = path.resolve(__dirname, '../..');
|
||||
const DB_PORTS = [5432, 5433, 5434];
|
||||
const DB_COMPOSE_FILE = path.join(ROOT, 'docker-compose.dev.db.yml');
|
||||
|
||||
/**
|
||||
* 通过 TCP 连接尝试判断端口是否被占用,若被占用则尝试用 fuser 释放
|
||||
*/
|
||||
@@ -36,6 +44,64 @@ function releasePort(port) {
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查端口是否可连接 (TCP connect, 超时2秒)
|
||||
*/
|
||||
function isPortOpen(port) {
|
||||
return new Promise((resolve) => {
|
||||
const sock = new net.Socket();
|
||||
sock.setTimeout(2000);
|
||||
sock.on('connect', () => { sock.destroy(); resolve(true); });
|
||||
sock.on('error', () => { sock.destroy(); resolve(false); });
|
||||
sock.on('timeout', () => { sock.destroy(); resolve(false); });
|
||||
sock.connect(port, '127.0.0.1');
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 确保数据库在线
|
||||
* 检查 DB_PORTS 中至少有一个端口可用,若不可用则尝试 docker compose up
|
||||
* 等待最多 30 秒检查数据库就绪
|
||||
* @param {string} serviceId - 正在启动的服务 ID
|
||||
* @param {EventEmitter} emitter - 用于发送日志事件
|
||||
*/
|
||||
async function ensureDBOnline(serviceId, emitter) {
|
||||
// 1. 快速检查:任意数据库端口是否已在线
|
||||
for (const port of DB_PORTS) {
|
||||
if (await isPortOpen(port)) {
|
||||
emitter.emit('log', serviceId, 'system', `数据库端口 ${port} 已在线`);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 数据库不在线,尝试 docker compose up
|
||||
emitter.emit('log', serviceId, 'system', '数据库未启动,正在通过 Docker Compose 启动...');
|
||||
try {
|
||||
execSync(`docker compose -f "${DB_COMPOSE_FILE}" up -d`, {
|
||||
timeout: 60000,
|
||||
stdio: 'pipe',
|
||||
});
|
||||
emitter.emit('log', serviceId, 'system', 'Docker Compose 启动命令已执行,等待数据库就绪...');
|
||||
} catch (err) {
|
||||
const stderr = err.stderr?.toString() || err.message;
|
||||
emitter.emit('log', serviceId, 'error', `Docker Compose 启动失败: ${stderr}`);
|
||||
}
|
||||
|
||||
// 3. 等待最多 30 秒检查数据库就绪
|
||||
for (let i = 0; i < 30; i++) {
|
||||
await new Promise((r) => setTimeout(r, 1000));
|
||||
for (const port of DB_PORTS) {
|
||||
if (await isPortOpen(port)) {
|
||||
emitter.emit('log', serviceId, 'system', `数据库端口 ${port} 已就绪 (等待 ${i + 1}s)`);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 30 秒后仍不可用
|
||||
emitter.emit('log', serviceId, 'error', '⚠️ 数据库无法启动,请手动检查 Docker。将继续启动后端服务...');
|
||||
}
|
||||
|
||||
class ProcessManager extends EventEmitter {
|
||||
constructor() {
|
||||
super();
|
||||
@@ -65,6 +131,12 @@ class ProcessManager extends EventEmitter {
|
||||
throw new Error(`${svc.name} 已在运行中`);
|
||||
}
|
||||
|
||||
// 对 gateway 和 ai-core 做数据库前置检查
|
||||
if (serviceId === 'gateway' || serviceId === 'ai-core') {
|
||||
this.emit('log', serviceId, 'system', '检查数据库连接状态...');
|
||||
await ensureDBOnline(serviceId, this);
|
||||
}
|
||||
|
||||
// 启动前释放端口,避免 "address already in use"
|
||||
if (svc.port) {
|
||||
this.emit('log', serviceId, 'system', `检查端口 ${svc.port}...`);
|
||||
|
||||
+134
-2
@@ -1,13 +1,37 @@
|
||||
import { useState } from 'react';
|
||||
import { useState, useEffect, useCallback, useRef } from 'react';
|
||||
import { AppLayout } from '@/components/layout/AppLayout';
|
||||
import { ChatContainer } from '@/components/chat/ChatContainer';
|
||||
import { ChatInput } from '@/components/chat/ChatInput';
|
||||
import { useAuth } from '@/hooks/useAuth';
|
||||
import { useChat } from '@/hooks/useChat';
|
||||
import { useSessionStore, isAdminUser } from '@/store/sessionStore';
|
||||
import { useChatStore } from '@/store/chatStore';
|
||||
import { fetchMessages } from '@/api/sessions';
|
||||
|
||||
/** URL Hash 工具 */
|
||||
const SESSION_HASH_PREFIX = 'session=';
|
||||
|
||||
function getSessionIdFromHash(): string | null {
|
||||
const hash = window.location.hash.slice(1); // 去掉 #
|
||||
if (hash.startsWith(SESSION_HASH_PREFIX)) {
|
||||
return hash.slice(SESSION_HASH_PREFIX.length);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function setHashSessionId(sessionId: string | null) {
|
||||
if (sessionId) {
|
||||
window.location.hash = SESSION_HASH_PREFIX + sessionId;
|
||||
} else {
|
||||
// 清除 hash
|
||||
history.replaceState(null, '', window.location.pathname + window.location.search);
|
||||
}
|
||||
}
|
||||
|
||||
export default function App() {
|
||||
const { isLoggedIn, login, register, loading: authLoading } = useAuth();
|
||||
const { isLoggedIn, login, register, loading: authLoading, userId } = useAuth();
|
||||
const { send } = useChat();
|
||||
const { loadSessionsFromServer, ensureMainSession, setCurrentSessionId, setMessages, loadMessagesFromServer, sessions, currentSessionId } = useSessionStore();
|
||||
|
||||
const [authMode, setAuthMode] = useState<'login' | 'register'>('login');
|
||||
const [username, setUsername] = useState('');
|
||||
@@ -17,6 +41,114 @@ export default function App() {
|
||||
const [error, setError] = useState('');
|
||||
const [successMsg, setSuccessMsg] = useState('');
|
||||
|
||||
const initializedRef = useRef(false);
|
||||
|
||||
// ========== URL Hash 路由 ==========
|
||||
|
||||
/** 根据 hash 恢复或选择初始会话 */
|
||||
const initSession = useCallback(async () => {
|
||||
if (!userId || initializedRef.current) return;
|
||||
initializedRef.current = true;
|
||||
|
||||
const admin = isAdminUser(userId);
|
||||
|
||||
// 1. 从服务端加载会话列表
|
||||
await loadSessionsFromServer(userId);
|
||||
const currentSessions = useSessionStore.getState().sessions;
|
||||
|
||||
// 2. 检查 URL hash
|
||||
const hashId = getSessionIdFromHash();
|
||||
|
||||
if (hashId) {
|
||||
// 尝试加载 hash 指定的会话
|
||||
const found = currentSessions.find((s) => s.id === hashId);
|
||||
if (found) {
|
||||
setCurrentSessionId(found.id);
|
||||
await loadMessagesFromServer(found.id);
|
||||
return;
|
||||
}
|
||||
// 会话可能已被删除,尝试从 API 获取消息(404 时 catch)
|
||||
try {
|
||||
const resp = await fetchMessages(hashId);
|
||||
if (resp.messages && resp.messages.length > 0) {
|
||||
// 消息存在说明会话仍有效(虽然不在列表里,可能是刚创建的)
|
||||
setCurrentSessionId(hashId);
|
||||
const msgs = resp.messages.map((m: any, i: number) => ({
|
||||
id: m.id ? String(m.id) : `hist_${i}_${Date.now()}`,
|
||||
role: m.role,
|
||||
content: m.content,
|
||||
timestamp: typeof m.created_at === 'number' ? m.created_at : Date.now(),
|
||||
isStreaming: false,
|
||||
}));
|
||||
setMessages(msgs);
|
||||
useChatStore.getState().setMessages(msgs);
|
||||
return;
|
||||
}
|
||||
} catch {
|
||||
// 会话不存在,回退
|
||||
}
|
||||
// 回退:清除 hash,加载最新/主对话
|
||||
setHashSessionId(null);
|
||||
}
|
||||
|
||||
// 3. 无 hash 或 hash 无效:加载最新会话
|
||||
if (admin) {
|
||||
// 管理员:确保主对话存在
|
||||
const mainSession = await ensureMainSession(userId);
|
||||
if (mainSession) {
|
||||
setCurrentSessionId(mainSession.id);
|
||||
setHashSessionId(mainSession.id);
|
||||
await loadMessagesFromServer(mainSession.id);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// 普通用户:选择最新会话
|
||||
if (currentSessions.length > 0) {
|
||||
const latest = currentSessions[0]; // 已按 updated_at DESC 排序
|
||||
setCurrentSessionId(latest.id);
|
||||
setHashSessionId(latest.id);
|
||||
await loadMessagesFromServer(latest.id);
|
||||
}
|
||||
}, [userId, loadSessionsFromServer, ensureMainSession, setCurrentSessionId, setMessages, loadMessagesFromServer]);
|
||||
|
||||
// 登录后初始化
|
||||
useEffect(() => {
|
||||
if (isLoggedIn && userId) {
|
||||
initSession();
|
||||
}
|
||||
}, [isLoggedIn, userId, initSession]);
|
||||
|
||||
// 监听 hashchange 事件 (浏览器前进/后退)
|
||||
useEffect(() => {
|
||||
if (!isLoggedIn) return;
|
||||
|
||||
const handleHashChange = async () => {
|
||||
const hashId = getSessionIdFromHash();
|
||||
const currentId = useSessionStore.getState().currentSessionId;
|
||||
|
||||
if (hashId && hashId !== currentId) {
|
||||
// hash 变化,切换会话
|
||||
setCurrentSessionId(hashId);
|
||||
await loadMessagesFromServer(hashId);
|
||||
}
|
||||
};
|
||||
|
||||
window.addEventListener('hashchange', handleHashChange);
|
||||
return () => window.removeEventListener('hashchange', handleHashChange);
|
||||
}, [isLoggedIn, setCurrentSessionId, loadMessagesFromServer]);
|
||||
|
||||
// 当前会话变化时更新 URL hash(仅在登录后、非 hashchange 驱动时)
|
||||
useEffect(() => {
|
||||
if (!isLoggedIn || !currentSessionId) return;
|
||||
const hashId = getSessionIdFromHash();
|
||||
if (hashId !== currentSessionId) {
|
||||
setHashSessionId(currentSessionId);
|
||||
}
|
||||
}, [isLoggedIn, currentSessionId]);
|
||||
|
||||
// ========== 认证相关 ==========
|
||||
|
||||
const handleLogin = async () => {
|
||||
setError('');
|
||||
const result = await login(username, password);
|
||||
|
||||
@@ -1,8 +1,110 @@
|
||||
// 会话API
|
||||
export {
|
||||
createSession,
|
||||
listSessions,
|
||||
getSession,
|
||||
deleteSession,
|
||||
fetchSessionMessages,
|
||||
} from './client';
|
||||
// 会话持久化 API — 对接 Gateway REST API
|
||||
|
||||
import { request } from './client';
|
||||
import type { Session, SessionListResponse, SessionMessagesResponse, SessionResponse } from '@/types/session';
|
||||
|
||||
/**
|
||||
* 获取用户的会话列表
|
||||
* GET /api/v1/sessions?user_id={userId}
|
||||
*/
|
||||
export async function fetchSessions(userId: string): Promise<Session[]> {
|
||||
const resp = await request<SessionListResponse>(
|
||||
`/sessions?user_id=${encodeURIComponent(userId)}`
|
||||
);
|
||||
if (resp.error) {
|
||||
console.error('[sessions] 获取会话列表失败:', resp.error);
|
||||
return [];
|
||||
}
|
||||
// Gateway 返回 { sessions: [...] }
|
||||
return (resp.data as SessionListResponse)?.sessions || [];
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取会话的消息历史
|
||||
* GET /api/v1/sessions/{sessionId}/messages?limit={limit}
|
||||
*/
|
||||
export async function fetchMessages(
|
||||
sessionId: string,
|
||||
limit: number = 50
|
||||
): Promise<SessionMessagesResponse> {
|
||||
const resp = await request<SessionMessagesResponse>(
|
||||
`/sessions/${encodeURIComponent(sessionId)}/messages?limit=${limit}`
|
||||
);
|
||||
if (resp.error) {
|
||||
console.error('[sessions] 获取消息失败:', resp.error);
|
||||
return { messages: [] };
|
||||
}
|
||||
return (resp.data as SessionMessagesResponse) || { messages: [] };
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建新会话
|
||||
* POST /api/v1/sessions
|
||||
*/
|
||||
export async function createSession(
|
||||
userId: string,
|
||||
sessionId: string,
|
||||
title: string = '新的对话',
|
||||
isMain: boolean = false
|
||||
): Promise<SessionResponse | null> {
|
||||
const resp = await request<SessionResponse>('/sessions', {
|
||||
method: 'POST',
|
||||
body: {
|
||||
user_id: userId,
|
||||
session_id: sessionId,
|
||||
title,
|
||||
is_main: isMain,
|
||||
},
|
||||
});
|
||||
if (resp.error) {
|
||||
console.error('[sessions] 创建会话失败:', resp.error);
|
||||
return null;
|
||||
}
|
||||
return resp.data as SessionResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* 删除单个会话 (不删除记忆)
|
||||
* DELETE /api/v1/sessions/{sessionId}
|
||||
*/
|
||||
export async function deleteSession(sessionId: string): Promise<boolean> {
|
||||
const resp = await request(`/sessions/${encodeURIComponent(sessionId)}`, {
|
||||
method: 'DELETE',
|
||||
});
|
||||
if (resp.error) {
|
||||
console.error('[sessions] 删除会话失败:', resp.error);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* 删除用户所有会话 (不删除记忆)
|
||||
* DELETE /api/v1/sessions?user_id={userId}
|
||||
*/
|
||||
export async function deleteAllSessions(userId: string): Promise<boolean> {
|
||||
const resp = await request(`/sessions?user_id=${encodeURIComponent(userId)}`, {
|
||||
method: 'DELETE',
|
||||
});
|
||||
if (resp.error) {
|
||||
console.error('[sessions] 删除所有会话失败:', resp.error);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* 清空会话消息但不删除会话本身
|
||||
* DELETE /api/v1/sessions/{sessionId}/messages
|
||||
*/
|
||||
export async function clearSessionMessages(sessionId: string): Promise<boolean> {
|
||||
const resp = await request(
|
||||
`/sessions/${encodeURIComponent(sessionId)}/messages`,
|
||||
{ method: 'DELETE' }
|
||||
);
|
||||
if (resp.error) {
|
||||
console.error('[sessions] 清空消息失败:', resp.error);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import { useState, useCallback } from 'react';
|
||||
import { useSession } from '@/hooks/useSession';
|
||||
import { useSessionStore } from '@/store/sessionStore';
|
||||
import { CyreneAvatar } from '@/components/persona/CyreneAvatar';
|
||||
import type { Session } from '@/types/session';
|
||||
|
||||
interface SidebarProps {
|
||||
onClose?: () => void;
|
||||
@@ -11,12 +13,25 @@ export function Sidebar({ onClose }: SidebarProps) {
|
||||
sessions,
|
||||
createSession,
|
||||
deleteSession,
|
||||
deleteAllSessions,
|
||||
clearMainSession,
|
||||
setCurrentSession,
|
||||
isAdmin,
|
||||
} = useSession();
|
||||
const currentSessionId = useSessionStore((s) => s.currentSessionId);
|
||||
|
||||
const displaySessions = sessions;
|
||||
const activeSessionId = currentSessionId;
|
||||
// 确认弹窗状态
|
||||
const [confirmAction, setConfirmAction] = useState<{
|
||||
type: 'delete' | 'clearMain' | 'deleteAll';
|
||||
sessionId?: string;
|
||||
} | null>(null);
|
||||
|
||||
// 按 updated_at 降序排列
|
||||
const displaySessions = [...sessions].sort((a, b) => {
|
||||
const ta = typeof a.updated_at === 'string' ? parseInt(a.updated_at, 10) : (a.updated_at as unknown as number);
|
||||
const tb = typeof b.updated_at === 'string' ? parseInt(b.updated_at, 10) : (b.updated_at as unknown as number);
|
||||
return (tb || 0) - (ta || 0);
|
||||
});
|
||||
|
||||
const handleNewChat = async () => {
|
||||
const session = await createSession();
|
||||
@@ -28,11 +43,43 @@ export function Sidebar({ onClose }: SidebarProps) {
|
||||
if (onClose) onClose();
|
||||
};
|
||||
|
||||
const handleDeleteSession = (e: { stopPropagation: () => void }, id: string) => {
|
||||
e.stopPropagation();
|
||||
deleteSession(id);
|
||||
const handleMainSession = async () => {
|
||||
// 找到主对话
|
||||
const mainSession = displaySessions.find((s) => s.is_main);
|
||||
if (mainSession) {
|
||||
setCurrentSession(mainSession.id);
|
||||
if (onClose) onClose();
|
||||
}
|
||||
};
|
||||
|
||||
const handleDeleteClick = (e: React.MouseEvent, id: string) => {
|
||||
e.stopPropagation();
|
||||
const session = displaySessions.find((s) => s.id === id);
|
||||
// 管理员主对话不可删除
|
||||
if (isAdmin && session?.is_main) return;
|
||||
setConfirmAction({ type: 'delete', sessionId: id });
|
||||
};
|
||||
|
||||
const handleConfirmAction = async () => {
|
||||
if (!confirmAction) return;
|
||||
switch (confirmAction.type) {
|
||||
case 'delete':
|
||||
if (confirmAction.sessionId) {
|
||||
await deleteSession(confirmAction.sessionId);
|
||||
}
|
||||
break;
|
||||
case 'clearMain':
|
||||
await clearMainSession();
|
||||
break;
|
||||
case 'deleteAll':
|
||||
await deleteAllSessions();
|
||||
break;
|
||||
}
|
||||
setConfirmAction(null);
|
||||
};
|
||||
|
||||
const cancelConfirm = () => setConfirmAction(null);
|
||||
|
||||
/** 格式化时间戳为可读字符串 */
|
||||
const formatTime = (ts: string | number): string => {
|
||||
if (!ts) return '';
|
||||
@@ -52,7 +99,28 @@ export function Sidebar({ onClose }: SidebarProps) {
|
||||
|
||||
return (
|
||||
<aside className="h-full bg-white/90 dark:bg-gray-900/90 border-r border-pink-100 dark:border-pink-900 flex flex-col">
|
||||
{/* 侧边栏头部 */}
|
||||
{/* 主对话按钮 (仅管理员可见) */}
|
||||
{isAdmin && (
|
||||
<div className="p-3 border-b border-pink-100 dark:border-pink-900 flex gap-2">
|
||||
<button
|
||||
onClick={handleMainSession}
|
||||
className="flex-1 flex items-center justify-center gap-1.5 px-3 py-2 bg-amber-50 hover:bg-amber-100 dark:bg-amber-900/20 dark:hover:bg-amber-900/30 text-amber-600 dark:text-amber-400 rounded-xl text-sm font-medium transition-colors border border-amber-200 dark:border-amber-800"
|
||||
title="回到主对话"
|
||||
>
|
||||
<span>🏠</span>
|
||||
<span>主对话</span>
|
||||
</button>
|
||||
<button
|
||||
onClick={() => setConfirmAction({ type: 'clearMain' })}
|
||||
className="flex items-center justify-center px-2 py-2 bg-red-50 hover:bg-red-100 dark:bg-red-900/20 dark:hover:bg-red-900/30 text-red-400 hover:text-red-500 rounded-xl text-sm transition-colors border border-red-200 dark:border-red-800"
|
||||
title="清空主对话消息"
|
||||
>
|
||||
<span>🗑️</span>
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* 新对话按钮 */}
|
||||
<div className="p-4 border-b border-pink-100 dark:border-pink-900">
|
||||
<button
|
||||
onClick={handleNewChat}
|
||||
@@ -70,46 +138,80 @@ export function Sidebar({ onClose }: SidebarProps) {
|
||||
还没有对话哦,开始和昔涟聊天吧 ♪
|
||||
</p>
|
||||
) : (
|
||||
displaySessions.map((session) => (
|
||||
<div
|
||||
key={session.id}
|
||||
onClick={() => handleSelectSession(session.id)}
|
||||
className={`
|
||||
group flex items-center justify-between px-4 py-2.5 mx-2 rounded-lg cursor-pointer transition-colors
|
||||
${
|
||||
activeSessionId === session.id
|
||||
? 'bg-pink-50 dark:bg-pink-900/30 text-pink-600'
|
||||
: 'hover:bg-gray-50 dark:hover:bg-gray-800 text-gray-600 dark:text-gray-300'
|
||||
}
|
||||
`}
|
||||
>
|
||||
<div className="flex items-center gap-2 min-w-0 flex-1">
|
||||
<CyreneAvatar size="sm" />
|
||||
<div className="min-w-0">
|
||||
<p className="text-sm font-medium truncate">{session.title || '新的对话'}</p>
|
||||
<p className="text-xs text-gray-400 truncate">
|
||||
{session.message_count != null ? `${session.message_count} 条消息` : ''}
|
||||
{session.updated_at ? ` · ${formatTime(session.updated_at)}` : ''}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<button
|
||||
onClick={(e) => handleDeleteSession(e, session.id)}
|
||||
className="opacity-0 group-hover:opacity-100 p-1 text-gray-400 hover:text-red-400 transition-all"
|
||||
title="删除会话"
|
||||
displaySessions.map((session) => {
|
||||
const isMainSession = session.is_main;
|
||||
const isActive = currentSessionId === session.id;
|
||||
// 管理员主对话不可删除
|
||||
const canDelete = !(isAdmin && isMainSession);
|
||||
|
||||
return (
|
||||
<div
|
||||
key={session.id}
|
||||
onClick={() => handleSelectSession(session.id)}
|
||||
className={`
|
||||
group flex items-center justify-between px-4 py-2.5 mx-2 rounded-lg cursor-pointer transition-colors
|
||||
${
|
||||
isActive
|
||||
? 'bg-pink-50 dark:bg-pink-900/30 text-pink-600'
|
||||
: 'hover:bg-gray-50 dark:hover:bg-gray-800 text-gray-600 dark:text-gray-300'
|
||||
}
|
||||
`}
|
||||
>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" className="h-4 w-4" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
))
|
||||
<div className="flex items-center gap-2 min-w-0 flex-1">
|
||||
<CyreneAvatar size="sm" />
|
||||
<div className="min-w-0">
|
||||
<p className="text-sm font-medium truncate">
|
||||
{isMainSession ? '🏠 ' : ''}
|
||||
{session.title || '新的对话'}
|
||||
</p>
|
||||
<p className="text-xs text-gray-400 truncate">
|
||||
{session.message_count != null
|
||||
? `${session.message_count} 条消息`
|
||||
: ''}
|
||||
{session.updated_at
|
||||
? `${session.message_count != null ? ' · ' : ''}${formatTime(session.updated_at)}`
|
||||
: ''}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
{canDelete && (
|
||||
<button
|
||||
onClick={(e) => handleDeleteClick(e, session.id)}
|
||||
className="opacity-0 group-hover:opacity-100 p-1 text-gray-400 hover:text-red-400 transition-all"
|
||||
title="删除会话"
|
||||
>
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className="h-4 w-4"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
strokeWidth={2}
|
||||
d="M19 7l-.867 12.142A2 2 0 0116.138 21H7.862a2 2 0 01-1.995-1.858L5 7m5 4v6m4-6v6m1-10V4a1 1 0 00-1-1h-4a1 1 0 00-1 1v3M4 7h16"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* 底部信息 */}
|
||||
<div className="p-4 border-t border-pink-100 dark:border-pink-900">
|
||||
<div className="flex items-center gap-2 text-xs text-gray-400">
|
||||
{/* 底部:一键清空所有对话 */}
|
||||
<div className="p-3 border-t border-pink-100 dark:border-pink-900">
|
||||
<button
|
||||
onClick={() => setConfirmAction({ type: 'deleteAll' })}
|
||||
className="w-full flex items-center justify-center gap-1.5 px-3 py-2 bg-red-50 hover:bg-red-100 dark:bg-red-900/20 dark:hover:bg-red-900/30 text-red-400 hover:text-red-500 rounded-xl text-xs font-medium transition-colors border border-red-200 dark:border-red-800"
|
||||
>
|
||||
<span>🗑️</span>
|
||||
<span>清空所有对话</span>
|
||||
</button>
|
||||
<div className="flex items-center gap-2 mt-3 text-xs text-gray-400">
|
||||
<CyreneAvatar size="sm" />
|
||||
<div>
|
||||
<p className="font-medium text-pink-400">昔涟 AI</p>
|
||||
@@ -117,6 +219,42 @@ export function Sidebar({ onClose }: SidebarProps) {
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* 确认弹窗 */}
|
||||
{confirmAction && (
|
||||
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/40">
|
||||
<div className="bg-white dark:bg-gray-800 rounded-2xl shadow-xl p-6 m-4 max-w-sm w-full border border-pink-100 dark:border-pink-700">
|
||||
<h3 className="text-lg font-semibold text-gray-800 dark:text-gray-200 mb-2">
|
||||
{confirmAction.type === 'delete'
|
||||
? '删除会话'
|
||||
: confirmAction.type === 'clearMain'
|
||||
? '清空主对话'
|
||||
: '清空所有对话'}
|
||||
</h3>
|
||||
<p className="text-sm text-gray-500 dark:text-gray-400 mb-6">
|
||||
{confirmAction.type === 'delete'
|
||||
? '确定要删除这个会话吗?此操作不会删除记忆。'
|
||||
: confirmAction.type === 'clearMain'
|
||||
? '确定要清空主对话的所有消息吗?此操作不会删除记忆。'
|
||||
: '确定要删除所有对话吗?此操作不会删除记忆。管理员将回到主对话,普通用户进入新对话。'}
|
||||
</p>
|
||||
<div className="flex justify-end gap-3">
|
||||
<button
|
||||
onClick={cancelConfirm}
|
||||
className="px-4 py-2 rounded-lg text-sm text-gray-500 hover:bg-gray-100 dark:hover:bg-gray-700 transition-colors"
|
||||
>
|
||||
取消
|
||||
</button>
|
||||
<button
|
||||
onClick={handleConfirmAction}
|
||||
className="px-4 py-2 rounded-lg text-sm bg-red-400 hover:bg-red-500 text-white font-medium transition-colors"
|
||||
>
|
||||
确定
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</aside>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,9 +1,24 @@
|
||||
import { useCallback, useEffect } from 'react';
|
||||
import { useSessionStore } from '@/store/sessionStore';
|
||||
import { useSessionStore, isAdminUser } from '@/store/sessionStore';
|
||||
import { useChatStore } from '@/store/chatStore';
|
||||
import { listSessions as apiListSessions, createSession as apiCreateSession, deleteSession as apiDeleteSession } from '@/api/client';
|
||||
import { createSession as apiCreateSession } from '@/api/sessions';
|
||||
import type { Session } from '@/types/session';
|
||||
|
||||
/** 生成简易随机ID */
|
||||
function randomID(n: number = 12): string {
|
||||
const letters = 'abcdefghijklmnopqrstuvwxyz0123456789';
|
||||
let result = '';
|
||||
for (let i = 0; i < n; i++) {
|
||||
result += letters.charAt(Math.floor(Math.random() * letters.length));
|
||||
}
|
||||
return `session_${result}`;
|
||||
}
|
||||
|
||||
/** 获取当前用户ID */
|
||||
function getUserId(): string {
|
||||
return localStorage.getItem('user_id') || '';
|
||||
}
|
||||
|
||||
export function useSession() {
|
||||
const {
|
||||
sessions,
|
||||
@@ -14,69 +29,105 @@ export function useSession() {
|
||||
removeSession,
|
||||
setCurrentSessionId,
|
||||
setLoading,
|
||||
loadSessionsFromServer,
|
||||
loadMessagesFromServer,
|
||||
clearMainSessionMessages,
|
||||
deleteSessionAndRefresh,
|
||||
deleteAllSessionsAndReset,
|
||||
ensureMainSession,
|
||||
} = useSessionStore();
|
||||
|
||||
const { clearMessages } = useChatStore();
|
||||
const userId = getUserId();
|
||||
const isAdmin = isAdminUser(userId);
|
||||
|
||||
/** 从服务端加载会话列表 */
|
||||
const loadSessions = useCallback(async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const resp = await apiListSessions();
|
||||
if (resp.data) {
|
||||
const data = resp.data as { sessions: Session[] };
|
||||
setSessions(data.sessions || []);
|
||||
}
|
||||
} catch {
|
||||
// ignore
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}, [setSessions, setLoading]);
|
||||
if (!userId) return;
|
||||
await loadSessionsFromServer(userId);
|
||||
}, [userId, loadSessionsFromServer]);
|
||||
|
||||
// 初始加载
|
||||
useEffect(() => {
|
||||
loadSessions();
|
||||
}, [loadSessions]);
|
||||
|
||||
const createSession = useCallback(async (title?: string) => {
|
||||
const resp = await apiCreateSession(title);
|
||||
if (resp.data) {
|
||||
const session = resp.data as Session;
|
||||
addSession(session);
|
||||
// setCurrentSessionId 内部已处理消息清空和历史加载,无需额外 clearMessages
|
||||
await setCurrentSessionId(session.id);
|
||||
return session;
|
||||
}
|
||||
return null;
|
||||
}, [addSession, setCurrentSessionId]);
|
||||
|
||||
const deleteSession = useCallback(async (id: string) => {
|
||||
await apiDeleteSession(id);
|
||||
// 如果删除的是当前活跃会话,先切换到其他会话
|
||||
if (currentSessionId === id) {
|
||||
const remaining = useSessionStore.getState().sessions.filter((s: Session) => s.id !== id);
|
||||
if (remaining.length > 0) {
|
||||
// 切换到列表中的第一个会话
|
||||
await setCurrentSessionId(remaining[0].id);
|
||||
} else {
|
||||
clearMessages();
|
||||
setCurrentSessionId(null);
|
||||
/** 创建新会话 (isMain 默认为 false) */
|
||||
const createSession = useCallback(
|
||||
async (title?: string, isMain: boolean = false) => {
|
||||
const sid = randomID();
|
||||
const created = await apiCreateSession(
|
||||
userId,
|
||||
sid,
|
||||
title || '新的对话',
|
||||
isMain
|
||||
);
|
||||
if (created) {
|
||||
const newSession: Session = {
|
||||
id: created.id,
|
||||
user_id: created.user_id,
|
||||
title: created.title,
|
||||
is_main: created.is_main,
|
||||
created_at: String(created.created_at),
|
||||
updated_at: String(created.updated_at),
|
||||
message_count: 0,
|
||||
};
|
||||
addSession(newSession);
|
||||
setCurrentSessionId(newSession.id);
|
||||
return newSession;
|
||||
}
|
||||
}
|
||||
removeSession(id);
|
||||
}, [removeSession, currentSessionId, clearMessages, setCurrentSessionId]);
|
||||
return null;
|
||||
},
|
||||
[userId, addSession, setCurrentSessionId]
|
||||
);
|
||||
|
||||
const setCurrentSession = useCallback((id: string) => {
|
||||
// setCurrentSessionId 内部已处理消息清空和历史加载
|
||||
setCurrentSessionId(id);
|
||||
}, [setCurrentSessionId]);
|
||||
/** 清空主对话消息 */
|
||||
const clearMainSession = useCallback(async () => {
|
||||
const mainSession = useSessionStore
|
||||
.getState()
|
||||
.sessions.find((s: Session) => s.is_main);
|
||||
if (!mainSession) {
|
||||
// 尝试确保主对话存在
|
||||
const ensured = await ensureMainSession(userId);
|
||||
if (!ensured) return false;
|
||||
return clearMainSessionMessages(ensured.id);
|
||||
}
|
||||
return clearMainSessionMessages(mainSession.id);
|
||||
}, [userId, clearMainSessionMessages, ensureMainSession]);
|
||||
|
||||
/** 删除单个会话 */
|
||||
const deleteSession = useCallback(
|
||||
async (id: string) => {
|
||||
await deleteSessionAndRefresh(id, userId);
|
||||
},
|
||||
[userId, deleteSessionAndRefresh]
|
||||
);
|
||||
|
||||
/** 删除所有会话 */
|
||||
const deleteAllSessions = useCallback(async () => {
|
||||
await deleteAllSessionsAndReset(userId);
|
||||
}, [userId, deleteAllSessionsAndReset]);
|
||||
|
||||
/** 切换当前会话 */
|
||||
const setCurrentSession = useCallback(
|
||||
async (id: string) => {
|
||||
setCurrentSessionId(id);
|
||||
// 加载该会话的消息历史
|
||||
await loadMessagesFromServer(id);
|
||||
},
|
||||
[setCurrentSessionId, loadMessagesFromServer]
|
||||
);
|
||||
|
||||
return {
|
||||
sessions,
|
||||
currentSessionId,
|
||||
loading,
|
||||
isAdmin,
|
||||
userId,
|
||||
loadSessions,
|
||||
createSession,
|
||||
clearMainSession,
|
||||
deleteSession,
|
||||
deleteAllSessions,
|
||||
setCurrentSession,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ import { useEffect, useRef, useCallback, useState } from 'react';
|
||||
import { useChatStore } from '@/store/chatStore';
|
||||
import { useSessionStore } from '@/store/sessionStore';
|
||||
import { getToken } from '@/api/client';
|
||||
import { fetchMessages } from '@/api/sessions';
|
||||
import type { WSClientMessage, WSServerMessage } from '@/types/chat';
|
||||
|
||||
const WS_BASE_URL =
|
||||
@@ -12,7 +13,8 @@ export function useWebSocket() {
|
||||
const wsRef = useRef<WebSocket | null>(null);
|
||||
const reconnectTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null);
|
||||
const shouldReconnectRef = useRef(true);
|
||||
const activeSessionRef = useRef<string | null>(null); // 追踪当前活跃会话,防止竞态
|
||||
const activeSessionRef = useRef<string | null>(null);
|
||||
const loadingRef = useRef(false); // 防止重复加载消息
|
||||
|
||||
// 订阅 sessionStore 中的 currentSessionId 变化
|
||||
const currentSessionId = useSessionStore((s) => s.currentSessionId);
|
||||
@@ -40,7 +42,7 @@ export function useWebSocket() {
|
||||
shouldReconnectRef.current = true;
|
||||
console.log('[WS] 已连接, session_id:', sessionID);
|
||||
|
||||
// 连接后发送会话恢复消息,恢复历史上下文
|
||||
// 连接后发送会话恢复消息,恢复后端上下文
|
||||
const sid = useSessionStore.getState().currentSessionId;
|
||||
if (sid) {
|
||||
const resumeMsg: WSClientMessage = {
|
||||
@@ -81,10 +83,39 @@ export function useWebSocket() {
|
||||
wsRef.current = ws;
|
||||
}, []);
|
||||
|
||||
// 初始连接 + 会话切换时重连
|
||||
// 会话切换时:先通过 REST API 加载历史消息,再建立 WS 连接
|
||||
useEffect(() => {
|
||||
activeSessionRef.current = currentSessionId;
|
||||
connect();
|
||||
|
||||
const loadAndConnect = async () => {
|
||||
// 如果是从 URL 恢复的 session,先加载历史消息
|
||||
if (currentSessionId && !loadingRef.current) {
|
||||
loadingRef.current = true;
|
||||
try {
|
||||
const resp = await fetchMessages(currentSessionId);
|
||||
const rawMessages = resp.messages || [];
|
||||
const msgs = rawMessages.map((m: any, i: number) => ({
|
||||
id: m.id ? String(m.id) : `hist_${i}_${Date.now()}`,
|
||||
role: m.role,
|
||||
content: m.content,
|
||||
timestamp:
|
||||
typeof m.created_at === 'number' ? m.created_at : Date.now(),
|
||||
isStreaming: false,
|
||||
}));
|
||||
useSessionStore.getState().setMessages(msgs);
|
||||
useChatStore.getState().setMessages(msgs);
|
||||
} catch {
|
||||
// 加载失败不影响后续连接
|
||||
} finally {
|
||||
loadingRef.current = false;
|
||||
}
|
||||
}
|
||||
|
||||
connect();
|
||||
};
|
||||
|
||||
loadAndConnect();
|
||||
|
||||
return () => {
|
||||
if (reconnectTimerRef.current) {
|
||||
clearTimeout(reconnectTimerRef.current);
|
||||
@@ -96,12 +127,13 @@ export function useWebSocket() {
|
||||
|
||||
const sendMessage = useCallback((msg: WSClientMessage) => {
|
||||
if (wsRef.current?.readyState === WebSocket.OPEN) {
|
||||
// 自动附上 session_id
|
||||
const sessionID = useSessionStore.getState().currentSessionId;
|
||||
wsRef.current.send(JSON.stringify({
|
||||
...msg,
|
||||
session_id: msg.session_id || sessionID || undefined,
|
||||
}));
|
||||
wsRef.current.send(
|
||||
JSON.stringify({
|
||||
...msg,
|
||||
session_id: msg.session_id || sessionID || undefined,
|
||||
})
|
||||
);
|
||||
}
|
||||
}, []);
|
||||
|
||||
@@ -109,7 +141,8 @@ export function useWebSocket() {
|
||||
}
|
||||
|
||||
function handleServerMessage(msg: WSServerMessage) {
|
||||
const { addMessage, appendToLastMessage, finishStreaming, setTyping } = useChatStore.getState();
|
||||
const { addMessage, appendToLastMessage, finishStreaming, setTyping } =
|
||||
useChatStore.getState();
|
||||
const { setMessages } = useSessionStore.getState();
|
||||
const chatState = useChatStore.getState();
|
||||
|
||||
@@ -128,21 +161,22 @@ function handleServerMessage(msg: WSServerMessage) {
|
||||
|
||||
case 'stream_chunk':
|
||||
if (msg.content) {
|
||||
// 首个 chunk 到达时创建消息并隐藏 typing indicator
|
||||
const { messages } = useChatStore.getState();
|
||||
const lastMsg = messages[messages.length - 1];
|
||||
if (!lastMsg || lastMsg.role !== 'assistant' || !lastMsg.isStreaming) {
|
||||
// 创建新的流式消息
|
||||
if (
|
||||
!lastMsg ||
|
||||
lastMsg.role !== 'assistant' ||
|
||||
!lastMsg.isStreaming
|
||||
) {
|
||||
addMessage({
|
||||
id: msg.message_id || ('msg_' + Date.now()),
|
||||
id: msg.message_id || 'msg_' + Date.now(),
|
||||
role: 'assistant',
|
||||
content: msg.content,
|
||||
timestamp: msg.timestamp,
|
||||
isStreaming: true,
|
||||
});
|
||||
setTyping(false); // 首个 chunk 到达,隐藏 typing 指示器
|
||||
setTyping(false);
|
||||
} else {
|
||||
// 追加到现有流式消息
|
||||
appendToLastMessage(msg.content);
|
||||
}
|
||||
}
|
||||
@@ -154,28 +188,23 @@ function handleServerMessage(msg: WSServerMessage) {
|
||||
|
||||
case 'history_response':
|
||||
if (msg.messages) {
|
||||
// 确保每条消息都有 id
|
||||
const msgsWithIds = msg.messages.map((m: any, i: number) => ({
|
||||
...m,
|
||||
id: m.id || `hist_${i}_${Date.now()}`,
|
||||
}));
|
||||
// 同步历史消息到两个 store
|
||||
setMessages(msgsWithIds);
|
||||
useChatStore.getState().setMessages(msgsWithIds);
|
||||
}
|
||||
// 确保历史加载后 typing indicator 关闭
|
||||
setTyping(false);
|
||||
break;
|
||||
|
||||
case 'device_update':
|
||||
// 处理 IoT 设备状态更新
|
||||
if (msg.devices && msg.devices.length > 0) {
|
||||
chatState.setIoTDevices(msg.devices);
|
||||
}
|
||||
break;
|
||||
|
||||
case 'background_thinking':
|
||||
// 处理后端推送的后台思考状态
|
||||
if (msg.thinking_status) {
|
||||
chatState.setBackgroundThinkingStatus(msg.thinking_status);
|
||||
}
|
||||
@@ -187,7 +216,6 @@ function handleServerMessage(msg: WSServerMessage) {
|
||||
break;
|
||||
|
||||
case 'pong':
|
||||
// 忽略心跳响应
|
||||
break;
|
||||
|
||||
default:
|
||||
|
||||
@@ -1,15 +1,38 @@
|
||||
import { create } from 'zustand';
|
||||
import type { Session } from '@/types/session';
|
||||
import type { Message } from '@/types/chat';
|
||||
import { fetchSessionMessages as apiFetchMessages } from '@/api/client';
|
||||
import {
|
||||
fetchSessions,
|
||||
fetchMessages,
|
||||
createSession as apiCreateSession,
|
||||
deleteSession as apiDeleteSession,
|
||||
deleteAllSessions as apiDeleteAllSessions,
|
||||
clearSessionMessages as apiClearMessages,
|
||||
} from '@/api/sessions';
|
||||
import { useChatStore } from '@/store/chatStore';
|
||||
|
||||
/** 生成简易随机ID */
|
||||
function randomID(n: number = 12): string {
|
||||
const letters = 'abcdefghijklmnopqrstuvwxyz0123456789';
|
||||
let result = '';
|
||||
for (let i = 0; i < n; i++) {
|
||||
result += letters.charAt(Math.floor(Math.random() * letters.length));
|
||||
}
|
||||
return `session_${result}`;
|
||||
}
|
||||
|
||||
/** 判断是否为管理员用户 (user_id 以 "admin_" 开头) */
|
||||
export function isAdminUser(userId: string | null): boolean {
|
||||
return userId?.startsWith('admin_') ?? false;
|
||||
}
|
||||
|
||||
interface SessionStore {
|
||||
sessions: Session[];
|
||||
currentSessionId: string | null;
|
||||
loading: boolean;
|
||||
messages: Message[];
|
||||
|
||||
// 基础操作
|
||||
setSessions: (sessions: Session[]) => void;
|
||||
addSession: (session: Session) => void;
|
||||
removeSession: (id: string) => void;
|
||||
@@ -17,9 +40,17 @@ interface SessionStore {
|
||||
setLoading: (loading: boolean) => void;
|
||||
setMessages: (messages: Message[]) => void;
|
||||
clearMessages: () => void;
|
||||
|
||||
// 服务端持久化操作
|
||||
loadSessionsFromServer: (userId: string) => Promise<void>;
|
||||
loadMessagesFromServer: (sessionId: string) => Promise<void>;
|
||||
clearMainSessionMessages: (sessionId: string) => Promise<boolean>;
|
||||
deleteSessionAndRefresh: (id: string, userId: string) => Promise<void>;
|
||||
deleteAllSessionsAndReset: (userId: string) => Promise<void>;
|
||||
ensureMainSession: (userId: string) => Promise<Session | null>;
|
||||
}
|
||||
|
||||
export const useSessionStore = create<SessionStore>((set) => ({
|
||||
export const useSessionStore = create<SessionStore>((set, get) => ({
|
||||
sessions: [],
|
||||
currentSessionId: null,
|
||||
loading: false,
|
||||
@@ -34,46 +65,143 @@ export const useSessionStore = create<SessionStore>((set) => ({
|
||||
currentSessionId: state.currentSessionId === id ? null : state.currentSessionId,
|
||||
messages: state.currentSessionId === id ? [] : state.messages,
|
||||
})),
|
||||
setCurrentSessionId: async (id) => {
|
||||
// 立即清除旧消息,防止闪旧数据
|
||||
set({ currentSessionId: id, messages: [], loading: true });
|
||||
useChatStore.getState().clearMessages();
|
||||
|
||||
// 清除旧消息(同时清 chatStore)
|
||||
if (id === null) {
|
||||
set({ messages: [], loading: false });
|
||||
return;
|
||||
}
|
||||
|
||||
// 从后端加载历史消息
|
||||
try {
|
||||
const resp = await apiFetchMessages(id);
|
||||
if (resp.data) {
|
||||
const data = resp.data as { messages: Message[] };
|
||||
const msgs = (data.messages || []).map((m: Message, i: number) => ({
|
||||
...m,
|
||||
id: m.id || `hist_${i}_${Date.now()}`,
|
||||
}));
|
||||
set({ messages: msgs, loading: false });
|
||||
// 同步到 chatStore 以便 ChatContainer 渲染
|
||||
useChatStore.getState().setMessages(msgs);
|
||||
} else {
|
||||
set({ messages: [], loading: false });
|
||||
useChatStore.getState().clearMessages();
|
||||
}
|
||||
} catch {
|
||||
set({ messages: [], loading: false });
|
||||
setCurrentSessionId: (id) => {
|
||||
set({ currentSessionId: id });
|
||||
// 切换会话时清空旧消息,等待加载
|
||||
if (id !== get().currentSessionId) {
|
||||
set({ messages: [], loading: true });
|
||||
useChatStore.getState().clearMessages();
|
||||
}
|
||||
},
|
||||
setLoading: (loading) => set({ loading }),
|
||||
setMessages: (messages) => {
|
||||
set({ messages });
|
||||
// 同步到 chatStore
|
||||
useChatStore.getState().setMessages(messages);
|
||||
},
|
||||
clearMessages: () => {
|
||||
set({ messages: [] });
|
||||
useChatStore.getState().clearMessages();
|
||||
},
|
||||
|
||||
// ========== 服务端持久化操作 ==========
|
||||
|
||||
/**
|
||||
* 从服务端加载会话列表
|
||||
*/
|
||||
loadSessionsFromServer: async (userId: string) => {
|
||||
set({ loading: true });
|
||||
try {
|
||||
const sessions = await fetchSessions(userId);
|
||||
set({ sessions, loading: false });
|
||||
} catch {
|
||||
set({ loading: false });
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* 从服务端加载指定会话的消息历史
|
||||
*/
|
||||
loadMessagesFromServer: async (sessionId: string) => {
|
||||
set({ loading: true });
|
||||
try {
|
||||
const resp = await fetchMessages(sessionId);
|
||||
const rawMessages = resp.messages || [];
|
||||
const msgs: Message[] = rawMessages.map((m: any, i: number) => ({
|
||||
id: m.id ? String(m.id) : `hist_${i}_${Date.now()}`,
|
||||
role: m.role,
|
||||
content: m.content,
|
||||
timestamp: typeof m.created_at === 'number' ? m.created_at : Date.now(),
|
||||
isStreaming: false,
|
||||
}));
|
||||
set({ messages: msgs, loading: false });
|
||||
useChatStore.getState().setMessages(msgs);
|
||||
} catch {
|
||||
set({ messages: [], loading: false });
|
||||
useChatStore.getState().clearMessages();
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* 清空主对话的消息但保留会话本身
|
||||
*/
|
||||
clearMainSessionMessages: async (sessionId: string) => {
|
||||
const ok = await apiClearMessages(sessionId);
|
||||
if (ok) {
|
||||
set({ messages: [] });
|
||||
useChatStore.getState().clearMessages();
|
||||
}
|
||||
return ok;
|
||||
},
|
||||
|
||||
/**
|
||||
* 删除会话并自动切换到下一个可用会话
|
||||
*/
|
||||
deleteSessionAndRefresh: async (id: string, userId: string) => {
|
||||
const ok = await apiDeleteSession(id);
|
||||
if (!ok) return;
|
||||
|
||||
const state = get();
|
||||
const remaining = state.sessions.filter((s) => s.id !== id);
|
||||
const wasCurrent = state.currentSessionId === id;
|
||||
|
||||
// 更新本地列表
|
||||
set({ sessions: remaining });
|
||||
|
||||
if (wasCurrent) {
|
||||
if (remaining.length > 0) {
|
||||
// 切换到列表中的第一个会话
|
||||
const nextId = remaining[0].id;
|
||||
set({ currentSessionId: nextId });
|
||||
await get().loadMessagesFromServer(nextId);
|
||||
} else {
|
||||
// 没有会话了:管理员回到主对话,普通用户创建新对话
|
||||
set({ currentSessionId: null, messages: [] });
|
||||
useChatStore.getState().clearMessages();
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* 删除所有会话并重置状态
|
||||
*/
|
||||
deleteAllSessionsAndReset: async (userId: string) => {
|
||||
const ok = await apiDeleteAllSessions(userId);
|
||||
if (!ok) return;
|
||||
|
||||
set({ sessions: [], currentSessionId: null, messages: [] });
|
||||
useChatStore.getState().clearMessages();
|
||||
},
|
||||
|
||||
/**
|
||||
* 确保主对话存在(管理员用户)。若不存在则创建并返回。
|
||||
*/
|
||||
ensureMainSession: async (userId: string) => {
|
||||
const state = get();
|
||||
// 先检查本地列表
|
||||
const existing = state.sessions.find((s) => s.is_main);
|
||||
if (existing) return existing;
|
||||
|
||||
// 本地没有,尝试从服务端加载
|
||||
await get().loadSessionsFromServer(userId);
|
||||
const refreshed = get().sessions.find((s) => s.is_main);
|
||||
if (refreshed) return refreshed;
|
||||
|
||||
// 服务端也没有,创建主对话
|
||||
const sid = randomID();
|
||||
const created = await apiCreateSession(userId, sid, '主对话', true);
|
||||
if (created) {
|
||||
const newSession: Session = {
|
||||
id: created.id,
|
||||
user_id: created.user_id,
|
||||
title: created.title,
|
||||
is_main: created.is_main,
|
||||
created_at: String(created.created_at),
|
||||
updated_at: String(created.updated_at),
|
||||
message_count: 0,
|
||||
};
|
||||
set((s) => ({ sessions: [newSession, ...s.sessions] }));
|
||||
return newSession;
|
||||
}
|
||||
return null;
|
||||
},
|
||||
}));
|
||||
|
||||
@@ -5,15 +5,16 @@ export interface Session {
|
||||
id: string;
|
||||
user_id: string;
|
||||
title: string;
|
||||
is_main: boolean;
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
message_count: number;
|
||||
is_active: boolean;
|
||||
message_count?: number;
|
||||
}
|
||||
|
||||
/** 创建会话参数 */
|
||||
export interface CreateSessionParams {
|
||||
title?: string;
|
||||
is_main?: boolean;
|
||||
}
|
||||
|
||||
/** 会话列表响应 */
|
||||
@@ -26,6 +27,16 @@ export interface SessionMessagesResponse {
|
||||
messages: import('@/types/chat').Message[];
|
||||
}
|
||||
|
||||
/** 单个会话响应 */
|
||||
export interface SessionResponse {
|
||||
id: string;
|
||||
user_id: string;
|
||||
title: string;
|
||||
is_main: boolean;
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
}
|
||||
|
||||
/** 认证相关 */
|
||||
export interface AuthResponse {
|
||||
user_id: string;
|
||||
|
||||
Reference in New Issue
Block a user