fix(security): 修复 P0 安全漏洞 (Session越权+CORS白名单+用户名枚举)

This commit is contained in:
2026-05-21 16:12:54 +08:00
parent 702d4ee1fe
commit 380cc24913
7 changed files with 161 additions and 15 deletions
@@ -153,7 +153,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
} else if req.Username == h.cfg.AdminUsername && h.db != nil {
// 管理员用户尚未迁移到 users 表,尝试用配置中的密码验证
if req.Password != h.cfg.AdminPassword {
c.JSON(http.StatusUnauthorized, gin.H{"error": "管理员密码错误"})
c.JSON(http.StatusUnauthorized, gin.H{"error": "用户名或密码错误"})
return
}
// 密码正确,迁移 admin 到 users 表
@@ -171,7 +171,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
} else if req.Username == h.cfg.AdminUsername {
// 数据库不可用时的回退:使用配置中的管理员密码
if req.Password != h.cfg.AdminPassword {
c.JSON(http.StatusUnauthorized, gin.H{"error": "管理员密码错误"})
c.JSON(http.StatusUnauthorized, gin.H{"error": "用户名或密码错误"})
return
}
userID = "admin"
@@ -32,6 +32,16 @@ func NewSessionHandler(hub *ws.Hub, s *store.SessionStore) *SessionHandler {
}
}
// canAccess 检查当前用户是否有权访问指定 userID 所属的资源
// admin 用户豁免所有权检查
func (h *SessionHandler) canAccess(c *gin.Context, resourceOwnerID string) bool {
currentUserID := middleware.GetUserID(c)
if middleware.GetIsAdmin(c) {
return true
}
return currentUserID == resourceOwnerID
}
// ========== POST /api/v1/sessions — 创建会话 ==========
type createSessionRequest struct {
@@ -49,9 +59,6 @@ func (h *SessionHandler) Create(c *gin.Context) {
if err := c.ShouldBindJSON(&req); err != nil {
// 允许空 body
}
if req.UserID != "" {
userID = req.UserID
}
if req.Title == "" {
req.Title = "新的对话"
}
@@ -81,9 +88,19 @@ func (h *SessionHandler) Create(c *gin.Context) {
// List 获取会话列表 (按 updated_at DESC 排序)
func (h *SessionHandler) List(c *gin.Context) {
currentUserID := middleware.GetUserID(c)
isAdmin := middleware.GetIsAdmin(c)
// 非管理员只能查询自己的会话
userID := c.Query("user_id")
if userID == "" {
userID = middleware.GetUserID(c)
userID = currentUserID
} else if !isAdmin {
// 非管理员试图查询其他用户的会话
if userID != currentUserID {
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问其他用户的会话"})
return
}
}
if h.useDB {
@@ -134,6 +151,13 @@ func (h *SessionHandler) Get(c *gin.Context) {
})
return
}
// 所有权校验:非管理员只能访问自己的会话
if !h.canAccess(c, session.UserID) {
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此会话"})
return
}
c.JSON(http.StatusOK, gin.H{
"id": session.ID,
"user_id": session.UserID,
@@ -159,6 +183,25 @@ func (h *SessionHandler) Delete(c *gin.Context) {
sessionID := c.Param("id")
if h.useDB {
// 所有权校验:先获取session再验证归属
session, err := h.store.GetSession(sessionID)
if err != nil {
log.Printf("[SessionHandler] 查询会话失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除会话失败", "errorType": "db_error"})
return
}
if session == nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "会话不存在",
"errorType": "session_not_found",
})
return
}
if !h.canAccess(c, session.UserID) {
c.JSON(http.StatusForbidden, gin.H{"error": "无权删除此会话"})
return
}
if err := h.store.DeleteSession(sessionID); err != nil {
log.Printf("[SessionHandler] 删除会话失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除会话失败", "errorType": "db_error"})
@@ -176,9 +219,18 @@ func (h *SessionHandler) Delete(c *gin.Context) {
// DeleteAll 删除用户所有会话 (不删除记忆)
func (h *SessionHandler) DeleteAll(c *gin.Context) {
currentUserID := middleware.GetUserID(c)
isAdmin := middleware.GetIsAdmin(c)
userID := c.Query("user_id")
if userID == "" {
userID = middleware.GetUserID(c)
userID = currentUserID
} else if !isAdmin {
// 非管理员只能删除自己的会话
if userID != currentUserID {
c.JSON(http.StatusForbidden, gin.H{"error": "无权删除其他用户的会话"})
return
}
}
if h.useDB {
@@ -197,6 +249,28 @@ func (h *SessionHandler) DeleteAll(c *gin.Context) {
// GetMessages 获取会话的完整消息列表
func (h *SessionHandler) GetMessages(c *gin.Context) {
sessionID := c.Param("id")
// 所有权校验
if h.useDB {
session, err := h.store.GetSession(sessionID)
if err != nil {
log.Printf("[SessionHandler] 查询会话失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询消息失败", "errorType": "db_error"})
return
}
if session == nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "会话不存在",
"errorType": "session_not_found",
})
return
}
if !h.canAccess(c, session.UserID) {
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此会话的消息"})
return
}
}
limit := 50
if l := c.Query("limit"); l != "" {
parsed := 0
@@ -248,6 +322,25 @@ func (h *SessionHandler) ClearMessages(c *gin.Context) {
sessionID := c.Param("id")
if h.useDB {
// 所有权校验
session, err := h.store.GetSession(sessionID)
if err != nil {
log.Printf("[SessionHandler] 查询会话失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "清空消息失败", "errorType": "db_error"})
return
}
if session == nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "会话不存在",
"errorType": "session_not_found",
})
return
}
if !h.canAccess(c, session.UserID) {
c.JSON(http.StatusForbidden, gin.H{"error": "无权清空此会话的消息"})
return
}
if err := h.store.ClearSessionMessages(sessionID); err != nil {
log.Printf("[SessionHandler] 清空消息失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "清空消息失败", "errorType": "db_error"})
@@ -436,6 +529,12 @@ func (h *SessionHandler) ExportSession(c *gin.Context) {
return
}
// 所有权校验:非管理员只能导出自己的会话
if !h.canAccess(c, session.UserID) {
c.JSON(http.StatusForbidden, gin.H{"error": "无权导出此会话"})
return
}
// 获取所有消息 (不限制数量,导出全部)
messages, err := h.store.GetMessages(sessionID, 0)
if err != nil {