fix(security): 修复 P0 安全漏洞 (Session越权+CORS白名单+用户名枚举)

This commit is contained in:
2026-05-21 16:12:54 +08:00
parent 702d4ee1fe
commit 380cc24913
7 changed files with 161 additions and 15 deletions
+3
View File
@@ -55,6 +55,9 @@ ENABLE_BACKGROUND_THINKING=true
# ========== Webhook (第三方平台接入) ==========
WEBHOOK_API_KEY=your-webhook-api-key
# ========== CORS 跨域白名单 (逗号分隔) ==========
ALLOWED_ORIGINS=http://localhost:5173,http://localhost:5199,http://localhost:3000
# ========== 记忆系统 ==========
MEMORY_FILE_PATH=./data/memory
VECTOR_DB_URL=http://localhost:6333
+1 -1
View File
@@ -158,7 +158,7 @@ func main() {
r := gin.New()
// 中间件
r.Use(middleware.CORS())
r.Use(middleware.CORS(cfg.AllowedOrigins))
r.Use(middleware.RequestLogging())
r.Use(gin.Recovery())
+23 -1
View File
@@ -3,6 +3,7 @@ package config
import (
"fmt"
"os"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
@@ -69,6 +70,9 @@ type Config struct {
// Internal Service Token (内部服务间认证)
InternalServiceToken string
// CORS 允许的 Origin 白名单
AllowedOrigins []string
// 每日简报时间 (HH:MM 格式)
BriefingTime string
}
@@ -114,12 +118,14 @@ func Load() *Config {
LLMAPIKey: getEnv("LLM_API_KEY", ""),
LLMModel: getEnv("LLM_MODEL", "gpt-4o"),
WSMaxConnections: getEnvInt("WS_MAX_CONNECTIONS", 1000),
WSMaxConnections: getEnvInt("WS_MAX_CONNECTIONS", 1000),
SessionIdleTimeoutMin: getEnvInt("SESSION_IDLE_TIMEOUT_MIN", 30),
WebhookAPIKey: getEnv("WEBHOOK_API_KEY", ""),
InternalServiceToken: getEnv("INTERNAL_SERVICE_TOKEN", "cyrene-internal-token-change-me"),
AllowedOrigins: parseAllowedOrigins(getEnv("ALLOWED_ORIGINS", "http://localhost:5173,http://localhost:5199,http://localhost:3000")),
BriefingTime: getEnv("BRIEFING_TIME", "08:00"),
}
}
@@ -195,3 +201,19 @@ func getEnvBool(key string, fallback bool) bool {
}
return v == "true" || v == "1" || v == "yes"
}
// parseAllowedOrigins 解析逗号分隔的 origins 字符串为切片
func parseAllowedOrigins(s string) []string {
if s == "" {
return []string{}
}
parts := strings.Split(s, ",")
result := make([]string, 0, len(parts))
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
result = append(result, p)
}
}
return result
}
@@ -153,7 +153,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
} else if req.Username == h.cfg.AdminUsername && h.db != nil {
// 管理员用户尚未迁移到 users 表,尝试用配置中的密码验证
if req.Password != h.cfg.AdminPassword {
c.JSON(http.StatusUnauthorized, gin.H{"error": "管理员密码错误"})
c.JSON(http.StatusUnauthorized, gin.H{"error": "用户名或密码错误"})
return
}
// 密码正确,迁移 admin 到 users 表
@@ -171,7 +171,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
} else if req.Username == h.cfg.AdminUsername {
// 数据库不可用时的回退:使用配置中的管理员密码
if req.Password != h.cfg.AdminPassword {
c.JSON(http.StatusUnauthorized, gin.H{"error": "管理员密码错误"})
c.JSON(http.StatusUnauthorized, gin.H{"error": "用户名或密码错误"})
return
}
userID = "admin"
@@ -32,6 +32,16 @@ func NewSessionHandler(hub *ws.Hub, s *store.SessionStore) *SessionHandler {
}
}
// canAccess 检查当前用户是否有权访问指定 userID 所属的资源
// admin 用户豁免所有权检查
func (h *SessionHandler) canAccess(c *gin.Context, resourceOwnerID string) bool {
currentUserID := middleware.GetUserID(c)
if middleware.GetIsAdmin(c) {
return true
}
return currentUserID == resourceOwnerID
}
// ========== POST /api/v1/sessions — 创建会话 ==========
type createSessionRequest struct {
@@ -49,9 +59,6 @@ func (h *SessionHandler) Create(c *gin.Context) {
if err := c.ShouldBindJSON(&req); err != nil {
// 允许空 body
}
if req.UserID != "" {
userID = req.UserID
}
if req.Title == "" {
req.Title = "新的对话"
}
@@ -81,9 +88,19 @@ func (h *SessionHandler) Create(c *gin.Context) {
// List 获取会话列表 (按 updated_at DESC 排序)
func (h *SessionHandler) List(c *gin.Context) {
currentUserID := middleware.GetUserID(c)
isAdmin := middleware.GetIsAdmin(c)
// 非管理员只能查询自己的会话
userID := c.Query("user_id")
if userID == "" {
userID = middleware.GetUserID(c)
userID = currentUserID
} else if !isAdmin {
// 非管理员试图查询其他用户的会话
if userID != currentUserID {
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问其他用户的会话"})
return
}
}
if h.useDB {
@@ -134,6 +151,13 @@ func (h *SessionHandler) Get(c *gin.Context) {
})
return
}
// 所有权校验:非管理员只能访问自己的会话
if !h.canAccess(c, session.UserID) {
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此会话"})
return
}
c.JSON(http.StatusOK, gin.H{
"id": session.ID,
"user_id": session.UserID,
@@ -159,6 +183,25 @@ func (h *SessionHandler) Delete(c *gin.Context) {
sessionID := c.Param("id")
if h.useDB {
// 所有权校验:先获取session再验证归属
session, err := h.store.GetSession(sessionID)
if err != nil {
log.Printf("[SessionHandler] 查询会话失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除会话失败", "errorType": "db_error"})
return
}
if session == nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "会话不存在",
"errorType": "session_not_found",
})
return
}
if !h.canAccess(c, session.UserID) {
c.JSON(http.StatusForbidden, gin.H{"error": "无权删除此会话"})
return
}
if err := h.store.DeleteSession(sessionID); err != nil {
log.Printf("[SessionHandler] 删除会话失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除会话失败", "errorType": "db_error"})
@@ -176,9 +219,18 @@ func (h *SessionHandler) Delete(c *gin.Context) {
// DeleteAll 删除用户所有会话 (不删除记忆)
func (h *SessionHandler) DeleteAll(c *gin.Context) {
currentUserID := middleware.GetUserID(c)
isAdmin := middleware.GetIsAdmin(c)
userID := c.Query("user_id")
if userID == "" {
userID = middleware.GetUserID(c)
userID = currentUserID
} else if !isAdmin {
// 非管理员只能删除自己的会话
if userID != currentUserID {
c.JSON(http.StatusForbidden, gin.H{"error": "无权删除其他用户的会话"})
return
}
}
if h.useDB {
@@ -197,6 +249,28 @@ func (h *SessionHandler) DeleteAll(c *gin.Context) {
// GetMessages 获取会话的完整消息列表
func (h *SessionHandler) GetMessages(c *gin.Context) {
sessionID := c.Param("id")
// 所有权校验
if h.useDB {
session, err := h.store.GetSession(sessionID)
if err != nil {
log.Printf("[SessionHandler] 查询会话失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询消息失败", "errorType": "db_error"})
return
}
if session == nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "会话不存在",
"errorType": "session_not_found",
})
return
}
if !h.canAccess(c, session.UserID) {
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此会话的消息"})
return
}
}
limit := 50
if l := c.Query("limit"); l != "" {
parsed := 0
@@ -248,6 +322,25 @@ func (h *SessionHandler) ClearMessages(c *gin.Context) {
sessionID := c.Param("id")
if h.useDB {
// 所有权校验
session, err := h.store.GetSession(sessionID)
if err != nil {
log.Printf("[SessionHandler] 查询会话失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "清空消息失败", "errorType": "db_error"})
return
}
if session == nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "会话不存在",
"errorType": "session_not_found",
})
return
}
if !h.canAccess(c, session.UserID) {
c.JSON(http.StatusForbidden, gin.H{"error": "无权清空此会话的消息"})
return
}
if err := h.store.ClearSessionMessages(sessionID); err != nil {
log.Printf("[SessionHandler] 清空消息失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "清空消息失败", "errorType": "db_error"})
@@ -436,6 +529,12 @@ func (h *SessionHandler) ExportSession(c *gin.Context) {
return
}
// 所有权校验:非管理员只能导出自己的会话
if !h.canAccess(c, session.UserID) {
c.JSON(http.StatusForbidden, gin.H{"error": "无权导出此会话"})
return
}
// 获取所有消息 (不限制数量,导出全部)
messages, err := h.store.GetMessages(sessionID, 0)
if err != nil {
@@ -55,3 +55,12 @@ func GetUserID(c *gin.Context) string {
}
return userID.(string)
}
// GetIsAdmin 从上下文获取是否为管理员
func GetIsAdmin(c *gin.Context) bool {
isAdmin, _ := c.Get(IsAdminKey)
if isAdmin == nil {
return false
}
return isAdmin.(bool)
}
+19 -6
View File
@@ -6,17 +6,30 @@ import (
"github.com/gin-gonic/gin"
)
// CORS 跨域中间件 (含安全头)
func CORS() gin.HandlerFunc {
// CORS 跨域中间件 (含安全头) — 白名单模式
func CORS(allowedOrigins []string) gin.HandlerFunc {
return func(c *gin.Context) {
origin := c.Request.Header.Get("Origin")
if origin == "" {
origin = "*"
// 白名单校验
allowed := false
if origin != "" {
for _, o := range allowedOrigins {
if o == origin {
allowed = true
break
}
}
}
c.Header("Access-Control-Allow-Origin", origin)
// 仅对白名单中的 origin 设置 CORS 头
if allowed {
c.Header("Access-Control-Allow-Origin", origin)
c.Header("Access-Control-Allow-Credentials", "true")
}
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Authorization, X-Request-ID")
c.Header("Access-Control-Allow-Credentials", "true")
c.Header("Access-Control-Max-Age", "86400")
// 安全头