Files
Cyrene/backend/gateway/internal/config/config.go
T
AskaEth a058b0ab8e fix: 第一轮修复 - 记忆管理/IoT操控/历史消息持久化/动作消息/链路优化/安全配置
- 修复记忆管理数据库连接不可用 (ai-core重编译+Unicode修复)
- 修复IoT子会话工具调用链路日志缺失
- 新增最终审查子会话(review_provider) 支持消息格式解析拆分
- 实现历史消息持久化(后端存储+前端分页加载)
- 前端新增动作消息(ActionMessage)类型和渲染
- 优化对话链路速度(非阻塞子会话+快速问候通道)
- JWT密钥环境变量化(无默认值启动panic)
- Token自动刷新机制(401拦截器+refresh接口)
- WebSocket指数退避重连(jitter+最大10次)
- localStorage清理一致性(cyrene_前缀+版本检查)
- IoT环境变量统一为IOT_SERVICE_URL
2026-05-21 23:10:07 +08:00

277 lines
7.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package config
import (
	"fmt"
	"net/url"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/golang-jwt/jwt/v5"
)
// Config holds the gateway's runtime configuration. Load populates it from
// environment variables.
type Config struct {
	// Deployment environment name and the port the gateway is configured with.
	Env  string
	Port string

	// PostgreSQL connection parameters (combined by DatabaseURL).
	PostgresHost string
	PostgresPort string
	PostgresUser string
	PostgresPass string
	PostgresDB   string

	// Redis connection parameters.
	RedisHost string
	RedisPort string
	RedisPass string

	// JWT signing secret and access-token lifetime.
	// NOTE: despite its name, JWTExpiryHours stores a time.Duration (Load
	// multiplies the configured hour count by time.Hour), not a number of hours.
	JWTSecret      string
	JWTExpiryHours time.Duration

	// Administrator account (intended for the development phase).
	AdminUsername string
	AdminPassword string
	AdminNickname string // how the assistant addresses the user by default

	// RegistrationEnabled toggles user self-registration.
	RegistrationEnabled bool

	// Base URLs of downstream services.
	AICoreURL          string // AI-Core service
	MemoryServiceURL   string // Memory service
	IoTDebugServiceURL string // IoT (debug) service
	VoiceServiceURL    string // Voice / speech-recognition service
	ToolEngineURL      string // Tool-Engine service

	// LLM settings (passed through to AI-Core; the gateway may need them too).
	LLMAPIURL string
	LLMAPIKey string
	LLMModel  string

	// WSMaxConnections is the configured WebSocket connection limit.
	WSMaxConnections int

	// SessionIdleTimeoutMin is the session idle timeout in minutes; per the
	// original note, sessions past it are marked idle rather than deleted.
	SessionIdleTimeoutMin int

	// WebhookAPIKey is the API key for third-party platform webhooks.
	WebhookAPIKey string

	// InternalServiceToken authenticates calls between internal services.
	InternalServiceToken string

	// AllowedOrigins is the CORS origin allow-list.
	AllowedOrigins []string

	// BriefingTime is the daily briefing time in HH:MM format.
	BriefingTime string
}
// Load builds a Config from environment variables, substituting built-in
// defaults for every optional variable that is unset or empty.
//
// JWT_SECRET and INTERNAL_SERVICE_TOKEN are mandatory: Load panics if either
// one is unset or empty, so the service refuses to start without them.
func Load() *Config {
	jwtSecret := os.Getenv("JWT_SECRET")
	if len(jwtSecret) == 0 {
		panic("致命错误: 环境变量 JWT_SECRET 未设置,服务拒绝启动。请在 .env 文件中设置 JWT_SECRET。")
	}

	internalServiceToken := os.Getenv("INTERNAL_SERVICE_TOKEN")
	if len(internalServiceToken) == 0 {
		panic("致命错误: 环境变量 INTERNAL_SERVICE_TOKEN 未设置,服务拒绝启动。请在 .env 文件中设置 INTERNAL_SERVICE_TOKEN。")
	}

	cfg := new(Config)

	cfg.Env = getEnv("ENV", "development")
	cfg.Port = getEnv("GATEWAY_PORT", "8080")

	// PostgreSQL.
	// NOTE(review): the POSTGRES_PASSWORD default is a placeholder value —
	// confirm it is always overridden outside development.
	cfg.PostgresHost = getEnv("POSTGRES_HOST", "localhost")
	cfg.PostgresPort = getEnv("POSTGRES_PORT", "5432")
	cfg.PostgresUser = getEnv("POSTGRES_USER", "cyrene")
	cfg.PostgresPass = getEnv("POSTGRES_PASSWORD", "change_me")
	cfg.PostgresDB = getEnv("POSTGRES_DB", "cyrene_ai")

	// Redis.
	cfg.RedisHost = getEnv("REDIS_HOST", "localhost")
	cfg.RedisPort = getEnv("REDIS_PORT", "6379")
	cfg.RedisPass = getEnv("REDIS_PASSWORD", "")

	// JWT: the hour count from the environment is converted to a Duration.
	cfg.JWTSecret = jwtSecret
	cfg.JWTExpiryHours = time.Duration(getEnvInt("JWT_EXPIRY_HOURS", 720)) * time.Hour

	// Administrator account (development phase).
	cfg.AdminUsername = getEnv("ADMIN_USERNAME", "admin")
	cfg.AdminPassword = getEnv("ADMIN_PASSWORD", "cyrene-dev-admin")
	cfg.AdminNickname = getEnv("ADMIN_NICKNAME", "管理员")

	// Registration is off unless explicitly enabled.
	cfg.RegistrationEnabled = getEnvBool("REGISTRATION_ENABLED", false)

	// Downstream services. The IoT URL prefers IOT_SERVICE_URL and falls back
	// to the legacy IOT_DEBUG_SERVICE_URL for backward compatibility.
	cfg.AICoreURL = getEnv("AI_CORE_URL", "http://localhost:8081")
	cfg.MemoryServiceURL = getEnv("MEMORY_SERVICE_URL", "http://localhost:8091")
	cfg.IoTDebugServiceURL = getEnv("IOT_SERVICE_URL", getEnv("IOT_DEBUG_SERVICE_URL", "http://localhost:8083"))
	cfg.VoiceServiceURL = getEnv("VOICE_SERVICE_URL", "http://localhost:8093")
	cfg.ToolEngineURL = getEnv("TOOL_ENGINE_URL", "http://localhost:8092")

	// LLM.
	cfg.LLMAPIURL = getEnv("LLM_API_URL", "https://api.openai.com/v1")
	cfg.LLMAPIKey = getEnv("LLM_API_KEY", "")
	cfg.LLMModel = getEnv("LLM_MODEL", "gpt-4o")

	// WebSocket and session limits.
	cfg.WSMaxConnections = getEnvInt("WS_MAX_CONNECTIONS", 1000)
	cfg.SessionIdleTimeoutMin = getEnvInt("SESSION_IDLE_TIMEOUT_MIN", 30)

	// Authentication keys, CORS allow-list and the daily briefing time.
	cfg.WebhookAPIKey = getEnv("WEBHOOK_API_KEY", "")
	cfg.InternalServiceToken = internalServiceToken
	cfg.AllowedOrigins = parseAllowedOrigins(getEnv("ALLOWED_ORIGINS", "http://localhost:5173,http://localhost:5199,http://localhost:3000"))
	cfg.BriefingTime = getEnv("BRIEFING_TIME", "08:00")

	return cfg
}
// DatabaseURL returns the PostgreSQL connection URL assembled from the
// Postgres* fields, always with sslmode=disable.
//
// The URL is built with net/url so that the user name and password are
// percent-encoded. Plain string interpolation produced a malformed or
// misparsed URL whenever a credential contained a reserved character such as
// '@', ':', '/', '?' or '#'. For values without reserved characters the
// result is identical to the previous format:
//
//	postgres://user:pass@host:port/db?sslmode=disable
func (c *Config) DatabaseURL() string {
	dsn := url.URL{
		Scheme:   "postgres",
		User:     url.UserPassword(c.PostgresUser, c.PostgresPass),
		Host:     c.PostgresHost + ":" + c.PostgresPort,
		Path:     "/" + c.PostgresDB,
		RawQuery: "sslmode=disable",
	}
	return dsn.String()
}
// GenerateToken issues an HS256-signed access token for userID.
//
// The token carries the claims user_id, type ("access"), exp and iat. Its
// lifetime is c.JWTExpiryHours, and it is signed with c.JWTSecret. It returns
// the encoded token string or the signing error.
func (c *Config) GenerateToken(userID string) (string, error) {
	expiresAt := time.Now().Add(c.JWTExpiryHours).Unix()
	issuedAt := time.Now().Unix()

	unsigned := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		"user_id": userID,
		"type":    "access",
		"exp":     expiresAt,
		"iat":     issuedAt,
	})

	signingKey := []byte(c.JWTSecret)
	return unsigned.SignedString(signingKey)
}
// GenerateRefreshToken issues an HS256-signed refresh token for userID.
//
// The token carries the claims user_id, type ("refresh"), exp and iat. It is
// valid for a fixed 30 days and is signed with c.JWTSecret. It returns the
// encoded token string or the signing error.
func (c *Config) GenerateRefreshToken(userID string) (string, error) {
	const refreshLifetime = 30 * 24 * time.Hour // 30 days

	expiresAt := time.Now().Add(refreshLifetime).Unix()
	issuedAt := time.Now().Unix()

	unsigned := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		"user_id": userID,
		"type":    "refresh",
		"exp":     expiresAt,
		"iat":     issuedAt,
	})

	signingKey := []byte(c.JWTSecret)
	return unsigned.SignedString(signingKey)
}
// ValidateToken verifies an access token and returns the user ID stored in its
// user_id claim.
//
// The token must be signed with an HMAC method using c.JWTSecret and must pass
// the jwt library's validation. On top of that:
//
//   - A token whose "type" claim is present but is not "access" is rejected.
//     Without this check a refresh token (signed with the same secret) would
//     be accepted wherever an access token is required. Tokens that carry no
//     "type" claim at all are still accepted for backward compatibility.
//   - A token whose user_id claim is missing, empty or not a string is
//     rejected instead of returning an empty user ID with a nil error.
func (c *Config) ValidateToken(tokenString string) (string, error) {
	token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
		// Only accept HMAC-signed tokens; anything else cannot have been issued here.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, jwt.ErrSignatureInvalid
		}
		return []byte(c.JWTSecret), nil
	})
	if err != nil {
		return "", err
	}

	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok || !token.Valid {
		return "", jwt.ErrSignatureInvalid
	}

	// Refuse refresh tokens (or any other non-access type) used as access tokens.
	if tokenType, present := claims["type"]; present && tokenType != "access" {
		return "", fmt.Errorf("无效的访问令牌类型")
	}

	userID, ok := claims["user_id"].(string)
	if !ok || userID == "" {
		return "", fmt.Errorf("令牌缺少有效的 user_id")
	}
	return userID, nil
}
// ValidateRefreshToken verifies a refresh token and returns the user ID stored
// in its user_id claim.
//
// The token must be signed with an HMAC method using c.JWTSecret, pass the jwt
// library's validation, and carry the claim type == "refresh". A token whose
// user_id claim is missing, empty or not a string is rejected instead of
// returning an empty user ID with a nil error.
func (c *Config) ValidateRefreshToken(tokenString string) (string, error) {
	token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
		// Only accept HMAC-signed tokens; anything else cannot have been issued here.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, jwt.ErrSignatureInvalid
		}
		return []byte(c.JWTSecret), nil
	})
	if err != nil {
		return "", err
	}

	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok || !token.Valid {
		return "", jwt.ErrSignatureInvalid
	}

	// The token type must be exactly "refresh"; access tokens are not accepted here.
	if tokenType, _ := claims["type"].(string); tokenType != "refresh" {
		return "", fmt.Errorf("无效的刷新令牌类型")
	}

	userID, ok := claims["user_id"].(string)
	if !ok || userID == "" {
		return "", fmt.Errorf("令牌缺少有效的 user_id")
	}
	return userID, nil
}
// getEnv returns the value of the environment variable key, or fallback when
// the variable is unset or set to the empty string.
func getEnv(key, fallback string) string {
	value := os.Getenv(key)
	if value == "" {
		return fallback
	}
	return value
}
// getEnvInt returns the environment variable key parsed as a non-negative
// decimal integer, or fallback when the variable is unset, empty, or not a
// plain run of ASCII digits (signs, spaces and other characters are rejected).
//
// Values too large for an int also yield fallback. The previous hand-rolled
// digit accumulation overflowed silently and returned a wrapped-around number
// in that case.
func getEnvInt(key string, fallback int) int {
	v := os.Getenv(key)
	if v == "" {
		return fallback
	}
	// ParseUint rejects sign prefixes; limiting the bit size to IntSize-1
	// guarantees the result fits in a non-negative int.
	n, err := strconv.ParseUint(v, 10, strconv.IntSize-1)
	if err != nil {
		return fallback
	}
	return int(n)
}
// getEnvBool interprets the environment variable key as a boolean. It returns
// fallback when the variable is unset or empty, true when the value is exactly
// "true", "1" or "yes" (case-sensitive), and false for any other value.
func getEnvBool(key string, fallback bool) bool {
	switch os.Getenv(key) {
	case "":
		return fallback
	case "true", "1", "yes":
		return true
	default:
		return false
	}
}
// parseAllowedOrigins splits a comma-separated list of origins into a slice.
// Each entry is trimmed of surrounding whitespace and empty entries are
// dropped. The result is always non-nil, even when no origins remain.
func parseAllowedOrigins(s string) []string {
	if len(s) == 0 {
		return []string{}
	}

	fields := strings.Split(s, ",")
	origins := make([]string, 0, len(fields))
	for i := range fields {
		origin := strings.TrimSpace(fields[i])
		if origin == "" {
			continue
		}
		origins = append(origins, origin)
	}
	return origins
}