package handler

import (
	"encoding/json"
	"io"
	"net/http"
	"path/filepath"
	"strconv"
	"strings"
	"time"

	"github.com/yourname/cyrene-ai/pkg/logger"
	"github.com/yourname/cyrene-ai/voice-service/internal/config"
	"github.com/yourname/cyrene-ai/voice-service/internal/service"
)

// STTHandler serves the voice-service HTTP API: audio transcription plus
// health and status endpoints.
type STTHandler struct {
	svc    *service.STTService
	ttsSvc *service.TTSService // optional; nil unless SetTTSService is called
	cfg    *config.Config
}

// NewSTTHandler creates an STTHandler backed by the given STT service and
// configuration. No TTS service is attached here; call SetTTSService to
// include TTS engine status in the health and status responses.
func NewSTTHandler(svc *service.STTService, cfg *config.Config) *STTHandler {
	return &STTHandler{svc: svc, cfg: cfg}
}

// SetTTSService attaches a TTS service whose engine status is then included
// in the /api/v1/health and /api/v1/status responses. The field is not
// guarded by a lock, so call this before the handler starts serving requests.
func (h *STTHandler) SetTTSService(ttsSvc *service.TTSService) {
	h.ttsSvc = ttsSvc
}

// RegisterRoutes registers the transcribe, health and status endpoints on mux.
func (h *STTHandler) RegisterRoutes(mux *http.ServeMux) {
	mux.HandleFunc("/api/v1/transcribe", h.handleTranscribe)
	mux.HandleFunc("/api/v1/health", h.handleHealth)
	mux.HandleFunc("/api/v1/status", h.handleStatus)
}

// handleTranscribe handles POST /api/v1/transcribe.
//
// The request is multipart/form-data with a required "audio" file field and
// an optional "language" field. The audio format is inferred from the
// uploaded file name's extension. On success it responds 200 with JSON
// fields success, text, language and duration_ms (time spent in the STT
// service call). When no language was supplied, the configured
// WhisperLanguage is reported in the response.
func (h *STTHandler) handleTranscribe(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		// RFC 9110 requires a 405 response to advertise the allowed methods.
		w.Header().Set("Allow", http.MethodPost)
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}

	// Cap the whole request body at the configured size; reads beyond the
	// limit fail, which makes ParseMultipartForm return an error.
	r.Body = http.MaxBytesReader(w, r.Body, h.cfg.MaxAudioSize)
	if err := r.ParseMultipartForm(h.cfg.MaxAudioSize); err != nil {
		// Report the configured limit instead of a hard-coded size so the
		// message stays correct when MaxAudioSize is changed.
		writeError(w, http.StatusBadRequest, "文件过大或解析失败,最大支持 "+formatMB(h.cfg.MaxAudioSize))
		return
	}

	file, header, err := r.FormFile("audio")
	if err != nil {
		writeError(w, http.StatusBadRequest, "缺少 audio 文件字段")
		return
	}
	defer file.Close()

	audioData, err := io.ReadAll(file)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "读取音频文件失败")
		return
	}
	if len(audioData) == 0 {
		writeError(w, http.StatusBadRequest, "音频文件为空")
		return
	}

	// Optional; passed through to the STT service as-is (may be empty).
	language := r.FormValue("language")

	format := inferFormat(header.Filename)
	if !isSupportedFormat(format) {
		writeError(w, http.StatusBadRequest, "不支持的音频格式: "+format+",支持的格式: WAV, MP3, OGG, FLAC, M4A")
		return
	}

	startTime := time.Now()
	text, err := h.svc.Transcribe(audioData, format, language)
	durationMs := time.Since(startTime).Milliseconds()
	if err != nil {
		logger.Printf("[stt-handler] 转录失败: %v", err)
		writeJSON(w, http.StatusInternalServerError, map[string]interface{}{
			"success": false,
			"error":   err.Error(),
		})
		return
	}

	actualLang := language
	if actualLang == "" {
		actualLang = h.cfg.WhisperLanguage
	}

	writeJSON(w, http.StatusOK, map[string]interface{}{
		"success":     true,
		"text":        text,
		"language":    actualLang,
		"duration_ms": durationMs,
	})
}

