127 lines
2.4 KiB
Go
127 lines
2.4 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// RateLimiter 基于内存令牌桶的限流中间件
|
|
// MVP阶段使用内存实现,后续可迁移到Redis
|
|
type RateLimiter struct {
|
|
mu sync.Mutex
|
|
buckets map[string]*tokenBucket
|
|
rate float64 // 每秒生成的令牌数
|
|
burst int // 桶容量
|
|
}
|
|
|
|
// tokenBucket is the mutable state for one key. It carries no lock of its
// own; RateLimiter only touches it while holding RateLimiter.mu.
type tokenBucket struct {
	tokens   float64   // tokens currently available; fractional between refills
	lastTime time.Time // time of the last refill, i.e. the key's last allow call
}
|
|
|
|
// NewRateLimiter 创建限流器
|
|
func NewRateLimiter(rate float64, burst int) *RateLimiter {
|
|
rl := &RateLimiter{
|
|
buckets: make(map[string]*tokenBucket),
|
|
rate: rate,
|
|
burst: burst,
|
|
}
|
|
|
|
// 定期清理过期桶
|
|
go rl.cleanup()
|
|
|
|
return rl
|
|
}
|
|
|
|
// Handler 返回Gin中间件
|
|
func (rl *RateLimiter) Handler() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
key := c.ClientIP() // 按IP限流
|
|
|
|
if !rl.allow(key) {
|
|
c.JSON(http.StatusTooManyRequests, gin.H{
|
|
"error": "请求过于频繁,请稍后再试",
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// HandlerWithKey 返回按自定义 key 限流的中间件(如 IP + 端点组合)
|
|
func (rl *RateLimiter) HandlerWithKey(keyFn func(c *gin.Context) string) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
key := keyFn(c)
|
|
|
|
if !rl.allow(key) {
|
|
c.JSON(http.StatusTooManyRequests, gin.H{
|
|
"error": "请求过于频繁,请稍后再试",
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// AuthIPKey 返回按 IP + 端点限流的 key(用于认证端点)
|
|
func AuthIPKey(endpoint string) func(c *gin.Context) string {
|
|
return func(c *gin.Context) string {
|
|
return "auth_" + endpoint + "_" + c.ClientIP()
|
|
}
|
|
}
|
|
|
|
func (rl *RateLimiter) allow(key string) bool {
|
|
rl.mu.Lock()
|
|
defer rl.mu.Unlock()
|
|
|
|
bucket, ok := rl.buckets[key]
|
|
now := time.Now()
|
|
|
|
if !ok {
|
|
rl.buckets[key] = &tokenBucket{
|
|
tokens: float64(rl.burst) - 1,
|
|
lastTime: now,
|
|
}
|
|
return true
|
|
}
|
|
|
|
// 补充令牌
|
|
elapsed := now.Sub(bucket.lastTime).Seconds()
|
|
bucket.tokens += elapsed * rl.rate
|
|
if bucket.tokens > float64(rl.burst) {
|
|
bucket.tokens = float64(rl.burst)
|
|
}
|
|
bucket.lastTime = now
|
|
|
|
// 消耗令牌
|
|
if bucket.tokens >= 1 {
|
|
bucket.tokens--
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// cleanup 定期清理长时间未使用的桶
|
|
func (rl *RateLimiter) cleanup() {
|
|
for {
|
|
time.Sleep(5 * time.Minute)
|
|
|
|
rl.mu.Lock()
|
|
cutoff := time.Now().Add(-10 * time.Minute)
|
|
for key, bucket := range rl.buckets {
|
|
if bucket.lastTime.Before(cutoff) {
|
|
delete(rl.buckets, key)
|
|
}
|
|
}
|
|
rl.mu.Unlock()
|
|
}
|
|
}
|