fix: round 10 critical fixes - WebSocket race, rate limiting, XSS protection, Caddyfile, and input validation
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -16,6 +17,9 @@ import (
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/store"
|
||||
)
|
||||
|
||||
// usernameRegex 用户名格式校验:仅允许字母、数字、下划线,长度 3-32
|
||||
var usernameRegex = regexp.MustCompile(`^[a-zA-Z0-9_]{3,32}$`)
|
||||
|
||||
// AuthHandler 认证处理器
|
||||
type AuthHandler struct {
|
||||
cfg *config.Config
|
||||
@@ -49,6 +53,12 @@ func (h *AuthHandler) Register(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 用户名格式校验:仅允许字母、数字、下划线,长度 3-32
|
||||
if !usernameRegex.MatchString(req.Username) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "用户名格式无效:仅允许字母、数字和下划线,长度 3-32 位"})
|
||||
return
|
||||
}
|
||||
|
||||
// MVP阶段:验证码简单校验 (开发环境接受 "000000")
|
||||
if req.VerifyCode != "000000" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "验证码错误 (开发阶段请使用 000000)"})
|
||||
@@ -118,6 +128,12 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 用户名格式校验:仅允许字母、数字、下划线,长度 3-32
|
||||
if !usernameRegex.MatchString(req.Username) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "用户名格式无效"})
|
||||
return
|
||||
}
|
||||
|
||||
var userID string
|
||||
|
||||
// 尝试从 users 表查询用户
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"html"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -77,8 +78,8 @@ func (h *KnowledgeHandler) CreateKB(c *gin.Context) {
|
||||
kb := &store.KnowledgeBase{
|
||||
ID: store.GenerateUUID(),
|
||||
UserID: userID,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Name: html.EscapeString(req.Name),
|
||||
Description: html.EscapeString(req.Description),
|
||||
}
|
||||
|
||||
if err := h.store.CreateKB(kb); err != nil {
|
||||
@@ -175,7 +176,7 @@ func (h *KnowledgeHandler) UpdateKB(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.UpdateKB(kbID, req.Name, req.Description); err != nil {
|
||||
if err := h.store.UpdateKB(kbID, html.EscapeString(req.Name), html.EscapeString(req.Description)); err != nil {
|
||||
log.Printf("[KnowledgeHandler] 更新知识库失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "更新知识库失败", "errorType": "db_error"})
|
||||
return
|
||||
@@ -315,8 +316,8 @@ func (h *KnowledgeHandler) AddDocument(c *gin.Context) {
|
||||
ID: store.GenerateUUID(),
|
||||
KBID: kbID,
|
||||
UserID: userID,
|
||||
Title: req.Title,
|
||||
SourceType: req.SourceType,
|
||||
Title: html.EscapeString(req.Title),
|
||||
SourceType: html.EscapeString(req.SourceType),
|
||||
SourceRef: sourceRef,
|
||||
ContentType: contentType,
|
||||
RawContent: content,
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
@@ -144,11 +145,11 @@ func (h *MemoryHandler) Add(c *gin.Context) {
|
||||
userID = req.UserID
|
||||
}
|
||||
|
||||
// 转发到 Memory-Service
|
||||
// 转发到 Memory-Service(对用户输入进行 HTML 转义防 XSS)
|
||||
memReq := map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"content": req.Content,
|
||||
"category": req.Category,
|
||||
"content": html.EscapeString(req.Content),
|
||||
"category": html.EscapeString(req.Category),
|
||||
"priority": req.Priority,
|
||||
}
|
||||
reqBody, _ := json.Marshal(memReq)
|
||||
|
||||
@@ -2,6 +2,7 @@ package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"html"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -113,8 +114,8 @@ func (h *ReminderHandler) Create(c *gin.Context) {
|
||||
reminder := &store.Reminder{
|
||||
ID: generateID(),
|
||||
UserID: userID,
|
||||
Title: req.Title,
|
||||
Description: req.Description,
|
||||
Title: html.EscapeString(req.Title),
|
||||
Description: html.EscapeString(req.Description),
|
||||
RemindAt: remindAt,
|
||||
Status: "pending",
|
||||
RepeatType: repeatType,
|
||||
|
||||
@@ -53,6 +53,30 @@ func (rl *RateLimiter) Handler() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// HandlerWithKey 返回按自定义 key 限流的中间件(如 IP + 端点组合)
|
||||
func (rl *RateLimiter) HandlerWithKey(keyFn func(c *gin.Context) string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
key := keyFn(c)
|
||||
|
||||
if !rl.allow(key) {
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "请求过于频繁,请稍后再试",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// AuthIPKey 返回按 IP + 端点限流的 key(用于认证端点)
|
||||
func AuthIPKey(endpoint string) func(c *gin.Context) string {
|
||||
return func(c *gin.Context) string {
|
||||
return "auth_" + endpoint + "_" + c.ClientIP()
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) allow(key string) bool {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
@@ -52,11 +52,14 @@ func Setup(r *gin.Engine, hub *ws.Hub, cfg *config.Config, sessionStore *store.S
|
||||
})
|
||||
})
|
||||
|
||||
// 认证路由专用限流器:每分钟每个IP每个端点最多5次请求(防暴力破解)
|
||||
authRateLimiter := middleware.NewRateLimiter(0.083, 5) // ~5 per minute per IP+endpoint
|
||||
|
||||
// 认证 (无需JWT)
|
||||
auth := api.Group("/auth")
|
||||
{
|
||||
auth.POST("/register", authHandler.Register)
|
||||
auth.POST("/login", authHandler.Login)
|
||||
auth.POST("/register", authRateLimiter.HandlerWithKey(middleware.AuthIPKey("register")), authHandler.Register)
|
||||
auth.POST("/login", authRateLimiter.HandlerWithKey(middleware.AuthIPKey("login")), authHandler.Login)
|
||||
}
|
||||
|
||||
// ========== 需要认证的路由 ==========
|
||||
|
||||
@@ -236,17 +236,59 @@ func (h *Hub) Run() {
|
||||
client.UserID, client.SessionID, len(h.clients))
|
||||
|
||||
case message := <-h.broadcast:
|
||||
// 两阶段广播:Phase 1 在 RLock 下收集失效客户端,Phase 2 在 Lock 下清理
|
||||
var staleClients []*Client
|
||||
h.mu.RLock()
|
||||
for client := range h.clients {
|
||||
select {
|
||||
case client.Send <- message:
|
||||
default:
|
||||
// 客户端发送通道已满,跳过
|
||||
close(client.Send)
|
||||
delete(h.clients, client)
|
||||
// 客户端发送通道已满,标记为失效
|
||||
staleClients = append(staleClients, client)
|
||||
}
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
|
||||
// Phase 2: 在写锁下清理失效客户端
|
||||
if len(staleClients) > 0 {
|
||||
h.mu.Lock()
|
||||
for _, client := range staleClients {
|
||||
// 二次检查:客户端可能已被 unregister 移除
|
||||
if _, ok := h.clients[client]; !ok {
|
||||
continue
|
||||
}
|
||||
delete(h.clients, client)
|
||||
close(client.Send)
|
||||
|
||||
// 清理用户索引
|
||||
if h.userClients[client.UserID] != nil {
|
||||
delete(h.userClients[client.UserID], client)
|
||||
if len(h.userClients[client.UserID]) == 0 {
|
||||
delete(h.userClients, client.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
// 检查该 session 是否还有其他连接
|
||||
hasOtherConn := false
|
||||
if clients, ok := h.userClients[client.UserID]; ok {
|
||||
for c := range clients {
|
||||
if c.SessionID == client.SessionID {
|
||||
hasOtherConn = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !hasOtherConn {
|
||||
if s, ok := h.sessions[client.SessionID]; ok {
|
||||
s.State = "idle"
|
||||
}
|
||||
}
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
log.Printf("[WS] 广播清理 %d 个失效客户端 (当前连接数: %d)",
|
||||
len(staleClients), len(h.clients))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user