// handleHealth handles GET /api/v1/health.
//
// It always responds 200. The "status" field is "ok" when the STT status
// reports available == true and "degraded" otherwise; the body also carries
// the STT status and, when attached, the TTS engine status.
func (h *STTHandler) handleHealth(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		w.Header().Set("Allow", http.MethodGet)
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}

	sttStatus := h.svc.GetStatus()

	// Comma-ok assertion: a missing or non-bool "available" entry is treated
	// as unavailable instead of panicking the handler.
	healthStatus := "degraded"
	if available, ok := sttStatus["available"].(bool); ok && available {
		healthStatus = "ok"
	}

	resp := h.statusPayload(sttStatus)
	resp["status"] = healthStatus
	writeJSON(w, http.StatusOK, resp)
}

// handleStatus handles GET /api/v1/status, returning the STT status and,
// when attached, the TTS engine status.
func (h *STTHandler) handleStatus(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		w.Header().Set("Allow", http.MethodGet)
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}

	writeJSON(w, http.StatusOK, h.statusPayload(h.svc.GetStatus()))
}

// statusPayload builds the body shared by the health and status endpoints:
// the service name, the given STT status and, when a TTS service is
// attached, its engine status.
func (h *STTHandler) statusPayload(sttStatus interface{}) map[string]interface{} {
	resp := map[string]interface{}{
		"service": "voice-service",
		"stt":     sttStatus,
	}
	if h.ttsSvc != nil {
		resp["tts"] = h.ttsSvc.GetEngineStatus()
	}
	return resp
}

// --- helpers ---

// writeJSON writes data as a JSON response with the given status code.
func writeJSON(w http.ResponseWriter, status int, data interface{}) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// The status line has already been sent, so an encode/write failure
	// (e.g. the client disconnected) can only be logged.
	if err := json.NewEncoder(w).Encode(data); err != nil {
		logger.Printf("[stt-handler] 写入 JSON 响应失败: %v", err)
	}
}

// writeError writes a JSON body of the form {"error": message}.
func writeError(w http.ResponseWriter, status int, message string) {
	writeJSON(w, status, map[string]interface{}{
		"error": message,
	})
}

// formatMB renders a byte limit for user-facing messages. Values of at least
// 1 MiB are shown in megabytes with at most one decimal (10485760 -> "10MB",
// 5767168 -> "5.5MB"); smaller values are shown in bytes ("512000B").
func formatMB(n int64) string {
	const mib = 1 << 20
	if n < mib {
		return strconv.FormatInt(n, 10) + "B"
	}
	mb := strconv.FormatFloat(float64(n)/mib, 'f', 1, 64)
	return strings.TrimSuffix(mb, ".0") + "MB"
}

// inferFormat maps a file name's extension (case-insensitive) to a
// canonical audio format name. Unknown extensions are returned lower-cased
// with their leading dot (e.g. ".txt"); a name without an extension yields "".
func inferFormat(filename string) string {
	ext := strings.ToLower(filepath.Ext(filename))
	switch ext {
	case ".wav", ".wave":
		return "wav"
	case ".mp3", ".mpeg":
		return "mp3"
	case ".ogg", ".opus":
		return "ogg"
	case ".flac":
		return "flac"
	case ".m4a", ".mp4", ".aac":
		return "m4a"
	default:
		return ext
	}
}

// isSupportedFormat reports whether format is one of the canonical names
// produced by inferFormat for a recognised extension.
func isSupportedFormat(format string) bool {
	switch format {
	case "wav", "mp3", "ogg", "flac", "m4a":
		return true
	default:
		return false
	}
}