refactor: 认证系统重构 + DevTools CLI 重写 + 文档全面更新
- auth: Login 简化为管理员始终通过 .env 验证,GetProfile 修正 admin DB 查询 - devtools: .sh/.bat 同步重写为完整 CLI (start/stop/status/logs/build/db:*) - docs: 新增 devtools.md,重写 Deploy.md (三种方式+Windows说明),更新 README/gateway-api - voice-service: DashScope 实时流式 STT 支持 - gateway: Phase 6 多模型配置 + 多端客户端管理 + WebSocket 增强 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -70,7 +70,7 @@ func main() {
|
||||
if err != nil {
|
||||
logger.Printf("⚠ 管理员密码哈希生成失败: %v", err)
|
||||
} else {
|
||||
if _, err := store.CreateUser(s.DB(), cfg.AdminUsername, string(passwordHash), true); err != nil {
|
||||
if _, err := store.CreateUser(s.DB(), cfg.AdminUsername, "管理员", string(passwordHash), true); err != nil {
|
||||
logger.Printf("⚠ 创建默认管理员失败: %v", err)
|
||||
} else {
|
||||
logger.Printf("✅ 默认管理员用户已创建 (username: %s)", cfg.AdminUsername)
|
||||
|
||||
@@ -3,7 +3,6 @@ package handler
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/yourname/cyrene-ai/pkg/logger"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
@@ -126,7 +125,9 @@ func (h *AuthHandler) Register(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// Login 用户登录 (支持管理员账户和普通用户)
|
||||
// Login 用户登录
|
||||
// 管理员始终通过 .env 配置验证,不受数据库状态影响
|
||||
// 普通用户通过数据库 bcrypt 密码哈希验证
|
||||
func (h *AuthHandler) Login(c *gin.Context) {
|
||||
var req struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
@@ -138,7 +139,6 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 用户名格式校验:仅允许字母、数字、下划线,长度 3-32
|
||||
if !usernameRegex.MatchString(req.Username) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "用户名格式无效"})
|
||||
return
|
||||
@@ -147,56 +147,36 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
var userID string
|
||||
var nickname string
|
||||
|
||||
// 尝试从 users 表查询用户
|
||||
authenticated, err := h.verifyUserPassword(req.Username, req.Password)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "服务器内部错误"})
|
||||
return
|
||||
}
|
||||
|
||||
if authenticated {
|
||||
// 用户存在于 users 表中且密码验证通过
|
||||
if req.Username == h.cfg.AdminUsername {
|
||||
userID = "admin"
|
||||
} else {
|
||||
userID = "user_" + req.Username
|
||||
// 管理员:始终通过 .env 配置验证,不依赖数据库
|
||||
if req.Username == h.cfg.AdminUsername && req.Password == h.cfg.AdminPassword {
|
||||
userID = "admin"
|
||||
nickname = h.cfg.AdminNickname
|
||||
if nickname == "" {
|
||||
nickname = "管理员"
|
||||
}
|
||||
// 获取用户昵称
|
||||
// 数据库可用时从 DB 获取昵称(覆盖配置默认值)
|
||||
if h.db != nil {
|
||||
if u, err := store.GetUserByUsername(h.db, req.Username); err == nil && u != nil {
|
||||
nickname = u.Nickname
|
||||
}
|
||||
}
|
||||
} else if req.Username == h.cfg.AdminUsername && h.db != nil {
|
||||
// 管理员用户尚未迁移到 users 表,尝试用配置中的密码验证
|
||||
if req.Password != h.cfg.AdminPassword {
|
||||
} else {
|
||||
// 普通用户:数据库 bcrypt 验证
|
||||
authenticated, err := h.verifyUserPassword(req.Username, req.Password)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "服务器内部错误"})
|
||||
return
|
||||
}
|
||||
if !authenticated {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "用户名或密码错误"})
|
||||
return
|
||||
}
|
||||
// 密码正确,迁移 admin 到 users 表
|
||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
logger.Printf("⚠ 迁移管理员密码哈希失败: %v", err)
|
||||
} else {
|
||||
if _, err := store.CreateUser(h.db, req.Username, "管理员", string(passwordHash), true); err != nil {
|
||||
logger.Printf("⚠ 迁移管理员到 users 表失败: %v", err)
|
||||
} else {
|
||||
logger.Println("✅ 管理员已迁移到 users 表")
|
||||
userID = "user_" + req.Username
|
||||
if h.db != nil {
|
||||
if u, err := store.GetUserByUsername(h.db, req.Username); err == nil && u != nil {
|
||||
nickname = u.Nickname
|
||||
}
|
||||
}
|
||||
userID = "admin"
|
||||
nickname = "管理员"
|
||||
} else if req.Username == h.cfg.AdminUsername {
|
||||
// 数据库不可用时的回退:使用配置中的管理员密码
|
||||
if req.Password != h.cfg.AdminPassword {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "用户名或密码错误"})
|
||||
return
|
||||
}
|
||||
userID = "admin"
|
||||
nickname = "管理员"
|
||||
} else {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "用户名或密码错误"})
|
||||
return
|
||||
}
|
||||
|
||||
token, err := h.cfg.GenerateToken(userID)
|
||||
@@ -205,7 +185,6 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 生成 refresh_token (长期有效)
|
||||
refreshToken, err := h.cfg.GenerateRefreshToken(userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "生成刷新令牌失败"})
|
||||
@@ -311,7 +290,7 @@ func (h *AuthHandler) GetProfile(c *gin.Context) {
|
||||
var nickname string
|
||||
|
||||
if isAdmin {
|
||||
username = "admin"
|
||||
username = h.cfg.AdminUsername
|
||||
} else if strings.HasPrefix(userID, "user_") {
|
||||
username = strings.TrimPrefix(userID, "user_")
|
||||
} else {
|
||||
|
||||
@@ -4,11 +4,12 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/yourname/cyrene-ai/pkg/logger"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -19,6 +20,7 @@ import (
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/config"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/store"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/ws"
|
||||
"github.com/yourname/cyrene-ai/pkg/logger"
|
||||
)
|
||||
|
||||
// ChatHandler 聊天处理器
|
||||
@@ -422,6 +424,8 @@ func (h *ChatHandler) streamResponse(client *ws.Client, mode string, reqBody []b
|
||||
Type: "stream_end",
|
||||
MessageID: msgID,
|
||||
SessionID: client.SessionID,
|
||||
Content: fullText,
|
||||
Text: fullText,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
})
|
||||
|
||||
@@ -458,14 +462,145 @@ func (h *ChatHandler) streamResponse(client *ws.Client, mode string, reqBody []b
|
||||
|
||||
// handleVoiceInput 处理语音输入
|
||||
func (h *ChatHandler) handleVoiceInput(client *ws.Client, msg ws.ClientMessage) {
|
||||
// MVP阶段:返回提示
|
||||
response := ws.ServerMessage{
|
||||
Type: "error",
|
||||
MessageID: "msg_" + generateID(),
|
||||
Error: "语音处理功能将在后续版本中启用",
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
audioB64 := msg.AudioData
|
||||
if audioB64 == "" {
|
||||
client.SendMessage(ws.ServerMessage{
|
||||
Type: "error",
|
||||
MessageID: "msg_" + generateID(),
|
||||
Error: "语音数据为空",
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
})
|
||||
return
|
||||
}
|
||||
client.SendMessage(response)
|
||||
|
||||
format := msg.Mode
|
||||
if format == "" {
|
||||
format = "webm"
|
||||
}
|
||||
|
||||
// 在 goroutine 中处理转录,避免阻塞 ReadPump
|
||||
go func() {
|
||||
text, err := h.transcribeAudio(audioB64, format)
|
||||
if err != nil {
|
||||
logger.Printf("[voice] 转录失败: %v", err)
|
||||
client.SendMessage(ws.ServerMessage{
|
||||
Type: "voice_transcript",
|
||||
MessageID: "msg_" + generateID(),
|
||||
Error: fmt.Sprintf("语音识别失败: %v", err),
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if text == "" {
|
||||
client.SendMessage(ws.ServerMessage{
|
||||
Type: "voice_transcript",
|
||||
MessageID: "msg_" + generateID(),
|
||||
Text: "",
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 发送转录结果给前端
|
||||
client.SendMessage(ws.ServerMessage{
|
||||
Type: "voice_transcript",
|
||||
MessageID: "msg_" + generateID(),
|
||||
Text: text,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
})
|
||||
|
||||
// 将转录文本作为聊天消息处理
|
||||
chatMsg := ws.ClientMessage{
|
||||
Type: "message",
|
||||
Content: text,
|
||||
Mode: msg.Mode,
|
||||
}
|
||||
h.handleChatMessage(client, chatMsg)
|
||||
}()
|
||||
}
|
||||
|
||||
// transcribeAudio 将 base64 编码的音频发送到 voice-service 进行转录。
|
||||
func (h *ChatHandler) transcribeAudio(audioB64 string, format string) (string, error) {
|
||||
audioData, err := decodeBase64(audioB64)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("解码音频数据失败: %w", err)
|
||||
}
|
||||
|
||||
// 构建 multipart form
|
||||
var buf bytes.Buffer
|
||||
mw := multipart.NewWriter(&buf)
|
||||
|
||||
ext := ".webm"
|
||||
switch format {
|
||||
case "wav", "wave":
|
||||
ext = ".wav"
|
||||
case "mp3", "mpeg":
|
||||
ext = ".mp3"
|
||||
case "ogg", "opus":
|
||||
ext = ".ogg"
|
||||
case "pcm":
|
||||
ext = ".pcm"
|
||||
}
|
||||
|
||||
fw, err := mw.CreateFormFile("audio", "recording"+ext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建表单字段失败: %w", err)
|
||||
}
|
||||
if _, err := fw.Write(audioData); err != nil {
|
||||
return "", fmt.Errorf("写入音频数据失败: %w", err)
|
||||
}
|
||||
mw.Close()
|
||||
|
||||
voiceURL := h.cfg.VoiceServiceURL
|
||||
if voiceURL == "" {
|
||||
voiceURL = "http://localhost:8093"
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequest("POST", voiceURL+"/api/v1/transcribe", &buf)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", mw.FormDataContentType())
|
||||
|
||||
httpClient := &http.Client{Timeout: 60 * time.Second}
|
||||
resp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("voice-service 调用失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Success bool `json:"success"`
|
||||
Text string `json:"text"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return "", fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
if !result.Success {
|
||||
if result.Error != "" {
|
||||
return "", fmt.Errorf("%s", result.Error)
|
||||
}
|
||||
return "", fmt.Errorf("转录返回空结果")
|
||||
}
|
||||
|
||||
return result.Text, nil
|
||||
}
|
||||
|
||||
// decodeBase64 解码 base64 字符串(支持 Data URL 前缀)。
|
||||
func decodeBase64(s string) ([]byte, error) {
|
||||
// 移除 data:xxx;base64, 前缀
|
||||
if idx := strings.Index(s, ","); idx != -1 {
|
||||
s = s[idx+1:]
|
||||
}
|
||||
return base64.StdEncoding.DecodeString(s)
|
||||
}
|
||||
|
||||
// handleHistoryRequest 处理历史消息请求
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
type User struct {
|
||||
ID int `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Nickname string `json:"nickname"`
|
||||
PasswordHash string `json:"-"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
@@ -23,6 +24,7 @@ func CreateUsersTable(db *sql.DB) error {
|
||||
query := `CREATE TABLE IF NOT EXISTS users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
username VARCHAR(255) UNIQUE NOT NULL,
|
||||
nickname VARCHAR(255) DEFAULT '',
|
||||
password_hash VARCHAR(255) NOT NULL,
|
||||
is_admin BOOLEAN DEFAULT FALSE,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
@@ -33,6 +35,9 @@ func CreateUsersTable(db *sql.DB) error {
|
||||
return fmt.Errorf("创建 users 表失败: %w", err)
|
||||
}
|
||||
|
||||
// 迁移:为已有 users 表添加 nickname 列
|
||||
db.Exec(`ALTER TABLE users ADD COLUMN IF NOT EXISTS nickname VARCHAR(255) DEFAULT ''`)
|
||||
|
||||
// 创建索引
|
||||
indexQueries := []string{
|
||||
`CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)`,
|
||||
@@ -52,10 +57,10 @@ func CreateUsersTable(db *sql.DB) error {
|
||||
func GetUserByUsername(db *sql.DB, username string) (*User, error) {
|
||||
var u User
|
||||
err := db.QueryRow(
|
||||
`SELECT id, username, password_hash, is_admin, created_at, updated_at
|
||||
`SELECT id, username, COALESCE(nickname, '') as nickname, password_hash, is_admin, created_at, updated_at
|
||||
FROM users WHERE username = $1`,
|
||||
username,
|
||||
).Scan(&u.ID, &u.Username, &u.PasswordHash, &u.IsAdmin, &u.CreatedAt, &u.UpdatedAt)
|
||||
).Scan(&u.ID, &u.Username, &u.Nickname, &u.PasswordHash, &u.IsAdmin, &u.CreatedAt, &u.UpdatedAt)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
@@ -70,7 +75,7 @@ func GetUserByUsername(db *sql.DB, username string) (*User, error) {
|
||||
// ListUsers 列出所有用户
|
||||
func ListUsers(db *sql.DB) ([]User, error) {
|
||||
rows, err := db.Query(
|
||||
`SELECT id, username, password_hash, is_admin, created_at, updated_at
|
||||
`SELECT id, username, COALESCE(nickname, '') as nickname, password_hash, is_admin, created_at, updated_at
|
||||
FROM users ORDER BY id`,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -81,7 +86,7 @@ func ListUsers(db *sql.DB) ([]User, error) {
|
||||
var users []User
|
||||
for rows.Next() {
|
||||
var u User
|
||||
if err := rows.Scan(&u.ID, &u.Username, &u.PasswordHash, &u.IsAdmin, &u.CreatedAt, &u.UpdatedAt); err != nil {
|
||||
if err := rows.Scan(&u.ID, &u.Username, &u.Nickname, &u.PasswordHash, &u.IsAdmin, &u.CreatedAt, &u.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描用户行失败: %w", err)
|
||||
}
|
||||
users = append(users, u)
|
||||
@@ -109,16 +114,16 @@ func DeleteUser(db *sql.DB, userID int) error {
|
||||
}
|
||||
|
||||
// CreateUser 创建新用户
|
||||
func CreateUser(db *sql.DB, username, passwordHash string, isAdmin bool) (*User, error) {
|
||||
func CreateUser(db *sql.DB, username, nickname, passwordHash string, isAdmin bool) (*User, error) {
|
||||
now := time.Now()
|
||||
var u User
|
||||
err := db.QueryRow(
|
||||
`INSERT INTO users (username, password_hash, is_admin, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $4)
|
||||
ON CONFLICT (username) DO UPDATE SET updated_at = $4
|
||||
RETURNING id, username, password_hash, is_admin, created_at, updated_at`,
|
||||
username, passwordHash, isAdmin, now,
|
||||
).Scan(&u.ID, &u.Username, &u.PasswordHash, &u.IsAdmin, &u.CreatedAt, &u.UpdatedAt)
|
||||
`INSERT INTO users (username, nickname, password_hash, is_admin, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $5)
|
||||
ON CONFLICT (username) DO UPDATE SET nickname = $2, updated_at = $5
|
||||
RETURNING id, username, COALESCE(nickname, '') as nickname, password_hash, is_admin, created_at, updated_at`,
|
||||
username, nickname, passwordHash, isAdmin, now,
|
||||
).Scan(&u.ID, &u.Username, &u.Nickname, &u.PasswordHash, &u.IsAdmin, &u.CreatedAt, &u.UpdatedAt)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建用户失败: %w", err)
|
||||
|
||||
@@ -19,7 +19,7 @@ const (
|
||||
pingPeriod = (pongWait * 9) / 10
|
||||
|
||||
// 最大消息大小
|
||||
maxMessageSize = 65536
|
||||
maxMessageSize = 2 * 1024 * 1024 // 2MB 支持语音消息
|
||||
)
|
||||
|
||||
// Client WebSocket客户端
|
||||
@@ -135,11 +135,16 @@ func (c *Client) SendMessage(msg ServerMessage) error {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Printf("[WS] 发送消息时连接已关闭: type=%s user=%s", msg.Type, c.UserID)
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case c.Send <- data:
|
||||
return nil
|
||||
default:
|
||||
// 通道满:记录警告并返回错误(避免静默丢弃)
|
||||
logger.Printf("[WS] 发送通道已满,丢弃消息: type=%s user=%s", msg.Type, c.UserID)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ func main() {
|
||||
if sttSvc.IsAvailable() {
|
||||
dashAvailable := cfg.DashScopeAPIKey != ""
|
||||
if dashAvailable {
|
||||
logger.Println("STT: DashScope Gummy (主) + Whisper (回退)")
|
||||
logger.Printf("STT: DashScope 实时=%s, 离线=%s + Whisper (回退)", cfg.DashScopeSTTRealtime, cfg.DashScopeModel)
|
||||
} else {
|
||||
logger.Println("STT: Whisper 本地引擎")
|
||||
}
|
||||
|
||||
@@ -11,20 +11,44 @@ type Config struct {
|
||||
MaxAudioSize int64 // 字节
|
||||
|
||||
// DashScope STT 配置
|
||||
DashScopeAPIKey string
|
||||
DashScopeModel string
|
||||
DashScopeAPIKey string
|
||||
DashScopeModel string // 离线/非实时 ASR 模型
|
||||
DashScopeSTTRealtime string // 实时 ASR 模型
|
||||
}
|
||||
|
||||
// Load 从环境变量加载配置
|
||||
// Load 从 models.json 和环境变量加载配置。
|
||||
// models.json 优先级高于环境变量。
|
||||
func Load() *Config {
|
||||
// 从 models.json 加载 ASR 配置
|
||||
modelsAPIKey, modelsOffline, modelsRealtime := LoadModelsConfig()
|
||||
|
||||
// .env / 环境变量作为回退
|
||||
envAPIKey := getEnv("DASHSCOPE_API_KEY", "")
|
||||
envModel := getEnv("DASHSCOPE_STT_MODEL", "qwen3-asr-flash-2026-02-10")
|
||||
envRealtime := getEnv("DASHSCOPE_STT_REALTIME_MODEL", "qwen3-asr-flash-realtime")
|
||||
|
||||
apiKey := modelsAPIKey
|
||||
if apiKey == "" {
|
||||
apiKey = envAPIKey
|
||||
}
|
||||
offlineModel := modelsOffline
|
||||
if offlineModel == "" {
|
||||
offlineModel = envModel
|
||||
}
|
||||
realtimeModel := modelsRealtime
|
||||
if realtimeModel == "" {
|
||||
realtimeModel = envRealtime
|
||||
}
|
||||
|
||||
return &Config{
|
||||
Port: getEnv("PORT", "8093"),
|
||||
WhisperBinary: getEnv("WHISPER_BINARY", "./whisper.cpp/main"),
|
||||
WhisperModel: getEnv("WHISPER_MODEL", "./whisper.cpp/models/ggml-small.bin"),
|
||||
WhisperLanguage: getEnv("WHISPER_LANGUAGE", "zh"),
|
||||
MaxAudioSize: 10 * 1024 * 1024, // 10MB
|
||||
DashScopeAPIKey: getEnv("DASHSCOPE_API_KEY", ""),
|
||||
DashScopeModel: getEnv("DASHSCOPE_STT_MODEL", "gummy-chat-v1"),
|
||||
Port: getEnv("PORT", "8093"),
|
||||
WhisperBinary: getEnv("WHISPER_BINARY", "./whisper.cpp/main"),
|
||||
WhisperModel: getEnv("WHISPER_MODEL", "./whisper.cpp/models/ggml-small.bin"),
|
||||
WhisperLanguage: getEnv("WHISPER_LANGUAGE", "zh"),
|
||||
MaxAudioSize: 10 * 1024 * 1024, // 10MB
|
||||
DashScopeAPIKey: apiKey,
|
||||
DashScopeModel: offlineModel,
|
||||
DashScopeSTTRealtime: realtimeModel,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// ModelsJSON 映射 models.json 文件结构(仅提取语音相关字段)。
|
||||
type ModelsJSON struct {
|
||||
Providers map[string]ModelsProvider `json:"providers"`
|
||||
Models map[string]ModelsModel `json:"models"`
|
||||
Routing map[string]ModelsRouting `json:"routing"`
|
||||
}
|
||||
|
||||
type ModelsProvider struct {
|
||||
Name string `json:"name"`
|
||||
BaseURL string `json:"base_url"`
|
||||
APIKey string `json:"api_key"`
|
||||
}
|
||||
|
||||
type ModelsModel struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
}
|
||||
|
||||
type ModelsRouting struct {
|
||||
Purpose string `json:"purpose"`
|
||||
FallbackChain []string `json:"fallback_chain"`
|
||||
}
|
||||
|
||||
// LoadModelsConfig 从 backend/models.json 加载模型配置。
|
||||
// 返回 provider API key 和 ASR 模型名称。如果文件不存在则返回零值。
|
||||
func LoadModelsConfig() (apiKey string, offlineModel string, realtimeModel string) {
|
||||
// 尝试多个可能的路径
|
||||
candidates := []string{
|
||||
"models.json",
|
||||
"../models.json",
|
||||
"../../models.json",
|
||||
filepath.Join("..", "models.json"),
|
||||
}
|
||||
|
||||
var data []byte
|
||||
var err error
|
||||
for _, p := range candidates {
|
||||
data, err = os.ReadFile(p)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if data == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var cfg ModelsJSON
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 从 speech_recognition_offline 路由取离线模型
|
||||
if r, ok := cfg.Routing["speech_recognition_offline"]; ok && len(r.FallbackChain) > 0 {
|
||||
modelID := r.FallbackChain[0]
|
||||
if m, ok := cfg.Models[modelID]; ok {
|
||||
offlineModel = m.Name
|
||||
if m.Name == "" {
|
||||
offlineModel = m.ID
|
||||
}
|
||||
// 从 provider 取 API key
|
||||
if p, ok := cfg.Providers[m.Provider]; ok && apiKey == "" {
|
||||
apiKey = p.APIKey
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 从 speech_recognition 路由取实时模型
|
||||
if r, ok := cfg.Routing["speech_recognition"]; ok && len(r.FallbackChain) > 0 {
|
||||
modelID := r.FallbackChain[0]
|
||||
if m, ok := cfg.Models[modelID]; ok {
|
||||
realtimeModel = m.Name
|
||||
if realtimeModel == "" {
|
||||
realtimeModel = m.ID
|
||||
}
|
||||
// 从 provider 取 API key
|
||||
if p, ok := cfg.Providers[m.Provider]; ok && apiKey == "" {
|
||||
apiKey = p.APIKey
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果 routing 里没有,尝试从所有 models 中找带有 ASR tag 的模型
|
||||
if offlineModel == "" {
|
||||
for _, m := range cfg.Models {
|
||||
if m.Provider == "dashscope" && m.Name != "" {
|
||||
if offlineModel == "" {
|
||||
offlineModel = m.Name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 从 provider 取 API key (兜底)
|
||||
if apiKey == "" {
|
||||
if p, ok := cfg.Providers["dashscope"]; ok {
|
||||
apiKey = p.APIKey
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
@@ -21,8 +21,8 @@ var upgrader = websocket.Upgrader{
|
||||
}
|
||||
|
||||
// StreamingSTTHandler 处理实时语音识别 WebSocket 连接。
|
||||
// 客户端通过 WebSocket 流式发送音频二进制帧,服务端逐帧转发到 DashScope,
|
||||
// 将识别结果通过 WebSocket JSON 消息返回。
|
||||
// 客户端通过 WebSocket 流式发送音频二进制帧,服务端通过一条持久的
|
||||
// DashScope WebSocket 连接转发音频并持续返回识别结果。
|
||||
type StreamingSTTHandler struct {
|
||||
svc *service.STTService
|
||||
}
|
||||
@@ -46,6 +46,10 @@ func (h *StreamingSTTHandler) HandleStreamingSTT(w http.ResponseWriter, r *http.
|
||||
if language == "" {
|
||||
language = "zh"
|
||||
}
|
||||
format := r.URL.Query().Get("format")
|
||||
if format == "" {
|
||||
format = "pcm"
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
@@ -54,12 +58,47 @@ func (h *StreamingSTTHandler) HandleStreamingSTT(w http.ResponseWriter, r *http.
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
logger.Printf("[stream-stt] 客户端已连接")
|
||||
logger.Printf("[stream-stt] 客户端已连接, format=%s, language=%s", format, language)
|
||||
|
||||
// 创建持久的 DashScope 流式会话
|
||||
session, err := h.svc.StartStreaming(format, language)
|
||||
if err != nil {
|
||||
logger.Printf("[stream-stt] 创建 DashScope 会话失败: %v", err)
|
||||
conn.WriteJSON(map[string]interface{}{
|
||||
"type": "error",
|
||||
"error": "启动语音识别失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
var mu sync.Mutex
|
||||
conn.SetWriteDeadline(time.Now().Add(60 * time.Second))
|
||||
conn.SetReadDeadline(time.Now().Add(300 * time.Second)) // 5 分钟超时
|
||||
|
||||
// 读取音频帧并发送到 DashScope
|
||||
// goroutine: 读取 DashScope 结果并推送到客户端
|
||||
resultDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(resultDone)
|
||||
for result := range session.Results() {
|
||||
mu.Lock()
|
||||
if result.Error != "" {
|
||||
logger.Printf("[stream-stt] DashScope 错误: %s", result.Error)
|
||||
conn.WriteJSON(map[string]interface{}{
|
||||
"type": "error",
|
||||
"error": result.Error,
|
||||
})
|
||||
} else if result.Text != "" {
|
||||
conn.WriteJSON(map[string]interface{}{
|
||||
"type": "result",
|
||||
"text": result.Text,
|
||||
"isFinal": result.IsFinal,
|
||||
})
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
// 主循环: 读取客户端音频帧
|
||||
for {
|
||||
msgType, data, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
@@ -69,11 +108,13 @@ func (h *StreamingSTTHandler) HandleStreamingSTT(w http.ResponseWriter, r *http.
|
||||
break
|
||||
}
|
||||
|
||||
// 支持文本控制消息
|
||||
// 文本控制消息
|
||||
if msgType == websocket.TextMessage {
|
||||
var ctrl map[string]interface{}
|
||||
if json.Unmarshal(data, &ctrl) == nil {
|
||||
if ctrl["action"] == "stop" {
|
||||
action, _ := ctrl["action"].(string)
|
||||
if action == "stop" {
|
||||
logger.Printf("[stream-stt] 客户端请求停止")
|
||||
mu.Lock()
|
||||
conn.WriteJSON(map[string]interface{}{
|
||||
"type": "done",
|
||||
@@ -82,34 +123,33 @@ func (h *StreamingSTTHandler) HandleStreamingSTT(w http.ResponseWriter, r *http.
|
||||
mu.Unlock()
|
||||
break
|
||||
}
|
||||
// 支持动态切换语言
|
||||
if lang, ok := ctrl["language"].(string); ok && lang != "" {
|
||||
language = lang
|
||||
logger.Printf("[stream-stt] 切换语言: %s", lang)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 二进制音频帧:进行识别
|
||||
if msgType == websocket.BinaryMessage {
|
||||
format := r.URL.Query().Get("format")
|
||||
if format == "" {
|
||||
format = "pcm"
|
||||
}
|
||||
|
||||
text, err := h.svc.Transcribe(data, format, language)
|
||||
mu.Lock()
|
||||
if err != nil {
|
||||
// 二进制音频帧: 发送到 DashScope
|
||||
if msgType == websocket.BinaryMessage && len(data) > 0 {
|
||||
if err := session.SendAudio(data); err != nil {
|
||||
logger.Printf("[stream-stt] 发送音频帧失败: %v", err)
|
||||
mu.Lock()
|
||||
conn.WriteJSON(map[string]interface{}{
|
||||
"type": "error",
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else if text != "" {
|
||||
conn.WriteJSON(map[string]interface{}{
|
||||
"type": "result",
|
||||
"text": text,
|
||||
"final": true,
|
||||
"error": "发送音频失败: " + err.Error(),
|
||||
})
|
||||
mu.Unlock()
|
||||
break
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// 等待结果推送完成
|
||||
<-resultDone
|
||||
logger.Printf("[stream-stt] 会话结束")
|
||||
}
|
||||
|
||||
// RegisterStreamingRoutes 注册流式 STT 路由。
|
||||
|
||||
@@ -2,15 +2,21 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// DashScopeSTT 使用阿里云百炼 Gummy 模型进行语音识别。
|
||||
// WebSocket API: wss://dashscope.aliyuncs.com/api-ws/v1/inference
|
||||
// DashScopeSTT 使用阿里云百炼 Qwen ASR 模型进行语音识别。
|
||||
// 实时模型 (qwen3-asr-flash-realtime) 通过 WebSocket realtime 端点进行流式识别,
|
||||
// 基于 session/VAD 协议(类似 OpenAI Realtime API)。
|
||||
type DashScopeSTT struct {
|
||||
apiKey string
|
||||
model string
|
||||
@@ -20,7 +26,7 @@ type DashScopeSTT struct {
|
||||
// NewDashScopeSTT 创建 DashScope STT 客户端。
|
||||
func NewDashScopeSTT(apiKey, model string) *DashScopeSTT {
|
||||
if model == "" {
|
||||
model = "gummy-chat-v1"
|
||||
model = "qwen3-asr-flash-realtime"
|
||||
}
|
||||
return &DashScopeSTT{
|
||||
apiKey: apiKey,
|
||||
@@ -34,232 +40,402 @@ func (d *DashScopeSTT) IsAvailable() bool {
|
||||
return d.apiKey != ""
|
||||
}
|
||||
|
||||
// sttMessage 定义 STT WebSocket 协议消息格式。
|
||||
type sttMessage struct {
|
||||
Header sttHeader `json:"header"`
|
||||
Payload sttPayload `json:"payload"`
|
||||
// Model 返回模型名。
|
||||
func (d *DashScopeSTT) Model() string { return d.model }
|
||||
|
||||
// --- Realtime 端点协议消息类型 ---
|
||||
|
||||
type rtClientMsg struct {
|
||||
EventID string `json:"event_id,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Session interface{} `json:"session,omitempty"`
|
||||
Audio string `json:"audio,omitempty"`
|
||||
}
|
||||
|
||||
type sttHeader struct {
|
||||
Streaming string `json:"streaming"`
|
||||
TaskID string `json:"task_id"`
|
||||
Action string `json:"action"`
|
||||
type rtServerMsg struct {
|
||||
EventID string `json:"event_id,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Session json.RawMessage `json:"session,omitempty"`
|
||||
Error *rtError `json:"error,omitempty"`
|
||||
|
||||
// response.audio_transcript.delta
|
||||
Delta string `json:"delta,omitempty"`
|
||||
Response *struct {
|
||||
Output []struct {
|
||||
Transcript string `json:"transcript,omitempty"`
|
||||
} `json:"output,omitempty"`
|
||||
} `json:"response,omitempty"`
|
||||
|
||||
// transcription completed transcript
|
||||
Transcript string `json:"transcript,omitempty"`
|
||||
|
||||
// conversation.item.input_audio_transcription.completed
|
||||
Item *struct {
|
||||
Content []struct {
|
||||
Transcript string `json:"transcript,omitempty"`
|
||||
} `json:"content,omitempty"`
|
||||
} `json:"item,omitempty"`
|
||||
}
|
||||
|
||||
type sttPayload struct {
|
||||
Model string `json:"model"`
|
||||
TaskGroup string `json:"task_group"`
|
||||
Task string `json:"task"`
|
||||
Function string `json:"function"`
|
||||
Input map[string]interface{} `json:"input,omitempty"`
|
||||
Parameters sttParameters `json:"parameters"`
|
||||
Output map[string]interface{} `json:"output,omitempty"`
|
||||
}
|
||||
|
||||
type sttParameters struct {
|
||||
SampleRate int `json:"sample_rate"`
|
||||
Format string `json:"format"`
|
||||
TranscriptionEnabled bool `json:"transcription_enabled"`
|
||||
TranslationEnabled bool `json:"translation_enabled"`
|
||||
SourceLanguage string `json:"source_language,omitempty"`
|
||||
MaxEndSilence int `json:"max_end_silence,omitempty"`
|
||||
}
|
||||
|
||||
// sttServerMsg 服务端返回的消息格式。
|
||||
type sttServerMsg struct {
|
||||
Header sttServerHeader `json:"header"`
|
||||
Payload sttServerPayload `json:"payload"`
|
||||
}
|
||||
|
||||
type sttServerHeader struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Event string `json:"event"`
|
||||
}
|
||||
|
||||
type sttServerPayload struct {
|
||||
Output map[string]interface{} `json:"output,omitempty"`
|
||||
Usage map[string]interface{} `json:"usage,omitempty"`
|
||||
Error sttError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type sttError struct {
|
||||
type rtError struct {
|
||||
Type string `json:"type"`
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Param string `json:"param,omitempty"`
|
||||
}
|
||||
|
||||
// Transcribe 将音频数据发送到 DashScope 进行识别,返回识别文本。
|
||||
// 使用 realtime 端点,通过 Server VAD 自动检测语音并触发转录。
|
||||
func (d *DashScopeSTT) Transcribe(ctx context.Context, audioData []byte, format string, language string) (string, error) {
|
||||
if !d.IsAvailable() {
|
||||
return "", fmt.Errorf("DashScope API Key 未配置")
|
||||
}
|
||||
|
||||
dialer := websocket.Dialer{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
url := fmt.Sprintf("wss://dashscope.aliyuncs.com/api-ws/v1/realtime?model=%s", d.model)
|
||||
header := map[string][]string{
|
||||
"Authorization": {"Bearer " + d.apiKey},
|
||||
}
|
||||
|
||||
header := make(map[string][]string)
|
||||
header["Authorization"] = []string{"Bearer " + d.apiKey}
|
||||
|
||||
conn, _, err := dialer.DialContext(ctx, "wss://dashscope.aliyuncs.com/api-ws/v1/inference", header)
|
||||
dialer := websocket.Dialer{HandshakeTimeout: 10 * time.Second}
|
||||
conn, _, err := dialer.DialContext(ctx, url, header)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("连接 DashScope STT 失败: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(d.timeout))
|
||||
// 1. session.created
|
||||
conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
var msg rtServerMsg
|
||||
if err := conn.ReadJSON(&msg); err != nil {
|
||||
return "", fmt.Errorf("等待 session.created 失败: %w", err)
|
||||
}
|
||||
if msg.Type != "session.created" {
|
||||
return "", fmt.Errorf("预期 session.created 但收到: %s", msg.Type)
|
||||
}
|
||||
|
||||
taskID := fmt.Sprintf("cyrene-stt-%d", time.Now().UnixNano())
|
||||
|
||||
// 规范化音频格式
|
||||
normFormat := normalizeSTTFormat(format)
|
||||
// 2. session.update
|
||||
if language == "" || language == "auto" {
|
||||
language = "zh"
|
||||
}
|
||||
|
||||
// 发送 run-task
|
||||
startMsg := sttMessage{
|
||||
Header: sttHeader{
|
||||
Streaming: "duplex",
|
||||
TaskID: taskID,
|
||||
Action: "run-task",
|
||||
},
|
||||
Payload: sttPayload{
|
||||
Model: d.model,
|
||||
TaskGroup: "audio",
|
||||
Task: "asr",
|
||||
Function: "recognition",
|
||||
Parameters: sttParameters{
|
||||
SampleRate: 16000,
|
||||
Format: normFormat,
|
||||
TranscriptionEnabled: true,
|
||||
TranslationEnabled: false,
|
||||
SourceLanguage: language,
|
||||
updateMsg := rtClientMsg{
|
||||
Type: "session.update",
|
||||
Session: map[string]interface{}{
|
||||
"modalities": []string{"text"},
|
||||
"input_audio_format": "pcm",
|
||||
"sample_rate": 16000,
|
||||
"input_audio_transcription": map[string]interface{}{
|
||||
"language": language,
|
||||
},
|
||||
"turn_detection": map[string]interface{}{
|
||||
"type": "server_vad",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := conn.WriteJSON(startMsg); err != nil {
|
||||
return "", fmt.Errorf("发送 run-task 失败: %w", err)
|
||||
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if err := conn.WriteJSON(updateMsg); err != nil {
|
||||
return "", fmt.Errorf("发送 session.update 失败: %w", err)
|
||||
}
|
||||
|
||||
// 等待 task-started
|
||||
var textResult string
|
||||
var mu sync.Mutex
|
||||
started := make(chan struct{})
|
||||
errc := make(chan error, 1)
|
||||
done := make(chan struct{})
|
||||
// 3. session.updated
|
||||
conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
if err := conn.ReadJSON(&msg); err != nil {
|
||||
return "", fmt.Errorf("等待 session.updated 失败: %w", err)
|
||||
}
|
||||
if msg.Type == "error" && msg.Error != nil {
|
||||
return "", fmt.Errorf("session.update 失败: %s", msg.Error.Message)
|
||||
}
|
||||
|
||||
// 4. 规范化音频格式并发送
|
||||
pcmData, err := convertToPCM16(audioData, format)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("音频格式转换失败: %w", err)
|
||||
}
|
||||
|
||||
chunkSize := 3200
|
||||
for i := 0; i < len(pcmData); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(pcmData) {
|
||||
end = len(pcmData)
|
||||
}
|
||||
chunkB64 := base64.StdEncoding.EncodeToString(pcmData[i:end])
|
||||
audioMsg := rtClientMsg{
|
||||
Type: "input_audio_buffer.append",
|
||||
Audio: chunkB64,
|
||||
}
|
||||
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if err := conn.WriteJSON(audioMsg); err != nil {
|
||||
return "", fmt.Errorf("发送音频数据失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 等待转录结果
|
||||
// 用 goroutine + channel 避免 gorilla/websocket 超时后重复读取 panic
|
||||
type readResult struct {
|
||||
msg rtServerMsg
|
||||
err error
|
||||
}
|
||||
msgCh := make(chan readResult, 1)
|
||||
readDone := make(chan struct{})
|
||||
defer close(readDone)
|
||||
|
||||
go func() {
|
||||
defer close(done)
|
||||
startedClosed := false
|
||||
for {
|
||||
var msg sttServerMsg
|
||||
if err := conn.ReadJSON(&msg); err != nil {
|
||||
select {
|
||||
case errc <- fmt.Errorf("读取响应失败: %w", err):
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case <-readDone:
|
||||
return
|
||||
default:
|
||||
}
|
||||
var m rtServerMsg
|
||||
err := conn.ReadJSON(&m)
|
||||
select {
|
||||
case msgCh <- readResult{m, err}:
|
||||
case <-readDone:
|
||||
return
|
||||
}
|
||||
|
||||
switch msg.Header.Event {
|
||||
case "task-started":
|
||||
if !startedClosed {
|
||||
close(started)
|
||||
startedClosed = true
|
||||
}
|
||||
case "result-generated":
|
||||
if out, ok := msg.Payload.Output["transcription"]; ok {
|
||||
if transMap, ok := out.(map[string]interface{}); ok {
|
||||
if text, ok := transMap["text"].(string); ok {
|
||||
mu.Lock()
|
||||
textResult = text
|
||||
mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
case "task-finished":
|
||||
return
|
||||
case "task-failed":
|
||||
errMsg := msg.Payload.Error.Message
|
||||
if errMsg == "" {
|
||||
errMsg = "未知错误"
|
||||
}
|
||||
select {
|
||||
case errc <- fmt.Errorf("DashScope 识别失败: %s (code=%s)", errMsg, msg.Payload.Error.Code):
|
||||
default:
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 等待 task-started 或错误
|
||||
select {
|
||||
case <-started:
|
||||
case err := <-errc:
|
||||
return "", err
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
}
|
||||
var textResult string
|
||||
silenceTimeout := 3 * time.Second
|
||||
timer := time.NewTimer(60 * time.Second)
|
||||
defer timer.Stop()
|
||||
|
||||
// 发送音频数据(分块发送,每块 ~10KB)
|
||||
chunkSize := 10240
|
||||
for i := 0; i < len(audioData); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(audioData) {
|
||||
end = len(audioData)
|
||||
}
|
||||
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if err := conn.WriteMessage(websocket.BinaryMessage, audioData[i:end]); err != nil {
|
||||
return "", fmt.Errorf("发送音频数据失败: %w", err)
|
||||
}
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case result := <-msgCh:
|
||||
if result.err != nil {
|
||||
if websocket.IsUnexpectedCloseError(result.err) {
|
||||
return "", fmt.Errorf("连接异常关闭: %w", result.err)
|
||||
}
|
||||
return textResult, nil
|
||||
}
|
||||
|
||||
// 发送 finish-task
|
||||
finishMsg := sttMessage{
|
||||
Header: sttHeader{
|
||||
Streaming: "duplex",
|
||||
TaskID: taskID,
|
||||
Action: "finish-task",
|
||||
},
|
||||
}
|
||||
if err := conn.WriteJSON(finishMsg); err != nil {
|
||||
return "", fmt.Errorf("发送 finish-task 失败: %w", err)
|
||||
}
|
||||
msg := result.msg
|
||||
|
||||
// 等待完成
|
||||
select {
|
||||
case <-done:
|
||||
mu.Lock()
|
||||
text := textResult
|
||||
mu.Unlock()
|
||||
if text == "" {
|
||||
return "", fmt.Errorf("未收到识别结果")
|
||||
switch msg.Type {
|
||||
case "conversation.item.input_audio_transcription.completed":
|
||||
if msg.Transcript != "" {
|
||||
if textResult != "" {
|
||||
textResult += "\n"
|
||||
}
|
||||
textResult += msg.Transcript
|
||||
}
|
||||
if textResult == "" && msg.Item != nil {
|
||||
for _, c := range msg.Item.Content {
|
||||
if c.Transcript != "" {
|
||||
textResult = c.Transcript
|
||||
}
|
||||
}
|
||||
}
|
||||
case "response.audio_transcript.delta":
|
||||
if msg.Delta != "" {
|
||||
textResult += msg.Delta
|
||||
}
|
||||
case "response.done":
|
||||
if textResult == "" && msg.Response != nil {
|
||||
for _, o := range msg.Response.Output {
|
||||
if o.Transcript != "" {
|
||||
textResult += o.Transcript
|
||||
}
|
||||
}
|
||||
}
|
||||
if textResult != "" {
|
||||
return textResult, nil
|
||||
}
|
||||
case "error":
|
||||
if msg.Error != nil {
|
||||
return "", fmt.Errorf("DashScope 识别失败: %s", msg.Error.Message)
|
||||
}
|
||||
return "", fmt.Errorf("DashScope 返回未知错误")
|
||||
}
|
||||
|
||||
if textResult != "" {
|
||||
timer.Reset(silenceTimeout)
|
||||
}
|
||||
|
||||
case <-timer.C:
|
||||
return textResult, nil
|
||||
}
|
||||
return text, nil
|
||||
case err := <-errc:
|
||||
return "", err
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeSTTFormat 将音频格式映射到 DashScope 支持的格式名。
|
||||
func normalizeSTTFormat(format string) string {
|
||||
switch format {
|
||||
case "wav":
|
||||
return "wav"
|
||||
case "mp3", "mpeg":
|
||||
return "mp3"
|
||||
case "ogg", "opus":
|
||||
return "ogg"
|
||||
case "flac":
|
||||
return "flac"
|
||||
case "m4a", "aac", "mp4":
|
||||
return "aac"
|
||||
default:
|
||||
return "pcm"
|
||||
// --- 流式识别 (StreamingSession) ---
|
||||
|
||||
// StreamingSession 维护一个持久的 DashScope WebSocket 连接,用于实时语音识别。
|
||||
type StreamingSession struct {
|
||||
conn *websocket.Conn
|
||||
results chan StreamingResult
|
||||
done chan struct{}
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
}
|
||||
|
||||
// StreamingResult 实时识别结果。
|
||||
type StreamingResult struct {
|
||||
Text string `json:"text"`
|
||||
IsFinal bool `json:"is_final"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// StartStreaming 建立 DashScope realtime WebSocket 连接并返回 StreamingSession。
|
||||
func (d *DashScopeSTT) StartStreaming(ctx context.Context, format, language string) (*StreamingSession, error) {
|
||||
if !d.IsAvailable() {
|
||||
return nil, fmt.Errorf("DashScope API Key 未配置")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("wss://dashscope.aliyuncs.com/api-ws/v1/realtime?model=%s", d.model)
|
||||
header := map[string][]string{
|
||||
"Authorization": {"Bearer " + d.apiKey},
|
||||
}
|
||||
dialer := websocket.Dialer{HandshakeTimeout: 10 * time.Second}
|
||||
conn, _, err := dialer.DialContext(ctx, url, header)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("连接 DashScope STT 失败: %w", err)
|
||||
}
|
||||
|
||||
// 1. session.created
|
||||
conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
var msg rtServerMsg
|
||||
if err := conn.ReadJSON(&msg); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("等待 session.created 失败: %w", err)
|
||||
}
|
||||
if msg.Type != "session.created" {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("预期 session.created 但收到: %s", msg.Type)
|
||||
}
|
||||
|
||||
// 2. session.update
|
||||
if language == "" || language == "auto" {
|
||||
language = "zh"
|
||||
}
|
||||
updateMsg := rtClientMsg{
|
||||
Type: "session.update",
|
||||
Session: map[string]interface{}{
|
||||
"modalities": []string{"text"},
|
||||
"input_audio_format": "pcm",
|
||||
"sample_rate": 16000,
|
||||
"input_audio_transcription": map[string]interface{}{
|
||||
"language": language,
|
||||
},
|
||||
"turn_detection": map[string]interface{}{
|
||||
"type": "server_vad",
|
||||
},
|
||||
},
|
||||
}
|
||||
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if err := conn.WriteJSON(updateMsg); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("发送 session.update 失败: %w", err)
|
||||
}
|
||||
|
||||
// 3. session.updated
|
||||
conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
if err := conn.ReadJSON(&msg); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("等待 session.updated 失败: %w", err)
|
||||
}
|
||||
if msg.Type == "error" && msg.Error != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("session.update 失败: %s", msg.Error.Message)
|
||||
}
|
||||
|
||||
session := &StreamingSession{
|
||||
conn: conn,
|
||||
results: make(chan StreamingResult, 64),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
go session.readLoop()
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// SendAudio 发送一帧 PCM 音频数据到 DashScope。
|
||||
// data 必须是 16-bit little-endian PCM,16000Hz,mono。
|
||||
func (s *StreamingSession) SendAudio(data []byte) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.closed {
|
||||
return fmt.Errorf("session 已关闭")
|
||||
}
|
||||
|
||||
b64 := base64.StdEncoding.EncodeToString(data)
|
||||
msg := rtClientMsg{
|
||||
Type: "input_audio_buffer.append",
|
||||
Audio: b64,
|
||||
}
|
||||
s.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
return s.conn.WriteJSON(msg)
|
||||
}
|
||||
|
||||
// Results 返回识别结果通道。
|
||||
func (s *StreamingSession) Results() <-chan StreamingResult {
|
||||
return s.results
|
||||
}
|
||||
|
||||
// Close 结束会话并关闭 WebSocket 连接。
|
||||
func (s *StreamingSession) Close() error {
|
||||
s.mu.Lock()
|
||||
if s.closed {
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
s.closed = true
|
||||
s.mu.Unlock()
|
||||
|
||||
finishMsg := rtClientMsg{Type: "session.finish"}
|
||||
s.conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
||||
s.conn.WriteJSON(finishMsg)
|
||||
|
||||
select {
|
||||
case <-s.done:
|
||||
case <-time.After(5 * time.Second):
|
||||
}
|
||||
|
||||
close(s.results)
|
||||
return s.conn.Close()
|
||||
}
|
||||
|
||||
// readLoop 读取 DashScope 服务端返回的消息并转换为 StreamingResult。
|
||||
func (s *StreamingSession) readLoop() {
|
||||
defer close(s.done)
|
||||
for {
|
||||
var msg rtServerMsg
|
||||
if err := s.conn.ReadJSON(&msg); err != nil {
|
||||
s.results <- StreamingResult{Error: fmt.Sprintf("读取响应失败: %v", err)}
|
||||
return
|
||||
}
|
||||
|
||||
switch msg.Type {
|
||||
case "conversation.item.input_audio_transcription.completed":
|
||||
if msg.Transcript != "" {
|
||||
s.results <- StreamingResult{Text: msg.Transcript, IsFinal: true}
|
||||
} else if msg.Item != nil {
|
||||
for _, c := range msg.Item.Content {
|
||||
if c.Transcript != "" {
|
||||
s.results <- StreamingResult{Text: c.Transcript, IsFinal: true}
|
||||
}
|
||||
}
|
||||
}
|
||||
case "response.audio_transcript.delta":
|
||||
s.results <- StreamingResult{Text: msg.Delta, IsFinal: false}
|
||||
case "response.done":
|
||||
// 全部完成
|
||||
case "error":
|
||||
errMsg := "未知错误"
|
||||
if msg.Error != nil {
|
||||
errMsg = msg.Error.Message
|
||||
}
|
||||
s.results <- StreamingResult{Error: fmt.Sprintf("DashScope 识别失败: %s", errMsg)}
|
||||
return
|
||||
case "response.created", "input_audio_buffer.committed",
|
||||
"input_audio_buffer.speech_started", "input_audio_buffer.speech_stopped",
|
||||
"conversation.item.created", "conversation.item.input_audio_transcription.text",
|
||||
"response.audio_transcript.done":
|
||||
// 内部事件,忽略
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -271,3 +447,74 @@ func (d *DashScopeSTT) GetStatus() map[string]interface{} {
|
||||
"provider": "dashscope",
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeSTTFormat 规范化音频格式字符串。
|
||||
func normalizeSTTFormat(format string) string {
|
||||
switch strings.ToLower(format) {
|
||||
case "pcm", "wav", "mp3", "mpeg", "ogg", "opus", "flac", "m4a", "mp4", "aac", "webm":
|
||||
return strings.ToLower(format)
|
||||
default:
|
||||
return format
|
||||
}
|
||||
}
|
||||
|
||||
// convertToPCM16 将音频数据转换为 16-bit PCM 16000Hz mono。
|
||||
func convertToPCM16(data []byte, format string) ([]byte, error) {
|
||||
normFormat := normalizeSTTFormat(format)
|
||||
switch normFormat {
|
||||
case "pcm":
|
||||
return data, nil
|
||||
case "wav":
|
||||
if len(data) > 44 {
|
||||
return data[44:], nil
|
||||
}
|
||||
return data, nil
|
||||
default:
|
||||
return transcodeToPCM(data, normFormat)
|
||||
}
|
||||
}
|
||||
|
||||
// transcodeToPCM 使用 ffmpeg 将音频数据转码为 PCM 16-bit 16000Hz mono。
|
||||
func transcodeToPCM(data []byte, format string) ([]byte, error) {
|
||||
inFile, err := os.CreateTemp(os.TempDir(), "cyrene-asr-in-*."+format)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建输入临时文件失败: %w", err)
|
||||
}
|
||||
inPath := inFile.Name()
|
||||
defer os.Remove(inPath)
|
||||
if _, err := inFile.Write(data); err != nil {
|
||||
inFile.Close()
|
||||
return nil, fmt.Errorf("写入输入临时文件失败: %w", err)
|
||||
}
|
||||
inFile.Close()
|
||||
|
||||
outFile, err := os.CreateTemp(os.TempDir(), "cyrene-asr-out-*.pcm")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建输出临时文件失败: %w", err)
|
||||
}
|
||||
outPath := outFile.Name()
|
||||
outFile.Close()
|
||||
defer os.Remove(outPath)
|
||||
|
||||
cmd := exec.Command("ffmpeg",
|
||||
"-i", inPath,
|
||||
"-ar", "16000",
|
||||
"-ac", "1",
|
||||
"-c:a", "pcm_s16le",
|
||||
"-f", "s16le",
|
||||
outPath,
|
||||
"-y",
|
||||
)
|
||||
cmd.Stderr = nil
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return nil, fmt.Errorf("音频转码失败 (ffmpeg): %w", err)
|
||||
}
|
||||
|
||||
outData, err := os.ReadFile(outPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取转码结果失败: %w", err)
|
||||
}
|
||||
|
||||
return outData, nil
|
||||
}
|
||||
|
||||
@@ -16,21 +16,27 @@ import (
|
||||
var SupportedLanguages = []string{"zh", "en", "ja", "ko", "auto"}
|
||||
|
||||
// STTService 语音转文字服务。
|
||||
// 优先使用 DashScope Gummy API,不可用时回退到本地 Whisper。
|
||||
// 优先使用 DashScope API,不可用时回退到本地 Whisper。
|
||||
type STTService struct {
|
||||
whisperBinary string
|
||||
whisperModel string
|
||||
language string
|
||||
dashscope *DashScopeSTT
|
||||
dashscope *DashScopeSTT // 实时 ASR (qwen3-asr-flash-realtime)
|
||||
}
|
||||
|
||||
// NewSTTService 创建 STT 服务。
|
||||
func NewSTTService(cfg *config.Config) *STTService {
|
||||
// 实时模型用于所有 WebSocket ASR 请求(支持 one-shot 和 streaming)
|
||||
// 离线模型 (qwen3-asr-flash-2026-02-10) 是 HTTP REST API,暂未实现
|
||||
model := cfg.DashScopeSTTRealtime
|
||||
if model == "" {
|
||||
model = cfg.DashScopeModel
|
||||
}
|
||||
return &STTService{
|
||||
whisperBinary: cfg.WhisperBinary,
|
||||
whisperModel: cfg.WhisperModel,
|
||||
language: cfg.WhisperLanguage,
|
||||
dashscope: NewDashScopeSTT(cfg.DashScopeAPIKey, cfg.DashScopeModel),
|
||||
dashscope: NewDashScopeSTT(cfg.DashScopeAPIKey, model),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,15 +64,30 @@ func (s *STTService) Transcribe(audioData []byte, format string, language string
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
text, err := s.dashscope.Transcribe(ctx, audioData, format, language)
|
||||
if err == nil && text != "" {
|
||||
if err == nil {
|
||||
return text, nil
|
||||
}
|
||||
// DashScope 失败,返回具体错误而不是回退到 Whisper
|
||||
return "", fmt.Errorf("语音识别失败: %w", err)
|
||||
}
|
||||
|
||||
// 回退到本地 Whisper
|
||||
return s.transcribeWhisper(audioData, format, language)
|
||||
}
|
||||
|
||||
// StartStreaming 创建持久的流式语音识别会话。
|
||||
func (s *STTService) StartStreaming(format, language string) (*StreamingSession, error) {
|
||||
if !s.dashscope.IsAvailable() {
|
||||
return nil, fmt.Errorf("流式识别需要 DashScope,请配置 DASHSCOPE_API_KEY")
|
||||
}
|
||||
if language == "" {
|
||||
language = s.language
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
return s.dashscope.StartStreaming(ctx, format, language)
|
||||
}
|
||||
|
||||
// transcribeWhisper 使用本地 Whisper 引擎转录。
|
||||
func (s *STTService) transcribeWhisper(audioData []byte, format string, language string) (string, error) {
|
||||
if _, err := os.Stat(s.whisperBinary); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user