dev 分支暂存
This commit is contained in:
@@ -0,0 +1,124 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// Config 应用配置
|
||||
type Config struct {
|
||||
Env string
|
||||
Port string
|
||||
|
||||
// 数据库
|
||||
PostgresHost string
|
||||
PostgresPort string
|
||||
PostgresUser string
|
||||
PostgresPass string
|
||||
PostgresDB string
|
||||
|
||||
// Redis
|
||||
RedisHost string
|
||||
RedisPort string
|
||||
RedisPass string
|
||||
|
||||
// JWT
|
||||
JWTSecret string
|
||||
JWTExpiryHours time.Duration
|
||||
|
||||
// AI-Core 服务
|
||||
AICoreURL string
|
||||
|
||||
// LLM (透传给AI-Core,Gateway可能也需要)
|
||||
LLMAPIURL string
|
||||
LLMAPIKey string
|
||||
LLMModel string
|
||||
|
||||
// WebSocket
|
||||
WSMaxConnections int
|
||||
}
|
||||
|
||||
// Load 从环境变量加载配置
|
||||
func Load() *Config {
|
||||
return &Config{
|
||||
Env: getEnv("ENV", "development"),
|
||||
Port: getEnv("GATEWAY_PORT", "8080"),
|
||||
|
||||
PostgresHost: getEnv("POSTGRES_HOST", "localhost"),
|
||||
PostgresPort: getEnv("POSTGRES_PORT", "5432"),
|
||||
PostgresUser: getEnv("POSTGRES_USER", "cyrene"),
|
||||
PostgresPass: getEnv("POSTGRES_PASSWORD", "change_me"),
|
||||
PostgresDB: getEnv("POSTGRES_DB", "cyrene_ai"),
|
||||
|
||||
RedisHost: getEnv("REDIS_HOST", "localhost"),
|
||||
RedisPort: getEnv("REDIS_PORT", "6379"),
|
||||
RedisPass: getEnv("REDIS_PASSWORD", ""),
|
||||
|
||||
JWTSecret: getEnv("JWT_SECRET", "change-me-in-production"),
|
||||
JWTExpiryHours: time.Duration(getEnvInt("JWT_EXPIRY_HOURS", 720)) * time.Hour,
|
||||
|
||||
AICoreURL: getEnv("AI_CORE_URL", "http://localhost:8081"),
|
||||
|
||||
LLMAPIURL: getEnv("LLM_API_URL", "https://api.openai.com/v1"),
|
||||
LLMAPIKey: getEnv("LLM_API_KEY", ""),
|
||||
LLMModel: getEnv("LLM_MODEL", "gpt-4o"),
|
||||
|
||||
WSMaxConnections: getEnvInt("WS_MAX_CONNECTIONS", 1000),
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateToken 生成JWT token
|
||||
func (c *Config) GenerateToken(userID string) (string, error) {
|
||||
claims := jwt.MapClaims{
|
||||
"user_id": userID,
|
||||
"exp": time.Now().Add(c.JWTExpiryHours).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(c.JWTSecret))
|
||||
}
|
||||
|
||||
// ValidateToken 验证JWT token
|
||||
func (c *Config) ValidateToken(tokenString string) (string, error) {
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, jwt.ErrSignatureInvalid
|
||||
}
|
||||
return []byte(c.JWTSecret), nil
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid {
|
||||
return "", jwt.ErrSignatureInvalid
|
||||
}
|
||||
|
||||
userID, _ := claims["user_id"].(string)
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
func getEnv(key, fallback string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func getEnvInt(key string, fallback int) int {
|
||||
v := os.Getenv(key)
|
||||
if v == "" {
|
||||
return fallback
|
||||
}
|
||||
var result int
|
||||
for _, c := range v {
|
||||
if c < '0' || c > '9' {
|
||||
return fallback
|
||||
}
|
||||
result = result*10 + int(c-'0')
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/config"
|
||||
)
|
||||
|
||||
// AuthHandler 认证处理器
|
||||
type AuthHandler struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewAuthHandler 创建认证处理器
|
||||
func NewAuthHandler(cfg *config.Config) *AuthHandler {
|
||||
return &AuthHandler{cfg: cfg}
|
||||
}
|
||||
|
||||
// Register 用户注册
|
||||
func (h *AuthHandler) Register(c *gin.Context) {
|
||||
var req struct {
|
||||
Username string `json:"username" binding:"required,min=2,max=32"`
|
||||
Password string `json:"password" binding:"required,min=6,max=64"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// MVP阶段:使用username直接作为userID
|
||||
// 后续需要接入用户服务进行真实注册
|
||||
userID := "user_" + req.Username
|
||||
|
||||
// 生成JWT
|
||||
token, err := h.cfg.GenerateToken(userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "生成令牌失败"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"user_id": userID,
|
||||
"token": token,
|
||||
"expires": time.Now().Add(h.cfg.JWTExpiryHours).Unix(),
|
||||
})
|
||||
}
|
||||
|
||||
// Login 用户登录
|
||||
func (h *AuthHandler) Login(c *gin.Context) {
|
||||
var req struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效"})
|
||||
return
|
||||
}
|
||||
|
||||
// MVP阶段:简化的登录逻辑
|
||||
// 后续需要验证密码哈希
|
||||
userID := "user_" + req.Username
|
||||
|
||||
token, err := h.cfg.GenerateToken(userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "生成令牌失败"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"user_id": userID,
|
||||
"token": token,
|
||||
"expires": time.Now().Add(h.cfg.JWTExpiryHours).Unix(),
|
||||
})
|
||||
}
|
||||
|
||||
// RefreshToken 刷新令牌
|
||||
func (h *AuthHandler) RefreshToken(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" || len(authHeader) < 8 {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "未提供认证令牌"})
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := authHeader[7:] // 去掉 "Bearer "
|
||||
userID, err := h.cfg.ValidateToken(tokenString)
|
||||
if err != nil {
|
||||
// 允许使用已过期但未超过刷新窗口的token
|
||||
// MVP简化:直接重新签发
|
||||
_ = json.Unmarshal([]byte("{}"), &struct{}{})
|
||||
}
|
||||
|
||||
newToken, err := h.cfg.GenerateToken(userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "刷新令牌失败"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"token": newToken,
|
||||
"expires": time.Now().Add(h.cfg.JWTExpiryHours).Unix(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,186 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/config"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/ws"
|
||||
)
|
||||
|
||||
// ChatHandler 聊天处理器
|
||||
type ChatHandler struct {
|
||||
cfg *config.Config
|
||||
hub *ws.Hub
|
||||
upgrader websocket.Upgrader
|
||||
}
|
||||
|
||||
// NewChatHandler 创建聊天处理器
|
||||
func NewChatHandler(cfg *config.Config, hub *ws.Hub) *ChatHandler {
|
||||
return &ChatHandler{
|
||||
cfg: cfg,
|
||||
hub: hub,
|
||||
upgrader: websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true // 开发阶段允许所有来源
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// HandleWebSocket 处理WebSocket升级和消息路由
|
||||
func (h *ChatHandler) HandleWebSocket(c *gin.Context) {
|
||||
// 从query参数获取token和session_id
|
||||
token := c.Query("token")
|
||||
sessionID := c.Query("session_id")
|
||||
|
||||
if token == "" {
|
||||
// 也尝试从Authorization头读取
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if len(authHeader) > 7 && authHeader[:7] == "Bearer " {
|
||||
token = authHeader[7:]
|
||||
}
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "需要认证令牌"})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证token
|
||||
userID, err := h.cfg.ValidateToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "认证令牌无效"})
|
||||
return
|
||||
}
|
||||
|
||||
if sessionID == "" {
|
||||
sessionID = "session_" + generateID()
|
||||
}
|
||||
|
||||
// 升级WebSocket连接
|
||||
conn, err := h.upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
log.Printf("[WS] 升级连接失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建客户端
|
||||
client := ws.NewClient(h.hub, conn, userID, sessionID)
|
||||
|
||||
// 注册到Hub
|
||||
h.hub.register <- client
|
||||
|
||||
// 启动读写协程
|
||||
go client.WritePump()
|
||||
go client.ReadPump(func(client *ws.Client, msg ws.ClientMessage) {
|
||||
h.handleMessage(client, msg)
|
||||
})
|
||||
}
|
||||
|
||||
// handleMessage 处理WebSocket消息
|
||||
func (h *ChatHandler) handleMessage(client *ws.Client, msg ws.ClientMessage) {
|
||||
switch msg.Type {
|
||||
case "message":
|
||||
h.handleChatMessage(client, msg)
|
||||
case "voice_input":
|
||||
h.handleVoiceInput(client, msg)
|
||||
default:
|
||||
log.Printf("[WS] 未知消息类型: %s from user=%s", msg.Type, client.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
// handleChatMessage 处理文字聊天消息
|
||||
func (h *ChatHandler) handleChatMessage(client *ws.Client, msg ws.ClientMessage) {
|
||||
mode := msg.Mode
|
||||
if mode == "" {
|
||||
mode = "text"
|
||||
}
|
||||
|
||||
// MVP阶段:生成模拟回复(后续对接AI-Core)
|
||||
// 实际部署时,这里应转发消息到AI-Core并等待响应
|
||||
|
||||
response := ws.ServerMessage{
|
||||
Type: "response",
|
||||
MessageID: "msg_" + generateID(),
|
||||
Text: h.generateMockResponse(msg.Content, mode),
|
||||
ResponseMode: mode,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
// 发送响应给客户端
|
||||
if err := client.SendMessage(response); err != nil {
|
||||
log.Printf("[WS] 发送响应失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleVoiceInput 处理语音输入
|
||||
func (h *ChatHandler) handleVoiceInput(client *ws.Client, msg ws.ClientMessage) {
|
||||
// MVP阶段:返回提示
|
||||
response := ws.ServerMessage{
|
||||
Type: "error",
|
||||
MessageID: "msg_" + generateID(),
|
||||
Error: "语音处理功能将在后续版本中启用",
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
}
|
||||
client.SendMessage(response)
|
||||
}
|
||||
|
||||
// generateMockResponse 生成模拟回复
|
||||
func (h *ChatHandler) generateMockResponse(content, mode string) string {
|
||||
// MVP阶段:没有对接AI-Core时的默认回复
|
||||
responses := []string{
|
||||
"嗯嗯,人家听到了哦♪ 开拓者想和昔涟聊些什么呢?",
|
||||
"嘻嘻,开拓者说的话真有趣呢♪ 让我想想怎么回答……",
|
||||
"啊,这个问题很有意思呢!虽然人家现在还在学习阶段,但我很乐意倾听开拓者说的每一句话哦♡",
|
||||
}
|
||||
|
||||
// 简单hash选一条
|
||||
hash := 0
|
||||
for _, c := range content {
|
||||
hash += int(c)
|
||||
}
|
||||
return responses[hash%len(responses)]
|
||||
}
|
||||
|
||||
// SendSystemMessage 向用户发送系统消息(用于主动通知)
|
||||
func (h *ChatHandler) SendSystemMessage(userID, sessionID, text string) error {
|
||||
msg := ws.ServerMessage{
|
||||
Type: "response",
|
||||
MessageID: "sys_" + generateID(),
|
||||
Text: text,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.hub.SendToSession(userID, sessionID, data)
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateID() string {
|
||||
return time.Now().Format("20060102150405") + randomStr(6)
|
||||
}
|
||||
|
||||
func randomStr(n int) string {
|
||||
const letters = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
b := make([]byte, n)
|
||||
for i := range b {
|
||||
b[i] = letters[time.Now().UnixNano()%int64(len(letters))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// 确保未使用变量不报错
|
||||
var _ = middleware.GetUserID
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
|
||||
)
|
||||
|
||||
// MemoryHandler 记忆查询处理器
|
||||
type MemoryHandler struct {
|
||||
// MVP阶段:直接透传到AI-Core,Gateway本身不需要记忆存储
|
||||
aiCoreURL string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewMemoryHandler 创建记忆处理器
|
||||
func NewMemoryHandler(aiCoreURL string) *MemoryHandler {
|
||||
return &MemoryHandler{
|
||||
aiCoreURL: aiCoreURL,
|
||||
client: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
// Query 查询用户记忆
|
||||
func (h *MemoryHandler) Query(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
|
||||
query := c.Query("q")
|
||||
if query == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "查询参数q不能为空"})
|
||||
return
|
||||
}
|
||||
|
||||
// MVP阶段:返回简单的内存数据
|
||||
// 后续将请求转发到AI-Core的记忆API
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"user_id": userID,
|
||||
"query": query,
|
||||
"memories": []gin.H{},
|
||||
"message": "记忆查询功能将在后续版本中接入AI-Core",
|
||||
})
|
||||
}
|
||||
|
||||
// List 列出用户所有记忆
|
||||
func (h *MemoryHandler) List(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"user_id": userID,
|
||||
"memories": []gin.H{},
|
||||
"message": "记忆列表功能将在后续版本中接入AI-Core",
|
||||
})
|
||||
}
|
||||
|
||||
// Add 手动添加记忆
|
||||
func (h *MemoryHandler) Add(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
|
||||
var req struct {
|
||||
Content string `json:"content" binding:"required"`
|
||||
Category string `json:"category"`
|
||||
Priority int `json:"priority"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效"})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Category == "" {
|
||||
req.Category = "other"
|
||||
}
|
||||
if req.Priority <= 0 {
|
||||
req.Priority = 1
|
||||
}
|
||||
|
||||
// MVP阶段:返回成功但暂不持久化
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"status": "accepted",
|
||||
"user_id": userID,
|
||||
"content": req.Content,
|
||||
"category": req.Category,
|
||||
"priority": req.Priority,
|
||||
"message": "记忆手动添加功能将在后续版本中接入AI-Core",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,121 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
|
||||
)
|
||||
|
||||
// SessionHandler 会话管理处理器
|
||||
type SessionHandler struct {
|
||||
// MVP阶段使用内存存储,后续迁移到PostgreSQL
|
||||
sessions map[string][]SessionInfo // userID -> sessions
|
||||
}
|
||||
|
||||
// SessionInfo 会话信息
|
||||
type SessionInfo struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Title string `json:"title"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
}
|
||||
|
||||
// NewSessionHandler 创建会话处理器
|
||||
func NewSessionHandler() *SessionHandler {
|
||||
return &SessionHandler{
|
||||
sessions: make(map[string][]SessionInfo),
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建新会话
|
||||
func (h *SessionHandler) Create(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
|
||||
var req struct {
|
||||
Title string `json:"title"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
// 允许空body
|
||||
req.Title = "新的对话"
|
||||
}
|
||||
if req.Title == "" {
|
||||
req.Title = "新的对话"
|
||||
}
|
||||
|
||||
session := SessionInfo{
|
||||
ID: "session_" + randomID(12),
|
||||
UserID: userID,
|
||||
Title: req.Title,
|
||||
CreatedAt: nowMillis(),
|
||||
UpdatedAt: nowMillis(),
|
||||
}
|
||||
|
||||
h.sessions[userID] = append([]SessionInfo{session}, h.sessions[userID]...)
|
||||
|
||||
c.JSON(http.StatusCreated, session)
|
||||
}
|
||||
|
||||
// List 获取会话列表
|
||||
func (h *SessionHandler) List(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
|
||||
sessions, ok := h.sessions[userID]
|
||||
if !ok {
|
||||
sessions = []SessionInfo{}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"sessions": sessions,
|
||||
})
|
||||
}
|
||||
|
||||
// Delete 删除会话
|
||||
func (h *SessionHandler) Delete(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
sessionID := c.Param("id")
|
||||
|
||||
sessions := h.sessions[userID]
|
||||
for i, s := range sessions {
|
||||
if s.ID == sessionID {
|
||||
h.sessions[userID] = append(sessions[:i], sessions[i+1:]...)
|
||||
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "会话不存在"})
|
||||
}
|
||||
|
||||
// Get 获取单个会话信息
|
||||
func (h *SessionHandler) Get(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
sessionID := c.Param("id")
|
||||
|
||||
for _, s := range h.sessions[userID] {
|
||||
if s.ID == sessionID {
|
||||
c.JSON(http.StatusOK, s)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "会话不存在"})
|
||||
}
|
||||
|
||||
// 简单的工具函数
|
||||
func randomID(n int) string {
|
||||
const letters = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
b := make([]byte, n)
|
||||
for i := range b {
|
||||
b[i] = letters[i%len(letters)]
|
||||
}
|
||||
// 使用纳秒时间戳增加唯一性
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func nowMillis() int64 {
|
||||
// 避免引入time包,直接返回一个值
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/config"
|
||||
)
|
||||
|
||||
// Auth 用户键值在context中的key
|
||||
const UserIDKey = "user_id"
|
||||
|
||||
// JWTAuth JWT认证中间件
|
||||
func JWTAuth(cfg *config.Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "未提供认证令牌"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Bearer token
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "认证格式错误"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := parts[1]
|
||||
userID, err := cfg.ValidateToken(tokenString)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "认证令牌无效或已过期"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将userID注入上下文
|
||||
c.Set(UserIDKey, userID)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserID 从上下文获取用户ID
|
||||
func GetUserID(c *gin.Context) string {
|
||||
userID, _ := c.Get(UserIDKey)
|
||||
if userID == nil {
|
||||
return ""
|
||||
}
|
||||
return userID.(string)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORS 跨域中间件
|
||||
func CORS() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
|
||||
c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Authorization, X-Request-ID")
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
c.Header("Access-Control-Max-Age", "86400")
|
||||
|
||||
// 预检请求
|
||||
if c.Request.Method == http.MethodOptions {
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RequestLogging 请求日志中间件
|
||||
func RequestLogging() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
|
||||
// 处理请求
|
||||
c.Next()
|
||||
|
||||
// 记录日志
|
||||
duration := time.Since(start)
|
||||
statusCode := c.Writer.Status()
|
||||
method := c.Request.Method
|
||||
path := c.Request.URL.Path
|
||||
clientIP := c.ClientIP()
|
||||
|
||||
logLevel := "[INFO]"
|
||||
if statusCode >= 500 {
|
||||
logLevel = "[ERROR]"
|
||||
} else if statusCode >= 400 {
|
||||
logLevel = "[WARN]"
|
||||
}
|
||||
|
||||
log.Printf("%s %s %s %d %v %s",
|
||||
logLevel, method, path, statusCode, duration, clientIP,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RateLimiter 基于内存令牌桶的限流中间件
|
||||
// MVP阶段使用内存实现,后续可迁移到Redis
|
||||
type RateLimiter struct {
|
||||
mu sync.Mutex
|
||||
buckets map[string]*tokenBucket
|
||||
rate float64 // 每秒生成的令牌数
|
||||
burst int // 桶容量
|
||||
}
|
||||
|
||||
type tokenBucket struct {
|
||||
tokens float64
|
||||
lastTime time.Time
|
||||
}
|
||||
|
||||
// NewRateLimiter 创建限流器
|
||||
func NewRateLimiter(rate float64, burst int) *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
buckets: make(map[string]*tokenBucket),
|
||||
rate: rate,
|
||||
burst: burst,
|
||||
}
|
||||
|
||||
// 定期清理过期桶
|
||||
go rl.cleanup()
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
// Handler 返回Gin中间件
|
||||
func (rl *RateLimiter) Handler() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
key := c.ClientIP() // 按IP限流
|
||||
|
||||
if !rl.allow(key) {
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "请求过于频繁,请稍后再试",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) allow(key string) bool {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
bucket, ok := rl.buckets[key]
|
||||
now := time.Now()
|
||||
|
||||
if !ok {
|
||||
rl.buckets[key] = &tokenBucket{
|
||||
tokens: float64(rl.burst) - 1,
|
||||
lastTime: now,
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// 补充令牌
|
||||
elapsed := now.Sub(bucket.lastTime).Seconds()
|
||||
bucket.tokens += elapsed * rl.rate
|
||||
if bucket.tokens > float64(rl.burst) {
|
||||
bucket.tokens = float64(rl.burst)
|
||||
}
|
||||
bucket.lastTime = now
|
||||
|
||||
// 消耗令牌
|
||||
if bucket.tokens >= 1 {
|
||||
bucket.tokens--
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// cleanup 定期清理长时间未使用的桶
|
||||
func (rl *RateLimiter) cleanup() {
|
||||
for {
|
||||
time.Sleep(5 * time.Minute)
|
||||
|
||||
rl.mu.Lock()
|
||||
cutoff := time.Now().Add(-10 * time.Minute)
|
||||
for key, bucket := range rl.buckets {
|
||||
if bucket.lastTime.Before(cutoff) {
|
||||
delete(rl.buckets, key)
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/config"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/handler"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/ws"
|
||||
)
|
||||
|
||||
// Setup 注册所有路由
|
||||
func Setup(r *gin.Engine, hub *ws.Hub, cfg *config.Config) {
|
||||
// 限流器
|
||||
rateLimiter := middleware.NewRateLimiter(10, 20) // 每秒10个请求,突发20
|
||||
|
||||
// 初始化处理器
|
||||
authHandler := handler.NewAuthHandler(cfg)
|
||||
sessionHandler := handler.NewSessionHandler()
|
||||
memoryHandler := handler.NewMemoryHandler(cfg.AICoreURL)
|
||||
chatHandler := handler.NewChatHandler(cfg, hub)
|
||||
|
||||
// ========== 公开路由 ==========
|
||||
api := r.Group("/api/v1")
|
||||
|
||||
// 健康检查
|
||||
api.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(200, gin.H{
|
||||
"status": "ok",
|
||||
"service": "cyrene-gateway",
|
||||
"ws_connections": hub.ClientCount(),
|
||||
})
|
||||
})
|
||||
|
||||
// 认证 (无需JWT)
|
||||
auth := api.Group("/auth")
|
||||
{
|
||||
auth.POST("/register", authHandler.Register)
|
||||
auth.POST("/login", authHandler.Login)
|
||||
}
|
||||
|
||||
// ========== 需要认证的路由 ==========
|
||||
protected := api.Group("")
|
||||
protected.Use(middleware.JWTAuth(cfg))
|
||||
protected.Use(rateLimiter.Handler())
|
||||
{
|
||||
// Token刷新
|
||||
protected.POST("/auth/refresh", authHandler.RefreshToken)
|
||||
|
||||
// 会话管理
|
||||
sessions := protected.Group("/sessions")
|
||||
{
|
||||
sessions.POST("", sessionHandler.Create)
|
||||
sessions.GET("", sessionHandler.List)
|
||||
sessions.GET("/:id", sessionHandler.Get)
|
||||
sessions.DELETE("/:id", sessionHandler.Delete)
|
||||
}
|
||||
|
||||
// 记忆管理
|
||||
memory := protected.Group("/memory")
|
||||
{
|
||||
memory.GET("/search", memoryHandler.Query)
|
||||
memory.GET("", memoryHandler.List)
|
||||
memory.POST("", memoryHandler.Add)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== WebSocket路由 ==========
|
||||
// WebSocket升级在HTTP层,token通过query参数或Header传递
|
||||
wsGroup := r.Group("/ws")
|
||||
{
|
||||
wsGroup.GET("/chat", chatHandler.HandleWebSocket)
|
||||
}
|
||||
|
||||
// ========== 静态文件服务 (生产环境) ==========
|
||||
if cfg.Env == "production" {
|
||||
r.Static("/assets", "./public/assets")
|
||||
r.StaticFile("/", "./public/index.html")
|
||||
r.NoRoute(func(c *gin.Context) {
|
||||
c.File("./public/index.html")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,138 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
// 写入超时
|
||||
writeWait = 10 * time.Second
|
||||
|
||||
// 读取pong超时
|
||||
pongWait = 60 * time.Second
|
||||
|
||||
// pong发送后等待下一次ping的间隔
|
||||
pingPeriod = (pongWait * 9) / 10
|
||||
|
||||
// 最大消息大小
|
||||
maxMessageSize = 65536
|
||||
)
|
||||
|
||||
// Client WebSocket客户端
|
||||
type Client struct {
|
||||
Hub *Hub
|
||||
Conn *websocket.Conn
|
||||
Send chan []byte
|
||||
UserID string
|
||||
SessionID string
|
||||
}
|
||||
|
||||
// NewClient 创建WebSocket客户端
|
||||
func NewClient(hub *Hub, conn *websocket.Conn, userID, sessionID string) *Client {
|
||||
return &Client{
|
||||
Hub: hub,
|
||||
Conn: conn,
|
||||
Send: make(chan []byte, 256),
|
||||
UserID: userID,
|
||||
SessionID: sessionID,
|
||||
}
|
||||
}
|
||||
|
||||
// ReadPump 读取协程 —— 从WebSocket连接读取消息
|
||||
func (c *Client) ReadPump(onMessage func(client *Client, msg ClientMessage)) {
|
||||
defer func() {
|
||||
c.Hub.unregister <- c
|
||||
c.Conn.Close()
|
||||
}()
|
||||
|
||||
c.Conn.SetReadLimit(maxMessageSize)
|
||||
c.Conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
c.Conn.SetPongHandler(func(string) error {
|
||||
c.Conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
return nil
|
||||
})
|
||||
|
||||
for {
|
||||
_, rawMessage, err := c.Conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure, websocket.CloseAbnormalClosure) {
|
||||
log.Printf("[WS] 读取错误: user=%s err=%v", c.UserID, err)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// 解析消息
|
||||
var msg ClientMessage
|
||||
if err := json.Unmarshal(rawMessage, &msg); err != nil {
|
||||
log.Printf("[WS] 消息解析失败: user=%s err=%v", c.UserID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 处理ping
|
||||
if msg.Type == "ping" {
|
||||
pongMsg := ServerMessage{
|
||||
Type: "pong",
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
}
|
||||
data, _ := json.Marshal(pongMsg)
|
||||
c.Send <- data
|
||||
continue
|
||||
}
|
||||
|
||||
// 调用消息处理器
|
||||
if onMessage != nil {
|
||||
onMessage(c, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WritePump 写入协程 —— 向WebSocket连接写入消息
|
||||
func (c *Client) WritePump() {
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
c.Conn.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case message, ok := <-c.Send:
|
||||
c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
if !ok {
|
||||
// Hub关闭了通道
|
||||
c.Conn.WriteMessage(websocket.CloseMessage, []byte{})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.Conn.WriteMessage(websocket.TextMessage, message); err != nil {
|
||||
log.Printf("[WS] 写入错误: user=%s err=%v", c.UserID, err)
|
||||
return
|
||||
}
|
||||
|
||||
case <-ticker.C:
|
||||
c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SendMessage 向客户端发送消息
|
||||
func (c *Client) SendMessage(msg ServerMessage) error {
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
select {
|
||||
case c.Send <- data:
|
||||
return nil
|
||||
default:
|
||||
return nil // 通道满则丢弃
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,132 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"log"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Hub WebSocket连接池
|
||||
type Hub struct {
|
||||
mu sync.RWMutex
|
||||
clients map[*Client]bool
|
||||
broadcast chan []byte
|
||||
register chan *Client
|
||||
unregister chan *Client
|
||||
|
||||
// 按用户ID索引的客户端映射
|
||||
userClients map[string]map[*Client]bool
|
||||
}
|
||||
|
||||
// NewHub 创建WebSocket Hub
|
||||
func NewHub() *Hub {
|
||||
return &Hub{
|
||||
clients: make(map[*Client]bool),
|
||||
broadcast: make(chan []byte, 256),
|
||||
register: make(chan *Client),
|
||||
unregister: make(chan *Client),
|
||||
userClients: make(map[string]map[*Client]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// Run 启动Hub主循环
|
||||
func (h *Hub) Run() {
|
||||
for {
|
||||
select {
|
||||
case client := <-h.register:
|
||||
h.mu.Lock()
|
||||
h.clients[client] = true
|
||||
|
||||
// 用户索引
|
||||
if h.userClients[client.UserID] == nil {
|
||||
h.userClients[client.UserID] = make(map[*Client]bool)
|
||||
}
|
||||
h.userClients[client.UserID][client] = true
|
||||
h.mu.Unlock()
|
||||
|
||||
log.Printf("[WS] 客户端连接: user=%s session=%s (当前连接数: %d)",
|
||||
client.UserID, client.SessionID, len(h.clients))
|
||||
|
||||
case client := <-h.unregister:
|
||||
h.mu.Lock()
|
||||
if _, ok := h.clients[client]; ok {
|
||||
delete(h.clients, client)
|
||||
close(client.Send)
|
||||
|
||||
// 清理用户索引
|
||||
if h.userClients[client.UserID] != nil {
|
||||
delete(h.userClients[client.UserID], client)
|
||||
if len(h.userClients[client.UserID]) == 0 {
|
||||
delete(h.userClients, client.UserID)
|
||||
}
|
||||
}
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
log.Printf("[WS] 客户端断开: user=%s session=%s (当前连接数: %d)",
|
||||
client.UserID, client.SessionID, len(h.clients))
|
||||
|
||||
case message := <-h.broadcast:
|
||||
h.mu.RLock()
|
||||
for client := range h.clients {
|
||||
select {
|
||||
case client.Send <- message:
|
||||
default:
|
||||
// 客户端发送通道已满,跳过
|
||||
close(client.Send)
|
||||
delete(h.clients, client)
|
||||
}
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SendToUser 向指定用户的所有连接发送消息
|
||||
func (h *Hub) SendToUser(userID string, message []byte) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
if clients, ok := h.userClients[userID]; ok {
|
||||
for client := range clients {
|
||||
select {
|
||||
case client.Send <- message:
|
||||
default:
|
||||
// 跳过阻塞的客户端
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SendToSession 向指定会话的连接发送消息
|
||||
func (h *Hub) SendToSession(userID, sessionID string, message []byte) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
if clients, ok := h.userClients[userID]; ok {
|
||||
for client := range clients {
|
||||
if client.SessionID == sessionID {
|
||||
select {
|
||||
case client.Send <- message:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ClientCount 获取当前连接数
|
||||
func (h *Hub) ClientCount() int {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return len(h.clients)
|
||||
}
|
||||
|
||||
// UserClientCount 获取指定用户的连接数
|
||||
func (h *Hub) UserClientCount(userID string) int {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
if clients, ok := h.userClients[userID]; ok {
|
||||
return len(clients)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user