dev 分支暂存
This commit is contained in:
@@ -0,0 +1,54 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/config"
|
||||
)
|
||||
|
||||
// Auth 用户键值在context中的key
|
||||
const UserIDKey = "user_id"
|
||||
|
||||
// JWTAuth JWT认证中间件
|
||||
func JWTAuth(cfg *config.Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "未提供认证令牌"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Bearer token
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "认证格式错误"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := parts[1]
|
||||
userID, err := cfg.ValidateToken(tokenString)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "认证令牌无效或已过期"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将userID注入上下文
|
||||
c.Set(UserIDKey, userID)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserID 从上下文获取用户ID
|
||||
func GetUserID(c *gin.Context) string {
|
||||
userID, _ := c.Get(UserIDKey)
|
||||
if userID == nil {
|
||||
return ""
|
||||
}
|
||||
return userID.(string)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORS 跨域中间件
|
||||
func CORS() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
|
||||
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Authorization, X-Request-ID")
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
c.Header("Access-Control-Max-Age", "86400")
|
||||
|
||||
// 预检请求
|
||||
if c.Request.Method == http.MethodOptions {
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RequestLogging 请求日志中间件
|
||||
func RequestLogging() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
|
||||
// 处理请求
|
||||
c.Next()
|
||||
|
||||
// 记录日志
|
||||
duration := time.Since(start)
|
||||
statusCode := c.Writer.Status()
|
||||
method := c.Request.Method
|
||||
path := c.Request.URL.Path
|
||||
clientIP := c.ClientIP()
|
||||
|
||||
logLevel := "[INFO]"
|
||||
if statusCode >= 500 {
|
||||
logLevel = "[ERROR]"
|
||||
} else if statusCode >= 400 {
|
||||
logLevel = "[WARN]"
|
||||
}
|
||||
|
||||
log.Printf("%s %s %s %d %v %s",
|
||||
logLevel, method, path, statusCode, duration, clientIP,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RateLimiter 基于内存令牌桶的限流中间件
|
||||
// MVP阶段使用内存实现,后续可迁移到Redis
|
||||
type RateLimiter struct {
|
||||
mu sync.Mutex
|
||||
buckets map[string]*tokenBucket
|
||||
rate float64 // 每秒生成的令牌数
|
||||
burst int // 桶容量
|
||||
}
|
||||
|
||||
type tokenBucket struct {
|
||||
tokens float64
|
||||
lastTime time.Time
|
||||
}
|
||||
|
||||
// NewRateLimiter 创建限流器
|
||||
func NewRateLimiter(rate float64, burst int) *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
buckets: make(map[string]*tokenBucket),
|
||||
rate: rate,
|
||||
burst: burst,
|
||||
}
|
||||
|
||||
// 定期清理过期桶
|
||||
go rl.cleanup()
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
// Handler 返回Gin中间件
|
||||
func (rl *RateLimiter) Handler() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
key := c.ClientIP() // 按IP限流
|
||||
|
||||
if !rl.allow(key) {
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "请求过于频繁,请稍后再试",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) allow(key string) bool {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
bucket, ok := rl.buckets[key]
|
||||
now := time.Now()
|
||||
|
||||
if !ok {
|
||||
rl.buckets[key] = &tokenBucket{
|
||||
tokens: float64(rl.burst) - 1,
|
||||
lastTime: now,
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// 补充令牌
|
||||
elapsed := now.Sub(bucket.lastTime).Seconds()
|
||||
bucket.tokens += elapsed * rl.rate
|
||||
if bucket.tokens > float64(rl.burst) {
|
||||
bucket.tokens = float64(rl.burst)
|
||||
}
|
||||
bucket.lastTime = now
|
||||
|
||||
// 消耗令牌
|
||||
if bucket.tokens >= 1 {
|
||||
bucket.tokens--
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// cleanup 定期清理长时间未使用的桶
|
||||
func (rl *RateLimiter) cleanup() {
|
||||
for {
|
||||
time.Sleep(5 * time.Minute)
|
||||
|
||||
rl.mu.Lock()
|
||||
cutoff := time.Now().Add(-10 * time.Minute)
|
||||
for key, bucket := range rl.buckets {
|
||||
if bucket.lastTime.Before(cutoff) {
|
||||
delete(rl.buckets, key)
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user