fix(security): 修复 P0 安全漏洞 (Session越权+CORS白名单+用户名枚举)
This commit is contained in:
@@ -55,6 +55,9 @@ ENABLE_BACKGROUND_THINKING=true
|
||||
# ========== Webhook (第三方平台接入) ==========
|
||||
WEBHOOK_API_KEY=your-webhook-api-key
|
||||
|
||||
# ========== CORS 跨域白名单 (逗号分隔) ==========
|
||||
ALLOWED_ORIGINS=http://localhost:5173,http://localhost:5199,http://localhost:3000
|
||||
|
||||
# ========== 记忆系统 ==========
|
||||
MEMORY_FILE_PATH=./data/memory
|
||||
VECTOR_DB_URL=http://localhost:6333
|
||||
|
||||
@@ -158,7 +158,7 @@ func main() {
|
||||
r := gin.New()
|
||||
|
||||
// 中间件
|
||||
r.Use(middleware.CORS())
|
||||
r.Use(middleware.CORS(cfg.AllowedOrigins))
|
||||
r.Use(middleware.RequestLogging())
|
||||
r.Use(gin.Recovery())
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package config
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
@@ -69,6 +70,9 @@ type Config struct {
|
||||
// Internal Service Token (内部服务间认证)
|
||||
InternalServiceToken string
|
||||
|
||||
// CORS 允许的 Origin 白名单
|
||||
AllowedOrigins []string
|
||||
|
||||
// 每日简报时间 (HH:MM 格式)
|
||||
BriefingTime string
|
||||
}
|
||||
@@ -114,12 +118,14 @@ func Load() *Config {
|
||||
LLMAPIKey: getEnv("LLM_API_KEY", ""),
|
||||
LLMModel: getEnv("LLM_MODEL", "gpt-4o"),
|
||||
|
||||
WSMaxConnections: getEnvInt("WS_MAX_CONNECTIONS", 1000),
|
||||
WSMaxConnections: getEnvInt("WS_MAX_CONNECTIONS", 1000),
|
||||
SessionIdleTimeoutMin: getEnvInt("SESSION_IDLE_TIMEOUT_MIN", 30),
|
||||
|
||||
WebhookAPIKey: getEnv("WEBHOOK_API_KEY", ""),
|
||||
InternalServiceToken: getEnv("INTERNAL_SERVICE_TOKEN", "cyrene-internal-token-change-me"),
|
||||
|
||||
AllowedOrigins: parseAllowedOrigins(getEnv("ALLOWED_ORIGINS", "http://localhost:5173,http://localhost:5199,http://localhost:3000")),
|
||||
|
||||
BriefingTime: getEnv("BRIEFING_TIME", "08:00"),
|
||||
}
|
||||
}
|
||||
@@ -195,3 +201,19 @@ func getEnvBool(key string, fallback bool) bool {
|
||||
}
|
||||
return v == "true" || v == "1" || v == "yes"
|
||||
}
|
||||
|
||||
// parseAllowedOrigins 解析逗号分隔的 origins 字符串为切片
|
||||
func parseAllowedOrigins(s string) []string {
|
||||
if s == "" {
|
||||
return []string{}
|
||||
}
|
||||
parts := strings.Split(s, ",")
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
result = append(result, p)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -153,7 +153,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
} else if req.Username == h.cfg.AdminUsername && h.db != nil {
|
||||
// 管理员用户尚未迁移到 users 表,尝试用配置中的密码验证
|
||||
if req.Password != h.cfg.AdminPassword {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "管理员密码错误"})
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "用户名或密码错误"})
|
||||
return
|
||||
}
|
||||
// 密码正确,迁移 admin 到 users 表
|
||||
@@ -171,7 +171,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
} else if req.Username == h.cfg.AdminUsername {
|
||||
// 数据库不可用时的回退:使用配置中的管理员密码
|
||||
if req.Password != h.cfg.AdminPassword {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "管理员密码错误"})
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "用户名或密码错误"})
|
||||
return
|
||||
}
|
||||
userID = "admin"
|
||||
|
||||
@@ -32,6 +32,16 @@ func NewSessionHandler(hub *ws.Hub, s *store.SessionStore) *SessionHandler {
|
||||
}
|
||||
}
|
||||
|
||||
// canAccess 检查当前用户是否有权访问指定 userID 所属的资源
|
||||
// admin 用户豁免所有权检查
|
||||
func (h *SessionHandler) canAccess(c *gin.Context, resourceOwnerID string) bool {
|
||||
currentUserID := middleware.GetUserID(c)
|
||||
if middleware.GetIsAdmin(c) {
|
||||
return true
|
||||
}
|
||||
return currentUserID == resourceOwnerID
|
||||
}
|
||||
|
||||
// ========== POST /api/v1/sessions — 创建会话 ==========
|
||||
|
||||
type createSessionRequest struct {
|
||||
@@ -49,9 +59,6 @@ func (h *SessionHandler) Create(c *gin.Context) {
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
// 允许空 body
|
||||
}
|
||||
if req.UserID != "" {
|
||||
userID = req.UserID
|
||||
}
|
||||
if req.Title == "" {
|
||||
req.Title = "新的对话"
|
||||
}
|
||||
@@ -81,9 +88,19 @@ func (h *SessionHandler) Create(c *gin.Context) {
|
||||
|
||||
// List 获取会话列表 (按 updated_at DESC 排序)
|
||||
func (h *SessionHandler) List(c *gin.Context) {
|
||||
currentUserID := middleware.GetUserID(c)
|
||||
isAdmin := middleware.GetIsAdmin(c)
|
||||
|
||||
// 非管理员只能查询自己的会话
|
||||
userID := c.Query("user_id")
|
||||
if userID == "" {
|
||||
userID = middleware.GetUserID(c)
|
||||
userID = currentUserID
|
||||
} else if !isAdmin {
|
||||
// 非管理员试图查询其他用户的会话
|
||||
if userID != currentUserID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问其他用户的会话"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if h.useDB {
|
||||
@@ -134,6 +151,13 @@ func (h *SessionHandler) Get(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 所有权校验:非管理员只能访问自己的会话
|
||||
if !h.canAccess(c, session.UserID) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此会话"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"id": session.ID,
|
||||
"user_id": session.UserID,
|
||||
@@ -159,6 +183,25 @@ func (h *SessionHandler) Delete(c *gin.Context) {
|
||||
sessionID := c.Param("id")
|
||||
|
||||
if h.useDB {
|
||||
// 所有权校验:先获取session再验证归属
|
||||
session, err := h.store.GetSession(sessionID)
|
||||
if err != nil {
|
||||
log.Printf("[SessionHandler] 查询会话失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除会话失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
if session == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "会话不存在",
|
||||
"errorType": "session_not_found",
|
||||
})
|
||||
return
|
||||
}
|
||||
if !h.canAccess(c, session.UserID) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权删除此会话"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.DeleteSession(sessionID); err != nil {
|
||||
log.Printf("[SessionHandler] 删除会话失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除会话失败", "errorType": "db_error"})
|
||||
@@ -176,9 +219,18 @@ func (h *SessionHandler) Delete(c *gin.Context) {
|
||||
|
||||
// DeleteAll 删除用户所有会话 (不删除记忆)
|
||||
func (h *SessionHandler) DeleteAll(c *gin.Context) {
|
||||
currentUserID := middleware.GetUserID(c)
|
||||
isAdmin := middleware.GetIsAdmin(c)
|
||||
|
||||
userID := c.Query("user_id")
|
||||
if userID == "" {
|
||||
userID = middleware.GetUserID(c)
|
||||
userID = currentUserID
|
||||
} else if !isAdmin {
|
||||
// 非管理员只能删除自己的会话
|
||||
if userID != currentUserID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权删除其他用户的会话"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if h.useDB {
|
||||
@@ -197,6 +249,28 @@ func (h *SessionHandler) DeleteAll(c *gin.Context) {
|
||||
// GetMessages 获取会话的完整消息列表
|
||||
func (h *SessionHandler) GetMessages(c *gin.Context) {
|
||||
sessionID := c.Param("id")
|
||||
|
||||
// 所有权校验
|
||||
if h.useDB {
|
||||
session, err := h.store.GetSession(sessionID)
|
||||
if err != nil {
|
||||
log.Printf("[SessionHandler] 查询会话失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询消息失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
if session == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "会话不存在",
|
||||
"errorType": "session_not_found",
|
||||
})
|
||||
return
|
||||
}
|
||||
if !h.canAccess(c, session.UserID) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此会话的消息"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
limit := 50
|
||||
if l := c.Query("limit"); l != "" {
|
||||
parsed := 0
|
||||
@@ -248,6 +322,25 @@ func (h *SessionHandler) ClearMessages(c *gin.Context) {
|
||||
sessionID := c.Param("id")
|
||||
|
||||
if h.useDB {
|
||||
// 所有权校验
|
||||
session, err := h.store.GetSession(sessionID)
|
||||
if err != nil {
|
||||
log.Printf("[SessionHandler] 查询会话失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "清空消息失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
if session == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "会话不存在",
|
||||
"errorType": "session_not_found",
|
||||
})
|
||||
return
|
||||
}
|
||||
if !h.canAccess(c, session.UserID) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权清空此会话的消息"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.ClearSessionMessages(sessionID); err != nil {
|
||||
log.Printf("[SessionHandler] 清空消息失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "清空消息失败", "errorType": "db_error"})
|
||||
@@ -436,6 +529,12 @@ func (h *SessionHandler) ExportSession(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 所有权校验:非管理员只能导出自己的会话
|
||||
if !h.canAccess(c, session.UserID) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权导出此会话"})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取所有消息 (不限制数量,导出全部)
|
||||
messages, err := h.store.GetMessages(sessionID, 0)
|
||||
if err != nil {
|
||||
|
||||
@@ -55,3 +55,12 @@ func GetUserID(c *gin.Context) string {
|
||||
}
|
||||
return userID.(string)
|
||||
}
|
||||
|
||||
// GetIsAdmin 从上下文获取是否为管理员
|
||||
func GetIsAdmin(c *gin.Context) bool {
|
||||
isAdmin, _ := c.Get(IsAdminKey)
|
||||
if isAdmin == nil {
|
||||
return false
|
||||
}
|
||||
return isAdmin.(bool)
|
||||
}
|
||||
|
||||
@@ -6,17 +6,30 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORS 跨域中间件 (含安全头)
|
||||
func CORS() gin.HandlerFunc {
|
||||
// CORS 跨域中间件 (含安全头) — 白名单模式
|
||||
func CORS(allowedOrigins []string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
origin = "*"
|
||||
|
||||
// 白名单校验
|
||||
allowed := false
|
||||
if origin != "" {
|
||||
for _, o := range allowedOrigins {
|
||||
if o == origin {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Header("Access-Control-Allow-Origin", origin)
|
||||
|
||||
// 仅对白名单中的 origin 设置 CORS 头
|
||||
if allowed {
|
||||
c.Header("Access-Control-Allow-Origin", origin)
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
|
||||
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Authorization, X-Request-ID")
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
c.Header("Access-Control-Max-Age", "86400")
|
||||
|
||||
// 安全头
|
||||
|
||||
Reference in New Issue
Block a user