refactor: 认证系统重构 + DevTools CLI 重写 + 文档全面更新
- auth: Login 简化为管理员始终通过 .env 验证,GetProfile 修正 admin DB 查询 - devtools: .sh/.bat 同步重写为完整 CLI (start/stop/status/logs/build/db:*) - docs: 新增 devtools.md,重写 Deploy.md (三种方式+Windows说明),更新 README/gateway-api - voice-service: DashScope 实时流式 STT 支持 - gateway: Phase 6 多模型配置 + 多端客户端管理 + WebSocket 增强 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -3,7 +3,6 @@ package handler
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/yourname/cyrene-ai/pkg/logger"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
@@ -126,7 +125,9 @@ func (h *AuthHandler) Register(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// Login 用户登录 (支持管理员账户和普通用户)
|
||||
// Login 用户登录
|
||||
// 管理员始终通过 .env 配置验证,不受数据库状态影响
|
||||
// 普通用户通过数据库 bcrypt 密码哈希验证
|
||||
func (h *AuthHandler) Login(c *gin.Context) {
|
||||
var req struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
@@ -138,7 +139,6 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 用户名格式校验:仅允许字母、数字、下划线,长度 3-32
|
||||
if !usernameRegex.MatchString(req.Username) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "用户名格式无效"})
|
||||
return
|
||||
@@ -147,56 +147,36 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
var userID string
|
||||
var nickname string
|
||||
|
||||
// 尝试从 users 表查询用户
|
||||
authenticated, err := h.verifyUserPassword(req.Username, req.Password)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "服务器内部错误"})
|
||||
return
|
||||
}
|
||||
|
||||
if authenticated {
|
||||
// 用户存在于 users 表中且密码验证通过
|
||||
if req.Username == h.cfg.AdminUsername {
|
||||
userID = "admin"
|
||||
} else {
|
||||
userID = "user_" + req.Username
|
||||
// 管理员:始终通过 .env 配置验证,不依赖数据库
|
||||
if req.Username == h.cfg.AdminUsername && req.Password == h.cfg.AdminPassword {
|
||||
userID = "admin"
|
||||
nickname = h.cfg.AdminNickname
|
||||
if nickname == "" {
|
||||
nickname = "管理员"
|
||||
}
|
||||
// 获取用户昵称
|
||||
// 数据库可用时从 DB 获取昵称(覆盖配置默认值)
|
||||
if h.db != nil {
|
||||
if u, err := store.GetUserByUsername(h.db, req.Username); err == nil && u != nil {
|
||||
nickname = u.Nickname
|
||||
}
|
||||
}
|
||||
} else if req.Username == h.cfg.AdminUsername && h.db != nil {
|
||||
// 管理员用户尚未迁移到 users 表,尝试用配置中的密码验证
|
||||
if req.Password != h.cfg.AdminPassword {
|
||||
} else {
|
||||
// 普通用户:数据库 bcrypt 验证
|
||||
authenticated, err := h.verifyUserPassword(req.Username, req.Password)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "服务器内部错误"})
|
||||
return
|
||||
}
|
||||
if !authenticated {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "用户名或密码错误"})
|
||||
return
|
||||
}
|
||||
// 密码正确,迁移 admin 到 users 表
|
||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
logger.Printf("⚠ 迁移管理员密码哈希失败: %v", err)
|
||||
} else {
|
||||
if _, err := store.CreateUser(h.db, req.Username, "管理员", string(passwordHash), true); err != nil {
|
||||
logger.Printf("⚠ 迁移管理员到 users 表失败: %v", err)
|
||||
} else {
|
||||
logger.Println("✅ 管理员已迁移到 users 表")
|
||||
userID = "user_" + req.Username
|
||||
if h.db != nil {
|
||||
if u, err := store.GetUserByUsername(h.db, req.Username); err == nil && u != nil {
|
||||
nickname = u.Nickname
|
||||
}
|
||||
}
|
||||
userID = "admin"
|
||||
nickname = "管理员"
|
||||
} else if req.Username == h.cfg.AdminUsername {
|
||||
// 数据库不可用时的回退:使用配置中的管理员密码
|
||||
if req.Password != h.cfg.AdminPassword {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "用户名或密码错误"})
|
||||
return
|
||||
}
|
||||
userID = "admin"
|
||||
nickname = "管理员"
|
||||
} else {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "用户名或密码错误"})
|
||||
return
|
||||
}
|
||||
|
||||
token, err := h.cfg.GenerateToken(userID)
|
||||
@@ -205,7 +185,6 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 生成 refresh_token (长期有效)
|
||||
refreshToken, err := h.cfg.GenerateRefreshToken(userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "生成刷新令牌失败"})
|
||||
@@ -311,7 +290,7 @@ func (h *AuthHandler) GetProfile(c *gin.Context) {
|
||||
var nickname string
|
||||
|
||||
if isAdmin {
|
||||
username = "admin"
|
||||
username = h.cfg.AdminUsername
|
||||
} else if strings.HasPrefix(userID, "user_") {
|
||||
username = strings.TrimPrefix(userID, "user_")
|
||||
} else {
|
||||
|
||||
@@ -4,11 +4,12 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/yourname/cyrene-ai/pkg/logger"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -19,6 +20,7 @@ import (
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/config"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/store"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/ws"
|
||||
"github.com/yourname/cyrene-ai/pkg/logger"
|
||||
)
|
||||
|
||||
// ChatHandler 聊天处理器
|
||||
@@ -422,6 +424,8 @@ func (h *ChatHandler) streamResponse(client *ws.Client, mode string, reqBody []b
|
||||
Type: "stream_end",
|
||||
MessageID: msgID,
|
||||
SessionID: client.SessionID,
|
||||
Content: fullText,
|
||||
Text: fullText,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
})
|
||||
|
||||
@@ -458,14 +462,145 @@ func (h *ChatHandler) streamResponse(client *ws.Client, mode string, reqBody []b
|
||||
|
||||
// handleVoiceInput 处理语音输入
|
||||
func (h *ChatHandler) handleVoiceInput(client *ws.Client, msg ws.ClientMessage) {
|
||||
// MVP阶段:返回提示
|
||||
response := ws.ServerMessage{
|
||||
Type: "error",
|
||||
MessageID: "msg_" + generateID(),
|
||||
Error: "语音处理功能将在后续版本中启用",
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
audioB64 := msg.AudioData
|
||||
if audioB64 == "" {
|
||||
client.SendMessage(ws.ServerMessage{
|
||||
Type: "error",
|
||||
MessageID: "msg_" + generateID(),
|
||||
Error: "语音数据为空",
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
})
|
||||
return
|
||||
}
|
||||
client.SendMessage(response)
|
||||
|
||||
format := msg.Mode
|
||||
if format == "" {
|
||||
format = "webm"
|
||||
}
|
||||
|
||||
// 在 goroutine 中处理转录,避免阻塞 ReadPump
|
||||
go func() {
|
||||
text, err := h.transcribeAudio(audioB64, format)
|
||||
if err != nil {
|
||||
logger.Printf("[voice] 转录失败: %v", err)
|
||||
client.SendMessage(ws.ServerMessage{
|
||||
Type: "voice_transcript",
|
||||
MessageID: "msg_" + generateID(),
|
||||
Error: fmt.Sprintf("语音识别失败: %v", err),
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if text == "" {
|
||||
client.SendMessage(ws.ServerMessage{
|
||||
Type: "voice_transcript",
|
||||
MessageID: "msg_" + generateID(),
|
||||
Text: "",
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 发送转录结果给前端
|
||||
client.SendMessage(ws.ServerMessage{
|
||||
Type: "voice_transcript",
|
||||
MessageID: "msg_" + generateID(),
|
||||
Text: text,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
})
|
||||
|
||||
// 将转录文本作为聊天消息处理
|
||||
chatMsg := ws.ClientMessage{
|
||||
Type: "message",
|
||||
Content: text,
|
||||
Mode: msg.Mode,
|
||||
}
|
||||
h.handleChatMessage(client, chatMsg)
|
||||
}()
|
||||
}
|
||||
|
||||
// transcribeAudio 将 base64 编码的音频发送到 voice-service 进行转录。
|
||||
func (h *ChatHandler) transcribeAudio(audioB64 string, format string) (string, error) {
|
||||
audioData, err := decodeBase64(audioB64)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("解码音频数据失败: %w", err)
|
||||
}
|
||||
|
||||
// 构建 multipart form
|
||||
var buf bytes.Buffer
|
||||
mw := multipart.NewWriter(&buf)
|
||||
|
||||
ext := ".webm"
|
||||
switch format {
|
||||
case "wav", "wave":
|
||||
ext = ".wav"
|
||||
case "mp3", "mpeg":
|
||||
ext = ".mp3"
|
||||
case "ogg", "opus":
|
||||
ext = ".ogg"
|
||||
case "pcm":
|
||||
ext = ".pcm"
|
||||
}
|
||||
|
||||
fw, err := mw.CreateFormFile("audio", "recording"+ext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建表单字段失败: %w", err)
|
||||
}
|
||||
if _, err := fw.Write(audioData); err != nil {
|
||||
return "", fmt.Errorf("写入音频数据失败: %w", err)
|
||||
}
|
||||
mw.Close()
|
||||
|
||||
voiceURL := h.cfg.VoiceServiceURL
|
||||
if voiceURL == "" {
|
||||
voiceURL = "http://localhost:8093"
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequest("POST", voiceURL+"/api/v1/transcribe", &buf)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
|
||||
httpClient := &http.Client{Timeout: 60 * time.Second}
|
||||
resp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("voice-service 调用失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Success bool `json:"success"`
|
||||
Text string `json:"text"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return "", fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
if result.Error != "" {
|
||||
return "", fmt.Errorf("%s", result.Error)
|
||||
}
|
||||
return "", fmt.Errorf("转录返回空结果")
|
||||
}
|
||||
|
||||
return result.Text, nil
|
||||
}
|
||||
|
||||
// decodeBase64 解码 base64 字符串(支持 Data URL 前缀)。
|
||||
func decodeBase64(s string) ([]byte, error) {
|
||||
// 移除 data:xxx;base64, 前缀
|
||||
if idx := strings.Index(s, ","); idx != -1 {
|
||||
s = s[idx+1:]
|
||||
}
|
||||
return base64.StdEncoding.DecodeString(s)
|
||||
}
|
||||
|
||||
// handleHistoryRequest 处理历史消息请求
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
type User struct {
|
||||
ID int `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Nickname string `json:"nickname"`
|
||||
PasswordHash string `json:"-"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
@@ -23,6 +24,7 @@ func CreateUsersTable(db *sql.DB) error {
|
||||
query := `CREATE TABLE IF NOT EXISTS users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
username VARCHAR(255) UNIQUE NOT NULL,
|
||||
nickname VARCHAR(255) DEFAULT '',
|
||||
password_hash VARCHAR(255) NOT NULL,
|
||||
is_admin BOOLEAN DEFAULT FALSE,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
@@ -33,6 +35,9 @@ func CreateUsersTable(db *sql.DB) error {
|
||||
return fmt.Errorf("创建 users 表失败: %w", err)
|
||||
}
|
||||
|
||||
// 迁移:为已有 users 表添加 nickname 列
|
||||
db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS nickname VARCHAR(255) DEFAULT ''`)
|
||||
|
||||
// 创建索引
|
||||
indexQueries := []string{
|
||||
`CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)`,
|
||||
@@ -52,10 +57,10 @@ func CreateUsersTable(db *sql.DB) error {
|
||||
func GetUserByUsername(db *sql.DB, username string) (*User, error) {
|
||||
var u User
|
||||
err := db.QueryRow(
|
||||
`SELECT id, username, password_hash, is_admin, created_at, updated_at
|
||||
`SELECT id, username, COALESCE(nickname, '') as nickname, password_hash, is_admin, created_at, updated_at
|
||||
FROM users WHERE username = $1`,
|
||||
username,
|
||||
).Scan(&u.ID, &u.Username, &u.PasswordHash, &u.IsAdmin, &u.CreatedAt, &u.UpdatedAt)
|
||||
).Scan(&u.ID, &u.Username, &u.Nickname, &u.PasswordHash, &u.IsAdmin, &u.CreatedAt, &u.UpdatedAt)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
@@ -70,7 +75,7 @@ func GetUserByUsername(db *sql.DB, username string) (*User, error) {
|
||||
// ListUsers 列出所有用户
|
||||
func ListUsers(db *sql.DB) ([]User, error) {
|
||||
rows, err := db.Query(
|
||||
`SELECT id, username, password_hash, is_admin, created_at, updated_at
|
||||
`SELECT id, username, COALESCE(nickname, '') as nickname, password_hash, is_admin, created_at, updated_at
|
||||
FROM users ORDER BY id`,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -81,7 +86,7 @@ func ListUsers(db *sql.DB) ([]User, error) {
|
||||
var users []User
|
||||
for rows.Next() {
|
||||
var u User
|
||||
if err := rows.Scan(&u.ID, &u.Username, &u.PasswordHash, &u.IsAdmin, &u.CreatedAt, &u.UpdatedAt); err != nil {
|
||||
if err := rows.Scan(&u.ID, &u.Username, &u.Nickname, &u.PasswordHash, &u.IsAdmin, &u.CreatedAt, &u.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描用户行失败: %w", err)
|
||||
}
|
||||
users = append(users, u)
|
||||
@@ -109,16 +114,16 @@ func DeleteUser(db *sql.DB, userID int) error {
|
||||
}
|
||||
|
||||
// CreateUser 创建新用户
|
||||
func CreateUser(db *sql.DB, username, passwordHash string, isAdmin bool) (*User, error) {
|
||||
func CreateUser(db *sql.DB, username, nickname, passwordHash string, isAdmin bool) (*User, error) {
|
||||
now := time.Now()
|
||||
var u User
|
||||
err := db.QueryRow(
|
||||
`INSERT INTO users (username, password_hash, is_admin, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $4)
|
||||
ON CONFLICT (username) DO UPDATE SET updated_at = $4
|
||||
RETURNING id, username, password_hash, is_admin, created_at, updated_at`,
|
||||
username, passwordHash, isAdmin, now,
|
||||
).Scan(&u.ID, &u.Username, &u.PasswordHash, &u.IsAdmin, &u.CreatedAt, &u.UpdatedAt)
|
||||
`INSERT INTO users (username, nickname, password_hash, is_admin, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $5)
|
||||
ON CONFLICT (username) DO UPDATE SET nickname = $2, updated_at = $5
|
||||
RETURNING id, username, COALESCE(nickname, '') as nickname, password_hash, is_admin, created_at, updated_at`,
|
||||
username, nickname, passwordHash, isAdmin, now,
|
||||
).Scan(&u.ID, &u.Username, &u.Nickname, &u.PasswordHash, &u.IsAdmin, &u.CreatedAt, &u.UpdatedAt)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建用户失败: %w", err)
|
||||
|
||||
@@ -19,7 +19,7 @@ const (
|
||||
pingPeriod = (pongWait * 9) / 10
|
||||
|
||||
// 最大消息大小
|
||||
maxMessageSize = 65536
|
||||
maxMessageSize = 2 * 1024 * 1024 // 2MB 支持语音消息
|
||||
)
|
||||
|
||||
// Client WebSocket客户端
|
||||
@@ -135,11 +135,16 @@ func (c *Client) SendMessage(msg ServerMessage) error {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Printf("[WS] 发送消息时连接已关闭: type=%s user=%s", msg.Type, c.UserID)
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case c.Send <- data:
|
||||
return nil
|
||||
default:
|
||||
// 通道满:记录警告并返回错误(避免静默丢弃)
|
||||
logger.Printf("[WS] 发送通道已满,丢弃消息: type=%s user=%s", msg.Type, c.UserID)
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user