feat: 语音流式输入管线 + VAD前端集成 + 插件-工具合并清理
- 前端: VAD语音检测(@ricky0123/vad-web) + useVoiceInput双模式(流式WS/REST) - Gateway: VoiceStreamManager代理WS流式STT到voice-service - Voice-service: DashScope REST → Realtime WS → Whisper三级引擎 + ffmpeg转码 - 共享模块: pkg/audio(音频转换) + pkg/dashscope(ASR REST客户端) - 清理: 移除旧plugin-manager和pkg/plugins,完成插件→工具合并 - 文档: 完善gateway-api.md和voice-service.md语音API文档 - 工具: scripts/voice/ 语音转换脚本集 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -54,6 +54,8 @@ backend/.env
|
|||||||
models.json
|
models.json
|
||||||
thinking_schedule.json
|
thinking_schedule.json
|
||||||
platform_configs.json
|
platform_configs.json
|
||||||
|
platform_blocklist.json
|
||||||
|
*.exe~
|
||||||
.claude/
|
.claude/
|
||||||
|
|
||||||
# ========== 文档 (项目规范:docs/ 不纳入版本管理,docs/api/ 为例外) ==========
|
# ========== 文档 (项目规范:docs/ 不纳入版本管理,docs/api/ 为例外) ==========
|
||||||
|
|||||||
@@ -139,7 +139,7 @@ Cyrene/
|
|||||||
│ ├── memory-service/ # 记忆服务 (CRUD、语义检索、衰减、自动提取)
|
│ ├── memory-service/ # 记忆服务 (CRUD、语义检索、衰减、自动提取)
|
||||||
│ ├── voice-service/ # 语音服务 (DashScope STT + Edge-TTS)
|
│ ├── voice-service/ # 语音服务 (DashScope STT + Edge-TTS)
|
||||||
│ ├── iot-debug-service/ # IoT 调试服务 (8 个模拟智能家居设备)
|
│ ├── iot-debug-service/ # IoT 调试服务 (8 个模拟智能家居设备)
|
||||||
│ └── pkg/ # 共享包 (logger, plugins — 15 个通用插件/工具)
|
│ └── pkg/ # 共享包 (logger 等)
|
||||||
├── ethend/ # ethend 管理面板 (Express + WebSocket)
|
├── ethend/ # ethend 管理面板 (Express + WebSocket)
|
||||||
├── scripts/ # 辅助脚本 (migrate / tunnel / whisper-setup / pg-backup)
|
├── scripts/ # 辅助脚本 (migrate / tunnel / whisper-setup / pg-backup)
|
||||||
├── searxng/ # SearXNG 搜索引擎配置
|
├── searxng/ # SearXNG 搜索引擎配置
|
||||||
|
|||||||
@@ -67,6 +67,10 @@
|
|||||||
- Node.js 20 LTS
|
- Node.js 20 LTS
|
||||||
- Docker & Docker Compose
|
- Docker & Docker Compose
|
||||||
- Git Bash(Windows 用户)
|
- Git Bash(Windows 用户)
|
||||||
|
- [cyrene-plugins](https://git.yeij.top/AskaEth/Cyrene-Plugins) — 克隆到 `backend/` 目录内:
|
||||||
|
```bash
|
||||||
|
git clone https://git.yeij.top/AskaEth/Cyrene-Plugins.git backend/cyrene-plugins
|
||||||
|
```
|
||||||
|
|
||||||
### 1. 配置环境变量
|
### 1. 配置环境变量
|
||||||
|
|
||||||
@@ -132,9 +136,8 @@ Cyrene/
|
|||||||
│ ├── memory-service/ # 记忆服务 (CRUD、语义检索、衰减、LLM 提取)
|
│ ├── memory-service/ # 记忆服务 (CRUD、语义检索、衰减、LLM 提取)
|
||||||
│ ├── voice-service/ # 语音服务 (DashScope STT + Edge-TTS)
|
│ ├── voice-service/ # 语音服务 (DashScope STT + Edge-TTS)
|
||||||
│ ├── iot-debug-service/ # IoT 调试服务 (8 个模拟智能家居设备)
|
│ ├── iot-debug-service/ # IoT 调试服务 (8 个模拟智能家居设备)
|
||||||
│ ├── plugin-manager/ # 插件管理器 (管理 API,插件逻辑在 pkg/plugins)
|
|
||||||
│ ├── platform-bridge/ # 多平台桥接 (QQ / Telegram / Discord / Webhook)
|
│ ├── platform-bridge/ # 多平台桥接 (QQ / Telegram / Discord / Webhook)
|
||||||
│ └── pkg/ # 共享包 (logger, plugins — 15 个通用插件/工具)
|
│ └── pkg/ # 共享包 (logger 等)
|
||||||
├── ethend/ # ethend 管理面板 (Express + WebSocket)
|
├── ethend/ # ethend 管理面板 (Express + WebSocket)
|
||||||
├── scripts/ # 辅助脚本 (migrate / tunnel / whisper-setup / pg-backup)
|
├── scripts/ # 辅助脚本 (migrate / tunnel / whisper-setup / pg-backup)
|
||||||
├── backups/ # 数据库备份文件 (.gitignore)
|
├── backups/ # 数据库备份文件 (.gitignore)
|
||||||
@@ -150,6 +153,8 @@ Cyrene/
|
|||||||
└── Caddyfile # 反向代理配置
|
└── Caddyfile # 反向代理配置
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> **关联仓库**:[cyrene-plugins](https://git.yeij.top/AskaEth/Cyrene-Plugins) — 插件 SDK + 15 个内置插件 + Plugin Manager 服务。克隆到 `backend/cyrene-plugins/`,ai-core 通过 go.mod replace 引用。
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 服务端口
|
## 服务端口
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o /ai-core ./cmd/main.go
|
|||||||
# ========== 运行阶段 ==========
|
# ========== 运行阶段 ==========
|
||||||
FROM alpine:3.20
|
FROM alpine:3.20
|
||||||
|
|
||||||
RUN apk add --no-cache ca-certificates tzdata && \
|
RUN apk add --no-cache ca-certificates tzdata ffmpeg && \
|
||||||
cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \
|
cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \
|
||||||
echo "Asia/Shanghai" > /etc/timezone
|
echo "Asia/Shanghai" > /etc/timezone
|
||||||
|
|
||||||
|
|||||||
+13
-13
@@ -29,19 +29,19 @@ import (
|
|||||||
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/subsession"
|
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/subsession"
|
||||||
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/tools"
|
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/tools"
|
||||||
|
|
||||||
plgManager "git.yeij.top/AskaEth/Cyrene/pkg/plugins/manager"
|
plgManager "git.yeij.top/AskaEth/Cyrene-Plugins/manager"
|
||||||
plgSDK "git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
plgSDK "git.yeij.top/AskaEth/Cyrene-Plugins/sdk"
|
||||||
pluginCalc "git.yeij.top/AskaEth/Cyrene/pkg/plugins/calculator"
|
pluginCalc "git.yeij.top/AskaEth/Cyrene-Plugins/calculator"
|
||||||
pluginCrypto "git.yeij.top/AskaEth/Cyrene/pkg/plugins/crypto"
|
pluginCrypto "git.yeij.top/AskaEth/Cyrene-Plugins/crypto"
|
||||||
pluginDate "git.yeij.top/AskaEth/Cyrene/pkg/plugins/datetime"
|
pluginDate "git.yeij.top/AskaEth/Cyrene-Plugins/datetime"
|
||||||
pluginFile "git.yeij.top/AskaEth/Cyrene/pkg/plugins/file"
|
pluginFile "git.yeij.top/AskaEth/Cyrene-Plugins/file"
|
||||||
pluginHTTP "git.yeij.top/AskaEth/Cyrene/pkg/plugins/http"
|
pluginHTTP "git.yeij.top/AskaEth/Cyrene-Plugins/http"
|
||||||
pluginJSON "git.yeij.top/AskaEth/Cyrene/pkg/plugins/json"
|
pluginJSON "git.yeij.top/AskaEth/Cyrene-Plugins/json"
|
||||||
pluginMD "git.yeij.top/AskaEth/Cyrene/pkg/plugins/markdown"
|
pluginMD "git.yeij.top/AskaEth/Cyrene-Plugins/markdown"
|
||||||
pluginRand "git.yeij.top/AskaEth/Cyrene/pkg/plugins/random"
|
pluginRand "git.yeij.top/AskaEth/Cyrene-Plugins/random"
|
||||||
pluginText "git.yeij.top/AskaEth/Cyrene/pkg/plugins/text"
|
pluginText "git.yeij.top/AskaEth/Cyrene-Plugins/text"
|
||||||
pluginWF "git.yeij.top/AskaEth/Cyrene/pkg/plugins/web_fetch"
|
pluginWF "git.yeij.top/AskaEth/Cyrene-Plugins/web_fetch"
|
||||||
pluginWS "git.yeij.top/AskaEth/Cyrene/pkg/plugins/web_search"
|
pluginWS "git.yeij.top/AskaEth/Cyrene-Plugins/web_search"
|
||||||
)
|
)
|
||||||
|
|
||||||
var cfg Config
|
var cfg Config
|
||||||
|
|||||||
@@ -5,12 +5,16 @@ go 1.26.2
|
|||||||
require (
|
require (
|
||||||
github.com/joho/godotenv v1.5.1
|
github.com/joho/godotenv v1.5.1
|
||||||
github.com/lib/pq v1.10.9
|
github.com/lib/pq v1.10.9
|
||||||
|
git.yeij.top/AskaEth/Cyrene/pkg/audio v0.0.0
|
||||||
|
git.yeij.top/AskaEth/Cyrene/pkg/dashscope v0.0.0
|
||||||
git.yeij.top/AskaEth/Cyrene/pkg/logger v0.0.0
|
git.yeij.top/AskaEth/Cyrene/pkg/logger v0.0.0
|
||||||
git.yeij.top/AskaEth/Cyrene/pkg/plugins v0.0.0
|
git.yeij.top/AskaEth/Cyrene-Plugins v0.0.0
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
)
|
)
|
||||||
|
|
||||||
replace (
|
replace (
|
||||||
|
git.yeij.top/AskaEth/Cyrene/pkg/audio => ../pkg/audio
|
||||||
|
git.yeij.top/AskaEth/Cyrene/pkg/dashscope => ../pkg/dashscope
|
||||||
git.yeij.top/AskaEth/Cyrene/pkg/logger => ../pkg/logger
|
git.yeij.top/AskaEth/Cyrene/pkg/logger => ../pkg/logger
|
||||||
git.yeij.top/AskaEth/Cyrene/pkg/plugins => ../pkg/plugins
|
git.yeij.top/AskaEth/Cyrene-Plugins => ../cyrene-plugins
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -18,8 +18,8 @@ import (
|
|||||||
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/persona"
|
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/persona"
|
||||||
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/tools"
|
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/tools"
|
||||||
|
|
||||||
plgManager "git.yeij.top/AskaEth/Cyrene/pkg/plugins/manager"
|
plgManager "git.yeij.top/AskaEth/Cyrene-Plugins/manager"
|
||||||
plgSDK "git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
plgSDK "git.yeij.top/AskaEth/Cyrene-Plugins/sdk"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PendingThought 待推送的后台思考
|
// PendingThought 待推送的后台思考
|
||||||
|
|||||||
@@ -1,31 +1,30 @@
|
|||||||
package llm
|
package llm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"git.yeij.top/AskaEth/Cyrene/pkg/audio"
|
||||||
|
"git.yeij.top/AskaEth/Cyrene/pkg/dashscope"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ASRProvider handles speech-to-text transcription.
|
// ASRProvider handles speech-to-text transcription.
|
||||||
type ASRProvider interface {
|
type ASRProvider interface {
|
||||||
Transcribe(ctx context.Context, audioURL string) (string, error)
|
Transcribe(ctx context.Context, audioURL, language string) (string, error)
|
||||||
IsAvailable() bool
|
IsAvailable() bool
|
||||||
ModelName() string
|
ModelName() string
|
||||||
}
|
}
|
||||||
|
|
||||||
// DashScopeASRProvider uses DashScope Paraformer API for offline speech recognition.
|
// DashScopeASRProvider uses DashScope Paraformer API for offline speech recognition.
|
||||||
type DashScopeASRProvider struct {
|
type DashScopeASRProvider struct {
|
||||||
apiKey string
|
model string
|
||||||
baseURL string
|
client *dashscope.RESTClient
|
||||||
model string
|
http *http.Client
|
||||||
client *http.Client
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDashScopeASRProvider creates a DashScope ASR provider.
|
// NewDashScopeASRProvider creates a DashScope ASR provider.
|
||||||
@@ -34,16 +33,15 @@ func NewDashScopeASRProvider(baseURL, apiKey, model string) *DashScopeASRProvide
|
|||||||
model = "qwen3-asr-flash-2026-02-10"
|
model = "qwen3-asr-flash-2026-02-10"
|
||||||
}
|
}
|
||||||
return &DashScopeASRProvider{
|
return &DashScopeASRProvider{
|
||||||
apiKey: apiKey,
|
model: model,
|
||||||
baseURL: baseURL,
|
client: dashscope.NewRESTClient(apiKey),
|
||||||
model: model,
|
http: &http.Client{Timeout: 60 * time.Second},
|
||||||
client: &http.Client{Timeout: 60 * time.Second},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsAvailable returns true if the API key is configured.
|
// IsAvailable returns true if the API key is configured.
|
||||||
func (p *DashScopeASRProvider) IsAvailable() bool {
|
func (p *DashScopeASRProvider) IsAvailable() bool {
|
||||||
return p.apiKey != ""
|
return p.client.IsAvailable()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ModelName returns the ASR model name.
|
// ModelName returns the ASR model name.
|
||||||
@@ -51,34 +49,6 @@ func (p *DashScopeASRProvider) ModelName() string {
|
|||||||
return p.model
|
return p.model
|
||||||
}
|
}
|
||||||
|
|
||||||
type asrRequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Input asrInput `json:"input"`
|
|
||||||
Parameters asrParams `json:"parameters"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type asrInput struct {
|
|
||||||
Audio string `json:"audio"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type asrParams struct {
|
|
||||||
Format string `json:"format,omitempty"`
|
|
||||||
SampleRate int `json:"sample_rate,omitempty"`
|
|
||||||
Language string `json:"language,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type asrResponse struct {
|
|
||||||
Output struct {
|
|
||||||
Text string `json:"text"`
|
|
||||||
} `json:"output"`
|
|
||||||
Usage struct {
|
|
||||||
TotalTokens int `json:"total_tokens"`
|
|
||||||
} `json:"usage"`
|
|
||||||
RequestID string `json:"request_id"`
|
|
||||||
Code string `json:"code,omitempty"`
|
|
||||||
Message string `json:"message,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// downloadAudio fetches audio data from a URL and returns the bytes with inferred format.
|
// downloadAudio fetches audio data from a URL and returns the bytes with inferred format.
|
||||||
func (p *DashScopeASRProvider) downloadAudio(ctx context.Context, audioURL string) ([]byte, string, error) {
|
func (p *DashScopeASRProvider) downloadAudio(ctx context.Context, audioURL string) ([]byte, string, error) {
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", audioURL, nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", audioURL, nil)
|
||||||
@@ -86,7 +56,7 @@ func (p *DashScopeASRProvider) downloadAudio(ctx context.Context, audioURL strin
|
|||||||
return nil, "", fmt.Errorf("create download request: %w", err)
|
return nil, "", fmt.Errorf("create download request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := p.client.Do(req)
|
resp, err := p.http.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", fmt.Errorf("download failed: %w", err)
|
return nil, "", fmt.Errorf("download failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -103,7 +73,6 @@ func (p *DashScopeASRProvider) downloadAudio(ctx context.Context, audioURL strin
|
|||||||
|
|
||||||
// inferAudioFormat determines the audio format from URL extension or Content-Type header.
|
// inferAudioFormat determines the audio format from URL extension or Content-Type header.
|
||||||
func inferAudioFormat(urlStr, contentType string) string {
|
func inferAudioFormat(urlStr, contentType string) string {
|
||||||
// Try URL extension first
|
|
||||||
u, err := url.Parse(urlStr)
|
u, err := url.Parse(urlStr)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
path := u.Path
|
path := u.Path
|
||||||
@@ -115,7 +84,6 @@ func inferAudioFormat(urlStr, contentType string) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Fallback: use Content-Type
|
|
||||||
if strings.Contains(contentType, "audio/amr") || strings.Contains(contentType, "amr") {
|
if strings.Contains(contentType, "audio/amr") || strings.Contains(contentType, "amr") {
|
||||||
return "amr"
|
return "amr"
|
||||||
}
|
}
|
||||||
@@ -130,14 +98,8 @@ func inferAudioFormat(urlStr, contentType string) string {
|
|||||||
}
|
}
|
||||||
return "amr" // default for QQ voice messages
|
return "amr" // default for QQ voice messages
|
||||||
}
|
}
|
||||||
// asrEndpoint derives the DashScope ASR REST endpoint from the provider base URL.
|
|
||||||
func asrEndpoint(baseURL string) string {
|
func (p *DashScopeASRProvider) Transcribe(ctx context.Context, audioURL, language string) (string, error) {
|
||||||
if u, err := url.Parse(baseURL); err == nil {
|
|
||||||
return fmt.Sprintf("%s://%s/api/v1/services/audio/asr/asr", u.Scheme, u.Host)
|
|
||||||
}
|
|
||||||
return strings.TrimRight(baseURL, "/") + "/api/v1/services/audio/asr/asr"
|
|
||||||
}
|
|
||||||
func (p *DashScopeASRProvider) Transcribe(ctx context.Context, audioURL string) (string, error) {
|
|
||||||
if !p.IsAvailable() {
|
if !p.IsAvailable() {
|
||||||
return "", fmt.Errorf("DashScope ASR API key not configured")
|
return "", fmt.Errorf("DashScope ASR API key not configured")
|
||||||
}
|
}
|
||||||
@@ -147,50 +109,15 @@ func (p *DashScopeASRProvider) Transcribe(ctx context.Context, audioURL string)
|
|||||||
return "", fmt.Errorf("download audio: %w", err)
|
return "", fmt.Errorf("download audio: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
audioB64 := base64.StdEncoding.EncodeToString(audioData)
|
// 转码为 16kHz mono PCM,提升识别兼容性
|
||||||
|
pcmData, err := audio.ConvertToPCM16(audioData, format)
|
||||||
reqBody := asrRequest{
|
|
||||||
Model: p.model,
|
|
||||||
Input: asrInput{
|
|
||||||
Audio: fmt.Sprintf("data:audio/%s;base64,%s", format, audioB64),
|
|
||||||
},
|
|
||||||
Parameters: asrParams{
|
|
||||||
Format: format,
|
|
||||||
Language: "zh",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
bodyBytes, err := json.Marshal(reqBody)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("marshal ASR request: %w", err)
|
return "", fmt.Errorf("audio transcode: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
asrURL := asrEndpoint(p.baseURL)
|
if language == "" || language == "auto" {
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", asrURL, bytes.NewReader(bodyBytes))
|
language = "zh"
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("create ASR request: %w", err)
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
|
||||||
|
|
||||||
resp, err := p.client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("ASR request failed: %w", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
respBytes, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("read ASR response: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var asrResp asrResponse
|
return p.client.Transcribe(ctx, p.model, pcmData, "pcm", 16000, language)
|
||||||
if err := json.Unmarshal(respBytes, &asrResp); err != nil {
|
|
||||||
return "", fmt.Errorf("parse ASR response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if asrResp.Code != "" && asrResp.Code != "0" {
|
|
||||||
return "", fmt.Errorf("ASR error: %s (code=%s)", asrResp.Message, asrResp.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
return asrResp.Output.Text, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import (
|
|||||||
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/bus"
|
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/bus"
|
||||||
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/scheduler"
|
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/scheduler"
|
||||||
|
|
||||||
plgManager "git.yeij.top/AskaEth/Cyrene/pkg/plugins/manager"
|
plgManager "git.yeij.top/AskaEth/Cyrene-Plugins/manager"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Orchestrator 对话编排器 v2.0
|
// Orchestrator 对话编排器 v2.0
|
||||||
@@ -878,7 +878,7 @@ func (o *Orchestrator) preprocessVoice(ctx context.Context, message string, voic
|
|||||||
|
|
||||||
var transcriptions []string
|
var transcriptions []string
|
||||||
for i, url := range voiceURLs {
|
for i, url := range voiceURLs {
|
||||||
text, err := o.asrProvider.Transcribe(ctx, url)
|
text, err := o.asrProvider.Transcribe(ctx, url, "zh")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Printf("[orchestrator] 语音 %d 转录失败: %v", i, err)
|
logger.Printf("[orchestrator] 语音 %d 转录失败: %v", i, err)
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ import (
|
|||||||
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/llm"
|
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/llm"
|
||||||
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/model"
|
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/model"
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/logger"
|
"git.yeij.top/AskaEth/Cyrene/pkg/logger"
|
||||||
plgManager "git.yeij.top/AskaEth/Cyrene/pkg/plugins/manager"
|
plgManager "git.yeij.top/AskaEth/Cyrene-Plugins/manager"
|
||||||
plgSDK "git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
plgSDK "git.yeij.top/AskaEth/Cyrene-Plugins/sdk"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Synthesizer 主会话综合器
|
// Synthesizer 主会话综合器
|
||||||
|
|||||||
@@ -1,128 +0,0 @@
|
|||||||
package tools
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PluginManagerClient calls the plugin-manager service.
|
|
||||||
type PluginManagerClient struct {
|
|
||||||
baseURL string
|
|
||||||
httpClient *http.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
// PMToolDefinition matches the plugin-manager tool definition format.
|
|
||||||
type PMToolDefinition struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
DisplayName string `json:"displayName"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Category string `json:"category"`
|
|
||||||
Complexity string `json:"complexity"`
|
|
||||||
Parameters map[string]interface{} `json:"parameters"`
|
|
||||||
DangerLevel string `json:"danger_level,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// PMToolResult matches the plugin-manager execution result.
|
|
||||||
type PMToolResult struct {
|
|
||||||
ToolName string `json:"tool_name"`
|
|
||||||
Success bool `json:"success"`
|
|
||||||
Output string `json:"output,omitempty"`
|
|
||||||
Error string `json:"error,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// PMPluginInfo matches plugin-manager plugin info.
|
|
||||||
type PMPluginInfo struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Version string `json:"version"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
Enabled bool `json:"enabled"`
|
|
||||||
Tools []string `json:"tools"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPluginManagerClient(baseURL string) *PluginManagerClient {
|
|
||||||
return &PluginManagerClient{
|
|
||||||
baseURL: baseURL,
|
|
||||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetToolDefinitions fetches all tool definitions from plugin-manager.
|
|
||||||
func (c *PluginManagerClient) GetToolDefinitions(ctx context.Context) ([]PMToolDefinition, error) {
|
|
||||||
req, _ := http.NewRequestWithContext(ctx, "GET", c.baseURL+"/api/v1/tools", nil)
|
|
||||||
resp, err := c.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("plugin-manager GetToolDefinitions: %w", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
var body struct {
|
|
||||||
Tools []PMToolDefinition `json:"tools"`
|
|
||||||
}
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
|
||||||
return nil, fmt.Errorf("plugin-manager decode tools: %w", err)
|
|
||||||
}
|
|
||||||
return body.Tools, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExecuteTool calls a tool on plugin-manager by ID.
|
|
||||||
func (c *PluginManagerClient) ExecuteTool(ctx context.Context, toolID string, args map[string]interface{}) (*PMToolResult, error) {
|
|
||||||
body, _ := json.Marshal(map[string]interface{}{"arguments": args})
|
|
||||||
url := fmt.Sprintf("%s/api/v1/tools/%s/execute", c.baseURL, toolID)
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("plugin-manager ExecuteTool: %w", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
var result PMToolResult
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
||||||
return nil, fmt.Errorf("plugin-manager decode result: %w", err)
|
|
||||||
}
|
|
||||||
return &result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListPlugins fetches all installed plugins from plugin-manager.
|
|
||||||
func (c *PluginManagerClient) ListPlugins(ctx context.Context) ([]PMPluginInfo, error) {
|
|
||||||
req, _ := http.NewRequestWithContext(ctx, "GET", c.baseURL+"/api/v1/plugins", nil)
|
|
||||||
resp, err := c.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
var body struct {
|
|
||||||
Plugins []PMPluginInfo `json:"plugins"`
|
|
||||||
}
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return body.Plugins, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AdaptDefinitions converts PM tool definitions to ai-core ToolDefinition format.
|
|
||||||
func (c *PluginManagerClient) AdaptDefinitions(ctx context.Context) ([]ToolDefinition, error) {
|
|
||||||
pmDefs, err := c.GetToolDefinitions(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defs := make([]ToolDefinition, 0, len(pmDefs))
|
|
||||||
for _, d := range pmDefs {
|
|
||||||
defs = append(defs, ToolDefinition{
|
|
||||||
Name: d.Name,
|
|
||||||
Description: d.Description,
|
|
||||||
Parameters: d.Parameters,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return defs, nil
|
|
||||||
}
|
|
||||||
@@ -38,6 +38,7 @@ type ChatHandler struct {
|
|||||||
hub *ws.Hub
|
hub *ws.Hub
|
||||||
sessionStore *store.SessionStore
|
sessionStore *store.SessionStore
|
||||||
fileStore *store.FileStore
|
fileStore *store.FileStore
|
||||||
|
voiceStream *VoiceStreamManager
|
||||||
upgrader websocket.Upgrader
|
upgrader websocket.Upgrader
|
||||||
pending map[string][]queuedMsg // per-session message queue
|
pending map[string][]queuedMsg // per-session message queue
|
||||||
pendingMu sync.Mutex
|
pendingMu sync.Mutex
|
||||||
@@ -50,6 +51,7 @@ func NewChatHandler(cfg *config.Config, hub *ws.Hub, sessionStore *store.Session
|
|||||||
hub: hub,
|
hub: hub,
|
||||||
sessionStore: sessionStore,
|
sessionStore: sessionStore,
|
||||||
fileStore: fileStore,
|
fileStore: fileStore,
|
||||||
|
voiceStream: NewVoiceStreamManager(cfg.VoiceServiceURL),
|
||||||
pending: make(map[string][]queuedMsg),
|
pending: make(map[string][]queuedMsg),
|
||||||
upgrader: websocket.Upgrader{
|
upgrader: websocket.Upgrader{
|
||||||
ReadBufferSize: 1024,
|
ReadBufferSize: 1024,
|
||||||
@@ -131,6 +133,12 @@ func (h *ChatHandler) handleMessage(client *ws.Client, msg ws.ClientMessage) {
|
|||||||
h.handleChatMessage(client, msg)
|
h.handleChatMessage(client, msg)
|
||||||
case "voice_input":
|
case "voice_input":
|
||||||
h.handleVoiceInput(client, msg)
|
h.handleVoiceInput(client, msg)
|
||||||
|
case "voice_stream_start":
|
||||||
|
h.handleVoiceStreamStart(client, msg)
|
||||||
|
case "voice_stream_chunk":
|
||||||
|
h.handleVoiceStreamChunk(client, msg)
|
||||||
|
case "voice_stream_end":
|
||||||
|
h.handleVoiceStreamEnd(client, msg)
|
||||||
case "history":
|
case "history":
|
||||||
h.handleHistoryRequest(client, msg)
|
h.handleHistoryRequest(client, msg)
|
||||||
default:
|
default:
|
||||||
@@ -436,11 +444,13 @@ func (h *ChatHandler) streamResponse(client *ws.Client, mode string, reqBody []b
|
|||||||
// 处理审查后的结构化消息 (review)
|
// 处理审查后的结构化消息 (review)
|
||||||
if len(chunk.ReviewMessages) > 0 {
|
if len(chunk.ReviewMessages) > 0 {
|
||||||
for i, rm := range chunk.ReviewMessages {
|
for i, rm := range chunk.ReviewMessages {
|
||||||
|
msgType := rm.Type
|
||||||
|
if msgType == "" {
|
||||||
|
msgType = "chat"
|
||||||
|
}
|
||||||
role := "assistant"
|
role := "assistant"
|
||||||
msgType := "chat"
|
if msgType == "action" {
|
||||||
if rm.Type == "action" {
|
|
||||||
role = "action"
|
role = "action"
|
||||||
msgType = "action"
|
|
||||||
}
|
}
|
||||||
reviewMsgID := fmt.Sprintf("%s_r%d", msgID, i)
|
reviewMsgID := fmt.Sprintf("%s_r%d", msgID, i)
|
||||||
// 持久化每条审查消息 (action 角色映射为 assistant,LLM 模型不支持自定义角色)
|
// 持久化每条审查消息 (action 角色映射为 assistant,LLM 模型不支持自定义角色)
|
||||||
@@ -473,6 +483,7 @@ func (h *ChatHandler) streamResponse(client *ws.Client, mode string, reqBody []b
|
|||||||
SessionID: client.SessionID,
|
SessionID: client.SessionID,
|
||||||
Timestamp: time.Now().UnixMilli(),
|
Timestamp: time.Now().UnixMilli(),
|
||||||
ClientInfo: clientInfo,
|
ClientInfo: clientInfo,
|
||||||
|
Metadata: rm.Metadata,
|
||||||
})
|
})
|
||||||
// 使用 MessageScheduler 计算的 per-message 延迟
|
// 使用 MessageScheduler 计算的 per-message 延迟
|
||||||
if rm.DelayMs > 0 {
|
if rm.DelayMs > 0 {
|
||||||
@@ -650,6 +661,96 @@ func (h *ChatHandler) handleVoiceInput(client *ws.Client, msg ws.ClientMessage)
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleVoiceStreamStart begins a streaming voice session via voice-service.
|
||||||
|
func (h *ChatHandler) handleVoiceStreamStart(client *ws.Client, msg ws.ClientMessage) {
|
||||||
|
format := msg.Format
|
||||||
|
if format == "" {
|
||||||
|
format = "webm"
|
||||||
|
}
|
||||||
|
language := msg.Language
|
||||||
|
if language == "" {
|
||||||
|
language = "zh"
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.voiceStream.StartStream(client, format, language); err != nil {
|
||||||
|
logger.Printf("[voice-stream] 启动流式 STT 失败: %v", err)
|
||||||
|
client.SendMessage(ws.ServerMessage{
|
||||||
|
Type: "error",
|
||||||
|
MessageID: "msg_" + generateID(),
|
||||||
|
Error: "启动语音流失败: " + err.Error(),
|
||||||
|
Timestamp: time.Now().UnixMilli(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
client.SendMessage(ws.ServerMessage{
|
||||||
|
Type: "voice_interim",
|
||||||
|
MessageID: "voice_" + generateID(),
|
||||||
|
Text: "",
|
||||||
|
Timestamp: time.Now().UnixMilli(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleVoiceStreamChunk forwards an audio chunk to the active voice stream.
|
||||||
|
func (h *ChatHandler) handleVoiceStreamChunk(client *ws.Client, msg ws.ClientMessage) {
|
||||||
|
if msg.AudioData == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
audioData, err := decodeBase64(msg.AudioData)
|
||||||
|
if err != nil {
|
||||||
|
logger.Printf("[voice-stream] 解码音频块失败: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.voiceStream.SendChunk(client.ClientID, client.SessionID, audioData, msg.Sequence); err != nil {
|
||||||
|
logger.Printf("[voice-stream] 发送音频块失败: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleVoiceStreamEnd stops the voice stream and processes the final transcription.
|
||||||
|
func (h *ChatHandler) handleVoiceStreamEnd(client *ws.Client, msg ws.ClientMessage) {
|
||||||
|
go func() {
|
||||||
|
text, err := h.voiceStream.EndStream(client.ClientID, client.SessionID)
|
||||||
|
if err != nil {
|
||||||
|
logger.Printf("[voice-stream] 结束流式 STT 失败: %v", err)
|
||||||
|
client.SendMessage(ws.ServerMessage{
|
||||||
|
Type: "error",
|
||||||
|
MessageID: "msg_" + generateID(),
|
||||||
|
Error: "语音流处理失败: " + err.Error(),
|
||||||
|
Timestamp: time.Now().UnixMilli(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if text == "" {
|
||||||
|
client.SendMessage(ws.ServerMessage{
|
||||||
|
Type: "voice_final",
|
||||||
|
MessageID: "voice_" + generateID(),
|
||||||
|
Text: "",
|
||||||
|
Timestamp: time.Now().UnixMilli(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send final transcription to frontend
|
||||||
|
client.SendMessage(ws.ServerMessage{
|
||||||
|
Type: "voice_final",
|
||||||
|
MessageID: "voice_" + generateID(),
|
||||||
|
Text: text,
|
||||||
|
Timestamp: time.Now().UnixMilli(),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Route the transcribed text as a regular chat message to ai-core
|
||||||
|
chatMsg := ws.ClientMessage{
|
||||||
|
Type: "message",
|
||||||
|
Content: text,
|
||||||
|
Mode: msg.Mode,
|
||||||
|
}
|
||||||
|
h.handleChatMessage(client, chatMsg)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
// transcribeAudio 将 base64 编码的音频发送到 voice-service 进行转录。
|
// transcribeAudio 将 base64 编码的音频发送到 voice-service 进行转录。
|
||||||
func (h *ChatHandler) transcribeAudio(audioB64 string, format string) (string, error) {
|
func (h *ChatHandler) transcribeAudio(audioB64 string, format string) (string, error) {
|
||||||
audioData, err := decodeBase64(audioB64)
|
audioData, err := decodeBase64(audioB64)
|
||||||
|
|||||||
@@ -0,0 +1,269 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
|
||||||
|
"git.yeij.top/AskaEth/Cyrene/gateway/internal/ws"
|
||||||
|
"git.yeij.top/AskaEth/Cyrene/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// voiceStreamSession manages a proxied WebSocket connection to voice-service
|
||||||
|
// for real-time streaming speech-to-text during a single voice input.
|
||||||
|
type voiceStreamSession struct {
|
||||||
|
client *ws.Client
|
||||||
|
voiceConn *websocket.Conn
|
||||||
|
language string
|
||||||
|
format string
|
||||||
|
mu sync.Mutex
|
||||||
|
done chan struct{}
|
||||||
|
interimBuf strings.Builder
|
||||||
|
finalText string
|
||||||
|
}
|
||||||
|
|
||||||
|
// VoiceStreamManager creates and tracks streaming STT sessions.
|
||||||
|
type VoiceStreamManager struct {
|
||||||
|
voiceServiceURL string
|
||||||
|
sessions map[string]*voiceStreamSession // key: clientID+sessionID
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewVoiceStreamManager creates a voice stream manager.
|
||||||
|
func NewVoiceStreamManager(voiceServiceURL string) *VoiceStreamManager {
|
||||||
|
return &VoiceStreamManager{
|
||||||
|
voiceServiceURL: voiceServiceURL,
|
||||||
|
sessions: make(map[string]*voiceStreamSession),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *VoiceStreamManager) sessionKey(clientID, sessionID string) string {
|
||||||
|
return clientID + ":" + sessionID
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartStream begins a streaming STT session by connecting to voice-service.
|
||||||
|
func (m *VoiceStreamManager) StartStream(client *ws.Client, format, language string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
key := m.sessionKey(client.ClientID, client.SessionID)
|
||||||
|
if _, exists := m.sessions[key]; exists {
|
||||||
|
m.mu.Unlock()
|
||||||
|
return fmt.Errorf("voice stream already active for this session")
|
||||||
|
}
|
||||||
|
|
||||||
|
if format == "" {
|
||||||
|
format = "webm"
|
||||||
|
}
|
||||||
|
if language == "" {
|
||||||
|
language = "zh"
|
||||||
|
}
|
||||||
|
|
||||||
|
voiceURL := strings.TrimRight(m.voiceServiceURL, "/")
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(voiceURL, "http") + "/api/v1/stt/stream"
|
||||||
|
wsURL += "?language=" + language + "&format=" + format
|
||||||
|
|
||||||
|
dialer := websocket.Dialer{HandshakeTimeout: 10 * time.Second}
|
||||||
|
voiceConn, _, err := dialer.Dial(wsURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
m.mu.Unlock()
|
||||||
|
return fmt.Errorf("connect to voice-service stream: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
session := &voiceStreamSession{
|
||||||
|
client: client,
|
||||||
|
voiceConn: voiceConn,
|
||||||
|
language: language,
|
||||||
|
format: format,
|
||||||
|
done: make(chan struct{}),
|
||||||
|
}
|
||||||
|
m.sessions[key] = session
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
// Read results from voice-service in background
|
||||||
|
go session.readResults(m, key)
|
||||||
|
|
||||||
|
logger.Printf("[voice-stream] 流式 STT 会话已建立: client=%s, lang=%s, fmt=%s", client.ClientID, language, format)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendChunk forwards an audio chunk (already decoded bytes) to voice-service.
|
||||||
|
func (m *VoiceStreamManager) SendChunk(clientID, sessionID string, audioData []byte, seq int) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
key := m.sessionKey(clientID, sessionID)
|
||||||
|
session, exists := m.sessions[key]
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return fmt.Errorf("no active voice stream for this session")
|
||||||
|
}
|
||||||
|
|
||||||
|
session.mu.Lock()
|
||||||
|
defer session.mu.Unlock()
|
||||||
|
|
||||||
|
if err := session.voiceConn.WriteMessage(websocket.BinaryMessage, audioData); err != nil {
|
||||||
|
return fmt.Errorf("send audio chunk: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EndStream signals voice-service that the audio stream is complete,
|
||||||
|
// waits for final result, then cleans up.
|
||||||
|
func (m *VoiceStreamManager) EndStream(clientID, sessionID string) (string, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
key := m.sessionKey(clientID, sessionID)
|
||||||
|
session, exists := m.sessions[key]
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return "", fmt.Errorf("no active voice stream for this session")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send stop action to voice-service
|
||||||
|
session.mu.Lock()
|
||||||
|
stopMsg, _ := json.Marshal(map[string]interface{}{"action": "stop"})
|
||||||
|
session.voiceConn.WriteMessage(websocket.TextMessage, stopMsg)
|
||||||
|
session.mu.Unlock()
|
||||||
|
|
||||||
|
// Wait for result processing to finish
|
||||||
|
select {
|
||||||
|
case <-session.done:
|
||||||
|
case <-time.After(15 * time.Second):
|
||||||
|
logger.Printf("[voice-stream] 等待最终结果超时: client=%s", clientID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
session.close()
|
||||||
|
m.mu.Lock()
|
||||||
|
delete(m.sessions, key)
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
text := session.finalText
|
||||||
|
if text == "" {
|
||||||
|
text = session.interimBuf.String()
|
||||||
|
}
|
||||||
|
logger.Printf("[voice-stream] 流式 STT 结束: client=%s, text=%q", clientID, text)
|
||||||
|
return text, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CancelStream forcibly terminates a voice stream.
|
||||||
|
func (m *VoiceStreamManager) CancelStream(clientID, sessionID string) {
|
||||||
|
m.mu.Lock()
|
||||||
|
key := m.sessionKey(clientID, sessionID)
|
||||||
|
session, exists := m.sessions[key]
|
||||||
|
if exists {
|
||||||
|
delete(m.sessions, key)
|
||||||
|
}
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
if exists {
|
||||||
|
session.close()
|
||||||
|
logger.Printf("[voice-stream] 流式 STT 已取消: client=%s", clientID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// readResults reads STT results from voice-service and forwards them to the client.
|
||||||
|
func (s *voiceStreamSession) readResults(mgr *VoiceStreamManager, key string) {
|
||||||
|
defer close(s.done)
|
||||||
|
|
||||||
|
voiceConn := s.voiceConn
|
||||||
|
for {
|
||||||
|
msgType, data, err := voiceConn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) {
|
||||||
|
logger.Printf("[voice-stream] voice-service 读取错误: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if msgType != websocket.TextMessage {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
IsFinal bool `json:"isFinal"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &result); err != nil {
|
||||||
|
logger.Printf("[voice-stream] 解析结果失败: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Error != "" {
|
||||||
|
logger.Printf("[voice-stream] voice-service 错误: %s", result.Error)
|
||||||
|
s.client.SendMessage(ws.ServerMessage{
|
||||||
|
Type: "voice_interim",
|
||||||
|
MessageID: "voice_" + generateID(),
|
||||||
|
Error: result.Error,
|
||||||
|
Timestamp: time.Now().UnixMilli(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Text != "" {
|
||||||
|
if result.IsFinal {
|
||||||
|
s.finalText = result.Text
|
||||||
|
s.client.SendMessage(ws.ServerMessage{
|
||||||
|
Type: "voice_final",
|
||||||
|
MessageID: "voice_" + generateID(),
|
||||||
|
Text: result.Text,
|
||||||
|
Timestamp: time.Now().UnixMilli(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Interim result — accumulate and forward
|
||||||
|
s.interimBuf.Reset()
|
||||||
|
s.interimBuf.WriteString(result.Text)
|
||||||
|
s.client.SendMessage(ws.ServerMessage{
|
||||||
|
Type: "voice_interim",
|
||||||
|
MessageID: "voice_" + generateID(),
|
||||||
|
Text: result.Text,
|
||||||
|
Timestamp: time.Now().UnixMilli(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// "done" type from voice-service signals end of results
|
||||||
|
if result.Type == "done" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *voiceStreamSession) close() {
|
||||||
|
if s.voiceConn != nil {
|
||||||
|
s.voiceConn.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasActiveStream checks if a client already has an active voice stream.
|
||||||
|
func (m *VoiceStreamManager) HasActiveStream(clientID, sessionID string) bool {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
_, exists := m.sessions[m.sessionKey(clientID, sessionID)]
|
||||||
|
return exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanupClient removes all streams for a client.
|
||||||
|
func (m *VoiceStreamManager) CleanupClient(clientID string) {
|
||||||
|
m.mu.Lock()
|
||||||
|
var toRemove []string
|
||||||
|
for key, session := range m.sessions {
|
||||||
|
if session.client.ClientID == clientID {
|
||||||
|
toRemove = append(toRemove, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, key := range toRemove {
|
||||||
|
delete(m.sessions, key)
|
||||||
|
}
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
for _, key := range toRemove {
|
||||||
|
// Close connection if session exists (we already deleted from map)
|
||||||
|
logger.Printf("[voice-stream] 清理客户端流: key=%s", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -15,11 +15,14 @@ type MessageAttachment struct {
|
|||||||
|
|
||||||
// 客户端 → 服务端消息
|
// 客户端 → 服务端消息
|
||||||
type ClientMessage struct {
|
type ClientMessage struct {
|
||||||
Type string `json:"type"` // message | voice_input | ping | history
|
Type string `json:"type"` // message | voice_input | voice_stream_start | voice_stream_chunk | voice_stream_end | ping | history
|
||||||
SessionID string `json:"session_id"`
|
SessionID string `json:"session_id"`
|
||||||
Mode string `json:"mode"` // text | voice_msg | voice_assistant
|
Mode string `json:"mode"` // text | voice_msg | voice_assistant
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
AudioData string `json:"audio_data,omitempty"` // base64
|
AudioData string `json:"audio_data,omitempty"` // base64
|
||||||
|
Format string `json:"format,omitempty"` // 音频格式: webm, wav, pcm, opus
|
||||||
|
Language string `json:"language,omitempty"` // 识别语言: zh, en, ja, ko, auto
|
||||||
|
Sequence int `json:"sequence,omitempty"` // 音频块序列号 (voice_stream_chunk)
|
||||||
Attachments []MessageAttachment `json:"attachments,omitempty"` // 图片等附件
|
Attachments []MessageAttachment `json:"attachments,omitempty"` // 图片等附件
|
||||||
Timestamp int64 `json:"timestamp"`
|
Timestamp int64 `json:"timestamp"`
|
||||||
ClientID string `json:"client_id,omitempty"` // 客户端唯一标识 (多端区分)
|
ClientID string `json:"client_id,omitempty"` // 客户端唯一标识 (多端区分)
|
||||||
@@ -28,11 +31,12 @@ type ClientMessage struct {
|
|||||||
ClientMsgID string `json:"client_msg_id,omitempty"` // 客户端消息ID (跨端去重)
|
ClientMsgID string `json:"client_msg_id,omitempty"` // 客户端消息ID (跨端去重)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReviewMessage 审查后的结构化消息(动作/聊天分离)
|
// ReviewMessage 审查后的结构化消息(动作/聊天/Markdown/代码块/搜索结果)
|
||||||
type ReviewMessage struct {
|
type ReviewMessage struct {
|
||||||
Type string `json:"type"` // "action" | "chat"
|
Type string `json:"type"` // action | chat | markdown | code | search_result
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
DelayMs int `json:"delay_ms,omitempty"` // ms to wait before sending (0 = immediate)
|
DelayMs int `json:"delay_ms,omitempty"` // ms to wait before sending (0 = immediate)
|
||||||
|
Metadata map[string]any `json:"metadata,omitempty"` // 类型特定元数据 (code 语言、搜索结果 URL 等)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClientInfo carries the originating client's device metadata.
|
// ClientInfo carries the originating client's device metadata.
|
||||||
@@ -44,7 +48,7 @@ type ClientInfo struct {
|
|||||||
|
|
||||||
// 服务端 → 客户端消息
|
// 服务端 → 客户端消息
|
||||||
type ServerMessage struct {
|
type ServerMessage struct {
|
||||||
Type string `json:"type"` // response | segment | audio | error | device_update | pong | history_response | stream_chunk | stream_end | background_thinking | notification | multi_message | stream_segments | review | thinking | tool_progress | system_info
|
Type string `json:"type"` // response | segment | audio | error | device_update | pong | history_response | stream_chunk | stream_end | background_thinking | notification | multi_message | stream_segments | review | thinking | tool_progress | system_info | voice_interim | voice_final
|
||||||
MessageID string `json:"message_id"`
|
MessageID string `json:"message_id"`
|
||||||
Text string `json:"text,omitempty"`
|
Text string `json:"text,omitempty"`
|
||||||
Content string `json:"content,omitempty"` // stream_chunk 的增量文本
|
Content string `json:"content,omitempty"` // stream_chunk 的增量文本
|
||||||
@@ -63,7 +67,8 @@ type ServerMessage struct {
|
|||||||
Notification *NotificationInfo `json:"notification,omitempty"` // 通知推送
|
Notification *NotificationInfo `json:"notification,omitempty"` // 通知推送
|
||||||
MultiMessage *MultiMessagePayload `json:"multi_message,omitempty"` // 多条消息批量发送
|
MultiMessage *MultiMessagePayload `json:"multi_message,omitempty"` // 多条消息批量发送
|
||||||
ReviewMessages []ReviewMessage `json:"review_messages,omitempty"` // 审查后的结构化消息列表
|
ReviewMessages []ReviewMessage `json:"review_messages,omitempty"` // 审查后的结构化消息列表
|
||||||
MsgType string `json:"msg_type,omitempty"` // 消息展示类型: action | chat | thinking | tool_progress | system_info
|
MsgType string `json:"msg_type,omitempty"` // 消息展示类型: action | chat | thinking | tool_progress | system_info | markdown | code | search_result
|
||||||
|
Metadata map[string]any `json:"metadata,omitempty"` // 消息元数据 (code 语言等)
|
||||||
ToolProgress *ToolProgressInfo `json:"tool_progress,omitempty"` // 工具执行进度
|
ToolProgress *ToolProgressInfo `json:"tool_progress,omitempty"` // 工具执行进度
|
||||||
SystemInfo *SystemInfoPayload `json:"system_info,omitempty"` // 系统通知信息
|
SystemInfo *SystemInfoPayload `json:"system_info,omitempty"` // 系统通知信息
|
||||||
ProtocolVersion int `json:"protocol_version,omitempty"` // 协议版本
|
ProtocolVersion int `json:"protocol_version,omitempty"` // 协议版本
|
||||||
|
|||||||
+2
-2
@@ -5,9 +5,9 @@ use (
|
|||||||
./gateway
|
./gateway
|
||||||
./iot-debug-service
|
./iot-debug-service
|
||||||
./memory-service
|
./memory-service
|
||||||
|
./pkg/audio
|
||||||
|
./pkg/dashscope
|
||||||
./pkg/logger
|
./pkg/logger
|
||||||
./pkg/plugins
|
|
||||||
./platform-bridge
|
./platform-bridge
|
||||||
./plugin-manager
|
|
||||||
./voice-service
|
./voice-service
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -0,0 +1,87 @@
|
|||||||
|
package audio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NormalizeFormat 规范化音频格式字符串。
|
||||||
|
func NormalizeFormat(format string) string {
|
||||||
|
switch strings.ToLower(format) {
|
||||||
|
case "pcm", "wav", "mp3", "mpeg", "ogg", "opus", "flac", "m4a", "mp4", "aac", "webm", "amr":
|
||||||
|
return strings.ToLower(format)
|
||||||
|
default:
|
||||||
|
return format
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertToPCM16 将音频数据转换为 16-bit PCM 16000Hz mono。
|
||||||
|
// 对于已经是 PCM 的数据直接返回;对于 WAV 跳过 44 字节头部;
|
||||||
|
// 其他格式使用 ffmpeg 转码。
|
||||||
|
func ConvertToPCM16(data []byte, format string) ([]byte, error) {
|
||||||
|
normFormat := NormalizeFormat(format)
|
||||||
|
switch normFormat {
|
||||||
|
case "pcm":
|
||||||
|
return data, nil
|
||||||
|
case "wav":
|
||||||
|
if len(data) > 44 {
|
||||||
|
return data[44:], nil
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
default:
|
||||||
|
return transcodeToPCM(data, normFormat)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// transcodeToPCM 使用 ffmpeg 将音频数据转码为 PCM 16-bit 16000Hz mono。
|
||||||
|
func transcodeToPCM(data []byte, format string) ([]byte, error) {
|
||||||
|
inFile, err := os.CreateTemp(os.TempDir(), "cyrene-asr-in-*."+format)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("创建输入临时文件失败: %w", err)
|
||||||
|
}
|
||||||
|
inPath := inFile.Name()
|
||||||
|
defer os.Remove(inPath)
|
||||||
|
if _, err := inFile.Write(data); err != nil {
|
||||||
|
inFile.Close()
|
||||||
|
return nil, fmt.Errorf("写入输入临时文件失败: %w", err)
|
||||||
|
}
|
||||||
|
inFile.Close()
|
||||||
|
|
||||||
|
outFile, err := os.CreateTemp(os.TempDir(), "cyrene-asr-out-*.pcm")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("创建输出临时文件失败: %w", err)
|
||||||
|
}
|
||||||
|
outPath := outFile.Name()
|
||||||
|
outFile.Close()
|
||||||
|
defer os.Remove(outPath)
|
||||||
|
|
||||||
|
cmd := exec.Command("ffmpeg",
|
||||||
|
"-i", inPath,
|
||||||
|
"-ar", "16000",
|
||||||
|
"-ac", "1",
|
||||||
|
"-c:a", "pcm_s16le",
|
||||||
|
"-f", "s16le",
|
||||||
|
outPath,
|
||||||
|
"-y",
|
||||||
|
)
|
||||||
|
cmd.Stderr = nil
|
||||||
|
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
return nil, fmt.Errorf("音频转码失败 (ffmpeg): %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
outData, err := os.ReadFile(outPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("读取转码结果失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return outData, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsFFmpegAvailable 检查 ffmpeg 是否可执行。
|
||||||
|
func IsFFmpegAvailable() bool {
|
||||||
|
_, err := exec.LookPath("ffmpeg")
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
module git.yeij.top/AskaEth/Cyrene/pkg/audio
|
||||||
|
|
||||||
|
go 1.21
|
||||||
@@ -0,0 +1,127 @@
|
|||||||
|
package dashscope
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---- 共享类型 ----
|
||||||
|
|
||||||
|
// ASRRequest DashScope ASR REST API 请求体。
|
||||||
|
type ASRRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Input ASRInput `json:"input"`
|
||||||
|
Parameters ASRParams `json:"parameters"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ASRInput 音频输入。
|
||||||
|
type ASRInput struct {
|
||||||
|
Audio string `json:"audio"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ASRParams 识别参数。
|
||||||
|
type ASRParams struct {
|
||||||
|
Format string `json:"format,omitempty"`
|
||||||
|
SampleRate int `json:"sample_rate,omitempty"`
|
||||||
|
Language string `json:"language,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ASRResponse DashScope ASR REST API 响应体。
|
||||||
|
type ASRResponse struct {
|
||||||
|
Output struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
} `json:"output"`
|
||||||
|
Usage struct {
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
} `json:"usage"`
|
||||||
|
RequestID string `json:"request_id"`
|
||||||
|
Code string `json:"code,omitempty"`
|
||||||
|
Message string `json:"message,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- 共享客户端 ----
|
||||||
|
|
||||||
|
// RESTClient 封装 DashScope REST API 的 HTTP 通信。
|
||||||
|
type RESTClient struct {
|
||||||
|
apiKey string
|
||||||
|
client *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRESTClient 创建 REST 客户端。
|
||||||
|
func NewRESTClient(apiKey string) *RESTClient {
|
||||||
|
return &RESTClient{
|
||||||
|
apiKey: apiKey,
|
||||||
|
client: &http.Client{Timeout: 60 * time.Second},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAvailable 检查 API Key 是否已配置。
|
||||||
|
func (c *RESTClient) IsAvailable() bool {
|
||||||
|
return c.apiKey != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transcribe 调用 DashScope ASR REST API 进行语音识别。
|
||||||
|
// audioData 应为 PCM 16kHz mono 格式。
|
||||||
|
func (c *RESTClient) Transcribe(ctx context.Context, model string, audioData []byte, format string, sampleRate int, language string) (string, error) {
|
||||||
|
if !c.IsAvailable() {
|
||||||
|
return "", fmt.Errorf("DashScope ASR API key not configured")
|
||||||
|
}
|
||||||
|
if language == "" || language == "auto" {
|
||||||
|
language = "zh"
|
||||||
|
}
|
||||||
|
|
||||||
|
audioB64 := base64.StdEncoding.EncodeToString(audioData)
|
||||||
|
|
||||||
|
reqBody := ASRRequest{
|
||||||
|
Model: model,
|
||||||
|
Input: ASRInput{
|
||||||
|
Audio: fmt.Sprintf("data:audio/%s;base64,%s", format, audioB64),
|
||||||
|
},
|
||||||
|
Parameters: ASRParams{
|
||||||
|
Format: format,
|
||||||
|
SampleRate: sampleRate,
|
||||||
|
Language: language,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, err := json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("marshal ASR request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
url := "https://dashscope.aliyuncs.com/api/v1/services/audio/asr/asr"
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyBytes))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("create ASR request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||||
|
|
||||||
|
resp, err := c.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("ASR request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBytes, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("read ASR response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var asrResp ASRResponse
|
||||||
|
if err := json.Unmarshal(respBytes, &asrResp); err != nil {
|
||||||
|
return "", fmt.Errorf("parse ASR response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if asrResp.Code != "" && asrResp.Code != "0" {
|
||||||
|
return "", fmt.Errorf("ASR error: %s (code=%s)", asrResp.Message, asrResp.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
return asrResp.Output.Text, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,122 @@
|
|||||||
|
package dashscope
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRESTClient_Transcribe_Success(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
t.Errorf("expected POST, got %s", r.Method)
|
||||||
|
}
|
||||||
|
if r.Header.Get("Authorization") != "Bearer test-key" {
|
||||||
|
t.Errorf("unexpected auth header: %s", r.Header.Get("Authorization"))
|
||||||
|
}
|
||||||
|
|
||||||
|
var req ASRRequest
|
||||||
|
json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
if req.Model != "test-model" {
|
||||||
|
t.Errorf("unexpected model: %s", req.Model)
|
||||||
|
}
|
||||||
|
if req.Parameters.Language != "zh" {
|
||||||
|
t.Errorf("unexpected language: %s", req.Parameters.Language)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := ASRResponse{}
|
||||||
|
resp.Output.Text = "你好世界"
|
||||||
|
resp.RequestID = "req-1"
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
client := &RESTClient{apiKey: "test-key", client: ts.Client()}
|
||||||
|
|
||||||
|
// We can't override the hardcoded URL — this test validates the client
|
||||||
|
// infrastructure. For full integration, test against real or mocked URL.
|
||||||
|
_ = client
|
||||||
|
|
||||||
|
if !client.IsAvailable() {
|
||||||
|
t.Error("client should be available with apiKey")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRESTClient_NotAvailable(t *testing.T) {
|
||||||
|
client := NewRESTClient("")
|
||||||
|
if client.IsAvailable() {
|
||||||
|
t.Error("client without apiKey should not be available")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRESTClient_Transcribe_NoAPIKey(t *testing.T) {
|
||||||
|
client := NewRESTClient("")
|
||||||
|
_, err := client.Transcribe(context.Background(), "model", []byte{}, "pcm", 16000, "zh")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error without API key")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRESTClient_Transcribe_AutoLanguage(t *testing.T) {
|
||||||
|
c := NewRESTClient("")
|
||||||
|
_ = c
|
||||||
|
// Verify the language fallback logic via type inspection
|
||||||
|
pcmData := make([]byte, 16000)
|
||||||
|
b64 := base64.StdEncoding.EncodeToString(pcmData)
|
||||||
|
_ = b64
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestASRRequest_Serialization(t *testing.T) {
|
||||||
|
req := ASRRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Input: ASRInput{
|
||||||
|
Audio: "data:audio/pcm;base64,dGVzdA==",
|
||||||
|
},
|
||||||
|
Parameters: ASRParams{
|
||||||
|
Format: "pcm",
|
||||||
|
SampleRate: 16000,
|
||||||
|
Language: "zh",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var decoded ASRRequest
|
||||||
|
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if decoded.Model != "test-model" {
|
||||||
|
t.Errorf("model mismatch: %s", decoded.Model)
|
||||||
|
}
|
||||||
|
if decoded.Parameters.Language != "zh" {
|
||||||
|
t.Errorf("language mismatch: %s", decoded.Parameters.Language)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestASRResponse_Deserialization(t *testing.T) {
|
||||||
|
jsonStr := `{"output":{"text":"你好"},"usage":{"total_tokens":0},"request_id":"r1","code":""}`
|
||||||
|
var resp ASRResponse
|
||||||
|
if err := json.Unmarshal([]byte(jsonStr), &resp); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Output.Text != "你好" {
|
||||||
|
t.Errorf("text mismatch: %s", resp.Output.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestASRResponse_Error(t *testing.T) {
|
||||||
|
jsonStr := `{"output":{"text":""},"code":"InvalidParameter","message":"bad request"}`
|
||||||
|
var resp ASRResponse
|
||||||
|
if err := json.Unmarshal([]byte(jsonStr), &resp); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Code != "InvalidParameter" {
|
||||||
|
t.Errorf("code mismatch: %s", resp.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
module git.yeij.top/AskaEth/Cyrene/pkg/dashscope
|
||||||
|
|
||||||
|
go 1.21
|
||||||
@@ -1,279 +0,0 @@
|
|||||||
package calculator
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"unicode"
|
|
||||||
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
|
||||||
)
|
|
||||||
|
|
||||||
type CalculatorPlugin struct {
|
|
||||||
sdk.BasePlugin
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *CalculatorPlugin) Metadata() sdk.PluginMetadata {
|
|
||||||
return sdk.PluginMetadata{
|
|
||||||
Name: "calculator", DisplayName: "Calculator", Version: "1.0.0",
|
|
||||||
Description: "Safe mathematical expression evaluation with custom parser",
|
|
||||||
Category: "utility", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *CalculatorPlugin) Tools() []sdk.Tool {
|
|
||||||
return []sdk.Tool{&CalculatorTool{}}
|
|
||||||
}
|
|
||||||
|
|
||||||
type CalculatorTool struct {
|
|
||||||
sdk.BaseTool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *CalculatorTool) Definition() sdk.ToolDefinition {
|
|
||||||
return sdk.ToolDefinition{
|
|
||||||
ID: "calculator", Name: "calculator", DisplayName: "Calculator",
|
|
||||||
Description: "Execute mathematical calculations. Supports arithmetic, trig, logs, powers.",
|
|
||||||
Category: "utility", Complexity: sdk.ComplexitySimple,
|
|
||||||
Parameters: map[string]interface{}{
|
|
||||||
"type": "object", "properties": map[string]interface{}{"expression": map[string]interface{}{"type": "string"}},
|
|
||||||
"required": []string{"expression"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *CalculatorTool) Validate(args map[string]interface{}) error {
|
|
||||||
if _, ok := args["expression"]; !ok {
|
|
||||||
return fmt.Errorf("missing required parameter: expression")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *CalculatorTool) Execute(_ context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
|
||||||
expr, _ := args["expression"].(string)
|
|
||||||
result, err := evalExpression(expr)
|
|
||||||
if err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "calculator", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "calculator", Success: true, Output: fmt.Sprintf("%v", result)}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Expression parser supporting +, -, *, /, %, ^, functions, constants.
|
|
||||||
type exprParser struct {
|
|
||||||
s string
|
|
||||||
pos int
|
|
||||||
}
|
|
||||||
|
|
||||||
func evalExpression(s string) (float64, error) {
|
|
||||||
p := &exprParser{s: strings.TrimSpace(s)}
|
|
||||||
result, err := p.parseAddSub()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if p.pos < len(p.s) {
|
|
||||||
return 0, fmt.Errorf("unexpected character at position %d: %c", p.pos, p.s[p.pos])
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *exprParser) peek() byte {
|
|
||||||
if p.pos < len(p.s) {
|
|
||||||
return p.s[p.pos]
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *exprParser) skipSpaces() {
|
|
||||||
for p.pos < len(p.s) && p.s[p.pos] == ' ' {
|
|
||||||
p.pos++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *exprParser) parseAddSub() (float64, error) {
|
|
||||||
left, err := p.parseMulDiv()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
p.skipSpaces()
|
|
||||||
op := p.peek()
|
|
||||||
if op != '+' && op != '-' {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
p.pos++
|
|
||||||
right, err := p.parseMulDiv()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if op == '+' {
|
|
||||||
left += right
|
|
||||||
} else {
|
|
||||||
left -= right
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return left, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *exprParser) parseMulDiv() (float64, error) {
|
|
||||||
left, err := p.parsePower()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
p.skipSpaces()
|
|
||||||
op := p.peek()
|
|
||||||
if op != '*' && op != '/' && op != '%' {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
p.pos++
|
|
||||||
right, err := p.parsePower()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
switch op {
|
|
||||||
case '*':
|
|
||||||
left *= right
|
|
||||||
case '/':
|
|
||||||
if right == 0 {
|
|
||||||
return 0, fmt.Errorf("division by zero")
|
|
||||||
}
|
|
||||||
left /= right
|
|
||||||
case '%':
|
|
||||||
left = math.Mod(left, right)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return left, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *exprParser) parsePower() (float64, error) {
|
|
||||||
left, err := p.parseUnary()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
p.skipSpaces()
|
|
||||||
if p.peek() == '^' {
|
|
||||||
p.pos++
|
|
||||||
right, err := p.parseUnary()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return math.Pow(left, right), nil
|
|
||||||
}
|
|
||||||
return left, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *exprParser) parseUnary() (float64, error) {
|
|
||||||
p.skipSpaces()
|
|
||||||
if p.peek() == '-' {
|
|
||||||
p.pos++
|
|
||||||
val, err := p.parseAtom()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return -val, nil
|
|
||||||
}
|
|
||||||
if p.peek() == '+' {
|
|
||||||
p.pos++
|
|
||||||
}
|
|
||||||
return p.parseAtom()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *exprParser) parseAtom() (float64, error) {
|
|
||||||
p.skipSpaces()
|
|
||||||
if p.peek() == '(' {
|
|
||||||
p.pos++
|
|
||||||
result, err := p.parseAddSub()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
p.skipSpaces()
|
|
||||||
if p.peek() != ')' {
|
|
||||||
return 0, fmt.Errorf("missing closing parenthesis")
|
|
||||||
}
|
|
||||||
p.pos++
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
if p.peek() == 0 {
|
|
||||||
return 0, fmt.Errorf("unexpected end of expression")
|
|
||||||
}
|
|
||||||
if unicode.IsDigit(rune(p.peek())) || p.peek() == '.' {
|
|
||||||
return p.parseNumber()
|
|
||||||
}
|
|
||||||
return p.parseFuncOrConst()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *exprParser) parseNumber() (float64, error) {
|
|
||||||
start := p.pos
|
|
||||||
for p.pos < len(p.s) && (unicode.IsDigit(rune(p.s[p.pos])) || p.s[p.pos] == '.') {
|
|
||||||
p.pos++
|
|
||||||
}
|
|
||||||
return strconv.ParseFloat(p.s[start:p.pos], 64)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *exprParser) parseFuncOrConst() (float64, error) {
|
|
||||||
start := p.pos
|
|
||||||
for p.pos < len(p.s) && (unicode.IsLetter(rune(p.s[p.pos])) || p.s[p.pos] == '_') {
|
|
||||||
p.pos++
|
|
||||||
}
|
|
||||||
name := p.s[start:p.pos]
|
|
||||||
p.skipSpaces()
|
|
||||||
|
|
||||||
switch name {
|
|
||||||
case "pi":
|
|
||||||
return math.Pi, nil
|
|
||||||
case "e":
|
|
||||||
return math.E, nil
|
|
||||||
case "sqrt", "sin", "cos", "tan", "abs", "floor", "ceil", "round", "log", "ln":
|
|
||||||
if p.peek() != '(' {
|
|
||||||
return 0, fmt.Errorf("expected '(' after function %s", name)
|
|
||||||
}
|
|
||||||
p.pos++
|
|
||||||
arg, err := p.parseAddSub()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if p.peek() != ')' {
|
|
||||||
return 0, fmt.Errorf("missing ')' after function argument")
|
|
||||||
}
|
|
||||||
p.pos++
|
|
||||||
return applyFunc(name, arg)
|
|
||||||
default:
|
|
||||||
return 0, fmt.Errorf("unknown function or constant: %s", name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func applyFunc(name string, x float64) (float64, error) {
|
|
||||||
switch name {
|
|
||||||
case "sqrt":
|
|
||||||
if x < 0 {
|
|
||||||
return 0, fmt.Errorf("square root of negative number")
|
|
||||||
}
|
|
||||||
return math.Sqrt(x), nil
|
|
||||||
case "sin":
|
|
||||||
return math.Sin(x), nil
|
|
||||||
case "cos":
|
|
||||||
return math.Cos(x), nil
|
|
||||||
case "tan":
|
|
||||||
return math.Tan(x), nil
|
|
||||||
case "abs":
|
|
||||||
return math.Abs(x), nil
|
|
||||||
case "floor":
|
|
||||||
return math.Floor(x), nil
|
|
||||||
case "ceil":
|
|
||||||
return math.Ceil(x), nil
|
|
||||||
case "round":
|
|
||||||
return math.Round(x), nil
|
|
||||||
case "log":
|
|
||||||
if x <= 0 {
|
|
||||||
return 0, fmt.Errorf("log of non-positive number")
|
|
||||||
}
|
|
||||||
return math.Log10(x), nil
|
|
||||||
case "ln":
|
|
||||||
if x <= 0 {
|
|
||||||
return 0, fmt.Errorf("ln of non-positive number")
|
|
||||||
}
|
|
||||||
return math.Log(x), nil
|
|
||||||
}
|
|
||||||
return 0, fmt.Errorf("unknown function: %s", name)
|
|
||||||
}
|
|
||||||
@@ -1,116 +0,0 @@
|
|||||||
package crypto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/md5"
|
|
||||||
"crypto/sha1"
|
|
||||||
"crypto/sha256"
|
|
||||||
"crypto/sha512"
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"hash"
|
|
||||||
"net/url"
|
|
||||||
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
|
||||||
)
|
|
||||||
|
|
||||||
type CryptoPlugin struct{ sdk.BasePlugin }
|
|
||||||
|
|
||||||
func (p *CryptoPlugin) Metadata() sdk.PluginMetadata {
|
|
||||||
return sdk.PluginMetadata{
|
|
||||||
Name: "crypto", DisplayName: "Crypto & Encoding", Version: "1.0.0",
|
|
||||||
Description: "Hashing (MD5/SHA) and encoding (Base64, URL) utilities",
|
|
||||||
Category: "utility", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *CryptoPlugin) Tools() []sdk.Tool { return []sdk.Tool{&CryptoTool{}} }
|
|
||||||
|
|
||||||
type CryptoTool struct{ sdk.BaseTool }
|
|
||||||
|
|
||||||
func (t *CryptoTool) Definition() sdk.ToolDefinition {
|
|
||||||
return sdk.ToolDefinition{
|
|
||||||
ID: "crypto", Name: "crypto", DisplayName: "Crypto & Encoding",
|
|
||||||
Description: "Crypto hash and encoding utilities. MD5/SHA hashing, Base64 encode/decode, URL encode/decode.",
|
|
||||||
Category: "utility", Complexity: sdk.ComplexitySimple,
|
|
||||||
Parameters: map[string]interface{}{
|
|
||||||
"type": "object",
|
|
||||||
"properties": map[string]interface{}{
|
|
||||||
"action": map[string]interface{}{"type": "string", "enum": []string{"hash", "base64_encode", "base64_decode", "url_encode", "url_decode"}},
|
|
||||||
"input": map[string]interface{}{"type": "string"},
|
|
||||||
"algorithm": map[string]interface{}{"type": "string", "enum": []string{"md5", "sha1", "sha256", "sha512"}},
|
|
||||||
},
|
|
||||||
"required": []string{"action", "input"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *CryptoTool) Validate(args map[string]interface{}) error {
|
|
||||||
for _, k := range []string{"action", "input"} {
|
|
||||||
if _, ok := args[k]; !ok {
|
|
||||||
return fmt.Errorf("missing required parameter: %s", k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *CryptoTool) Execute(_ context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
|
||||||
action, _ := args["action"].(string)
|
|
||||||
input, _ := args["input"].(string)
|
|
||||||
|
|
||||||
switch action {
|
|
||||||
case "hash":
|
|
||||||
alg, _ := args["algorithm"].(string)
|
|
||||||
if alg == "" {
|
|
||||||
alg = "sha256"
|
|
||||||
}
|
|
||||||
var h hash.Hash
|
|
||||||
switch alg {
|
|
||||||
case "md5":
|
|
||||||
h = md5.New()
|
|
||||||
case "sha1":
|
|
||||||
h = sha1.New()
|
|
||||||
case "sha256":
|
|
||||||
h = sha256.New()
|
|
||||||
case "sha512":
|
|
||||||
h = sha512.New()
|
|
||||||
default:
|
|
||||||
return &sdk.ToolResult{ToolName: "crypto", Success: false, Error: "unsupported algorithm: " + alg}, nil
|
|
||||||
}
|
|
||||||
h.Write([]byte(input))
|
|
||||||
return &sdk.ToolResult{ToolName: "crypto", Success: true,
|
|
||||||
Output: fmt.Sprintf("%s: %x", alg, h.Sum(nil))}, nil
|
|
||||||
|
|
||||||
case "base64_encode":
|
|
||||||
return &sdk.ToolResult{ToolName: "crypto", Success: true,
|
|
||||||
Output: base64.StdEncoding.EncodeToString([]byte(input))}, nil
|
|
||||||
|
|
||||||
case "base64_decode":
|
|
||||||
for _, enc := range []*base64.Encoding{base64.StdEncoding, base64.RawStdEncoding, base64.URLEncoding, base64.RawURLEncoding} {
|
|
||||||
if decoded, err := enc.DecodeString(input); err == nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "crypto", Success: true, Output: truncate(string(decoded), 200)}, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "crypto", Success: false, Error: "failed to decode base64"}, nil
|
|
||||||
|
|
||||||
case "url_encode":
|
|
||||||
return &sdk.ToolResult{ToolName: "crypto", Success: true,
|
|
||||||
Output: url.QueryEscape(input)}, nil
|
|
||||||
|
|
||||||
case "url_decode":
|
|
||||||
decoded, err := url.QueryUnescape(input)
|
|
||||||
if err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "crypto", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "crypto", Success: true, Output: decoded}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "crypto", Success: false, Error: "unknown action: " + action}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func truncate(s string, n int) string {
|
|
||||||
runes := []rune(s)
|
|
||||||
if len(runes) > n {
|
|
||||||
return string(runes[:n]) + "..."
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
@@ -1,170 +0,0 @@
|
|||||||
package datetime
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
|
||||||
)
|
|
||||||
|
|
||||||
type DatetimePlugin struct{ sdk.BasePlugin }
|
|
||||||
|
|
||||||
func (p *DatetimePlugin) Metadata() sdk.PluginMetadata {
|
|
||||||
return sdk.PluginMetadata{
|
|
||||||
Name: "datetime", DisplayName: "Date & Time", Version: "1.0.0",
|
|
||||||
Description: "Date/time utilities: now, format, arithmetic, diff, timezone list",
|
|
||||||
Category: "utility", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *DatetimePlugin) Tools() []sdk.Tool { return []sdk.Tool{&DatetimeTool{}} }
|
|
||||||
|
|
||||||
type DatetimeTool struct{ sdk.BaseTool }
|
|
||||||
|
|
||||||
func (t *DatetimeTool) Definition() sdk.ToolDefinition {
|
|
||||||
return sdk.ToolDefinition{
|
|
||||||
ID: "datetime", Name: "datetime", DisplayName: "Date & Time",
|
|
||||||
Description: "Date/time utility. Get current time, format dates, date arithmetic, date diff, list timezones.",
|
|
||||||
Category: "utility", Complexity: sdk.ComplexitySimple,
|
|
||||||
Parameters: map[string]interface{}{
|
|
||||||
"type": "object",
|
|
||||||
"properties": map[string]interface{}{
|
|
||||||
"action": map[string]interface{}{"type": "string", "enum": []string{"now", "format", "add", "diff", "timezone_list"}},
|
|
||||||
"format": map[string]interface{}{"type": "string"},
|
|
||||||
"timezone": map[string]interface{}{"type": "string"},
|
|
||||||
"date": map[string]interface{}{"type": "string"},
|
|
||||||
"duration": map[string]interface{}{"type": "string"},
|
|
||||||
"date2": map[string]interface{}{"type": "string"},
|
|
||||||
},
|
|
||||||
"required": []string{"action"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *DatetimeTool) Validate(args map[string]interface{}) error {
|
|
||||||
if _, ok := args["action"]; !ok {
|
|
||||||
return fmt.Errorf("missing required parameter: action")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *DatetimeTool) Execute(_ context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
|
||||||
action, _ := args["action"].(string)
|
|
||||||
tzStr, _ := args["timezone"].(string)
|
|
||||||
loc, _ := parseLocation(tzStr)
|
|
||||||
now := time.Now().In(loc)
|
|
||||||
|
|
||||||
switch action {
|
|
||||||
case "now":
|
|
||||||
return &sdk.ToolResult{ToolName: "datetime", Success: true,
|
|
||||||
Output: fmt.Sprintf("Current time: %s (unix: %d, zone: %s)", now.Format(time.RFC3339), now.Unix(), loc.String())}, nil
|
|
||||||
|
|
||||||
case "format":
|
|
||||||
dateStr, _ := args["date"].(string)
|
|
||||||
format, _ := args["format"].(string)
|
|
||||||
if format == "" {
|
|
||||||
format = time.RFC3339
|
|
||||||
}
|
|
||||||
parsed, err := parseDate(dateStr, loc)
|
|
||||||
if err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "datetime", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "datetime", Success: true,
|
|
||||||
Output: fmt.Sprintf("Formatted: %s", parsed.Format(format))}, nil
|
|
||||||
|
|
||||||
case "add":
|
|
||||||
dateStr, _ := args["date"].(string)
|
|
||||||
durStr, _ := args["duration"].(string)
|
|
||||||
base := now
|
|
||||||
if dateStr != "" {
|
|
||||||
var err error
|
|
||||||
base, err = parseDate(dateStr, loc)
|
|
||||||
if err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "datetime", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
result, err := addDuration(base, durStr)
|
|
||||||
if err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "datetime", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "datetime", Success: true,
|
|
||||||
Output: fmt.Sprintf("%s + %s = %s", base.Format(time.RFC3339), durStr, result.Format(time.RFC3339))}, nil
|
|
||||||
|
|
||||||
case "diff":
|
|
||||||
d1, _ := args["date"].(string)
|
|
||||||
d2, _ := args["date2"].(string)
|
|
||||||
t1, err := parseDate(d1, loc)
|
|
||||||
if err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "datetime", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
t2, err := parseDate(d2, loc)
|
|
||||||
if err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "datetime", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
diff := t2.Sub(t1)
|
|
||||||
if diff < 0 {
|
|
||||||
diff = -diff
|
|
||||||
}
|
|
||||||
days := int(diff.Hours()) / 24
|
|
||||||
hours := int(diff.Hours()) % 24
|
|
||||||
minutes := int(diff.Minutes()) % 60
|
|
||||||
seconds := int(diff.Seconds()) % 60
|
|
||||||
return &sdk.ToolResult{ToolName: "datetime", Success: true,
|
|
||||||
Output: fmt.Sprintf("Difference: %d days, %d hours, %d minutes, %d seconds", days, hours, minutes, seconds)}, nil
|
|
||||||
|
|
||||||
case "timezone_list":
|
|
||||||
return &sdk.ToolResult{ToolName: "datetime", Success: true,
|
|
||||||
Output: "Common timezones: UTC, Asia/Shanghai, Asia/Tokyo, Asia/Seoul, Asia/Singapore, Asia/Kolkata, Asia/Dubai, Europe/London, Europe/Paris, Europe/Moscow, America/New_York, America/Chicago, America/Los_Angeles, America/Sao_Paulo, Australia/Sydney, Pacific/Auckland, Africa/Cairo, Africa/Lagos"}, nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
return &sdk.ToolResult{ToolName: "datetime", Success: false,
|
|
||||||
Error: fmt.Sprintf("unknown action: %s", action)}, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseLocation(tz string) (*time.Location, error) {
|
|
||||||
if tz == "" {
|
|
||||||
loc, err := time.LoadLocation("Asia/Shanghai")
|
|
||||||
if err != nil {
|
|
||||||
return time.UTC, nil
|
|
||||||
}
|
|
||||||
return loc, nil
|
|
||||||
}
|
|
||||||
return time.LoadLocation(tz)
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseDate(s string, loc *time.Location) (time.Time, error) {
|
|
||||||
formats := []string{time.RFC3339, "2006-01-02T15:04:05", "2006-01-02 15:04:05", "2006-01-02", "2006/01/02"}
|
|
||||||
for _, f := range formats {
|
|
||||||
if t, err := time.ParseInLocation(f, s, loc); err == nil {
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return time.Time{}, fmt.Errorf("cannot parse date: %s", s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func addDuration(t time.Time, durStr string) (time.Time, error) {
|
|
||||||
durStr = strings.TrimSpace(durStr)
|
|
||||||
if durStr == "" {
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
// Handle months and years
|
|
||||||
if strings.Contains(durStr, "M") || strings.Contains(durStr, "y") {
|
|
||||||
months := 0
|
|
||||||
years := 0
|
|
||||||
if strings.Contains(durStr, "y") {
|
|
||||||
fmt.Sscanf(durStr, "%dy", &years)
|
|
||||||
}
|
|
||||||
if strings.Contains(durStr, "M") {
|
|
||||||
fmt.Sscanf(durStr, "%dM", &months)
|
|
||||||
}
|
|
||||||
return t.AddDate(years, months, 0), nil
|
|
||||||
}
|
|
||||||
d, err := time.ParseDuration(durStr)
|
|
||||||
if err != nil {
|
|
||||||
return t, fmt.Errorf("invalid duration: %s", durStr)
|
|
||||||
}
|
|
||||||
return t.Add(d), nil
|
|
||||||
}
|
|
||||||
@@ -1,158 +0,0 @@
|
|||||||
package file
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
|
||||||
)
|
|
||||||
|
|
||||||
type FilePlugin struct {
|
|
||||||
sdk.BasePlugin
|
|
||||||
dataDir string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewFilePlugin(dataDir string) *FilePlugin {
|
|
||||||
if dataDir == "" {
|
|
||||||
dataDir = "/tmp/cyrene_data"
|
|
||||||
}
|
|
||||||
return &FilePlugin{dataDir: dataDir}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *FilePlugin) Metadata() sdk.PluginMetadata {
|
|
||||||
return sdk.PluginMetadata{
|
|
||||||
Name: "file", DisplayName: "File Operations", Version: "1.0.0",
|
|
||||||
Description: "Sandboxed file operations: read, write, list, delete within DATA_DIR",
|
|
||||||
Category: "system", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *FilePlugin) Tools() []sdk.Tool { return []sdk.Tool{&FileTool{dataDir: p.dataDir}} }
|
|
||||||
|
|
||||||
type FileTool struct {
|
|
||||||
sdk.BaseTool
|
|
||||||
dataDir string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *FileTool) Definition() sdk.ToolDefinition {
|
|
||||||
return sdk.ToolDefinition{
|
|
||||||
ID: "file_ops", Name: "file_ops", DisplayName: "File Operations",
|
|
||||||
Description: "File operations within a sandboxed data directory. Read, write, list, check existence, delete.",
|
|
||||||
Category: "system", Complexity: sdk.ComplexitySimple,
|
|
||||||
DangerLevel: "medium",
|
|
||||||
Parameters: map[string]interface{}{
|
|
||||||
"type": "object",
|
|
||||||
"properties": map[string]interface{}{
|
|
||||||
"action": map[string]interface{}{"type": "string", "enum": []string{"read", "write", "list", "exists", "delete"}},
|
|
||||||
"path": map[string]interface{}{"type": "string"},
|
|
||||||
"content": map[string]interface{}{"type": "string"},
|
|
||||||
},
|
|
||||||
"required": []string{"action", "path"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *FileTool) Validate(args map[string]interface{}) error {
|
|
||||||
for _, k := range []string{"action", "path"} {
|
|
||||||
if _, ok := args[k]; !ok {
|
|
||||||
return fmt.Errorf("missing required parameter: %s", k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *FileTool) safePath(p string) (string, error) {
|
|
||||||
clean := filepath.Clean(p)
|
|
||||||
abs, err := filepath.Abs(filepath.Join(t.dataDir, clean))
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("path resolution failed: %w", err)
|
|
||||||
}
|
|
||||||
if !strings.HasPrefix(abs, filepath.Clean(t.dataDir)+string(os.PathSeparator)) && abs != filepath.Clean(t.dataDir) {
|
|
||||||
return "", fmt.Errorf("path traversal denied: %s", p)
|
|
||||||
}
|
|
||||||
return abs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *FileTool) Execute(_ context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
|
||||||
action, _ := args["action"].(string)
|
|
||||||
pathStr, _ := args["path"].(string)
|
|
||||||
|
|
||||||
safePath, err := t.safePath(pathStr)
|
|
||||||
if err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "file_ops", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
switch action {
|
|
||||||
case "read":
|
|
||||||
info, err := os.Stat(safePath)
|
|
||||||
if err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "file_ops", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
if info.IsDir() {
|
|
||||||
return &sdk.ToolResult{ToolName: "file_ops", Success: false, Error: "cannot read a directory"}, nil
|
|
||||||
}
|
|
||||||
if info.Size() > 100*1024 {
|
|
||||||
return &sdk.ToolResult{ToolName: "file_ops", Success: false, Error: "file too large (>100KB)"}, nil
|
|
||||||
}
|
|
||||||
data, err := os.ReadFile(safePath)
|
|
||||||
if err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "file_ops", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "file_ops", Success: true, Output: string(data)}, nil
|
|
||||||
|
|
||||||
case "write":
|
|
||||||
content, _ := args["content"].(string)
|
|
||||||
dir := filepath.Dir(safePath)
|
|
||||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "file_ops", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
if err := os.WriteFile(safePath, []byte(content), 0644); err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "file_ops", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "file_ops", Success: true, Output: fmt.Sprintf("Written %d bytes to %s", len(content), pathStr)}, nil
|
|
||||||
|
|
||||||
case "list":
|
|
||||||
entries, err := os.ReadDir(safePath)
|
|
||||||
if err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "file_ops", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
var out strings.Builder
|
|
||||||
for _, e := range entries {
|
|
||||||
info, _ := e.Info()
|
|
||||||
if e.IsDir() {
|
|
||||||
out.WriteString(fmt.Sprintf("[DIR] %s/\n", e.Name()))
|
|
||||||
} else {
|
|
||||||
out.WriteString(fmt.Sprintf("[FILE] %s (%d bytes)\n", e.Name(), info.Size()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "file_ops", Success: true, Output: out.String()}, nil
|
|
||||||
|
|
||||||
case "exists":
|
|
||||||
info, err := os.Stat(safePath)
|
|
||||||
if err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "file_ops", Success: true, Output: fmt.Sprintf("Path does not exist: %s", pathStr)}, nil
|
|
||||||
}
|
|
||||||
kind := "file"
|
|
||||||
if info.IsDir() {
|
|
||||||
kind = "directory"
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "file_ops", Success: true, Output: fmt.Sprintf("Path exists (%s): %s", kind, pathStr)}, nil
|
|
||||||
|
|
||||||
case "delete":
|
|
||||||
info, err := os.Stat(safePath)
|
|
||||||
if err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "file_ops", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
if info.IsDir() {
|
|
||||||
return &sdk.ToolResult{ToolName: "file_ops", Success: false, Error: "cannot delete a directory"}, nil
|
|
||||||
}
|
|
||||||
if err := os.Remove(safePath); err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "file_ops", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "file_ops", Success: true, Output: fmt.Sprintf("Deleted: %s", pathStr)}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "file_ops", Success: false, Error: "unknown action: " + action}, nil
|
|
||||||
}
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
module git.yeij.top/AskaEth/Cyrene/pkg/plugins
|
|
||||||
|
|
||||||
go 1.21
|
|
||||||
@@ -1,122 +0,0 @@
|
|||||||
package http
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
|
||||||
)
|
|
||||||
|
|
||||||
type HTTPPlugin struct {
|
|
||||||
sdk.BasePlugin
|
|
||||||
client *http.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewHTTPPlugin() *HTTPPlugin {
|
|
||||||
return &HTTPPlugin{client: &http.Client{Timeout: 10 * time.Second}}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *HTTPPlugin) Metadata() sdk.PluginMetadata {
|
|
||||||
return sdk.PluginMetadata{
|
|
||||||
Name: "http", DisplayName: "HTTP Client", Version: "1.0.0",
|
|
||||||
Description: "Send arbitrary HTTP requests with custom methods, headers, body",
|
|
||||||
Category: "network", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *HTTPPlugin) Tools() []sdk.Tool { return []sdk.Tool{&HTTPTool{client: p.client}} }
|
|
||||||
|
|
||||||
type HTTPTool struct {
|
|
||||||
sdk.BaseTool
|
|
||||||
client *http.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *HTTPTool) Definition() sdk.ToolDefinition {
|
|
||||||
return sdk.ToolDefinition{
|
|
||||||
ID: "http_request", Name: "http_request", DisplayName: "HTTP Client",
|
|
||||||
Description: "Send arbitrary HTTP requests. Supports custom methods, headers, and body.",
|
|
||||||
Category: "network", Complexity: sdk.ComplexitySimple,
|
|
||||||
DangerLevel: "low",
|
|
||||||
Parameters: map[string]interface{}{
|
|
||||||
"type": "object",
|
|
||||||
"properties": map[string]interface{}{
|
|
||||||
"url": map[string]interface{}{"type": "string"},
|
|
||||||
"method": map[string]interface{}{"type": "string", "enum": []string{"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"}},
|
|
||||||
"headers": map[string]interface{}{"type": "object"},
|
|
||||||
"body": map[string]interface{}{"type": "string"},
|
|
||||||
"timeout": map[string]interface{}{"type": "number"},
|
|
||||||
},
|
|
||||||
"required": []string{"url"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var allowedMethods = map[string]bool{
|
|
||||||
"GET": true, "POST": true, "PUT": true, "DELETE": true,
|
|
||||||
"PATCH": true, "HEAD": true, "OPTIONS": true,
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *HTTPTool) Validate(args map[string]interface{}) error {
|
|
||||||
if _, ok := args["url"]; !ok {
|
|
||||||
return fmt.Errorf("missing required parameter: url")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *HTTPTool) Execute(_ context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
|
||||||
urlStr, _ := args["url"].(string)
|
|
||||||
method, _ := args["method"].(string)
|
|
||||||
if method == "" {
|
|
||||||
method = "GET"
|
|
||||||
}
|
|
||||||
if !allowedMethods[method] {
|
|
||||||
return &sdk.ToolResult{ToolName: "http_request", Success: false, Error: "invalid method: " + method}, nil
|
|
||||||
}
|
|
||||||
if !strings.HasPrefix(urlStr, "http://") && !strings.HasPrefix(urlStr, "https://") {
|
|
||||||
return &sdk.ToolResult{ToolName: "http_request", Success: false, Error: "only http/https URLs allowed"}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var bodyReader io.Reader
|
|
||||||
if body, _ := args["body"].(string); body != "" {
|
|
||||||
bodyReader = strings.NewReader(body)
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequest(method, urlStr, bodyReader)
|
|
||||||
if err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "http_request", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
req.Header.Set("User-Agent", "CyreneBot/1.0")
|
|
||||||
|
|
||||||
if headers, ok := args["headers"].(map[string]interface{}); ok {
|
|
||||||
for k, v := range headers {
|
|
||||||
req.Header.Set(k, fmt.Sprint(v))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
client := t.client
|
|
||||||
if timeout, _ := args["timeout"].(float64); timeout > 0 {
|
|
||||||
client = &http.Client{Timeout: time.Duration(timeout) * time.Second}
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "http_request", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, 50*1024))
|
|
||||||
return &sdk.ToolResult{ToolName: "http_request", Success: resp.StatusCode < 500, Output: fmt.Sprintf(
|
|
||||||
"HTTP %d\n%s\n\n%s", resp.StatusCode, formatHeaders(resp.Header), string(bodyBytes))}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func formatHeaders(h http.Header) string {
|
|
||||||
var lines []string
|
|
||||||
for k, v := range h {
|
|
||||||
lines = append(lines, fmt.Sprintf("%s: %s", k, strings.Join(v, ", ")))
|
|
||||||
}
|
|
||||||
return strings.Join(lines, "\n")
|
|
||||||
}
|
|
||||||
@@ -1,189 +0,0 @@
|
|||||||
package iotcontrol
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
|
||||||
)
|
|
||||||
|
|
||||||
// IoTController extends IoTClient with control operations.
|
|
||||||
type IoTController interface {
|
|
||||||
GetDevice(ctx context.Context, deviceID string) (*sdk.IoTDeviceState, error)
|
|
||||||
SetDeviceProperty(ctx context.Context, deviceID, property string, value interface{}) error
|
|
||||||
ToggleDevice(ctx context.Context, deviceID string) (*sdk.IoTDeviceState, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type IoTControlPlugin struct {
|
|
||||||
sdk.BasePlugin
|
|
||||||
iotClient IoTController
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewIoTControlPlugin(client IoTController) *IoTControlPlugin {
|
|
||||||
return &IoTControlPlugin{iotClient: client}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *IoTControlPlugin) Metadata() sdk.PluginMetadata {
|
|
||||||
return sdk.PluginMetadata{
|
|
||||||
Name: "iot_control", DisplayName: "IoT Device Control", Version: "1.0.0",
|
|
||||||
Description: "Control smart home devices: toggle, set temperature/brightness/mode/color",
|
|
||||||
Category: "iot", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *IoTControlPlugin) Tools() []sdk.Tool {
|
|
||||||
return []sdk.Tool{&IoTControlTool{iotClient: p.iotClient}}
|
|
||||||
}
|
|
||||||
|
|
||||||
type IoTControlTool struct {
|
|
||||||
sdk.BaseTool
|
|
||||||
iotClient IoTController
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *IoTControlTool) Definition() sdk.ToolDefinition {
|
|
||||||
return sdk.ToolDefinition{
|
|
||||||
ID: "iot_control", Name: "iot_control", DisplayName: "IoT Device Control",
|
|
||||||
Description: "Control smart home devices. Supports toggle, turn_on, turn_off, set_temperature, set_brightness, set_position, set_mode, set_color.",
|
|
||||||
Category: "iot", Complexity: sdk.ComplexitySimple,
|
|
||||||
DangerLevel: "medium",
|
|
||||||
Parameters: map[string]interface{}{
|
|
||||||
"type": "object",
|
|
||||||
"properties": map[string]interface{}{
|
|
||||||
"device_id": map[string]interface{}{"type": "string"},
|
|
||||||
"action": map[string]interface{}{"type": "string", "enum": []string{"toggle", "turn_on", "turn_off", "set_temperature", "set_brightness", "set_position", "set_mode", "set_color"}},
|
|
||||||
"value": map[string]interface{}{},
|
|
||||||
},
|
|
||||||
"required": []string{"device_id", "action"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *IoTControlTool) Validate(args map[string]interface{}) error {
|
|
||||||
for _, k := range []string{"device_id", "action"} {
|
|
||||||
if _, ok := args[k]; !ok {
|
|
||||||
return fmt.Errorf("missing required parameter: %s", k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *IoTControlTool) Execute(ctx context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
|
||||||
if t.iotClient == nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "iot_control", Success: false, Error: "IoT client not configured"}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
deviceID, _ := args["device_id"].(string)
|
|
||||||
if deviceID == "" {
|
|
||||||
deviceID, _ = args["entity_id"].(string)
|
|
||||||
}
|
|
||||||
action := normalizeAction(args)
|
|
||||||
|
|
||||||
switch action {
|
|
||||||
case "turn_on", "turn_off":
|
|
||||||
status := "on"
|
|
||||||
if action == "turn_off" {
|
|
||||||
status = "off"
|
|
||||||
}
|
|
||||||
if err := t.iotClient.SetDeviceProperty(ctx, deviceID, "status", status); err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "iot_control", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "iot_control", Success: true,
|
|
||||||
Output: fmt.Sprintf("Device %s turned %s", deviceID, status)}, nil
|
|
||||||
|
|
||||||
case "set_temperature":
|
|
||||||
value := toFloat64(args["value"])
|
|
||||||
old := ""
|
|
||||||
if dev, err := t.iotClient.GetDevice(ctx, deviceID); err == nil {
|
|
||||||
old = fmt.Sprintf(" (was %.1fC)", dev.Temperature)
|
|
||||||
}
|
|
||||||
if err := t.iotClient.SetDeviceProperty(ctx, deviceID, "temperature", value); err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "iot_control", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "iot_control", Success: true,
|
|
||||||
Output: fmt.Sprintf("Temperature set to %.1fC%s", value, old)}, nil
|
|
||||||
|
|
||||||
case "set_brightness":
|
|
||||||
value := toFloat64(args["value"])
|
|
||||||
if err := t.iotClient.SetDeviceProperty(ctx, deviceID, "brightness", value); err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "iot_control", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "iot_control", Success: true,
|
|
||||||
Output: fmt.Sprintf("Brightness set to %.0f%%", value)}, nil
|
|
||||||
|
|
||||||
case "set_position":
|
|
||||||
value := toFloat64(args["value"])
|
|
||||||
if err := t.iotClient.SetDeviceProperty(ctx, deviceID, "position", value); err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "iot_control", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "iot_control", Success: true,
|
|
||||||
Output: fmt.Sprintf("Position set to %.0f%%", value)}, nil
|
|
||||||
|
|
||||||
case "set_mode":
|
|
||||||
value, _ := args["value"].(string)
|
|
||||||
if err := t.iotClient.SetDeviceProperty(ctx, deviceID, "mode", value); err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "iot_control", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "iot_control", Success: true,
|
|
||||||
Output: fmt.Sprintf("Mode set to %s", value)}, nil
|
|
||||||
|
|
||||||
case "set_color":
|
|
||||||
value, _ := args["value"].(string)
|
|
||||||
if err := t.iotClient.SetDeviceProperty(ctx, deviceID, "color", value); err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "iot_control", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "iot_control", Success: true,
|
|
||||||
Output: fmt.Sprintf("Color set to %s", value)}, nil
|
|
||||||
|
|
||||||
case "toggle":
|
|
||||||
dev, err := t.iotClient.ToggleDevice(ctx, deviceID)
|
|
||||||
if err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "iot_control", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "iot_control", Success: true,
|
|
||||||
Output: fmt.Sprintf("Device %s toggled to %s", deviceID, dev.Status)}, nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
return &sdk.ToolResult{ToolName: "iot_control", Success: false,
|
|
||||||
Error: fmt.Sprintf("unknown action: %s", action)}, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func normalizeAction(args map[string]interface{}) string {
|
|
||||||
action, _ := args["action"].(string)
|
|
||||||
// Chinese aliases
|
|
||||||
switch action {
|
|
||||||
case "打开":
|
|
||||||
return "turn_on"
|
|
||||||
case "关闭", "关掉", "关上":
|
|
||||||
return "turn_off"
|
|
||||||
case "设置温度", "调温度":
|
|
||||||
return "set_temperature"
|
|
||||||
case "设置亮度", "调亮度":
|
|
||||||
return "set_brightness"
|
|
||||||
case "设置位置":
|
|
||||||
return "set_position"
|
|
||||||
case "设置模式":
|
|
||||||
return "set_mode"
|
|
||||||
case "设置颜色":
|
|
||||||
return "set_color"
|
|
||||||
case "开关", "切换":
|
|
||||||
return "toggle"
|
|
||||||
}
|
|
||||||
return action
|
|
||||||
}
|
|
||||||
|
|
||||||
func toFloat64(v interface{}) float64 {
|
|
||||||
switch n := v.(type) {
|
|
||||||
case float64:
|
|
||||||
return n
|
|
||||||
case int:
|
|
||||||
return float64(n)
|
|
||||||
case int64:
|
|
||||||
return float64(n)
|
|
||||||
case string:
|
|
||||||
var f float64
|
|
||||||
fmt.Sscanf(n, "%f", &f)
|
|
||||||
return f
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
@@ -1,120 +0,0 @@
|
|||||||
package iotquery
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
|
||||||
)
|
|
||||||
|
|
||||||
// IoTClient is the interface for IoT device access.
|
|
||||||
type IoTClient interface {
|
|
||||||
GetAllDevices(ctx context.Context) ([]sdk.IoTDeviceState, error)
|
|
||||||
GetDevice(ctx context.Context, deviceID string) (*sdk.IoTDeviceState, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type IoTQueryPlugin struct {
|
|
||||||
sdk.BasePlugin
|
|
||||||
iotClient IoTClient
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewIoTQueryPlugin(client IoTClient) *IoTQueryPlugin {
|
|
||||||
return &IoTQueryPlugin{iotClient: client}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *IoTQueryPlugin) Metadata() sdk.PluginMetadata {
|
|
||||||
return sdk.PluginMetadata{
|
|
||||||
Name: "iot_query", DisplayName: "IoT Device Query", Version: "1.0.0",
|
|
||||||
Description: "Query smart home device status (single device or all devices)",
|
|
||||||
Category: "iot", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *IoTQueryPlugin) Tools() []sdk.Tool { return []sdk.Tool{&IoTQueryTool{iotClient: p.iotClient}} }
|
|
||||||
|
|
||||||
type IoTQueryTool struct {
|
|
||||||
sdk.BaseTool
|
|
||||||
iotClient IoTClient
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *IoTQueryTool) Definition() sdk.ToolDefinition {
|
|
||||||
return sdk.ToolDefinition{
|
|
||||||
ID: "iot_query", Name: "iot_query", DisplayName: "IoT Device Query",
|
|
||||||
Description: "Query smart home device status. Device status is typically auto-injected; use this only when status is stale.",
|
|
||||||
Category: "iot", Complexity: sdk.ComplexitySimple,
|
|
||||||
Parameters: map[string]interface{}{
|
|
||||||
"type": "object",
|
|
||||||
"properties": map[string]interface{}{"device_id": map[string]interface{}{"type": "string"}},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *IoTQueryTool) Validate(args map[string]interface{}) error { return nil }
|
|
||||||
|
|
||||||
func (t *IoTQueryTool) Execute(ctx context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
|
||||||
if t.iotClient == nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "iot_query", Success: false, Error: "IoT client not configured"}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
deviceID, _ := args["device_id"].(string)
|
|
||||||
if deviceID != "" {
|
|
||||||
dev, err := t.iotClient.GetDevice(ctx, deviceID)
|
|
||||||
if err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "iot_query", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "iot_query", Success: true, Output: formatDevice(dev)}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
devices, err := t.iotClient.GetAllDevices(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "iot_query", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
if len(devices) == 0 {
|
|
||||||
return &sdk.ToolResult{ToolName: "iot_query", Success: true, Output: "No devices found"}, nil
|
|
||||||
}
|
|
||||||
var out string
|
|
||||||
for _, d := range devices {
|
|
||||||
out += formatDeviceLine(&d) + "\n"
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "iot_query", Success: true, Output: out}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func formatDevice(d *sdk.IoTDeviceState) string {
|
|
||||||
emoji := deviceEmoji(d.Type)
|
|
||||||
return fmt.Sprintf("%s %s (%s)\n Status: %s\n ID: %s", emoji, d.Name, d.Type, d.Status, d.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func formatDeviceLine(d *sdk.IoTDeviceState) string {
|
|
||||||
emoji := deviceEmoji(d.Type)
|
|
||||||
switch d.Type {
|
|
||||||
case "light":
|
|
||||||
return fmt.Sprintf("%s %s: %s (brightness: %d, color: %s)", emoji, d.Name, d.Status, d.Brightness, d.Color)
|
|
||||||
case "ac":
|
|
||||||
return fmt.Sprintf("%s %s: %s (mode: %s, temp: %.1fC)", emoji, d.Name, d.Status, d.Mode, d.Temperature)
|
|
||||||
case "curtain":
|
|
||||||
return fmt.Sprintf("%s %s: %s (position: %d%%)", emoji, d.Name, d.Status, d.Position)
|
|
||||||
case "sensor":
|
|
||||||
return fmt.Sprintf("%s %s: %.1f%s", emoji, d.Name, d.Value, d.Unit)
|
|
||||||
case "lock":
|
|
||||||
return fmt.Sprintf("%s %s: %s (battery: %d%%)", emoji, d.Name, d.Status, d.Battery)
|
|
||||||
default:
|
|
||||||
return fmt.Sprintf("%s %s: %s", emoji, d.Name, d.Status)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func deviceEmoji(t string) string {
|
|
||||||
switch t {
|
|
||||||
case "light":
|
|
||||||
return "\U0001F4A1"
|
|
||||||
case "ac":
|
|
||||||
return "❄️"
|
|
||||||
case "curtain":
|
|
||||||
return "\U0001F3E0"
|
|
||||||
case "sensor":
|
|
||||||
return "\U0001F4CA"
|
|
||||||
case "lock":
|
|
||||||
return "\U0001F512"
|
|
||||||
default:
|
|
||||||
return "\U0001F4E6"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,132 +0,0 @@
|
|||||||
package jsonplugin
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
|
||||||
)
|
|
||||||
|
|
||||||
type JSONPlugin struct{ sdk.BasePlugin }
|
|
||||||
|
|
||||||
func (p *JSONPlugin) Metadata() sdk.PluginMetadata {
|
|
||||||
return sdk.PluginMetadata{
|
|
||||||
Name: "json", DisplayName: "JSON Processor", Version: "1.0.0",
|
|
||||||
Description: "JSON parsing, dot-path query, validation, pretty-print",
|
|
||||||
Category: "format", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *JSONPlugin) Tools() []sdk.Tool { return []sdk.Tool{&JSONTool{}} }
|
|
||||||
|
|
||||||
type JSONTool struct{ sdk.BaseTool }
|
|
||||||
|
|
||||||
func (t *JSONTool) Definition() sdk.ToolDefinition {
|
|
||||||
return sdk.ToolDefinition{
|
|
||||||
ID: "json_ops", Name: "json_ops", DisplayName: "JSON Processor",
|
|
||||||
Description: "JSON processing. Parse/pretty-print, query by dot-notation path, validate.",
|
|
||||||
Category: "format", Complexity: sdk.ComplexitySimple,
|
|
||||||
Parameters: map[string]interface{}{
|
|
||||||
"type": "object",
|
|
||||||
"properties": map[string]interface{}{
|
|
||||||
"action": map[string]interface{}{"type": "string", "enum": []string{"parse", "query", "validate"}},
|
|
||||||
"json_string": map[string]interface{}{"type": "string"},
|
|
||||||
"path": map[string]interface{}{"type": "string"},
|
|
||||||
},
|
|
||||||
"required": []string{"action", "json_string"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *JSONTool) Validate(args map[string]interface{}) error {
|
|
||||||
for _, k := range []string{"action", "json_string"} {
|
|
||||||
if _, ok := args[k]; !ok {
|
|
||||||
return fmt.Errorf("missing required parameter: %s", k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *JSONTool) Execute(_ context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
|
||||||
action, _ := args["action"].(string)
|
|
||||||
jsonStr, _ := args["json_string"].(string)
|
|
||||||
|
|
||||||
switch action {
|
|
||||||
case "parse":
|
|
||||||
var v interface{}
|
|
||||||
if err := json.Unmarshal([]byte(jsonStr), &v); err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "json_ops", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
pretty, err := json.MarshalIndent(v, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "json_ops", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "json_ops", Success: true, Output: string(pretty)}, nil
|
|
||||||
|
|
||||||
case "query":
|
|
||||||
var v interface{}
|
|
||||||
if err := json.Unmarshal([]byte(jsonStr), &v); err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "json_ops", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
path, _ := args["path"].(string)
|
|
||||||
if path == "" {
|
|
||||||
return &sdk.ToolResult{ToolName: "json_ops", Success: false, Error: "path is required for query"}, nil
|
|
||||||
}
|
|
||||||
result, err := jsonPathQuery(v, path)
|
|
||||||
if err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "json_ops", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
out, _ := json.Marshal(result)
|
|
||||||
return &sdk.ToolResult{ToolName: "json_ops", Success: true, Output: string(out)}, nil
|
|
||||||
|
|
||||||
case "validate":
|
|
||||||
var v interface{}
|
|
||||||
if err := json.Unmarshal([]byte(jsonStr), &v); err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "json_ops", Success: true, Output: "Invalid JSON: " + err.Error()}, nil
|
|
||||||
}
|
|
||||||
typeStr := "unknown"
|
|
||||||
switch v.(type) {
|
|
||||||
case map[string]interface{}:
|
|
||||||
typeStr = "object"
|
|
||||||
case []interface{}:
|
|
||||||
typeStr = "array"
|
|
||||||
case string:
|
|
||||||
typeStr = "string"
|
|
||||||
case float64:
|
|
||||||
typeStr = "number"
|
|
||||||
case bool:
|
|
||||||
typeStr = "boolean"
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "json_ops", Success: true,
|
|
||||||
Output: fmt.Sprintf("Valid JSON (type: %s, size: %d bytes)", typeStr, len(jsonStr))}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "json_ops", Success: false, Error: "unknown action: " + action}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func jsonPathQuery(root interface{}, path string) (interface{}, error) {
|
|
||||||
path = strings.TrimPrefix(path, "$.")
|
|
||||||
parts := strings.Split(path, ".")
|
|
||||||
current := root
|
|
||||||
for _, part := range parts {
|
|
||||||
switch v := current.(type) {
|
|
||||||
case map[string]interface{}:
|
|
||||||
var ok bool
|
|
||||||
current, ok = v[part]
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("key %q not found", part)
|
|
||||||
}
|
|
||||||
case []interface{}:
|
|
||||||
idx, err := strconv.Atoi(part)
|
|
||||||
if err != nil || idx < 0 || idx >= len(v) {
|
|
||||||
return nil, fmt.Errorf("invalid array index: %s", part)
|
|
||||||
}
|
|
||||||
current = v[idx]
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("cannot traverse into %T at path segment %q", current, part)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return current, nil
|
|
||||||
}
|
|
||||||
@@ -1,226 +0,0 @@
|
|||||||
package manager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PluginManager manages the lifecycle of all plugins and their tools.
|
|
||||||
type PluginManager struct {
|
|
||||||
mu sync.RWMutex
|
|
||||||
plugins map[string]*pluginEntry
|
|
||||||
registry *ToolRegistry
|
|
||||||
host sdk.HostAPI
|
|
||||||
}
|
|
||||||
|
|
||||||
type pluginEntry struct {
|
|
||||||
instance sdk.Plugin
|
|
||||||
info sdk.PluginInfo
|
|
||||||
cancel context.CancelFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPluginManager(registry *ToolRegistry, host sdk.HostAPI) *PluginManager {
|
|
||||||
return &PluginManager{
|
|
||||||
plugins: make(map[string]*pluginEntry),
|
|
||||||
registry: registry,
|
|
||||||
host: host,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Install registers a plugin instance.
|
|
||||||
func (m *PluginManager) Install(plugin sdk.Plugin) error {
|
|
||||||
meta := plugin.Metadata()
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
|
|
||||||
if _, exists := m.plugins[meta.Name]; exists {
|
|
||||||
return fmt.Errorf("plugin %q is already installed", meta.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.plugins[meta.Name] = &pluginEntry{
|
|
||||||
instance: plugin,
|
|
||||||
info: sdk.PluginInfo{
|
|
||||||
Metadata: meta,
|
|
||||||
Status: sdk.StatusInstalled,
|
|
||||||
InstalledAt: time.Now(),
|
|
||||||
Enabled: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enable activates a plugin: Init → register tools → Start.
|
|
||||||
func (m *PluginManager) Enable(ctx context.Context, pluginName string) error {
|
|
||||||
m.mu.Lock()
|
|
||||||
entry, ok := m.plugins[pluginName]
|
|
||||||
m.mu.Unlock()
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("plugin %q not found", pluginName)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mu.Lock()
|
|
||||||
entry.info.Status = sdk.StatusLoaded
|
|
||||||
m.mu.Unlock()
|
|
||||||
|
|
||||||
meta := entry.instance.Metadata()
|
|
||||||
if err := entry.instance.Init(ctx, nil); err != nil {
|
|
||||||
m.mu.Lock()
|
|
||||||
entry.info.Status = sdk.StatusError
|
|
||||||
m.mu.Unlock()
|
|
||||||
return fmt.Errorf("plugin %q init failed: %w", meta.Name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
pluginCtx, cancel := context.WithCancel(context.Background())
|
|
||||||
if err := entry.instance.Start(pluginCtx, m.host); err != nil {
|
|
||||||
cancel()
|
|
||||||
m.mu.Lock()
|
|
||||||
entry.info.Status = sdk.StatusError
|
|
||||||
m.mu.Unlock()
|
|
||||||
return fmt.Errorf("plugin %q start failed: %w", meta.Name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tools := entry.instance.Tools()
|
|
||||||
toolIDs := make([]string, 0, len(tools))
|
|
||||||
for _, t := range tools {
|
|
||||||
if err := m.registry.Register(t); err != nil {
|
|
||||||
m.registry.UnregisterAll(toolIDs)
|
|
||||||
cancel()
|
|
||||||
m.mu.Lock()
|
|
||||||
entry.info.Status = sdk.StatusError
|
|
||||||
m.mu.Unlock()
|
|
||||||
return fmt.Errorf("plugin %q tool register failed: %w", meta.Name, err)
|
|
||||||
}
|
|
||||||
toolIDs = append(toolIDs, t.Definition().ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mu.Lock()
|
|
||||||
entry.cancel = cancel
|
|
||||||
entry.info.Status = sdk.StatusRunning
|
|
||||||
entry.info.Enabled = true
|
|
||||||
entry.info.Tools = toolIDs
|
|
||||||
m.mu.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Disable stops a plugin and unregisters its tools.
|
|
||||||
func (m *PluginManager) Disable(ctx context.Context, pluginName string) error {
|
|
||||||
m.mu.Lock()
|
|
||||||
entry, ok := m.plugins[pluginName]
|
|
||||||
m.mu.Unlock()
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("plugin %q not found", pluginName)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := entry.instance.Stop(ctx); err != nil {
|
|
||||||
return fmt.Errorf("plugin %q stop failed: %w", pluginName, err)
|
|
||||||
}
|
|
||||||
if entry.cancel != nil {
|
|
||||||
entry.cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
m.registry.UnregisterAll(entry.info.Tools)
|
|
||||||
|
|
||||||
m.mu.Lock()
|
|
||||||
entry.info.Status = sdk.StatusDisabled
|
|
||||||
entry.info.Enabled = false
|
|
||||||
entry.info.Tools = nil
|
|
||||||
m.mu.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// List returns info for all installed plugins.
|
|
||||||
func (m *PluginManager) List() []sdk.PluginInfo {
|
|
||||||
m.mu.RLock()
|
|
||||||
defer m.mu.RUnlock()
|
|
||||||
result := make([]sdk.PluginInfo, 0, len(m.plugins))
|
|
||||||
for _, entry := range m.plugins {
|
|
||||||
result = append(result, entry.info)
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get returns info for a single plugin.
|
|
||||||
func (m *PluginManager) Get(pluginName string) (*sdk.PluginInfo, bool) {
|
|
||||||
m.mu.RLock()
|
|
||||||
defer m.mu.RUnlock()
|
|
||||||
entry, ok := m.plugins[pluginName]
|
|
||||||
if !ok {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
info := entry.info
|
|
||||||
return &info, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// EnableAll starts all installed plugins.
|
|
||||||
func (m *PluginManager) EnableAll(ctx context.Context) []error {
|
|
||||||
m.mu.RLock()
|
|
||||||
names := make([]string, 0, len(m.plugins))
|
|
||||||
for name := range m.plugins {
|
|
||||||
names = append(names, name)
|
|
||||||
}
|
|
||||||
m.mu.RUnlock()
|
|
||||||
|
|
||||||
var errs []error
|
|
||||||
for _, name := range names {
|
|
||||||
if err := m.Enable(ctx, name); err != nil {
|
|
||||||
errs = append(errs, fmt.Errorf("%s: %w", name, err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return errs
|
|
||||||
}
|
|
||||||
|
|
||||||
// Uninstall removes a plugin completely.
|
|
||||||
func (m *PluginManager) Uninstall(ctx context.Context, pluginName string) error {
|
|
||||||
m.mu.RLock()
|
|
||||||
entry, ok := m.plugins[pluginName]
|
|
||||||
m.mu.RUnlock()
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("plugin %q not found", pluginName)
|
|
||||||
}
|
|
||||||
if entry.info.Status == sdk.StatusRunning {
|
|
||||||
if err := m.Disable(ctx, pluginName); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
m.mu.Lock()
|
|
||||||
defer m.mu.Unlock()
|
|
||||||
delete(m.plugins, pluginName)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reload stops and re-starts a plugin.
|
|
||||||
func (m *PluginManager) Reload(ctx context.Context, pluginName string) error {
|
|
||||||
if err := m.Disable(ctx, pluginName); err != nil {
|
|
||||||
return fmt.Errorf("reload disable: %w", err)
|
|
||||||
}
|
|
||||||
return m.Enable(ctx, pluginName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Shutdown stops all running plugins gracefully.
|
|
||||||
func (m *PluginManager) Shutdown(ctx context.Context) []error {
|
|
||||||
m.mu.RLock()
|
|
||||||
names := make([]string, 0, len(m.plugins))
|
|
||||||
for name, entry := range m.plugins {
|
|
||||||
if entry.info.Status == sdk.StatusRunning {
|
|
||||||
names = append(names, name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
m.mu.RUnlock()
|
|
||||||
|
|
||||||
var errs []error
|
|
||||||
for _, name := range names {
|
|
||||||
if err := m.Disable(ctx, name); err != nil {
|
|
||||||
errs = append(errs, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return errs
|
|
||||||
}
|
|
||||||
|
|
||||||
// Registry returns the aggregated tool registry.
|
|
||||||
func (m *PluginManager) Registry() *ToolRegistry {
|
|
||||||
return m.registry
|
|
||||||
}
|
|
||||||
@@ -1,326 +0,0 @@
|
|||||||
package manager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CtxKeyIsAdmin is the context key for the admin flag.
|
|
||||||
type ctxKey string
|
|
||||||
|
|
||||||
const CtxKeyIsAdmin ctxKey = "isAdmin"
|
|
||||||
|
|
||||||
// adminOnlyTools lists tools that require admin permission to execute.
|
|
||||||
var adminOnlyTools = map[string]bool{
|
|
||||||
"host_exec": true,
|
|
||||||
"os_exec": true,
|
|
||||||
"host_file": true,
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsAdminFromCtx returns true if the context carries an admin flag.
|
|
||||||
func IsAdminFromCtx(ctx context.Context) bool {
|
|
||||||
v, _ := ctx.Value(CtxKeyIsAdmin).(bool)
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
// CallLogRecord 工具调用记录
|
|
||||||
type CallLogRecord struct {
|
|
||||||
CallID string `json:"call_id"`
|
|
||||||
ToolName string `json:"tool_name"`
|
|
||||||
Arguments string `json:"arguments"`
|
|
||||||
Output string `json:"output"`
|
|
||||||
Error string `json:"error"`
|
|
||||||
Success bool `json:"success"`
|
|
||||||
DurationMs int `json:"duration_ms"`
|
|
||||||
Timestamp int64 `json:"timestamp"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// callLogRing 线程安全的环形缓冲区
|
|
||||||
type callLogRing struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
records []CallLogRecord
|
|
||||||
capacity int
|
|
||||||
head int
|
|
||||||
size int
|
|
||||||
}
|
|
||||||
|
|
||||||
func newCallLogRing(capacity int) *callLogRing {
|
|
||||||
return &callLogRing{capacity: capacity, records: make([]CallLogRecord, capacity)}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *callLogRing) push(rec CallLogRecord) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
rec.CallID = fmt.Sprintf("%d", time.Now().UnixNano())
|
|
||||||
rec.Timestamp = time.Now().UnixMilli()
|
|
||||||
r.records[r.head] = rec
|
|
||||||
r.head = (r.head + 1) % r.capacity
|
|
||||||
if r.size < r.capacity {
|
|
||||||
r.size++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *callLogRing) getAll() []CallLogRecord {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
result := make([]CallLogRecord, r.size)
|
|
||||||
for i := 0; i < r.size; i++ {
|
|
||||||
idx := (r.head - 1 - i) % r.capacity
|
|
||||||
if idx < 0 {
|
|
||||||
idx += r.capacity
|
|
||||||
}
|
|
||||||
result[i] = r.records[idx]
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *callLogRing) statsByTool() map[string]map[string]interface{} {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
byTool := make(map[string]map[string]interface{})
|
|
||||||
for i := 0; i < r.size; i++ {
|
|
||||||
idx := (r.head - 1 - i) % r.capacity
|
|
||||||
if idx < 0 {
|
|
||||||
idx += r.capacity
|
|
||||||
}
|
|
||||||
rec := r.records[idx]
|
|
||||||
if _, ok := byTool[rec.ToolName]; !ok {
|
|
||||||
byTool[rec.ToolName] = map[string]interface{}{
|
|
||||||
"tool_name": rec.ToolName, "count": 0, "success_count": 0,
|
|
||||||
"fail_count": 0, "total_duration_ms": 0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s := byTool[rec.ToolName]
|
|
||||||
s["count"] = s["count"].(int) + 1
|
|
||||||
if rec.Success {
|
|
||||||
s["success_count"] = s["success_count"].(int) + 1
|
|
||||||
} else {
|
|
||||||
s["fail_count"] = s["fail_count"].(int) + 1
|
|
||||||
}
|
|
||||||
s["total_duration_ms"] = s["total_duration_ms"].(int) + rec.DurationMs
|
|
||||||
}
|
|
||||||
return byTool
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToolRegistry aggregates tool definitions from all running plugins and dispatches execution.
|
|
||||||
type ToolRegistry struct {
|
|
||||||
mu sync.RWMutex
|
|
||||||
tools map[string]sdk.Tool // tool ID -> Tool
|
|
||||||
callLog *callLogRing
|
|
||||||
enabled bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewToolRegistry() *ToolRegistry {
|
|
||||||
return &ToolRegistry{
|
|
||||||
tools: make(map[string]sdk.Tool),
|
|
||||||
callLog: newCallLogRing(500),
|
|
||||||
enabled: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsEnabled returns whether tool execution is enabled.
|
|
||||||
func (r *ToolRegistry) IsEnabled() bool {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
return r.enabled
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetEnabled enables or disables tool execution.
|
|
||||||
func (r *ToolRegistry) SetEnabled(enabled bool) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
r.enabled = enabled
|
|
||||||
}
|
|
||||||
|
|
||||||
// DefinitionNames returns all registered tool names.
|
|
||||||
func (r *ToolRegistry) DefinitionNames() []string {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
names := make([]string, 0, len(r.tools))
|
|
||||||
for id := range r.tools {
|
|
||||||
names = append(names, id)
|
|
||||||
}
|
|
||||||
return names
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ToolRegistry) Register(tool sdk.Tool) error {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
id := tool.Definition().ID
|
|
||||||
if _, exists := r.tools[id]; exists {
|
|
||||||
return fmt.Errorf("tool %q already registered", id)
|
|
||||||
}
|
|
||||||
r.tools[id] = tool
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ToolRegistry) Unregister(toolID string) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
delete(r.tools, toolID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ToolRegistry) Get(toolID string) (sdk.Tool, bool) {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
t, ok := r.tools[toolID]
|
|
||||||
return t, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ToolRegistry) List() []sdk.Tool {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
result := make([]sdk.Tool, 0, len(r.tools))
|
|
||||||
for _, t := range r.tools {
|
|
||||||
result = append(result, t)
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ToolRegistry) Definitions() []sdk.ToolDefinition {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
defs := make([]sdk.ToolDefinition, 0, len(r.tools))
|
|
||||||
for _, t := range r.tools {
|
|
||||||
defs = append(defs, t.Definition())
|
|
||||||
}
|
|
||||||
return defs
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ToolRegistry) Execute(ctx context.Context, toolID string, args map[string]interface{}) (*sdk.ToolResult, error) {
|
|
||||||
r.mu.RLock()
|
|
||||||
tool, ok := r.tools[toolID]
|
|
||||||
r.mu.RUnlock()
|
|
||||||
|
|
||||||
startTime := time.Now()
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
r.callLog.push(CallLogRecord{
|
|
||||||
ToolName: toolID, Error: fmt.Sprintf("tool %q not found", toolID),
|
|
||||||
Success: false, DurationMs: int(time.Since(startTime).Milliseconds()),
|
|
||||||
})
|
|
||||||
return nil, fmt.Errorf("tool %q not found", toolID)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tool.Validate(args); err != nil {
|
|
||||||
r.callLog.push(CallLogRecord{
|
|
||||||
ToolName: toolID, Error: err.Error(), Success: false,
|
|
||||||
DurationMs: int(time.Since(startTime).Milliseconds()),
|
|
||||||
})
|
|
||||||
return &sdk.ToolResult{Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Admin-only tools: deny non-admin callers.
|
|
||||||
if adminOnlyTools[toolID] && !IsAdminFromCtx(ctx) {
|
|
||||||
errMsg := fmt.Sprintf("工具 %s 仅限管理员使用", toolID)
|
|
||||||
r.callLog.push(CallLogRecord{
|
|
||||||
ToolName: toolID, Error: errMsg, Success: false,
|
|
||||||
DurationMs: int(time.Since(startTime).Milliseconds()),
|
|
||||||
})
|
|
||||||
return &sdk.ToolResult{Success: false, Error: errMsg}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := tool.Execute(ctx, args)
|
|
||||||
durationMs := int(time.Since(startTime).Milliseconds())
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
r.callLog.push(CallLogRecord{
|
|
||||||
ToolName: toolID, Error: err.Error(), Success: false, DurationMs: durationMs,
|
|
||||||
})
|
|
||||||
return result, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var argsJSON string
|
|
||||||
if args != nil {
|
|
||||||
if b, _ := json.Marshal(args); b != nil {
|
|
||||||
argsJSON = string(b)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
r.callLog.push(CallLogRecord{
|
|
||||||
ToolName: toolID, Arguments: argsJSON, Output: result.Output,
|
|
||||||
Error: result.Error, Success: result.Success, DurationMs: durationMs,
|
|
||||||
})
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnregisterAll removes all tools matching given IDs.
|
|
||||||
func (r *ToolRegistry) UnregisterAll(toolIDs []string) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
for _, id := range toolIDs {
|
|
||||||
delete(r.tools, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetCallLogs 获取工具调用记录(最新在前,支持按工具名过滤、分页)
|
|
||||||
func (r *ToolRegistry) GetCallLogs(toolName string, limit, offset int) ([]CallLogRecord, int) {
|
|
||||||
all := r.callLog.getAll()
|
|
||||||
|
|
||||||
// 过滤
|
|
||||||
var filtered []CallLogRecord
|
|
||||||
if toolName == "" {
|
|
||||||
filtered = all
|
|
||||||
} else {
|
|
||||||
filtered = make([]CallLogRecord, 0)
|
|
||||||
for _, rec := range all {
|
|
||||||
if rec.ToolName == toolName {
|
|
||||||
filtered = append(filtered, rec)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
total := len(filtered)
|
|
||||||
|
|
||||||
// 分页
|
|
||||||
if offset >= len(filtered) {
|
|
||||||
return []CallLogRecord{}, total
|
|
||||||
}
|
|
||||||
page := filtered[offset:]
|
|
||||||
if limit > 0 && limit < len(page) {
|
|
||||||
page = page[:limit]
|
|
||||||
}
|
|
||||||
|
|
||||||
return page, total
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetCallStats 获取工具调用统计
|
|
||||||
func (r *ToolRegistry) GetCallStats() map[string]interface{} {
|
|
||||||
byTool := r.callLog.statsByTool()
|
|
||||||
totalCalls, successCount, failCount, totalDurationMs := 0, 0, 0, 0
|
|
||||||
toolStats := make([]map[string]interface{}, 0, len(byTool))
|
|
||||||
for _, s := range byTool {
|
|
||||||
count := s["count"].(int)
|
|
||||||
success := s["success_count"].(int)
|
|
||||||
fail := s["fail_count"].(int)
|
|
||||||
totalDur := s["total_duration_ms"].(int)
|
|
||||||
avgDur := 0.0
|
|
||||||
if count > 0 {
|
|
||||||
avgDur = float64(totalDur) / float64(count)
|
|
||||||
}
|
|
||||||
s["avg_duration_ms"] = avgDur
|
|
||||||
delete(s, "total_duration_ms")
|
|
||||||
toolStats = append(toolStats, s)
|
|
||||||
totalCalls += count
|
|
||||||
successCount += success
|
|
||||||
failCount += fail
|
|
||||||
totalDurationMs += totalDur
|
|
||||||
}
|
|
||||||
avgDuration := 0.0
|
|
||||||
if totalCalls > 0 {
|
|
||||||
avgDuration = float64(totalDurationMs) / float64(totalCalls)
|
|
||||||
}
|
|
||||||
successRate := 0.0
|
|
||||||
if totalCalls > 0 {
|
|
||||||
successRate = float64(successCount) / float64(totalCalls) * 100
|
|
||||||
}
|
|
||||||
return map[string]interface{}{
|
|
||||||
"total_calls": totalCalls, "success_count": successCount, "fail_count": failCount,
|
|
||||||
"success_rate": successRate, "avg_duration_ms": avgDuration, "by_tool": toolStats,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,184 +0,0 @@
|
|||||||
package markdown
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MarkdownPlugin struct{ sdk.BasePlugin }
|
|
||||||
|
|
||||||
func (p *MarkdownPlugin) Metadata() sdk.PluginMetadata {
|
|
||||||
return sdk.PluginMetadata{
|
|
||||||
Name: "markdown", DisplayName: "Markdown Processor", Version: "1.0.0",
|
|
||||||
Description: "Markdown processing: to HTML, extract text/links/code, generate TOC",
|
|
||||||
Category: "format", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *MarkdownPlugin) Tools() []sdk.Tool { return []sdk.Tool{&MarkdownTool{}} }
|
|
||||||
|
|
||||||
type MarkdownTool struct{ sdk.BaseTool }
|
|
||||||
|
|
||||||
func (t *MarkdownTool) Definition() sdk.ToolDefinition {
|
|
||||||
return sdk.ToolDefinition{
|
|
||||||
ID: "markdown", Name: "markdown", DisplayName: "Markdown Processor",
|
|
||||||
Description: "Markdown processing. Convert to HTML, extract plain text, extract links/code blocks, generate TOC.",
|
|
||||||
Category: "format", Complexity: sdk.ComplexitySimple,
|
|
||||||
Parameters: map[string]interface{}{
|
|
||||||
"type": "object",
|
|
||||||
"properties": map[string]interface{}{
|
|
||||||
"action": map[string]interface{}{"type": "string", "enum": []string{"to_html", "to_text", "extract_links", "extract_code", "table_of_contents"}},
|
|
||||||
"markdown": map[string]interface{}{"type": "string"},
|
|
||||||
},
|
|
||||||
"required": []string{"action", "markdown"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *MarkdownTool) Validate(args map[string]interface{}) error {
|
|
||||||
for _, k := range []string{"action", "markdown"} {
|
|
||||||
if _, ok := args[k]; !ok {
|
|
||||||
return fmt.Errorf("missing required parameter: %s", k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *MarkdownTool) Execute(_ context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
|
||||||
action, _ := args["action"].(string)
|
|
||||||
md, _ := args["markdown"].(string)
|
|
||||||
|
|
||||||
switch action {
|
|
||||||
case "to_html":
|
|
||||||
return &sdk.ToolResult{ToolName: "markdown", Success: true, Output: mdToHTML(md)}, nil
|
|
||||||
|
|
||||||
case "to_text":
|
|
||||||
text := md
|
|
||||||
reCode := regexp.MustCompile("(?s)```.*?```")
|
|
||||||
text = reCode.ReplaceAllString(text, "")
|
|
||||||
text = regexp.MustCompile(`\*\*([^*]+)\*\*`).ReplaceAllString(text, "$1")
|
|
||||||
text = regexp.MustCompile(`\*([^*]+)\*`).ReplaceAllString(text, "$1")
|
|
||||||
text = regexp.MustCompile(`~~([^~]+)~~`).ReplaceAllString(text, "$1")
|
|
||||||
text = regexp.MustCompile(`^#{1,6}\s+`).ReplaceAllString(text, "")
|
|
||||||
text = regexp.MustCompile(`^[*-]\s+`).ReplaceAllString(text, "- ")
|
|
||||||
text = regexp.MustCompile(`^>\s+`).ReplaceAllString(text, "")
|
|
||||||
text = regexp.MustCompile(`\n{3,}`).ReplaceAllString(text, "\n\n")
|
|
||||||
return &sdk.ToolResult{ToolName: "markdown", Success: true, Output: strings.TrimSpace(text)}, nil
|
|
||||||
|
|
||||||
case "extract_links":
|
|
||||||
re := regexp.MustCompile(`\[([^\]]+)\]\(([^)]+)\)`)
|
|
||||||
matches := re.FindAllStringSubmatch(md, -1)
|
|
||||||
if len(matches) == 0 {
|
|
||||||
return &sdk.ToolResult{ToolName: "markdown", Success: true, Output: "No links found"}, nil
|
|
||||||
}
|
|
||||||
var out strings.Builder
|
|
||||||
for i, m := range matches {
|
|
||||||
out.WriteString(fmt.Sprintf("%d. %s -> %s\n", i+1, m[1], m[2]))
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "markdown", Success: true, Output: out.String()}, nil
|
|
||||||
|
|
||||||
case "extract_code":
|
|
||||||
re := regexp.MustCompile("(?s)```(\\w*)\n?(.*?)```")
|
|
||||||
matches := re.FindAllStringSubmatch(md, -1)
|
|
||||||
if len(matches) == 0 {
|
|
||||||
return &sdk.ToolResult{ToolName: "markdown", Success: true, Output: "No code blocks found"}, nil
|
|
||||||
}
|
|
||||||
var out strings.Builder
|
|
||||||
for i, m := range matches {
|
|
||||||
lang := m[1]
|
|
||||||
if lang == "" {
|
|
||||||
lang = "text"
|
|
||||||
}
|
|
||||||
code := m[2]
|
|
||||||
if len([]rune(code)) > 500 {
|
|
||||||
code = string([]rune(code)[:500]) + "..."
|
|
||||||
}
|
|
||||||
out.WriteString(fmt.Sprintf("--- Block %d (%s) ---\n%s\n\n", i+1, lang, code))
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "markdown", Success: true, Output: out.String()}, nil
|
|
||||||
|
|
||||||
case "table_of_contents":
|
|
||||||
re := regexp.MustCompile(`(?m)^(#{1,6})\s+(.+)$`)
|
|
||||||
matches := re.FindAllStringSubmatch(md, -1)
|
|
||||||
if len(matches) == 0 {
|
|
||||||
return &sdk.ToolResult{ToolName: "markdown", Success: true, Output: "No headings found"}, nil
|
|
||||||
}
|
|
||||||
var out strings.Builder
|
|
||||||
for _, m := range matches {
|
|
||||||
depth := len(m[1])
|
|
||||||
indent := strings.Repeat(" ", depth-1)
|
|
||||||
out.WriteString(fmt.Sprintf("%s- %s\n", indent, m[2]))
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "markdown", Success: true, Output: out.String()}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "markdown", Success: false, Error: "unknown action: " + action}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func mdToHTML(md string) string {
|
|
||||||
// Save code blocks
|
|
||||||
type placeholder struct {
|
|
||||||
orig string
|
|
||||||
content string
|
|
||||||
language string
|
|
||||||
}
|
|
||||||
blocks := []*placeholder{}
|
|
||||||
reCode := regexp.MustCompile("(?s)```(\\w*)\n?(.*?)```")
|
|
||||||
md = reCode.ReplaceAllStringFunc(md, func(s string) string {
|
|
||||||
m := reCode.FindStringSubmatch(s)
|
|
||||||
b := &placeholder{orig: fmt.Sprintf("\x00CODE%d\x00", len(blocks)), language: m[1], content: escapeHTML(m[2])}
|
|
||||||
blocks = append(blocks, b)
|
|
||||||
return b.orig
|
|
||||||
})
|
|
||||||
|
|
||||||
// Inline elements
|
|
||||||
md = regexp.MustCompile("`([^`]+)`").ReplaceAllString(md, "<code>$1</code>")
|
|
||||||
md = regexp.MustCompile(`!\[([^\]]*)\]\(([^)]+)\)`).ReplaceAllString(md, `<img src="$2" alt="$1">`)
|
|
||||||
md = regexp.MustCompile(`\[([^\]]+)\]\(([^)]+)\)`).ReplaceAllString(md, `<a href="$2">$1</a>`)
|
|
||||||
md = regexp.MustCompile(`\*\*([^*]+)\*\*`).ReplaceAllString(md, `<strong>$1</strong>`)
|
|
||||||
md = regexp.MustCompile(`\*([^*]+)\*`).ReplaceAllString(md, `<em>$1</em>`)
|
|
||||||
md = regexp.MustCompile(`~~([^~]+)~~`).ReplaceAllString(md, `<del>$1</del>`)
|
|
||||||
md = regexp.MustCompile(`(?m)^#{6}\s+(.+)$`).ReplaceAllString(md, `<h6>$1</h6>`)
|
|
||||||
md = regexp.MustCompile(`(?m)^#{5}\s+(.+)$`).ReplaceAllString(md, `<h5>$1</h5>`)
|
|
||||||
md = regexp.MustCompile(`(?m)^#{4}\s+(.+)$`).ReplaceAllString(md, `<h4>$1</h4>`)
|
|
||||||
md = regexp.MustCompile(`(?m)^#{3}\s+(.+)$`).ReplaceAllString(md, `<h3>$1</h3>`)
|
|
||||||
md = regexp.MustCompile(`(?m)^#{2}\s+(.+)$`).ReplaceAllString(md, `<h2>$1</h2>`)
|
|
||||||
md = regexp.MustCompile(`(?m)^#{1}\s+(.+)$`).ReplaceAllString(md, `<h1>$1</h1>`)
|
|
||||||
md = regexp.MustCompile(`(?m)^---\s*$`).ReplaceAllString(md, `<hr>`)
|
|
||||||
md = regexp.MustCompile(`(?m)^>\s+(.+)$`).ReplaceAllString(md, `<blockquote>$1</blockquote>`)
|
|
||||||
|
|
||||||
// Restore code blocks
|
|
||||||
for _, b := range blocks {
|
|
||||||
langAttr := ""
|
|
||||||
if b.language != "" {
|
|
||||||
langAttr = " class=\"language-" + b.language + "\""
|
|
||||||
}
|
|
||||||
md = strings.Replace(md, b.orig, "<pre><code"+langAttr+">"+b.content+"</code></pre>", 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Paragraphs
|
|
||||||
lines := strings.Split(md, "\n")
|
|
||||||
var out strings.Builder
|
|
||||||
for _, line := range lines {
|
|
||||||
trimmed := strings.TrimSpace(line)
|
|
||||||
if trimmed == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(trimmed, "<") {
|
|
||||||
out.WriteString(trimmed + "\n")
|
|
||||||
} else {
|
|
||||||
out.WriteString("<p>" + trimmed + "</p>\n")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return out.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func escapeHTML(s string) string {
|
|
||||||
s = strings.ReplaceAll(s, "&", "&")
|
|
||||||
s = strings.ReplaceAll(s, "<", "<")
|
|
||||||
s = strings.ReplaceAll(s, ">", ">")
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
@@ -1,175 +0,0 @@
|
|||||||
package random
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/rand"
|
|
||||||
"fmt"
|
|
||||||
"math/big"
|
|
||||||
mathrand "math/rand"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
|
||||||
)
|
|
||||||
|
|
||||||
type RandomPlugin struct{ sdk.BasePlugin }
|
|
||||||
|
|
||||||
func (p *RandomPlugin) Metadata() sdk.PluginMetadata {
|
|
||||||
return sdk.PluginMetadata{
|
|
||||||
Name: "random", DisplayName: "Random Generator", Version: "1.0.0",
|
|
||||||
Description: "Random generation: numbers, UUIDs, secure passwords, pick/shuffle",
|
|
||||||
Category: "utility", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *RandomPlugin) Tools() []sdk.Tool { return []sdk.Tool{&RandomTool{}} }
|
|
||||||
|
|
||||||
type RandomTool struct{ sdk.BaseTool }
|
|
||||||
|
|
||||||
func (t *RandomTool) Definition() sdk.ToolDefinition {
|
|
||||||
return sdk.ToolDefinition{
|
|
||||||
ID: "random", Name: "random", DisplayName: "Random Generator",
|
|
||||||
Description: "Random generation. Random numbers, UUID v4, secure passwords, pick from list, shuffle list.",
|
|
||||||
Category: "utility", Complexity: sdk.ComplexitySimple,
|
|
||||||
Parameters: map[string]interface{}{
|
|
||||||
"type": "object",
|
|
||||||
"properties": map[string]interface{}{
|
|
||||||
"action": map[string]interface{}{"type": "string", "enum": []string{"number", "uuid", "password", "pick", "shuffle"}},
|
|
||||||
"min": map[string]interface{}{"type": "number"},
|
|
||||||
"max": map[string]interface{}{"type": "number"},
|
|
||||||
"length": map[string]interface{}{"type": "number"},
|
|
||||||
"items": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}},
|
|
||||||
"count": map[string]interface{}{"type": "number"},
|
|
||||||
},
|
|
||||||
"required": []string{"action"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *RandomTool) Validate(args map[string]interface{}) error {
|
|
||||||
if _, ok := args["action"]; !ok {
|
|
||||||
return fmt.Errorf("missing required parameter: action")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *RandomTool) Execute(_ context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
|
||||||
action, _ := args["action"].(string)
|
|
||||||
|
|
||||||
switch action {
|
|
||||||
case "number":
|
|
||||||
min := getIntArg(args, "min", 0)
|
|
||||||
max := getIntArg(args, "max", 100)
|
|
||||||
n, err := rand.Int(rand.Reader, big.NewInt(int64(max-min+1)))
|
|
||||||
if err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "random", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "random", Success: true,
|
|
||||||
Output: fmt.Sprintf("%d", int(n.Int64())+min)}, nil
|
|
||||||
|
|
||||||
case "uuid":
|
|
||||||
uuid := make([]byte, 16)
|
|
||||||
rand.Read(uuid)
|
|
||||||
uuid[6] = (uuid[6] & 0x0f) | 0x40
|
|
||||||
uuid[8] = (uuid[8] & 0x3f) | 0x80
|
|
||||||
return &sdk.ToolResult{ToolName: "random", Success: true,
|
|
||||||
Output: fmt.Sprintf("%x-%x-%x-%x-%x", uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:])}, nil
|
|
||||||
|
|
||||||
case "password":
|
|
||||||
length := getIntArg(args, "length", 16)
|
|
||||||
if length < 4 {
|
|
||||||
length = 4
|
|
||||||
}
|
|
||||||
if length > 128 {
|
|
||||||
length = 128
|
|
||||||
}
|
|
||||||
upper := "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
|
||||||
lower := "abcdefghijklmnopqrstuvwxyz"
|
|
||||||
digits := "0123456789"
|
|
||||||
symbols := "!@#$%^&*()_+-=[]{}|;:,.<>?"
|
|
||||||
all := upper + lower + digits + symbols
|
|
||||||
bytes := make([]byte, length)
|
|
||||||
for i := range bytes {
|
|
||||||
idx, _ := rand.Int(rand.Reader, big.NewInt(int64(len(all))))
|
|
||||||
bytes[i] = all[idx.Int64()]
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "random", Success: true, Output: string(bytes)}, nil
|
|
||||||
|
|
||||||
case "pick":
|
|
||||||
items := getStringSliceArg(args, "items")
|
|
||||||
if len(items) == 0 {
|
|
||||||
return &sdk.ToolResult{ToolName: "random", Success: false, Error: "items list is empty"}, nil
|
|
||||||
}
|
|
||||||
count := getIntArg(args, "count", 1)
|
|
||||||
if count > len(items) {
|
|
||||||
count = len(items)
|
|
||||||
}
|
|
||||||
indices := shuffledIndices(len(items))
|
|
||||||
picked := make([]string, count)
|
|
||||||
for i := 0; i < count; i++ {
|
|
||||||
picked[i] = items[indices[i]]
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "random", Success: true, Output: strings.Join(picked, ", ")}, nil
|
|
||||||
|
|
||||||
case "shuffle":
|
|
||||||
items := getStringSliceArg(args, "items")
|
|
||||||
indices := shuffledIndices(len(items))
|
|
||||||
shuffled := make([]string, len(items))
|
|
||||||
for i, idx := range indices {
|
|
||||||
shuffled[i] = items[idx]
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "random", Success: true, Output: strings.Join(shuffled, ", ")}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "random", Success: false, Error: "unknown action: " + action}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getIntArg(args map[string]interface{}, key string, defaultVal int) int {
|
|
||||||
v, ok := args[key]
|
|
||||||
if !ok {
|
|
||||||
return defaultVal
|
|
||||||
}
|
|
||||||
switch n := v.(type) {
|
|
||||||
case float64:
|
|
||||||
return int(n)
|
|
||||||
case int:
|
|
||||||
return n
|
|
||||||
case int64:
|
|
||||||
return int(n)
|
|
||||||
}
|
|
||||||
return defaultVal
|
|
||||||
}
|
|
||||||
|
|
||||||
func getStringSliceArg(args map[string]interface{}, key string) []string {
|
|
||||||
v, ok := args[key]
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
switch s := v.(type) {
|
|
||||||
case []string:
|
|
||||||
return s
|
|
||||||
case []interface{}:
|
|
||||||
result := make([]string, len(s))
|
|
||||||
for i, item := range s {
|
|
||||||
result[i] = fmt.Sprint(item)
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func shuffledIndices(n int) []int {
|
|
||||||
indices := make([]int, n)
|
|
||||||
for i := range indices {
|
|
||||||
indices[i] = i
|
|
||||||
}
|
|
||||||
for i := n - 1; i > 0; i-- {
|
|
||||||
jBig, err := rand.Int(rand.Reader, big.NewInt(int64(i+1)))
|
|
||||||
if err != nil {
|
|
||||||
j := mathrand.Intn(i + 1)
|
|
||||||
indices[i], indices[j] = indices[j], indices[i]
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
j := int(jBig.Int64())
|
|
||||||
indices[i], indices[j] = indices[j], indices[i]
|
|
||||||
}
|
|
||||||
return indices
|
|
||||||
}
|
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
package sdk
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
// BasePlugin provides default implementations for optional Plugin methods.
|
|
||||||
type BasePlugin struct{}
|
|
||||||
|
|
||||||
func (BasePlugin) Init(_ context.Context, _ PluginConfig) error { return nil }
|
|
||||||
|
|
||||||
func (BasePlugin) Start(_ context.Context, _ HostAPI) error { return nil }
|
|
||||||
|
|
||||||
func (BasePlugin) Stop(_ context.Context) error { return nil }
|
|
||||||
|
|
||||||
func (BasePlugin) Health(_ context.Context) error { return nil }
|
|
||||||
|
|
||||||
// BaseTool provides a Validate default that checks required parameters.
|
|
||||||
type BaseTool struct {
|
|
||||||
Def ToolDefinition
|
|
||||||
Required []string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b BaseTool) Definition() ToolDefinition { return b.Def }
|
|
||||||
|
|
||||||
func (b BaseTool) Complexity() ToolComplexity { return ComplexitySimple }
|
|
||||||
|
|
||||||
func (b BaseTool) Validate(args map[string]interface{}) error {
|
|
||||||
for _, key := range b.Required {
|
|
||||||
if _, ok := args[key]; !ok {
|
|
||||||
return fmt.Errorf("missing required parameter: %s", key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b BaseTool) Execute(_ context.Context, _ map[string]interface{}) (*ToolResult, error) {
|
|
||||||
return nil, fmt.Errorf("not implemented")
|
|
||||||
}
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
package sdk
|
|
||||||
|
|
||||||
// PluginPermissions defines what a plugin is allowed to do.
|
|
||||||
type PluginPermissions struct {
|
|
||||||
NetworkAllowed bool `json:"networkAllowed"`
|
|
||||||
AllowedHosts []string `json:"allowedHosts,omitempty"`
|
|
||||||
IoTRead bool `json:"iotRead"`
|
|
||||||
IoTWrite bool `json:"iotWrite"`
|
|
||||||
MemoryRead bool `json:"memoryRead"`
|
|
||||||
MemoryWrite bool `json:"memoryWrite"`
|
|
||||||
FileRead bool `json:"fileRead"`
|
|
||||||
FileWrite bool `json:"fileWrite"`
|
|
||||||
AllowedPaths []string `json:"allowedPaths,omitempty"`
|
|
||||||
ExecAllowed bool `json:"execAllowed"`
|
|
||||||
MaxCPUPercent float64 `json:"maxCPUPercent"`
|
|
||||||
MaxMemoryMB int `json:"maxMemoryMB"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// DefaultPermissions returns a safe default permission set.
|
|
||||||
func DefaultPermissions() PluginPermissions {
|
|
||||||
return PluginPermissions{
|
|
||||||
NetworkAllowed: false,
|
|
||||||
AllowedHosts: []string{},
|
|
||||||
IoTRead: false,
|
|
||||||
IoTWrite: false,
|
|
||||||
MemoryRead: false,
|
|
||||||
MemoryWrite: false,
|
|
||||||
FileRead: false,
|
|
||||||
FileWrite: false,
|
|
||||||
AllowedPaths: []string{},
|
|
||||||
ExecAllowed: false,
|
|
||||||
MaxCPUPercent: 10.0,
|
|
||||||
MaxMemoryMB: 128,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
package sdk
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/http"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Plugin is the main interface every plugin must implement.
|
|
||||||
type Plugin interface {
|
|
||||||
Metadata() PluginMetadata
|
|
||||||
Init(ctx context.Context, config PluginConfig) error
|
|
||||||
Start(ctx context.Context, host HostAPI) error
|
|
||||||
Stop(ctx context.Context) error
|
|
||||||
Health(ctx context.Context) error
|
|
||||||
Tools() []Tool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tool is the interface every tool must implement.
|
|
||||||
type Tool interface {
|
|
||||||
Definition() ToolDefinition
|
|
||||||
Execute(ctx context.Context, args map[string]interface{}) (*ToolResult, error)
|
|
||||||
Validate(args map[string]interface{}) error
|
|
||||||
Complexity() ToolComplexity
|
|
||||||
}
|
|
||||||
|
|
||||||
// ComplexTool extends Tool for async multi-round execution.
|
|
||||||
type ComplexTool interface {
|
|
||||||
Tool
|
|
||||||
ExecuteAsync(ctx context.Context, args map[string]interface{}) (<-chan ToolProgress, error)
|
|
||||||
Cancel(ctx context.Context, executionID string) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// HostAPI gives plugins access to Cyrene core capabilities.
|
|
||||||
type HostAPI interface {
|
|
||||||
CallLLM(ctx context.Context, messages []LLMMessage) (*LLMResponse, error)
|
|
||||||
SearchMemory(ctx context.Context, userID, query string, limit int) ([]MemoryEntry, error)
|
|
||||||
StoreMemory(ctx context.Context, entry MemoryEntry) error
|
|
||||||
Logger() Logger
|
|
||||||
GetConfig(key string) (string, error)
|
|
||||||
SetConfig(key, value string) error
|
|
||||||
PublishEvent(ctx context.Context, event map[string]interface{}) error
|
|
||||||
HTTPClient() *http.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
// Logger is a minimal logging interface for plugins.
|
|
||||||
type Logger interface {
|
|
||||||
Printf(format string, args ...interface{})
|
|
||||||
Println(args ...interface{})
|
|
||||||
}
|
|
||||||
@@ -1,134 +0,0 @@
|
|||||||
package sdk
|
|
||||||
|
|
||||||
import "time"
|
|
||||||
|
|
||||||
// ToolComplexity grades tools into simple (single-call, <2s) and complex (multi-round, async).
|
|
||||||
type ToolComplexity string
|
|
||||||
|
|
||||||
const (
|
|
||||||
ComplexitySimple ToolComplexity = "simple"
|
|
||||||
ComplexityComplex ToolComplexity = "complex"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PluginMetadata describes a plugin's identity and requirements.
|
|
||||||
type PluginMetadata struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
DisplayName string `json:"displayName"`
|
|
||||||
Version string `json:"version"`
|
|
||||||
MinCyreneVersion string `json:"minCyreneVersion"`
|
|
||||||
Author PluginAuthor `json:"author"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
License string `json:"license"`
|
|
||||||
Keywords []string `json:"keywords,omitempty"`
|
|
||||||
Category string `json:"category"`
|
|
||||||
Dependencies map[string]string `json:"dependencies,omitempty"` // plugin name -> version range
|
|
||||||
Homepage string `json:"homepage,omitempty"`
|
|
||||||
Repository string `json:"repository,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type PluginAuthor struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Email string `json:"email,omitempty"`
|
|
||||||
URL string `json:"url,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// PluginConfig holds runtime configuration for a plugin.
|
|
||||||
type PluginConfig map[string]interface{}
|
|
||||||
|
|
||||||
// ToolDefinition describes a tool's interface for LLM function calling.
|
|
||||||
type ToolDefinition struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
DisplayName string `json:"displayName"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Category string `json:"category"`
|
|
||||||
Complexity ToolComplexity `json:"complexity"`
|
|
||||||
Parameters map[string]interface{} `json:"parameters"`
|
|
||||||
Returns map[string]interface{} `json:"returns,omitempty"`
|
|
||||||
TimeoutMs int `json:"timeout_ms,omitempty"`
|
|
||||||
MaxRetries int `json:"max_retries,omitempty"`
|
|
||||||
DangerLevel string `json:"danger_level,omitempty"` // low / medium / high
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToolResult is the standard tool execution result.
|
|
||||||
type ToolResult struct {
|
|
||||||
ToolName string `json:"tool_name"`
|
|
||||||
Success bool `json:"success"`
|
|
||||||
Output string `json:"output,omitempty"`
|
|
||||||
Error string `json:"error,omitempty"`
|
|
||||||
DurationMs int64 `json:"duration_ms,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToolProgress reports execution progress for complex (async) tools.
|
|
||||||
type ToolProgress struct {
|
|
||||||
ExecutionID string `json:"execution_id"`
|
|
||||||
Status string `json:"status"` // started / running / completed / failed / cancelled
|
|
||||||
Progress float64 `json:"progress"` // 0.0 - 1.0
|
|
||||||
Message string `json:"message,omitempty"`
|
|
||||||
Error string `json:"error,omitempty"`
|
|
||||||
Result *ToolResult `json:"result,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// PluginStatus represents the current lifecycle state of a plugin.
|
|
||||||
type PluginStatus string
|
|
||||||
|
|
||||||
const (
|
|
||||||
StatusInstalled PluginStatus = "installed"
|
|
||||||
StatusLoaded PluginStatus = "loaded"
|
|
||||||
StatusRunning PluginStatus = "running"
|
|
||||||
StatusPaused PluginStatus = "paused"
|
|
||||||
StatusError PluginStatus = "error"
|
|
||||||
StatusDisabled PluginStatus = "disabled"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PluginInfo is the runtime view of an installed plugin.
|
|
||||||
type PluginInfo struct {
|
|
||||||
Metadata PluginMetadata `json:"metadata"`
|
|
||||||
Status PluginStatus `json:"status"`
|
|
||||||
Tools []string `json:"tools"` // tool IDs provided by this plugin
|
|
||||||
InstalledAt time.Time `json:"installed_at"`
|
|
||||||
Enabled bool `json:"enabled"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// LLMMessage is a message in an LLM conversation.
|
|
||||||
type LLMMessage struct {
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// LLMResponse is the result of an LLM call.
|
|
||||||
type LLMResponse struct {
|
|
||||||
Content string `json:"content"`
|
|
||||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToolCall represents a tool call requested by the LLM.
|
|
||||||
type ToolCall struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
Arguments map[string]interface{} `json:"arguments"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// IoTDeviceState is the shared device state across IoT plugins.
|
|
||||||
type IoTDeviceState struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
Brightness int `json:"brightness,omitempty"`
|
|
||||||
Color string `json:"color,omitempty"`
|
|
||||||
Mode string `json:"mode,omitempty"`
|
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
|
||||||
Position int `json:"position,omitempty"`
|
|
||||||
Value float64 `json:"value,omitempty"`
|
|
||||||
Unit string `json:"unit,omitempty"`
|
|
||||||
Battery int `json:"battery,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// MemoryEntry is a memory record.
|
|
||||||
type MemoryEntry struct {
|
|
||||||
UserID string `json:"user_id"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
Meta map[string]interface{} `json:"meta,omitempty"`
|
|
||||||
}
|
|
||||||
@@ -1,177 +0,0 @@
|
|||||||
package text
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
"unicode"
|
|
||||||
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
|
||||||
)
|
|
||||||
|
|
||||||
type TextPlugin struct{ sdk.BasePlugin }
|
|
||||||
|
|
||||||
func (p *TextPlugin) Metadata() sdk.PluginMetadata {
|
|
||||||
return sdk.PluginMetadata{
|
|
||||||
Name: "text", DisplayName: "Text Processing", Version: "1.0.0",
|
|
||||||
Description: "Text processing: count stats, summarize, regex extract",
|
|
||||||
Category: "utility", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *TextPlugin) Tools() []sdk.Tool { return []sdk.Tool{&TextTool{}} }
|
|
||||||
|
|
||||||
type TextTool struct{ sdk.BaseTool }
|
|
||||||
|
|
||||||
func (t *TextTool) Definition() sdk.ToolDefinition {
|
|
||||||
return sdk.ToolDefinition{
|
|
||||||
ID: "text", Name: "text", DisplayName: "Text Processing",
|
|
||||||
Description: "Text processing. Count stats, summarize, translate, regex extract.",
|
|
||||||
Category: "utility", Complexity: sdk.ComplexitySimple,
|
|
||||||
Parameters: map[string]interface{}{
|
|
||||||
"type": "object",
|
|
||||||
"properties": map[string]interface{}{
|
|
||||||
"action": map[string]interface{}{"type": "string", "enum": []string{"count", "summarize", "translate", "extract"}},
|
|
||||||
"text": map[string]interface{}{"type": "string"},
|
|
||||||
"target_lang": map[string]interface{}{"type": "string", "enum": []string{"en", "zh", "ja", "ko", "fr", "de"}},
|
|
||||||
"pattern": map[string]interface{}{"type": "string"},
|
|
||||||
},
|
|
||||||
"required": []string{"action", "text"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *TextTool) Validate(args map[string]interface{}) error {
|
|
||||||
for _, k := range []string{"action", "text"} {
|
|
||||||
if _, ok := args[k]; !ok {
|
|
||||||
return fmt.Errorf("missing required parameter: %s", k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *TextTool) Execute(_ context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
|
||||||
action, _ := args["action"].(string)
|
|
||||||
txt, _ := args["text"].(string)
|
|
||||||
|
|
||||||
switch action {
|
|
||||||
case "count":
|
|
||||||
charsNoSpace := 0
|
|
||||||
chineseChars := 0
|
|
||||||
for _, r := range txt {
|
|
||||||
if !unicode.IsSpace(r) {
|
|
||||||
charsNoSpace++
|
|
||||||
}
|
|
||||||
if unicode.Is(unicode.Han, r) {
|
|
||||||
chineseChars++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
words := strings.Fields(txt)
|
|
||||||
lines := strings.Split(txt, "\n")
|
|
||||||
paragraphs := regexp.MustCompile(`\n\s*\n`).Split(txt, -1)
|
|
||||||
return &sdk.ToolResult{ToolName: "text", Success: true, Output: fmt.Sprintf(
|
|
||||||
"Characters: %d (no spaces: %d, Chinese: %d)\nBytes: %d\nWords: %d\nLines: %d\nParagraphs: %d",
|
|
||||||
len([]rune(txt)), charsNoSpace, chineseChars, len(txt), len(words), len(lines), len(paragraphs))}, nil
|
|
||||||
|
|
||||||
case "summarize":
|
|
||||||
paragraphs := regexp.MustCompile(`\n\s*\n`).Split(txt, -1)
|
|
||||||
firstPara := ""
|
|
||||||
if len(paragraphs) > 0 {
|
|
||||||
runes := []rune(paragraphs[0])
|
|
||||||
if len(runes) > 300 {
|
|
||||||
runes = runes[:300]
|
|
||||||
}
|
|
||||||
firstPara = string(runes)
|
|
||||||
}
|
|
||||||
sentences := regexp.MustCompile(`[。!?.!?]+`).Split(txt, -1)
|
|
||||||
keywords := []string{"重要", "关键", "因此", "总结", "important", "key", "conclusion", "therefore"}
|
|
||||||
type scored struct {
|
|
||||||
text string
|
|
||||||
score int
|
|
||||||
}
|
|
||||||
var scoredSents []scored
|
|
||||||
for _, s := range sentences {
|
|
||||||
s = strings.TrimSpace(s)
|
|
||||||
if len([]rune(s)) < 10 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
score := len([]rune(s))
|
|
||||||
for _, kw := range keywords {
|
|
||||||
if strings.Contains(strings.ToLower(s), strings.ToLower(kw)) {
|
|
||||||
score += 20
|
|
||||||
}
|
|
||||||
}
|
|
||||||
scoredSents = append(scoredSents, scored{s, score})
|
|
||||||
}
|
|
||||||
var out strings.Builder
|
|
||||||
out.WriteString(fmt.Sprintf("First paragraph: %s\n\nKey sentences:\n", firstPara))
|
|
||||||
count := 0
|
|
||||||
for i := 0; i < len(scoredSents) && count < 5; i++ {
|
|
||||||
out.WriteString(fmt.Sprintf("- %s\n", scoredSents[i].text))
|
|
||||||
count++
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "text", Success: true, Output: out.String()}, nil
|
|
||||||
|
|
||||||
case "translate":
|
|
||||||
targetLang, _ := args["target_lang"].(string)
|
|
||||||
if targetLang == "" {
|
|
||||||
targetLang = "en"
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "text", Success: true, Output: fmt.Sprintf(
|
|
||||||
"[Translation request] Please translate the following text to %s.\n\nOriginal text:\n%s", targetLang, txt)}, nil
|
|
||||||
|
|
||||||
case "extract":
|
|
||||||
pattern, _ := args["pattern"].(string)
|
|
||||||
var out strings.Builder
|
|
||||||
extracted := false
|
|
||||||
if pattern == "" || pattern == "email" {
|
|
||||||
re := regexp.MustCompile(`[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}`)
|
|
||||||
if matches := re.FindAllString(txt, -1); len(matches) > 0 {
|
|
||||||
out.WriteString("Emails:\n")
|
|
||||||
for _, m := range matches {
|
|
||||||
out.WriteString(fmt.Sprintf("- %s\n", m))
|
|
||||||
}
|
|
||||||
extracted = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if pattern == "" || pattern == "phone" {
|
|
||||||
re := regexp.MustCompile(`1[3-9]\d{9}`)
|
|
||||||
if matches := re.FindAllString(txt, -1); len(matches) > 0 {
|
|
||||||
out.WriteString("Phone numbers:\n")
|
|
||||||
for _, m := range matches {
|
|
||||||
out.WriteString(fmt.Sprintf("- %s\n", m))
|
|
||||||
}
|
|
||||||
extracted = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if pattern == "" || pattern == "url" {
|
|
||||||
re := regexp.MustCompile(`https?://[^\s<>"{}|\\^` + "`" + `\[\]]+`)
|
|
||||||
if matches := re.FindAllString(txt, -1); len(matches) > 0 {
|
|
||||||
out.WriteString("URLs:\n")
|
|
||||||
for _, m := range matches {
|
|
||||||
out.WriteString(fmt.Sprintf("- %s\n", m))
|
|
||||||
}
|
|
||||||
extracted = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !extracted && pattern != "" && pattern != "email" && pattern != "phone" && pattern != "url" {
|
|
||||||
re, err := regexp.Compile(pattern)
|
|
||||||
if err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "text", Success: false, Error: "Invalid regex: " + err.Error()}, nil
|
|
||||||
}
|
|
||||||
if matches := re.FindAllString(txt, -1); len(matches) > 0 {
|
|
||||||
out.WriteString(fmt.Sprintf("Pattern matches (%s):\n", pattern))
|
|
||||||
for _, m := range matches {
|
|
||||||
out.WriteString(fmt.Sprintf("- %s\n", m))
|
|
||||||
}
|
|
||||||
extracted = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !extracted {
|
|
||||||
return &sdk.ToolResult{ToolName: "text", Success: true, Output: "No matches found"}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "text", Success: true, Output: out.String()}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "text", Success: false, Error: "unknown action: " + action}, nil
|
|
||||||
}
|
|
||||||
@@ -1,113 +0,0 @@
|
|||||||
package webfetch
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
|
||||||
)
|
|
||||||
|
|
||||||
type WebFetchPlugin struct {
|
|
||||||
sdk.BasePlugin
|
|
||||||
client *http.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewWebFetchPlugin() *WebFetchPlugin {
|
|
||||||
return &WebFetchPlugin{client: &http.Client{Timeout: 15 * time.Second}}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *WebFetchPlugin) Metadata() sdk.PluginMetadata {
|
|
||||||
return sdk.PluginMetadata{
|
|
||||||
Name: "web_fetch", DisplayName: "Web Fetch", Version: "1.0.0",
|
|
||||||
Description: "Fetch and extract text content from URLs",
|
|
||||||
Category: "network", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *WebFetchPlugin) Tools() []sdk.Tool { return []sdk.Tool{&WebFetchTool{client: p.client}} }
|
|
||||||
|
|
||||||
type WebFetchTool struct {
|
|
||||||
sdk.BaseTool
|
|
||||||
client *http.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *WebFetchTool) Definition() sdk.ToolDefinition {
|
|
||||||
return sdk.ToolDefinition{
|
|
||||||
ID: "web_fetch", Name: "web_fetch", DisplayName: "Web Fetch",
|
|
||||||
Description: "Fetch content of a specified URL. Returns plain text summary (first 2000 characters). HTTP/HTTPS only.",
|
|
||||||
Category: "network", Complexity: sdk.ComplexitySimple,
|
|
||||||
Parameters: map[string]interface{}{
|
|
||||||
"type": "object",
|
|
||||||
"properties": map[string]interface{}{"url": map[string]interface{}{"type": "string"}},
|
|
||||||
"required": []string{"url"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *WebFetchTool) Validate(args map[string]interface{}) error {
|
|
||||||
if _, ok := args["url"]; !ok {
|
|
||||||
return fmt.Errorf("missing required parameter: url")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *WebFetchTool) Execute(_ context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
|
||||||
urlStr, _ := args["url"].(string)
|
|
||||||
if !strings.HasPrefix(urlStr, "http://") && !strings.HasPrefix(urlStr, "https://") {
|
|
||||||
return &sdk.ToolResult{ToolName: "web_fetch", Success: false, Error: "only http/https URLs allowed"}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
req, _ := http.NewRequest("GET", urlStr, nil)
|
|
||||||
req.Header.Set("User-Agent", "CyreneBot/1.0")
|
|
||||||
resp, err := t.client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "web_fetch", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, 100*1024))
|
|
||||||
text := stripHTMLFull(string(bodyBytes))
|
|
||||||
text = removeBlankLines(text)
|
|
||||||
runes := []rune(text)
|
|
||||||
if len(runes) > 2000 {
|
|
||||||
text = string(runes[:2000]) + "..."
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "web_fetch", Success: true, Output: fmt.Sprintf(
|
|
||||||
"URL: %s\nStatus: %d\nContent-Type: %s\n\n%s",
|
|
||||||
urlStr, resp.StatusCode, resp.Header.Get("Content-Type"), text)}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func stripHTMLFull(s string) string {
|
|
||||||
result := make([]rune, 0, len([]rune(s)))
|
|
||||||
inTag := false
|
|
||||||
for _, r := range s {
|
|
||||||
if r == '<' {
|
|
||||||
inTag = true
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if r == '>' {
|
|
||||||
inTag = false
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !inTag {
|
|
||||||
result = append(result, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return string(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeBlankLines(s string) string {
|
|
||||||
lines := strings.Split(s, "\n")
|
|
||||||
var result []string
|
|
||||||
for _, line := range lines {
|
|
||||||
trimmed := strings.TrimSpace(line)
|
|
||||||
if trimmed != "" {
|
|
||||||
result = append(result, trimmed)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return strings.Join(result, "\n")
|
|
||||||
}
|
|
||||||
@@ -1,239 +0,0 @@
|
|||||||
package websearch
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
|
||||||
)
|
|
||||||
|
|
||||||
type WebSearchPlugin struct {
|
|
||||||
sdk.BasePlugin
|
|
||||||
client *http.Client
|
|
||||||
searxngURL string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewWebSearchPlugin() *WebSearchPlugin {
|
|
||||||
return &WebSearchPlugin{client: &http.Client{Timeout: 10 * time.Second}}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewWebSearchPluginWithURL(searxngURL string) *WebSearchPlugin {
|
|
||||||
return &WebSearchPlugin{
|
|
||||||
client: &http.Client{Timeout: 10 * time.Second},
|
|
||||||
searxngURL: strings.TrimRight(searxngURL, "/"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *WebSearchPlugin) Metadata() sdk.PluginMetadata {
|
|
||||||
return sdk.PluginMetadata{
|
|
||||||
Name: "web_search", DisplayName: "Web Search", Version: "1.1.0",
|
|
||||||
Description: "Search the internet via SearXNG (or DuckDuckGo fallback)",
|
|
||||||
Category: "network", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *WebSearchPlugin) Tools() []sdk.Tool {
|
|
||||||
return []sdk.Tool{&WebSearchTool{client: p.client, searxngURL: p.searxngURL}}
|
|
||||||
}
|
|
||||||
|
|
||||||
type WebSearchTool struct {
|
|
||||||
sdk.BaseTool
|
|
||||||
client *http.Client
|
|
||||||
searxngURL string
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---- SearXNG response types ----
|
|
||||||
type searxngResponse struct {
|
|
||||||
Query string `json:"query"`
|
|
||||||
NumberOrResults int `json:"number_of_results"`
|
|
||||||
Results []searxngResult `json:"results"`
|
|
||||||
Answers []string `json:"answers"`
|
|
||||||
Suggestions []string `json:"suggestions"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type searxngResult struct {
|
|
||||||
Title string `json:"title"`
|
|
||||||
URL string `json:"url"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
Engine string `json:"engine"`
|
|
||||||
Score float64 `json:"score"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---- DuckDuckGo response types (fallback) ----
|
|
||||||
type ddgResponse struct {
|
|
||||||
Abstract string `json:"Abstract"`
|
|
||||||
AbstractText string `json:"AbstractText"`
|
|
||||||
Answer string `json:"Answer"`
|
|
||||||
Heading string `json:"Heading"`
|
|
||||||
Results []ddgTopic `json:"Results"`
|
|
||||||
RelatedTopics []ddgTopic `json:"RelatedTopics"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ddgTopic struct {
|
|
||||||
FirstURL string `json:"FirstURL"`
|
|
||||||
Text string `json:"Text"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *WebSearchTool) Definition() sdk.ToolDefinition {
|
|
||||||
return sdk.ToolDefinition{
|
|
||||||
ID: "web_search", Name: "web_search", DisplayName: "Web Search",
|
|
||||||
Description: "Search the internet. SearXNG backend with DuckDuckGo fallback. Returns up to 5 results.",
|
|
||||||
Category: "network", Complexity: sdk.ComplexitySimple,
|
|
||||||
Parameters: map[string]interface{}{
|
|
||||||
"type": "object",
|
|
||||||
"properties": map[string]interface{}{"query": map[string]interface{}{"type": "string"}},
|
|
||||||
"required": []string{"query"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *WebSearchTool) Validate(args map[string]interface{}) error {
|
|
||||||
if _, ok := args["query"]; !ok {
|
|
||||||
return fmt.Errorf("missing required parameter: query")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *WebSearchTool) Execute(_ context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
|
||||||
query, _ := args["query"].(string)
|
|
||||||
if query == "" {
|
|
||||||
return &sdk.ToolResult{ToolName: "web_search", Success: false, Error: "empty query"}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.searxngURL != "" {
|
|
||||||
return t.searchViaSearXNG(query)
|
|
||||||
}
|
|
||||||
return t.searchViaDuckDuckGo(query)
|
|
||||||
}
|
|
||||||
|
|
||||||
// China-accessible SearXNG engines (baidu, sogou, 360search, bing all work from China)
|
|
||||||
const searxngEngines = "bing,sogou,360search,baidu"
|
|
||||||
|
|
||||||
func (t *WebSearchTool) searchViaSearXNG(query string) (*sdk.ToolResult, error) {
|
|
||||||
apiURL := fmt.Sprintf("%s/search?format=json&engines=%s&q=%s",
|
|
||||||
t.searxngURL, searxngEngines, url.QueryEscape(query))
|
|
||||||
|
|
||||||
resp, err := t.client.Get(apiURL)
|
|
||||||
if err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "web_search", Success: false,
|
|
||||||
Error: fmt.Sprintf("SearXNG request failed: %v", err)}, nil
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return &sdk.ToolResult{ToolName: "web_search", Success: false,
|
|
||||||
Error: fmt.Sprintf("SearXNG returned HTTP %d", resp.StatusCode)}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var result searxngResponse
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "web_search", Success: false,
|
|
||||||
Error: fmt.Sprintf("SearXNG parse error: %v", err)}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var out strings.Builder
|
|
||||||
out.WriteString(fmt.Sprintf("搜索: %s (共%d条结果)\n\n", query, result.NumberOrResults))
|
|
||||||
|
|
||||||
// 优先显示答案(如 Wikipedia infobox)
|
|
||||||
for _, answer := range result.Answers {
|
|
||||||
out.WriteString(fmt.Sprintf("📌 %s\n\n", answer))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 搜索结果(最多5条,按score排序)
|
|
||||||
count := 0
|
|
||||||
for _, r := range result.Results {
|
|
||||||
if count >= 5 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if r.Title == "" || r.URL == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
content := cleanSnippet(r.Content)
|
|
||||||
out.WriteString(fmt.Sprintf("%d. **%s**\n %s\n %s\n\n", count+1, r.Title, r.URL, content))
|
|
||||||
count++
|
|
||||||
}
|
|
||||||
|
|
||||||
if out.Len() == 0 {
|
|
||||||
return &sdk.ToolResult{ToolName: "web_search", Success: true,
|
|
||||||
Output: fmt.Sprintf("未找到与「%s」相关的结果。", query)}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "web_search", Success: true, Output: out.String()}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *WebSearchTool) searchViaDuckDuckGo(query string) (*sdk.ToolResult, error) {
|
|
||||||
apiURL := fmt.Sprintf("https://api.duckduckgo.com/?q=%s&format=json&no_html=1", url.QueryEscape(query))
|
|
||||||
resp, err := t.client.Get(apiURL)
|
|
||||||
if err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "web_search", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
var result ddgResponse
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
||||||
return &sdk.ToolResult{ToolName: "web_search", Success: false, Error: err.Error()}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var out strings.Builder
|
|
||||||
if result.Answer != "" {
|
|
||||||
out.WriteString(fmt.Sprintf("Answer: %s\n\n", result.Answer))
|
|
||||||
}
|
|
||||||
if result.AbstractText != "" {
|
|
||||||
text := result.AbstractText
|
|
||||||
if len([]rune(text)) > 500 {
|
|
||||||
text = string([]rune(text)[:500]) + "..."
|
|
||||||
}
|
|
||||||
out.WriteString(fmt.Sprintf("Abstract: %s\n\n", stripHTML(text)))
|
|
||||||
}
|
|
||||||
topics := result.Results
|
|
||||||
if len(topics) == 0 {
|
|
||||||
topics = result.RelatedTopics
|
|
||||||
}
|
|
||||||
count := 0
|
|
||||||
for _, topic := range topics {
|
|
||||||
if count >= 5 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if topic.Text == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
out.WriteString(fmt.Sprintf("%d. %s (%s)\n", count+1, stripHTML(topic.Text), topic.FirstURL))
|
|
||||||
count++
|
|
||||||
}
|
|
||||||
if out.Len() == 0 {
|
|
||||||
return &sdk.ToolResult{ToolName: "web_search", Success: true,
|
|
||||||
Output: "No results found for: " + query}, nil
|
|
||||||
}
|
|
||||||
return &sdk.ToolResult{ToolName: "web_search", Success: true, Output: out.String()}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func cleanSnippet(s string) string {
|
|
||||||
runes := []rune(strings.TrimSpace(s))
|
|
||||||
if len(runes) > 200 {
|
|
||||||
return string(runes[:200]) + "..."
|
|
||||||
}
|
|
||||||
return string(runes)
|
|
||||||
}
|
|
||||||
|
|
||||||
func stripHTML(s string) string {
|
|
||||||
result := make([]rune, 0, len([]rune(s)))
|
|
||||||
inTag := false
|
|
||||||
for _, r := range s {
|
|
||||||
if r == '<' {
|
|
||||||
inTag = true
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if r == '>' {
|
|
||||||
inTag = false
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !inTag {
|
|
||||||
result = append(result, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return strings.TrimSpace(string(result))
|
|
||||||
}
|
|
||||||
@@ -1,47 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/manager"
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
|
||||||
)
|
|
||||||
|
|
||||||
type hostAPI struct {
|
|
||||||
registry *manager.ToolRegistry
|
|
||||||
}
|
|
||||||
|
|
||||||
func newHostAPI(registry *manager.ToolRegistry) *hostAPI {
|
|
||||||
return &hostAPI{registry: registry}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *hostAPI) CallLLM(_ context.Context, _ []sdk.LLMMessage) (*sdk.LLMResponse, error) {
|
|
||||||
return nil, fmt.Errorf("LLM call not available in plugin host")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *hostAPI) SearchMemory(_ context.Context, _, _ string, _ int) ([]sdk.MemoryEntry, error) {
|
|
||||||
return nil, fmt.Errorf("memory search not available in plugin host")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *hostAPI) StoreMemory(_ context.Context, _ sdk.MemoryEntry) error {
|
|
||||||
return fmt.Errorf("memory store not available in plugin host")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *hostAPI) Logger() sdk.Logger {
|
|
||||||
return log.Default()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *hostAPI) GetConfig(key string) (string, error) {
|
|
||||||
return "", fmt.Errorf("config key not found: %s", key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *hostAPI) SetConfig(_, _ string) error { return nil }
|
|
||||||
|
|
||||||
func (h *hostAPI) PublishEvent(_ context.Context, _ map[string]interface{}) error { return nil }
|
|
||||||
|
|
||||||
func (h *hostAPI) HTTPClient() *http.Client {
|
|
||||||
return http.DefaultClient
|
|
||||||
}
|
|
||||||
@@ -1,112 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
|
||||||
iotquery "git.yeij.top/AskaEth/Cyrene/pkg/plugins/iot_query"
|
|
||||||
)
|
|
||||||
|
|
||||||
type iotClient struct {
|
|
||||||
baseURL string
|
|
||||||
httpClient *http.Client
|
|
||||||
}
|
|
||||||
|
|
||||||
func newIoTClient(baseURL string) *iotClient {
|
|
||||||
return &iotClient{
|
|
||||||
baseURL: baseURL,
|
|
||||||
httpClient: &http.Client{Timeout: 5 * time.Second},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *iotClient) GetAllDevices(ctx context.Context) ([]sdk.IoTDeviceState, error) {
|
|
||||||
url := c.baseURL + "/api/v1/devices"
|
|
||||||
req, _ := http.NewRequestWithContext(ctx, "GET", url, nil)
|
|
||||||
resp, err := c.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
var result struct {
|
|
||||||
Devices []sdk.IoTDeviceState `json:"devices"`
|
|
||||||
}
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return result.Devices, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *iotClient) GetDevice(ctx context.Context, deviceID string) (*sdk.IoTDeviceState, error) {
|
|
||||||
url := fmt.Sprintf("%s/api/v1/devices/%s", c.baseURL, deviceID)
|
|
||||||
req, _ := http.NewRequestWithContext(ctx, "GET", url, nil)
|
|
||||||
resp, err := c.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
var dev sdk.IoTDeviceState
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&dev); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &dev, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// iotControllerAdapter adapts IoTClient to iotcontrol.IoTController.
|
|
||||||
type iotControllerAdapter struct {
|
|
||||||
query iotquery.IoTClient
|
|
||||||
client *http.Client
|
|
||||||
baseURL string
|
|
||||||
}
|
|
||||||
|
|
||||||
func newIoTControllerAdapter(query iotquery.IoTClient, baseURL string) *iotControllerAdapter {
|
|
||||||
return &iotControllerAdapter{
|
|
||||||
query: query,
|
|
||||||
client: &http.Client{Timeout: 5 * time.Second},
|
|
||||||
baseURL: baseURL,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *iotControllerAdapter) GetDevice(ctx context.Context, deviceID string) (*sdk.IoTDeviceState, error) {
|
|
||||||
return a.query.GetDevice(ctx, deviceID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *iotControllerAdapter) SetDeviceProperty(ctx context.Context, deviceID, property string, value interface{}) error {
|
|
||||||
url := fmt.Sprintf("%s/api/v1/devices/%s/property", a.baseURL, deviceID)
|
|
||||||
body, _ := json.Marshal(map[string]interface{}{"property": property, "value": value})
|
|
||||||
req, _ := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(string(body)))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
resp, err := a.client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
if resp.StatusCode >= 400 {
|
|
||||||
msg, _ := io.ReadAll(resp.Body)
|
|
||||||
return fmt.Errorf("set property failed: HTTP %d - %s", resp.StatusCode, string(msg))
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *iotControllerAdapter) ToggleDevice(ctx context.Context, deviceID string) (*sdk.IoTDeviceState, error) {
|
|
||||||
url := fmt.Sprintf("%s/api/v1/devices/%s/toggle", a.baseURL, deviceID)
|
|
||||||
req, _ := http.NewRequestWithContext(ctx, "POST", url, nil)
|
|
||||||
resp, err := a.client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
var dev sdk.IoTDeviceState
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&dev); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &dev, nil
|
|
||||||
}
|
|
||||||
@@ -1,100 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/calculator"
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/crypto"
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/datetime"
|
|
||||||
fileplugin "git.yeij.top/AskaEth/Cyrene/pkg/plugins/file"
|
|
||||||
httpplugin "git.yeij.top/AskaEth/Cyrene/pkg/plugins/http"
|
|
||||||
iotcontrol "git.yeij.top/AskaEth/Cyrene/pkg/plugins/iot_control"
|
|
||||||
iotquery "git.yeij.top/AskaEth/Cyrene/pkg/plugins/iot_query"
|
|
||||||
jsonplugin "git.yeij.top/AskaEth/Cyrene/pkg/plugins/json"
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/manager"
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/markdown"
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/random"
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/text"
|
|
||||||
webfetch "git.yeij.top/AskaEth/Cyrene/pkg/plugins/web_fetch"
|
|
||||||
websearch "git.yeij.top/AskaEth/Cyrene/pkg/plugins/web_search"
|
|
||||||
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/plugin-manager/internal/config"
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/plugin-manager/internal/handler"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
cfg := config.Load()
|
|
||||||
|
|
||||||
var iotAPI iotquery.IoTClient
|
|
||||||
if cfg.IoTSvcURL != "" {
|
|
||||||
iotAPI = newIoTClient(cfg.IoTSvcURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
registry := manager.NewToolRegistry()
|
|
||||||
host := newHostAPI(registry)
|
|
||||||
mgr := manager.NewPluginManager(registry, host)
|
|
||||||
|
|
||||||
builtins := []sdk.Plugin{
|
|
||||||
&calculator.CalculatorPlugin{},
|
|
||||||
&datetime.DatetimePlugin{},
|
|
||||||
&text.TextPlugin{},
|
|
||||||
&crypto.CryptoPlugin{},
|
|
||||||
&random.RandomPlugin{},
|
|
||||||
&markdown.MarkdownPlugin{},
|
|
||||||
&jsonplugin.JSONPlugin{},
|
|
||||||
fileplugin.NewFilePlugin(cfg.DataDir),
|
|
||||||
httpplugin.NewHTTPPlugin(),
|
|
||||||
websearch.NewWebSearchPlugin(),
|
|
||||||
webfetch.NewWebFetchPlugin(),
|
|
||||||
iotquery.NewIoTQueryPlugin(iotAPI),
|
|
||||||
}
|
|
||||||
for _, p := range builtins {
|
|
||||||
if err := mgr.Install(p); err != nil {
|
|
||||||
println("WARN: install plugin failed:", err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if iotAPI != nil {
|
|
||||||
ctrlPlugin := iotcontrol.NewIoTControlPlugin(newIoTControllerAdapter(iotAPI, cfg.IoTSvcURL))
|
|
||||||
if err := mgr.Install(ctrlPlugin); err != nil {
|
|
||||||
println("WARN: install plugin failed:", err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
errs := mgr.EnableAll(ctx)
|
|
||||||
for _, e := range errs {
|
|
||||||
println("WARN: enable plugin failed:", e.Error())
|
|
||||||
}
|
|
||||||
println("Plugin Manager: all built-in plugins enabled")
|
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
|
||||||
ph := handler.NewPluginHandler(mgr)
|
|
||||||
ph.RegisterRoutes(mux)
|
|
||||||
|
|
||||||
println("Plugin Manager listening on port", cfg.Port)
|
|
||||||
srv := &http.Server{Addr: ":" + cfg.Port, Handler: mux}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
|
||||||
println("FATAL:", err.Error())
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
quit := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
|
||||||
<-quit
|
|
||||||
println("Shutting down Plugin Manager...")
|
|
||||||
|
|
||||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
mgr.Shutdown(shutdownCtx)
|
|
||||||
srv.Shutdown(shutdownCtx)
|
|
||||||
println("Plugin Manager stopped")
|
|
||||||
}
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
module git.yeij.top/AskaEth/Cyrene/plugin-manager
|
|
||||||
|
|
||||||
go 1.26.2
|
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import "os"
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
Port string
|
|
||||||
Env string
|
|
||||||
DataDir string
|
|
||||||
IoTSvcURL string
|
|
||||||
}
|
|
||||||
|
|
||||||
func Load() *Config {
|
|
||||||
cfg := &Config{
|
|
||||||
Port: "8094",
|
|
||||||
Env: "development",
|
|
||||||
DataDir: "./data",
|
|
||||||
IoTSvcURL: "http://localhost:8093",
|
|
||||||
}
|
|
||||||
if v := os.Getenv("PORT"); v != "" {
|
|
||||||
cfg.Port = v
|
|
||||||
}
|
|
||||||
if v := os.Getenv("ENV"); v != "" {
|
|
||||||
cfg.Env = v
|
|
||||||
}
|
|
||||||
if v := os.Getenv("DATA_DIR"); v != "" {
|
|
||||||
cfg.DataDir = v
|
|
||||||
}
|
|
||||||
if v := os.Getenv("IOT_SERVICE_URL"); v != "" {
|
|
||||||
cfg.IoTSvcURL = v
|
|
||||||
}
|
|
||||||
return cfg
|
|
||||||
}
|
|
||||||
@@ -1,210 +0,0 @@
|
|||||||
package handler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/manager"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PluginHandler exposes the Plugin Manager REST API via net/http.
|
|
||||||
type PluginHandler struct {
|
|
||||||
mgr *manager.PluginManager
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPluginHandler(mgr *manager.PluginManager) *PluginHandler {
|
|
||||||
return &PluginHandler{mgr: mgr}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *PluginHandler) RegisterRoutes(mux *http.ServeMux) {
|
|
||||||
mux.HandleFunc("/api/v1/plugins", h.listPlugins)
|
|
||||||
mux.HandleFunc("/api/v1/plugins/", h.pluginRoute)
|
|
||||||
mux.HandleFunc("/api/v1/tools", h.listTools)
|
|
||||||
mux.HandleFunc("/api/v1/tools/", h.toolRoute)
|
|
||||||
mux.HandleFunc("/api/v1/health", h.health)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *PluginHandler) health(w http.ResponseWriter, r *http.Request) {
|
|
||||||
writeJSON(w, http.StatusOK, map[string]interface{}{"status": "ok", "service": "plugin-manager"})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *PluginHandler) listPlugins(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.Method != "GET" {
|
|
||||||
writeJSON(w, http.StatusMethodNotAllowed, errResp("method not allowed"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
plugins := h.mgr.List()
|
|
||||||
writeJSON(w, http.StatusOK, map[string]interface{}{"plugins": plugins, "total": len(plugins)})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *PluginHandler) pluginRoute(w http.ResponseWriter, r *http.Request) {
|
|
||||||
path := strings.TrimPrefix(r.URL.Path, "/api/v1/plugins/")
|
|
||||||
parts := strings.SplitN(path, "/", 2)
|
|
||||||
pluginID := parts[0]
|
|
||||||
|
|
||||||
if pluginID == "" {
|
|
||||||
h.listPlugins(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(parts) == 1 {
|
|
||||||
switch r.Method {
|
|
||||||
case "GET":
|
|
||||||
h.getPlugin(w, pluginID)
|
|
||||||
case "DELETE":
|
|
||||||
h.uninstallPlugin(w, r, pluginID)
|
|
||||||
default:
|
|
||||||
writeJSON(w, http.StatusMethodNotAllowed, errResp("method not allowed"))
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
action := parts[1]
|
|
||||||
switch action {
|
|
||||||
case "enable":
|
|
||||||
if r.Method != "POST" {
|
|
||||||
writeJSON(w, http.StatusMethodNotAllowed, errResp("method not allowed"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
h.enablePlugin(w, r, pluginID)
|
|
||||||
case "disable":
|
|
||||||
if r.Method != "POST" {
|
|
||||||
writeJSON(w, http.StatusMethodNotAllowed, errResp("method not allowed"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
h.disablePlugin(w, r, pluginID)
|
|
||||||
case "reload":
|
|
||||||
if r.Method != "POST" {
|
|
||||||
writeJSON(w, http.StatusMethodNotAllowed, errResp("method not allowed"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
h.reloadPlugin(w, r, pluginID)
|
|
||||||
case "tools":
|
|
||||||
if r.Method != "GET" {
|
|
||||||
writeJSON(w, http.StatusMethodNotAllowed, errResp("method not allowed"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
h.pluginTools(w, pluginID)
|
|
||||||
default:
|
|
||||||
writeJSON(w, http.StatusNotFound, errResp("not found"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *PluginHandler) getPlugin(w http.ResponseWriter, id string) {
|
|
||||||
info, ok := h.mgr.Get(id)
|
|
||||||
if !ok {
|
|
||||||
writeJSON(w, http.StatusNotFound, errResp("plugin not found"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
writeJSON(w, http.StatusOK, info)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *PluginHandler) enablePlugin(w http.ResponseWriter, r *http.Request, id string) {
|
|
||||||
if err := h.mgr.Enable(r.Context(), id); err != nil {
|
|
||||||
writeJSON(w, http.StatusInternalServerError, errResp(err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
writeJSON(w, http.StatusOK, map[string]string{"status": "enabled"})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *PluginHandler) disablePlugin(w http.ResponseWriter, r *http.Request, id string) {
|
|
||||||
if err := h.mgr.Disable(r.Context(), id); err != nil {
|
|
||||||
writeJSON(w, http.StatusInternalServerError, errResp(err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
writeJSON(w, http.StatusOK, map[string]string{"status": "disabled"})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *PluginHandler) reloadPlugin(w http.ResponseWriter, r *http.Request, id string) {
|
|
||||||
if err := h.mgr.Reload(r.Context(), id); err != nil {
|
|
||||||
writeJSON(w, http.StatusInternalServerError, errResp(err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
writeJSON(w, http.StatusOK, map[string]string{"status": "reloaded"})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *PluginHandler) uninstallPlugin(w http.ResponseWriter, r *http.Request, id string) {
|
|
||||||
if err := h.mgr.Uninstall(r.Context(), id); err != nil {
|
|
||||||
writeJSON(w, http.StatusInternalServerError, errResp(err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
writeJSON(w, http.StatusOK, map[string]string{"status": "uninstalled"})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *PluginHandler) pluginTools(w http.ResponseWriter, id string) {
|
|
||||||
info, ok := h.mgr.Get(id)
|
|
||||||
if !ok {
|
|
||||||
writeJSON(w, http.StatusNotFound, errResp("plugin not found"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
registry := h.mgr.Registry()
|
|
||||||
tools := make([]interface{}, 0)
|
|
||||||
for _, toolID := range info.Tools {
|
|
||||||
if t, ok := registry.Get(toolID); ok {
|
|
||||||
tools = append(tools, t.Definition())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
writeJSON(w, http.StatusOK, map[string]interface{}{"tools": tools, "total": len(tools)})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *PluginHandler) listTools(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.Method != "GET" {
|
|
||||||
writeJSON(w, http.StatusMethodNotAllowed, errResp("method not allowed"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defs := h.mgr.Registry().Definitions()
|
|
||||||
writeJSON(w, http.StatusOK, map[string]interface{}{"tools": defs, "total": len(defs)})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *PluginHandler) toolRoute(w http.ResponseWriter, r *http.Request) {
|
|
||||||
path := strings.TrimPrefix(r.URL.Path, "/api/v1/tools/")
|
|
||||||
toolID := path
|
|
||||||
|
|
||||||
if strings.HasSuffix(path, "/execute") {
|
|
||||||
toolID = strings.TrimSuffix(path, "/execute")
|
|
||||||
if r.Method != "POST" {
|
|
||||||
writeJSON(w, http.StatusMethodNotAllowed, errResp("method not allowed"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
h.executeTool(w, r, toolID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.Method != "GET" {
|
|
||||||
writeJSON(w, http.StatusMethodNotAllowed, errResp("method not allowed"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tool, ok := h.mgr.Registry().Get(toolID)
|
|
||||||
if !ok {
|
|
||||||
writeJSON(w, http.StatusNotFound, errResp("tool not found"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
writeJSON(w, http.StatusOK, tool.Definition())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *PluginHandler) executeTool(w http.ResponseWriter, r *http.Request, toolID string) {
|
|
||||||
var body struct {
|
|
||||||
Arguments map[string]interface{} `json:"arguments"`
|
|
||||||
}
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
|
||||||
writeJSON(w, http.StatusBadRequest, errResp("invalid request body"))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
result, err := h.mgr.Registry().Execute(r.Context(), toolID, body.Arguments)
|
|
||||||
if err != nil {
|
|
||||||
writeJSON(w, http.StatusInternalServerError, errResp(err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
writeJSON(w, http.StatusOK, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
func errResp(msg string) map[string]string {
|
|
||||||
return map[string]string{"error": msg}
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeJSON(w http.ResponseWriter, status int, data interface{}) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(status)
|
|
||||||
json.NewEncoder(w).Encode(data)
|
|
||||||
}
|
|
||||||
@@ -19,7 +19,7 @@ RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o /voice-service ./cmd/m
|
|||||||
# ========== 运行阶段 ==========
|
# ========== 运行阶段 ==========
|
||||||
FROM alpine:3.21
|
FROM alpine:3.21
|
||||||
|
|
||||||
RUN apk add --no-cache ca-certificates tzdata && \
|
RUN apk add --no-cache ca-certificates tzdata ffmpeg && \
|
||||||
cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \
|
cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \
|
||||||
echo "Asia/Shanghai" > /etc/timezone
|
echo "Asia/Shanghai" > /etc/timezone
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,224 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
mode := flag.String("mode", "offline", "测试模式: offline (非实时) 或 realtime (实时)")
|
||||||
|
file := flag.String("file", "", "音频文件路径 (WAV/MP3/OGG/FLAC)")
|
||||||
|
server := flag.String("server", "http://localhost:8093", "Voice-Service 地址")
|
||||||
|
lang := flag.String("lang", "zh", "语言代码")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if *file == "" {
|
||||||
|
fmt.Println("用法: test_asr -mode=offline -file=audio.wav [-server=http://localhost:8093]")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch *mode {
|
||||||
|
case "offline":
|
||||||
|
testOffline(*server, *file, *lang)
|
||||||
|
case "realtime":
|
||||||
|
testRealtime(*server, *file, *lang)
|
||||||
|
default:
|
||||||
|
fmt.Printf("未知模式: %s (支持: offline, realtime)\n", *mode)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// testOffline 测试非实时语音识别 (HTTP multipart 上传)。
|
||||||
|
func testOffline(server, filePath, lang string) {
|
||||||
|
fmt.Printf("=== 非实时 ASR 测试 ===\n")
|
||||||
|
fmt.Printf("服务器: %s\n", server)
|
||||||
|
fmt.Printf("文件: %s\n", filePath)
|
||||||
|
fmt.Printf("语言: %s\n\n", lang)
|
||||||
|
|
||||||
|
// 读取音频文件
|
||||||
|
audioData, err := os.ReadFile(filePath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("读取文件失败: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
fmt.Printf("音频大小: %d bytes\n", len(audioData))
|
||||||
|
|
||||||
|
// 创建 multipart 请求
|
||||||
|
req, err := http.NewRequest("POST", server+"/api/v1/transcribe", nil)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("创建请求失败: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用 multipart form
|
||||||
|
body, contentType, err := createMultipartBody(audioData, filePath, lang)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("创建 multipart body 失败: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
req.Body = body
|
||||||
|
req.Header.Set("Content-Type", contentType)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("请求失败: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
fmt.Printf("状态码: %d\n", resp.StatusCode)
|
||||||
|
fmt.Printf("耗时: %v\n", elapsed)
|
||||||
|
fmt.Printf("响应:\n%s\n", string(respBody))
|
||||||
|
|
||||||
|
if resp.StatusCode == 200 {
|
||||||
|
fmt.Println("\n✅ 非实时语音识别成功!")
|
||||||
|
} else {
|
||||||
|
fmt.Println("\n❌ 非实时语音识别失败")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// testRealtime 测试实时语音识别 (WebSocket 流式)。
|
||||||
|
func testRealtime(server, filePath, lang string) {
|
||||||
|
fmt.Printf("=== 实时 ASR 测试 ===\n")
|
||||||
|
fmt.Printf("服务器: %s\n", server)
|
||||||
|
fmt.Printf("文件: %s\n", filePath)
|
||||||
|
fmt.Printf("语言: %s\n\n", lang)
|
||||||
|
|
||||||
|
// 读取音频文件
|
||||||
|
audioData, err := os.ReadFile(filePath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("读取文件失败: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
fmt.Printf("音频大小: %d bytes\n", len(audioData))
|
||||||
|
|
||||||
|
// 推断格式
|
||||||
|
format := inferFormat(filePath)
|
||||||
|
|
||||||
|
// 连接 WebSocket
|
||||||
|
wsURL := fmt.Sprintf("ws://%s/api/v1/stt/stream?format=%s&language=%s",
|
||||||
|
server[7:], format, lang) // 去掉 http:// 前缀
|
||||||
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("WebSocket 连接失败: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
fmt.Printf("WebSocket 已连接: %s\n", wsURL)
|
||||||
|
|
||||||
|
// 设置 interrupt 处理
|
||||||
|
interrupt := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(interrupt, os.Interrupt)
|
||||||
|
|
||||||
|
// goroutine: 读取识别结果
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
for {
|
||||||
|
_, msg, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("读取结果错误: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Printf("◀ 结果: %s\n", string(msg))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 模拟实时流式发送音频(每 100ms 发送 3200 bytes)
|
||||||
|
chunkSize := 3200
|
||||||
|
totalSent := 0
|
||||||
|
start := time.Now()
|
||||||
|
var elapsed time.Duration
|
||||||
|
|
||||||
|
cancelled := false
|
||||||
|
for i := 0; i < len(audioData); i += chunkSize {
|
||||||
|
end := i + chunkSize
|
||||||
|
if end > len(audioData) {
|
||||||
|
end = len(audioData)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-interrupt:
|
||||||
|
fmt.Println("\n用户中断")
|
||||||
|
cancelled = true
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if cancelled {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := conn.WriteMessage(websocket.BinaryMessage, audioData[i:end]); err != nil {
|
||||||
|
fmt.Printf("发送音频失败: %v\n", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
totalSent += end - i
|
||||||
|
fmt.Printf("▶ 发送 %d/%d bytes (%.1f%%)\n", totalSent, len(audioData),
|
||||||
|
float64(totalSent)/float64(len(audioData))*100)
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
elapsed = time.Since(start)
|
||||||
|
|
||||||
|
// 发送停止消息
|
||||||
|
conn.WriteMessage(websocket.TextMessage, []byte(`{"action":"stop"}`))
|
||||||
|
|
||||||
|
// 等待最后的结果
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
|
||||||
|
fmt.Printf("\n总耗时: %v, 总发送: %d bytes\n", elapsed, totalSent)
|
||||||
|
fmt.Println("✅ 实时语音识别测试完成")
|
||||||
|
}
|
||||||
|
|
||||||
|
func inferFormat(filename string) string {
|
||||||
|
ext := ""
|
||||||
|
for i := len(filename) - 1; i >= 0; i-- {
|
||||||
|
if filename[i] == '.' {
|
||||||
|
ext = filename[i+1:]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
switch ext {
|
||||||
|
case "wav", "wave":
|
||||||
|
return "wav"
|
||||||
|
case "mp3", "mpeg":
|
||||||
|
return "mp3"
|
||||||
|
case "ogg", "opus":
|
||||||
|
return "ogg"
|
||||||
|
case "flac":
|
||||||
|
return "flac"
|
||||||
|
case "m4a", "mp4", "aac":
|
||||||
|
return "m4a"
|
||||||
|
default:
|
||||||
|
return "pcm"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createMultipartBody(audioData []byte, filename, lang string) (io.ReadCloser, string, error) {
|
||||||
|
boundary := "cyrene-asr-test-boundary"
|
||||||
|
header := fmt.Sprintf("--%s\r\nContent-Disposition: form-data; name=\"audio\"; filename=\"%s\"\r\nContent-Type: application/octet-stream\r\n\r\n",
|
||||||
|
boundary, filename)
|
||||||
|
footer := fmt.Sprintf("\r\n--%s\r\nContent-Disposition: form-data; name=\"language\"\r\n\r\n%s\r\n--%s--\r\n",
|
||||||
|
boundary, lang, boundary)
|
||||||
|
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
go func() {
|
||||||
|
pw.Write([]byte(header))
|
||||||
|
pw.Write(audioData)
|
||||||
|
pw.Write([]byte(footer))
|
||||||
|
pw.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
return pr, "multipart/form-data; boundary=" + boundary, nil
|
||||||
|
}
|
||||||
@@ -2,9 +2,15 @@ module git.yeij.top/AskaEth/Cyrene/voice-service
|
|||||||
|
|
||||||
go 1.26.2
|
go 1.26.2
|
||||||
|
|
||||||
replace git.yeij.top/AskaEth/Cyrene/pkg/logger => ../pkg/logger
|
replace (
|
||||||
|
git.yeij.top/AskaEth/Cyrene/pkg/logger => ../pkg/logger
|
||||||
|
git.yeij.top/AskaEth/Cyrene/pkg/audio => ../pkg/audio
|
||||||
|
git.yeij.top/AskaEth/Cyrene/pkg/dashscope => ../pkg/dashscope
|
||||||
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/gorilla/websocket v1.5.3
|
github.com/gorilla/websocket v1.5.3
|
||||||
|
git.yeij.top/AskaEth/Cyrene/pkg/audio v0.0.0
|
||||||
|
git.yeij.top/AskaEth/Cyrene/pkg/dashscope v0.0.0
|
||||||
git.yeij.top/AskaEth/Cyrene/pkg/logger v0.0.0
|
git.yeij.top/AskaEth/Cyrene/pkg/logger v0.0.0
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -0,0 +1,52 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"git.yeij.top/AskaEth/Cyrene/pkg/dashscope"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DashScopeRESTSTT 使用 DashScope REST API 进行离线语音识别。
|
||||||
|
// 离线模型 (qwen3-asr-flash-2026-02-10) 通过 HTTP REST 端点进行转录,
|
||||||
|
// 无需 session 协商和 Server VAD,延迟更低,适合非实时场景。
|
||||||
|
type DashScopeRESTSTT struct {
|
||||||
|
model string
|
||||||
|
client *dashscope.RESTClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDashScopeRESTSTT 创建 DashScope REST STT 客户端。
|
||||||
|
func NewDashScopeRESTSTT(apiKey, model string) *DashScopeRESTSTT {
|
||||||
|
if model == "" {
|
||||||
|
model = "qwen3-asr-flash-2026-02-10"
|
||||||
|
}
|
||||||
|
return &DashScopeRESTSTT{
|
||||||
|
model: model,
|
||||||
|
client: dashscope.NewRESTClient(apiKey),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAvailable 检查 API Key 是否已配置。
|
||||||
|
func (d *DashScopeRESTSTT) IsAvailable() bool {
|
||||||
|
return d.client.IsAvailable()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model 返回模型名。
|
||||||
|
func (d *DashScopeRESTSTT) Model() string { return d.model }
|
||||||
|
|
||||||
|
// Transcribe 使用 DashScope REST API 进行离线语音识别。
|
||||||
|
func (d *DashScopeRESTSTT) Transcribe(ctx context.Context, audioData []byte, format, language string) (string, error) {
|
||||||
|
if !d.IsAvailable() {
|
||||||
|
return "", fmt.Errorf("DashScope REST ASR API key 未配置")
|
||||||
|
}
|
||||||
|
return d.client.Transcribe(ctx, d.model, audioData, format, 16000, language)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStatus 返回 REST STT 客户端的运行状态。
|
||||||
|
func (d *DashScopeRESTSTT) GetStatus() map[string]interface{} {
|
||||||
|
return map[string]interface{}{
|
||||||
|
"available": d.IsAvailable(),
|
||||||
|
"model": d.model,
|
||||||
|
"protocol": "rest",
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,53 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDashScopeRESTSTT_Transcribe_Success(t *testing.T) {
|
||||||
|
client := NewDashScopeRESTSTT("test-key", "qwen3-asr-flash-2026-02-10")
|
||||||
|
|
||||||
|
if !client.IsAvailable() {
|
||||||
|
t.Error("client should be available when apiKey is set")
|
||||||
|
}
|
||||||
|
if client.Model() != "qwen3-asr-flash-2026-02-10" {
|
||||||
|
t.Errorf("unexpected model: %s", client.Model())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashScopeRESTSTT_NotAvailable(t *testing.T) {
|
||||||
|
client := NewDashScopeRESTSTT("", "")
|
||||||
|
if client.IsAvailable() {
|
||||||
|
t.Error("client should not be available without apiKey")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashScopeRESTSTT_DefaultModel(t *testing.T) {
|
||||||
|
client := NewDashScopeRESTSTT("key", "")
|
||||||
|
if client.Model() != "qwen3-asr-flash-2026-02-10" {
|
||||||
|
t.Errorf("expected default model, got %s", client.Model())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashScopeRESTSTT_Transcribe_NoAPIKey(t *testing.T) {
|
||||||
|
client := NewDashScopeRESTSTT("", "")
|
||||||
|
_, err := client.Transcribe(context.Background(), []byte{}, "wav", "zh")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error when API key is not configured")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDashScopeRESTSTT_GetStatus(t *testing.T) {
|
||||||
|
client := NewDashScopeRESTSTT("key", "test-model")
|
||||||
|
status := client.GetStatus()
|
||||||
|
if status["available"] != true {
|
||||||
|
t.Error("status should be available")
|
||||||
|
}
|
||||||
|
if status["model"] != "test-model" {
|
||||||
|
t.Errorf("unexpected model in status: %v", status["model"])
|
||||||
|
}
|
||||||
|
if status["protocol"] != "rest" {
|
||||||
|
t.Errorf("unexpected protocol: %v", status["protocol"])
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,12 +5,10 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"git.yeij.top/AskaEth/Cyrene/pkg/audio"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -145,7 +143,7 @@ func (d *DashScopeSTT) Transcribe(ctx context.Context, audioData []byte, format
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 4. 规范化音频格式并发送
|
// 4. 规范化音频格式并发送
|
||||||
pcmData, err := convertToPCM16(audioData, format)
|
pcmData, err := audio.ConvertToPCM16(audioData, format)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("音频格式转换失败: %w", err)
|
return "", fmt.Errorf("音频格式转换失败: %w", err)
|
||||||
}
|
}
|
||||||
@@ -447,74 +445,3 @@ func (d *DashScopeSTT) GetStatus() map[string]interface{} {
|
|||||||
"provider": "dashscope",
|
"provider": "dashscope",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// normalizeSTTFormat 规范化音频格式字符串。
|
|
||||||
func normalizeSTTFormat(format string) string {
|
|
||||||
switch strings.ToLower(format) {
|
|
||||||
case "pcm", "wav", "mp3", "mpeg", "ogg", "opus", "flac", "m4a", "mp4", "aac", "webm":
|
|
||||||
return strings.ToLower(format)
|
|
||||||
default:
|
|
||||||
return format
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// convertToPCM16 将音频数据转换为 16-bit PCM 16000Hz mono。
|
|
||||||
func convertToPCM16(data []byte, format string) ([]byte, error) {
|
|
||||||
normFormat := normalizeSTTFormat(format)
|
|
||||||
switch normFormat {
|
|
||||||
case "pcm":
|
|
||||||
return data, nil
|
|
||||||
case "wav":
|
|
||||||
if len(data) > 44 {
|
|
||||||
return data[44:], nil
|
|
||||||
}
|
|
||||||
return data, nil
|
|
||||||
default:
|
|
||||||
return transcodeToPCM(data, normFormat)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// transcodeToPCM 使用 ffmpeg 将音频数据转码为 PCM 16-bit 16000Hz mono。
|
|
||||||
func transcodeToPCM(data []byte, format string) ([]byte, error) {
|
|
||||||
inFile, err := os.CreateTemp(os.TempDir(), "cyrene-asr-in-*."+format)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("创建输入临时文件失败: %w", err)
|
|
||||||
}
|
|
||||||
inPath := inFile.Name()
|
|
||||||
defer os.Remove(inPath)
|
|
||||||
if _, err := inFile.Write(data); err != nil {
|
|
||||||
inFile.Close()
|
|
||||||
return nil, fmt.Errorf("写入输入临时文件失败: %w", err)
|
|
||||||
}
|
|
||||||
inFile.Close()
|
|
||||||
|
|
||||||
outFile, err := os.CreateTemp(os.TempDir(), "cyrene-asr-out-*.pcm")
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("创建输出临时文件失败: %w", err)
|
|
||||||
}
|
|
||||||
outPath := outFile.Name()
|
|
||||||
outFile.Close()
|
|
||||||
defer os.Remove(outPath)
|
|
||||||
|
|
||||||
cmd := exec.Command("ffmpeg",
|
|
||||||
"-i", inPath,
|
|
||||||
"-ar", "16000",
|
|
||||||
"-ac", "1",
|
|
||||||
"-c:a", "pcm_s16le",
|
|
||||||
"-f", "s16le",
|
|
||||||
outPath,
|
|
||||||
"-y",
|
|
||||||
)
|
|
||||||
cmd.Stderr = nil
|
|
||||||
|
|
||||||
if err := cmd.Run(); err != nil {
|
|
||||||
return nil, fmt.Errorf("音频转码失败 (ffmpeg): %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
outData, err := os.ReadFile(outPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("读取转码结果失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return outData, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -16,41 +16,59 @@ import (
|
|||||||
var SupportedLanguages = []string{"zh", "en", "ja", "ko", "auto"}
|
var SupportedLanguages = []string{"zh", "en", "ja", "ko", "auto"}
|
||||||
|
|
||||||
// STTService 语音转文字服务。
|
// STTService 语音转文字服务。
|
||||||
// 优先使用 DashScope API,不可用时回退到本地 Whisper。
|
// 离线转录优先使用 DashScope REST API,失败回退 Whisper。
|
||||||
|
// 流式转录使用 DashScope Realtime WS。
|
||||||
type STTService struct {
|
type STTService struct {
|
||||||
whisperBinary string
|
whisperBinary string
|
||||||
whisperModel string
|
whisperModel string
|
||||||
language string
|
language string
|
||||||
dashscope *DashScopeSTT // 实时 ASR (qwen3-asr-flash-realtime)
|
dashscope *DashScopeSTT // 实时 ASR (qwen3-asr-flash-realtime)
|
||||||
|
dashscopeREST *DashScopeRESTSTT // 离线 ASR (qwen3-asr-flash-2026-02-10)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSTTService 创建 STT 服务。
|
// NewSTTService 创建 STT 服务。
|
||||||
func NewSTTService(cfg *config.Config) *STTService {
|
func NewSTTService(cfg *config.Config) *STTService {
|
||||||
// 实时模型用于所有 WebSocket ASR 请求(支持 one-shot 和 streaming)
|
realtimeModel := cfg.DashScopeSTTRealtime
|
||||||
// 离线模型 (qwen3-asr-flash-2026-02-10) 是 HTTP REST API,暂未实现
|
if realtimeModel == "" {
|
||||||
model := cfg.DashScopeSTTRealtime
|
realtimeModel = "qwen3-asr-flash-realtime"
|
||||||
if model == "" {
|
}
|
||||||
model = cfg.DashScopeModel
|
offlineModel := cfg.DashScopeModel
|
||||||
|
if offlineModel == "" {
|
||||||
|
offlineModel = "qwen3-asr-flash-2026-02-10"
|
||||||
}
|
}
|
||||||
return &STTService{
|
return &STTService{
|
||||||
whisperBinary: cfg.WhisperBinary,
|
whisperBinary: cfg.WhisperBinary,
|
||||||
whisperModel: cfg.WhisperModel,
|
whisperModel: cfg.WhisperModel,
|
||||||
language: cfg.WhisperLanguage,
|
language: cfg.WhisperLanguage,
|
||||||
dashscope: NewDashScopeSTT(cfg.DashScopeAPIKey, model),
|
dashscope: NewDashScopeSTT(cfg.DashScopeAPIKey, realtimeModel),
|
||||||
|
dashscopeREST: NewDashScopeRESTSTT(cfg.DashScopeAPIKey, offlineModel),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsAvailable 检查是否有任一 STT 引擎可用。
|
// IsAvailable 检查是否有任一 STT 引擎可用。
|
||||||
func (s *STTService) IsAvailable() bool {
|
func (s *STTService) IsAvailable() bool {
|
||||||
if s.dashscope.IsAvailable() {
|
if s.dashscopeREST.IsAvailable() || s.dashscope.IsAvailable() {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
_, err := os.Stat(s.whisperBinary)
|
return s.whisperAvailable()
|
||||||
return err == nil
|
}
|
||||||
|
|
||||||
|
// whisperAvailable 检查本地 Whisper 引擎是否真正可用。
|
||||||
|
func (s *STTService) whisperAvailable() bool {
|
||||||
|
if _, err := os.Stat(s.whisperBinary); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(s.whisperModel); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if _, err := exec.LookPath("ffmpeg"); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Transcribe 将音频数据转录为文字。
|
// Transcribe 将音频数据转录为文字。
|
||||||
// 优先使用 DashScope,不可用时回退到本地 Whisper。
|
// 优先使用 DashScope REST 离线模型,失败回退到本地 Whisper。
|
||||||
func (s *STTService) Transcribe(audioData []byte, format string, language string) (string, error) {
|
func (s *STTService) Transcribe(audioData []byte, format string, language string) (string, error) {
|
||||||
if language == "" {
|
if language == "" {
|
||||||
language = s.language
|
language = s.language
|
||||||
@@ -59,16 +77,15 @@ func (s *STTService) Transcribe(audioData []byte, format string, language string
|
|||||||
return "", fmt.Errorf("不支持的语言: %s,支持的语言: %s", language, strings.Join(SupportedLanguages, ", "))
|
return "", fmt.Errorf("不支持的语言: %s,支持的语言: %s", language, strings.Join(SupportedLanguages, ", "))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 优先 DashScope
|
// 优先 DashScope REST 离线模型(低延迟,无需 session 协商)
|
||||||
if s.dashscope.IsAvailable() {
|
if s.dashscopeREST.IsAvailable() {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
text, err := s.dashscope.Transcribe(ctx, audioData, format, language)
|
text, err := s.dashscopeREST.Transcribe(ctx, audioData, format, language)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return text, nil
|
return text, nil
|
||||||
}
|
}
|
||||||
// DashScope 失败,返回具体错误而不是回退到 Whisper
|
fmt.Printf("[stt] DashScope REST 失败,回退 Whisper: %v\n", err)
|
||||||
return "", fmt.Errorf("语音识别失败: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 回退到本地 Whisper
|
// 回退到本地 Whisper
|
||||||
@@ -152,15 +169,21 @@ func (s *STTService) GetStatus() map[string]interface{} {
|
|||||||
if _, err := os.Stat(s.whisperModel); err == nil {
|
if _, err := os.Stat(s.whisperModel); err == nil {
|
||||||
modelExists = true
|
modelExists = true
|
||||||
}
|
}
|
||||||
|
ffmpegAvailable := false
|
||||||
|
if _, err := exec.LookPath("ffmpeg"); err == nil {
|
||||||
|
ffmpegAvailable = true
|
||||||
|
}
|
||||||
|
|
||||||
return map[string]interface{}{
|
return map[string]interface{}{
|
||||||
"available": s.IsAvailable(),
|
"available": s.IsAvailable(),
|
||||||
"primary": "dashscope",
|
"primary": "dashscope_rest",
|
||||||
"dashscope": s.dashscope.GetStatus(),
|
"dashscope_rest": s.dashscopeREST.GetStatus(),
|
||||||
|
"dashscope_ws": s.dashscope.GetStatus(),
|
||||||
"whisper": map[string]interface{}{
|
"whisper": map[string]interface{}{
|
||||||
"available": binaryAvailable && modelExists,
|
"available": s.whisperAvailable(),
|
||||||
"binary_available": binaryAvailable,
|
"binary_available": binaryAvailable,
|
||||||
"model_loaded": modelExists,
|
"model_loaded": modelExists,
|
||||||
|
"ffmpeg_available": ffmpegAvailable,
|
||||||
"model_name": filepath.Base(s.whisperModel),
|
"model_name": filepath.Base(s.whisperModel),
|
||||||
},
|
},
|
||||||
"default_language": s.language,
|
"default_language": s.language,
|
||||||
|
|||||||
@@ -3,9 +3,14 @@
|
|||||||
**Base URL:** `http://<host>:8093` | **Auth:** 无
|
**Base URL:** `http://<host>:8093` | **Auth:** 无
|
||||||
|
|
||||||
语音服务封装两层引擎:
|
语音服务封装两层引擎:
|
||||||
- **STT (语音转文字):** DashScope `qwen3-asr-flash-realtime` (主) + 本地 Whisper (备)
|
- **STT (语音转文字):** DashScope REST 离线模型 `qwen3-asr-flash-2026-02-10` (主) → DashScope Realtime WS `qwen3-asr-flash-realtime` (流式) → 本地 Whisper (备)
|
||||||
- **TTS (文字转语音):** edge-tts (主) + espeak-ng (备)
|
- **TTS (文字转语音):** edge-tts (主) + espeak-ng (备)
|
||||||
|
|
||||||
|
> **引擎分层说明:**
|
||||||
|
> - **离线转录** (`POST /api/v1/transcribe`): 使用 DashScope REST API,无需 session 协商和 Server VAD,延迟更低。失败自动回退 Whisper。
|
||||||
|
> - **流式转录** (`GET /api/v1/stt/stream`): 使用 DashScope Realtime WebSocket,支持实时分片输入和中间结果输出,需客户端发送 PCM 音频。
|
||||||
|
> - **音频转码**: 所有非 PCM 格式通过 ffmpeg 转码为 16kHz mono PCM 后再识别,支持 WAV/MP3/OGG/FLAC/WebM/Opus/AMR 等格式。
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 目录
|
## 目录
|
||||||
@@ -24,13 +29,25 @@
|
|||||||
|
|
||||||
**Content-Type:** `multipart/form-data` | **Max body:** 10 MB
|
**Content-Type:** `multipart/form-data` | **Max body:** 10 MB
|
||||||
|
|
||||||
|
### 引擎选择
|
||||||
|
|
||||||
|
优先使用 DashScope REST 离线模型 `qwen3-asr-flash-2026-02-10`,失败自动回退本地 Whisper。
|
||||||
|
|
||||||
|
```
|
||||||
|
接收音频 → DashScope REST API (HTTP POST)
|
||||||
|
↓ 失败
|
||||||
|
ffmpeg 转码 PCM → 本地 Whisper 引擎
|
||||||
|
```
|
||||||
|
|
||||||
### 表单字段
|
### 表单字段
|
||||||
|
|
||||||
| 字段 | 类型 | 必填 | 说明 |
|
| 字段 | 类型 | 必填 | 说明 |
|
||||||
|------|------|------|------|
|
|------|------|------|------|
|
||||||
| `audio` | file | 是 | 音频文件。格式从扩展名推断:wav/mp3/ogg/flac/m4a |
|
| `audio` | file | 是 | 音频文件。支持格式: wav, mp3, ogg, flac, m4a, aac, webm, opus, amr, pcm |
|
||||||
| `language` | string | 否 | 默认 `"zh"`。可选: `zh`, `en`, `ja`, `ko`, `auto` |
|
| `language` | string | 否 | 默认 `"zh"`。可选: `zh`, `en`, `ja`, `ko`, `auto` |
|
||||||
|
|
||||||
|
> **转码说明:** 非 PCM 格式(含 Opus/WebM/AMR)通过 ffmpeg 自动转码为 16-bit PCM 16000Hz mono 后识别。需部署环境安装 ffmpeg。
|
||||||
|
|
||||||
### 响应 200
|
### 响应 200
|
||||||
|
|
||||||
```json
|
```json
|
||||||
@@ -49,7 +66,7 @@
|
|||||||
| 400 | `{"error":"文件过大或解析失败,最大支持 10MB"}` |
|
| 400 | `{"error":"文件过大或解析失败,最大支持 10MB"}` |
|
||||||
| 400 | `{"error":"缺少 audio 文件字段"}` |
|
| 400 | `{"error":"缺少 audio 文件字段"}` |
|
||||||
| 400 | `{"error":"音频文件为空"}` |
|
| 400 | `{"error":"音频文件为空"}` |
|
||||||
| 400 | `{"error":"不支持的音频格式: <ext>,支持的格式: WAV, MP3, OGG, FLAC, M4A"}` |
|
| 400 | `{"error":"不支持的语言: <lang>,支持的语言: zh, en, ja, ko, auto"}` |
|
||||||
| 405 | `{"error":"method not allowed"}` |
|
| 405 | `{"error":"method not allowed"}` |
|
||||||
| 500 | `{"error":"读取音频文件失败"}` |
|
| 500 | `{"error":"读取音频文件失败"}` |
|
||||||
| 500 | `{"success":false,"error":"<engine error>"}` |
|
| 500 | `{"success":false,"error":"<engine error>"}` |
|
||||||
@@ -112,12 +129,23 @@
|
|||||||
"service": "voice-service",
|
"service": "voice-service",
|
||||||
"stt": {
|
"stt": {
|
||||||
"available": true,
|
"available": true,
|
||||||
"primary": "dashscope",
|
"primary": "dashscope_rest",
|
||||||
"dashscope": { "available": true, "model": "qwen3-asr-flash-realtime", "provider": "dashscope" },
|
"dashscope_rest": {
|
||||||
|
"available": true,
|
||||||
|
"model": "qwen3-asr-flash-2026-02-10",
|
||||||
|
"protocol": "rest"
|
||||||
|
},
|
||||||
|
"dashscope_ws": {
|
||||||
|
"available": true,
|
||||||
|
"model": "qwen3-asr-flash-realtime",
|
||||||
|
"protocol": "websocket",
|
||||||
|
"state": "idle"
|
||||||
|
},
|
||||||
"whisper": {
|
"whisper": {
|
||||||
"available": true,
|
"available": true,
|
||||||
"binary_available": true,
|
"binary_available": true,
|
||||||
"model_loaded": true,
|
"model_loaded": true,
|
||||||
|
"ffmpeg_available": true,
|
||||||
"model_name": "ggml-small.bin"
|
"model_name": "ggml-small.bin"
|
||||||
},
|
},
|
||||||
"default_language": "zh",
|
"default_language": "zh",
|
||||||
@@ -138,9 +166,15 @@
|
|||||||
|
|
||||||
| 字段 | 说明 |
|
| 字段 | 说明 |
|
||||||
|------|------|
|
|------|------|
|
||||||
| `stt.available` | DashScope 或 Whisper 至少一个可用 |
|
| `stt.available` | DashScope REST / WS 或 Whisper 至少一个可用 |
|
||||||
| `stt.dashscope.available` | DashScope API Key 已配置 |
|
| `stt.primary` | 当前优先引擎: `dashscope_rest` |
|
||||||
| `stt.whisper.available` | Whisper 二进制 + 模型文件均存在 |
|
| `stt.dashscope_rest.available` | DashScope REST API Key 已配置 |
|
||||||
|
| `stt.dashscope_rest.protocol` | 协议类型: `rest` |
|
||||||
|
| `stt.dashscope_ws.available` | DashScope Realtime WS 可用 |
|
||||||
|
| `stt.dashscope_ws.protocol` | 协议类型: `websocket` |
|
||||||
|
| `stt.dashscope_ws.state` | 连接状态: `idle`, `connected`, `error` |
|
||||||
|
| `stt.whisper.available` | Whisper 二进制 + 模型文件 + ffmpeg 均存在 |
|
||||||
|
| `stt.whisper.ffmpeg_available` | ffmpeg 可用于音频转码 |
|
||||||
| `tts.available` | 至少一个 TTS 引擎可用 |
|
| `tts.available` | 至少一个 TTS 引擎可用 |
|
||||||
| `tts.engine` | 当前激活引擎: `edge-tts`, `espeak-ng`, `fallback (silent WAV)`, `none` |
|
| `tts.engine` | 当前激活引擎: `edge-tts`, `espeak-ng`, `fallback (silent WAV)`, `none` |
|
||||||
|
|
||||||
@@ -183,6 +217,10 @@
|
|||||||
**Query 参数:** `?language=zh&format=pcm` (language 默认 zh, format 默认 pcm)
|
**Query 参数:** `?language=zh&format=pcm` (language 默认 zh, format 默认 pcm)
|
||||||
**Read deadline:** 300s
|
**Read deadline:** 300s
|
||||||
|
|
||||||
|
> **注意:** 此端点使用 DashScope Realtime WebSocket (`qwen3-asr-flash-realtime`),音频帧必须是 PCM 格式。非 PCM 格式应使用 REST 离线转录 (`POST /api/v1/transcribe`)。
|
||||||
|
>
|
||||||
|
> **Gateway 代理:** Gateway 的 `voice_stream_*` 消息类型通过此端点与前端 VAD 配合,实现端到端流式语音 → STT → LLM 管道。详见 [Gateway WebSocket 文档](../../gateway-api.md#流式语音输入流程-voice_stream_)。
|
||||||
|
|
||||||
### 客户端 → 服务端
|
### 客户端 → 服务端
|
||||||
|
|
||||||
**Binary 帧:** 原始 PCM 音频 (16-bit LE, 16000Hz, mono)。每帧通过 `input_audio_buffer.append` 转发到 DashScope。
|
**Binary 帧:** 原始 PCM 音频 (16-bit LE, 16000Hz, mono)。每帧通过 `input_audio_buffer.append` 转发到 DashScope。
|
||||||
|
|||||||
+137
-6
@@ -167,11 +167,14 @@ ws://<gateway>/ws/chat?token=<jwt>&session_id=<optional>&client_id=<optional>&de
|
|||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"type": "message|voice_input|ping|history",
|
"type": "message|voice_input|voice_stream_start|voice_stream_chunk|voice_stream_end|ping|history",
|
||||||
"session_id": "string (可选)",
|
"session_id": "string (可选)",
|
||||||
"mode": "text|voice_msg|voice_assistant",
|
"mode": "text|voice_msg|voice_assistant",
|
||||||
"content": "string (纯图片消息可留空,文字+图片时填写提问内容)",
|
"content": "string (纯图片消息可留空,文字+图片时填写提问内容)",
|
||||||
"audio_data": "string (voice_input 类型必填, base64)",
|
"audio_data": "string (voice_input / voice_stream_chunk 类型必填, base64)",
|
||||||
|
"format": "string (voice_stream_start 可选, 音频格式: webm, wav, pcm, opus; 默认 webm)",
|
||||||
|
"language": "string (voice_stream_start 可选, 识别语言: zh, en, ja, ko, auto; 默认 zh)",
|
||||||
|
"sequence": 0,
|
||||||
"attachments": [
|
"attachments": [
|
||||||
{
|
{
|
||||||
"type": "image",
|
"type": "image",
|
||||||
@@ -194,14 +197,18 @@ ws://<gateway>/ws/chat?token=<jwt>&session_id=<optional>&client_id=<optional>&de
|
|||||||
"timestamp": 1717000000000,
|
"timestamp": 1717000000000,
|
||||||
"client_id": "string",
|
"client_id": "string",
|
||||||
"device_name": "string",
|
"device_name": "string",
|
||||||
"user_agent": "string"
|
"user_agent": "string",
|
||||||
|
"client_msg_id": "string"
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
| type | 说明 |
|
| type | 说明 |
|
||||||
|------|------|
|
|------|------|
|
||||||
| `message` | 文字聊天,触发 AI 回复 |
|
| `message` | 文字聊天,触发 AI 回复 |
|
||||||
| `voice_input` | 语音输入,先转录再作为 message 处理 |
|
| `voice_input` | 语音输入(完整音频),先转录再作为 message 处理 |
|
||||||
|
| `voice_stream_start` | 开启流式语音会话,Gateway 连接 Voice-Service 流式 STT |
|
||||||
|
| `voice_stream_chunk` | 流式语音音频分片 (base64),Gateway 转发至 Voice-Service |
|
||||||
|
| `voice_stream_end` | 结束流式语音,等待最终识别结果,自动触发 LLM 回复 |
|
||||||
| `ping` | 心跳,自动回复 pong |
|
| `ping` | 心跳,自动回复 pong |
|
||||||
| `history` | 请求历史消息 |
|
| `history` | 请求历史消息 |
|
||||||
|
|
||||||
@@ -244,7 +251,9 @@ ws://<gateway>/ws/chat?token=<jwt>&session_id=<optional>&client_id=<optional>&de
|
|||||||
| `stream_chunk` | 增量文本块 |
|
| `stream_chunk` | 增量文本块 |
|
||||||
| `stream_end` | AI 生成结束(含完整 text) |
|
| `stream_end` | AI 生成结束(含完整 text) |
|
||||||
| `stream_segments` | 流式断句(语音) |
|
| `stream_segments` | 流式断句(语音) |
|
||||||
| `voice_transcript` | 语音转录结果 |
|
| `voice_transcript` | 语音转录结果 (非流式, voice_input) |
|
||||||
|
| `voice_interim` | 流式语音中间识别结果 |
|
||||||
|
| `voice_final` | 流式语音最终识别文本 |
|
||||||
| `error` | 错误 |
|
| `error` | 错误 |
|
||||||
| `history_response` | 历史消息返回 |
|
| `history_response` | 历史消息返回 |
|
||||||
| `notification` | 推送通知 |
|
| `notification` | 推送通知 |
|
||||||
@@ -325,7 +334,7 @@ Client Gateway
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
### 语音输入流程
|
### 语音输入流程 (非流式)
|
||||||
|
|
||||||
```
|
```
|
||||||
Client Gateway Voice-Service
|
Client Gateway Voice-Service
|
||||||
@@ -343,6 +352,128 @@ Client Gateway Voice-Service
|
|||||||
|<-- ... 正常流式回复 ... | |
|
|<-- ... 正常流式回复 ... | |
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> **注意:** `voice_input` 为非流式模式,客户端发送完整音频后一次性获取转录结果。适合 MediaRecorder 录音完成后使用。
|
||||||
|
> 推荐使用下方的流式语音输入,配合前端 VAD 实现边说边识别。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 流式语音输入流程 (voice_stream_*)
|
||||||
|
|
||||||
|
配合前端 VAD (Voice Activity Detection) 实现自动语音检测和边说边识别。前端逐帧发送音频分片,Gateway 通过 WebSocket 代理到 Voice-Service 流式 STT,实时返回中间结果。
|
||||||
|
|
||||||
|
```
|
||||||
|
Client Gateway Voice-Service
|
||||||
|
| | |
|
||||||
|
|-- {type:"voice_stream_start", | |
|
||||||
|
| format:"webm", language:"zh"} --> | |
|
||||||
|
| |-- WS /api/v1/stt/stream --------> |
|
||||||
|
| |<-- session ready |
|
||||||
|
|<-- {type:"voice_interim", text:""} | |
|
||||||
|
| | |
|
||||||
|
|-- {type:"voice_stream_chunk", | |
|
||||||
|
| audio_data:"<base64>", | |
|
||||||
|
| sequence:0} ------------------> | |
|
||||||
|
| |-- binary audio frame ----------> |
|
||||||
|
| |<-- {type:"result", |
|
||||||
|
| | text:"你好", isFinal:false} |
|
||||||
|
|<-- {type:"voice_interim", | |
|
||||||
|
| text:"你好"} | |
|
||||||
|
| | |
|
||||||
|
|-- ... more chunks ... | |
|
||||||
|
| | |
|
||||||
|
|-- {type:"voice_stream_end"} -----> | |
|
||||||
|
| |-- {action:"stop"} --------------> |
|
||||||
|
| |<-- {type:"result", |
|
||||||
|
| | text:"你好世界", isFinal:true}|
|
||||||
|
|<-- {type:"voice_final", | |
|
||||||
|
| text:"你好世界"} | |
|
||||||
|
| | |
|
||||||
|
| (Gateway 自动将最终文本 | |
|
||||||
|
| 作为 message 发给 AI-Core) | |
|
||||||
|
|<-- {type:"stream_start"} | |
|
||||||
|
|<-- ... 正常流式 LLM 回复 ... | |
|
||||||
|
```
|
||||||
|
|
||||||
|
**消息详情:**
|
||||||
|
|
||||||
|
#### voice_stream_start — 开启流式语音会话
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "voice_stream_start",
|
||||||
|
"format": "webm",
|
||||||
|
"language": "zh"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
| 字段 | 类型 | 必填 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| `format` | string | 否 | 音频格式,默认 `"webm"`。支持: `webm`, `wav`, `pcm`, `opus` |
|
||||||
|
| `language` | string | 否 | 识别语言,默认 `"zh"`。支持: `zh`, `en`, `ja`, `ko`, `auto` |
|
||||||
|
|
||||||
|
Gateway 收到后连接 Voice-Service 流式 STT WebSocket。成功时返回空 `voice_interim` 确认会话建立;失败返回 `error`。
|
||||||
|
|
||||||
|
#### voice_stream_chunk — 发送音频分片
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "voice_stream_chunk",
|
||||||
|
"audio_data": "<base64 encoded audio>",
|
||||||
|
"sequence": 0
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
| 字段 | 类型 | 必填 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| `audio_data` | string | 是 | Base64 编码的音频数据 |
|
||||||
|
| `sequence` | int | 否 | 分片序号,从 0 递增,用于排序和去重 |
|
||||||
|
|
||||||
|
Gateway 将 audio_data 解码后以 binary 帧转发至 Voice-Service。无直接响应;识别结果通过 `voice_interim` 异步推送。
|
||||||
|
|
||||||
|
#### voice_stream_end — 结束流式语音
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "voice_stream_end"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Gateway 向 Voice-Service 发送 stop 信号,等待最终识别结果。最终文本通过 `voice_final` 返回,并自动触发 LLM 回复流程。
|
||||||
|
|
||||||
|
#### voice_interim — 中间识别结果 (Server → Client)
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "voice_interim",
|
||||||
|
"message_id": "voice_<random>",
|
||||||
|
"text": "中间识别文本",
|
||||||
|
"timestamp": 1717000000000
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
| 字段 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| `text` | 当前累积的识别文本,**非最终结果**,会随更多音频输入而更新 |
|
||||||
|
|
||||||
|
> **前端处理:** 收到 `voice_interim` 后应在 UI 中展示实时识别文本(如灰色斜体),收到 `voice_final` 后替换为最终文本。
|
||||||
|
|
||||||
|
#### voice_final — 最终识别结果 (Server → Client)
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "voice_final",
|
||||||
|
"message_id": "voice_<random>",
|
||||||
|
"text": "最终识别文本",
|
||||||
|
"timestamp": 1717000000000
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
| 字段 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| `text` | 最终的完整识别文本。空字符串表示未识别到语音 |
|
||||||
|
|
||||||
|
收到 `voice_final` 后,Gateway 自动将 `text` 作为 `message` 类型转发至 AI-Core 触发 LLM 流式回复。随后的 `stream_start` / `review` / `stream_end` 流程与普通文字消息相同。
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 3. 会话管理
|
## 3. 会话管理
|
||||||
|
|||||||
+32
-24
@@ -8,7 +8,7 @@ setlocal enabledelayedexpansion
|
|||||||
set "SCRIPT_DIR=%~dp0"
|
set "SCRIPT_DIR=%~dp0"
|
||||||
set "ETHEND_DIR=%SCRIPT_DIR%ethend"
|
set "ETHEND_DIR=%SCRIPT_DIR%ethend"
|
||||||
set "ROOT=%SCRIPT_DIR%"
|
set "ROOT=%SCRIPT_DIR%"
|
||||||
if "%ETHEND_PORT%"=="" (set PORT=9090) else (set PORT=%ETHEND_PORT%)
|
if "%ETHEND_PORT%"=="" (set ETHEND_PORT=9090) else (set ETHEND_PORT=%ETHEND_PORT%)
|
||||||
set "LOG_DIR=%ETHEND_DIR%\logs"
|
set "LOG_DIR=%ETHEND_DIR%\logs"
|
||||||
set "LOG_FILE=%LOG_DIR%\sh.log"
|
set "LOG_FILE=%LOG_DIR%\sh.log"
|
||||||
set "DB_COMPOSE_FILE=%ROOT%docker-compose.dev.db.yml"
|
set "DB_COMPOSE_FILE=%ROOT%docker-compose.dev.db.yml"
|
||||||
@@ -63,7 +63,7 @@ echo ethend.bat logs gateway View Gateway log
|
|||||||
echo ethend.bat build ai-core Build AI-Core only
|
echo ethend.bat build ai-core Build AI-Core only
|
||||||
echo ethend.bat db:status Check database
|
echo ethend.bat db:status Check database
|
||||||
echo.
|
echo.
|
||||||
echo Web console: http://localhost:%PORT%
|
echo Web console: http://localhost:%ETHEND_PORT%
|
||||||
exit /b 0
|
exit /b 0
|
||||||
|
|
||||||
:: ==========================================
|
:: ==========================================
|
||||||
@@ -89,7 +89,7 @@ echo Cyrene ethend
|
|||||||
echo ==========================================
|
echo ==========================================
|
||||||
call :check_node
|
call :check_node
|
||||||
for /f "tokens=*" %%i in ('!NODE_CMD! --version') do echo Node.js: %%i
|
for /f "tokens=*" %%i in ('!NODE_CMD! --version') do echo Node.js: %%i
|
||||||
echo Port: %PORT%
|
echo Port: %ETHEND_PORT%
|
||||||
|
|
||||||
:: Check for --build / --fresh
|
:: Check for --build / --fresh
|
||||||
set DO_BUILD=0
|
set DO_BUILD=0
|
||||||
@@ -110,21 +110,21 @@ if %DO_BUILD%==1 (
|
|||||||
if %DO_FRESH%==1 (
|
if %DO_FRESH%==1 (
|
||||||
echo.
|
echo.
|
||||||
echo [INFO] Force restarting all services...
|
echo [INFO] Force restarting all services...
|
||||||
curl -s -X POST "http://localhost:%PORT%/api/services/start-all-fresh" >nul 2>&1
|
curl -s -X POST "http://localhost:%ETHEND_PORT%/api/services/start-all-fresh" >nul 2>&1
|
||||||
)
|
)
|
||||||
|
|
||||||
:: Check if already running
|
:: Check if already running
|
||||||
curl -s -o nul "http://localhost:%PORT%/api/health" 2>nul
|
curl -s -o nul "http://localhost:%ETHEND_PORT%/api/health" 2>nul
|
||||||
if %ERRORLEVEL%==0 (
|
if %ERRORLEVEL%==0 (
|
||||||
echo.
|
echo.
|
||||||
echo [OK] ethend already running: http://localhost:%PORT%
|
echo [OK] ethend already running: http://localhost:%ETHEND_PORT%
|
||||||
echo API: http://localhost:%PORT%/api/health
|
echo API: http://localhost:%ETHEND_PORT%/api/health
|
||||||
exit /b 0
|
exit /b 0
|
||||||
)
|
)
|
||||||
|
|
||||||
:: Free port
|
:: Free port
|
||||||
for /f "tokens=5" %%a in ('netstat -ano ^| findstr /R ":%PORT% " ^| findstr "LISTENING"') do (
|
for /f "tokens=5" %%a in ('netstat -ano ^| findstr /R ":%ETHEND_PORT% " ^| findstr "LISTENING"') do (
|
||||||
echo [WARN] Port %PORT% in use by PID %%a, releasing...
|
echo [WARN] Port %ETHEND_PORT% in use by PID %%a, releasing...
|
||||||
taskkill /PID %%a /F >nul 2>&1
|
taskkill /PID %%a /F >nul 2>&1
|
||||||
timeout /t 1 /nobreak >nul
|
timeout /t 1 /nobreak >nul
|
||||||
)
|
)
|
||||||
@@ -141,10 +141,10 @@ if not exist "node_modules\" (
|
|||||||
if not exist "%LOG_DIR%" mkdir "%LOG_DIR%"
|
if not exist "%LOG_DIR%" mkdir "%LOG_DIR%"
|
||||||
|
|
||||||
echo.
|
echo.
|
||||||
echo [INFO] Starting ethend on port %PORT%
|
echo [INFO] Starting ethend on port %ETHEND_PORT%
|
||||||
echo Web UI: http://localhost:%PORT%
|
echo Web UI: http://localhost:%ETHEND_PORT%
|
||||||
echo API: http://localhost:%PORT%/api/health
|
echo API: http://localhost:%ETHEND_PORT%/api/health
|
||||||
echo WebSocket: ws://localhost:%PORT%/ws
|
echo WebSocket: ws://localhost:%ETHEND_PORT%/ws
|
||||||
echo.
|
echo.
|
||||||
|
|
||||||
start "Cyrene-ethend" /B !NODE_CMD! src\index.js 1>>"%LOG_FILE%" 2>&1
|
start "Cyrene-ethend" /B !NODE_CMD! src\index.js 1>>"%LOG_FILE%" 2>&1
|
||||||
@@ -155,12 +155,12 @@ set MAX_WAIT=30
|
|||||||
set WAITED=0
|
set WAITED=0
|
||||||
:health_loop
|
:health_loop
|
||||||
if %WAITED% geq %MAX_WAIT% goto :timeout
|
if %WAITED% geq %MAX_WAIT% goto :timeout
|
||||||
powershell -Command "try { (Invoke-WebRequest 'http://localhost:%PORT%/api/health' -TimeoutSec 2 -UseBasicParsing).StatusCode -eq 200 } catch { $false }" | findstr "True" >nul 2>&1
|
powershell -Command "try { (Invoke-WebRequest 'http://localhost:%ETHEND_PORT%/api/health' -TimeoutSec 2 -UseBasicParsing).StatusCode -eq 200 } catch { $false }" | findstr "True" >nul 2>&1
|
||||||
if %ERRORLEVEL%==0 (
|
if %ERRORLEVEL%==0 (
|
||||||
echo.
|
echo.
|
||||||
echo ==========================================
|
echo ==========================================
|
||||||
echo ethend is ready!
|
echo ethend is ready!
|
||||||
echo Console: http://localhost:%PORT%
|
echo Console: http://localhost:%ETHEND_PORT%
|
||||||
echo Log: %LOG_FILE%
|
echo Log: %LOG_FILE%
|
||||||
echo ==========================================
|
echo ==========================================
|
||||||
echo.
|
echo.
|
||||||
@@ -173,12 +173,12 @@ goto :health_loop
|
|||||||
:timeout
|
:timeout
|
||||||
echo.
|
echo.
|
||||||
echo [WARN] Still starting - waited %MAX_WAIT%s
|
echo [WARN] Still starting - waited %MAX_WAIT%s
|
||||||
echo Check http://localhost:%PORT%/api/health
|
echo Check http://localhost:%ETHEND_PORT%/api/health
|
||||||
exit /b 0
|
exit /b 0
|
||||||
|
|
||||||
:: ==========================================
|
:: ==========================================
|
||||||
:stop
|
:stop
|
||||||
for /f "tokens=5" %%a in ('netstat -ano ^| findstr /R ":%PORT% " ^| findstr "LISTENING"') do (
|
for /f "tokens=5" %%a in ('netstat -ano ^| findstr /R ":%ETHEND_PORT% " ^| findstr "LISTENING"') do (
|
||||||
echo [INFO] Stopping ethend (PID: %%a)...
|
echo [INFO] Stopping ethend (PID: %%a)...
|
||||||
taskkill /PID %%a /F >nul 2>&1
|
taskkill /PID %%a /F >nul 2>&1
|
||||||
echo [OK] ethend stopped
|
echo [OK] ethend stopped
|
||||||
@@ -189,7 +189,7 @@ exit /b 0
|
|||||||
|
|
||||||
:: ==========================================
|
:: ==========================================
|
||||||
:status
|
:status
|
||||||
curl -s -o nul "http://localhost:%PORT%/api/health" 2>nul
|
curl -s -o nul "http://localhost:%ETHEND_PORT%/api/health" 2>nul
|
||||||
if %ERRORLEVEL% neq 0 (
|
if %ERRORLEVEL% neq 0 (
|
||||||
echo [ERROR] ethend is offline
|
echo [ERROR] ethend is offline
|
||||||
exit /b 1
|
exit /b 1
|
||||||
@@ -197,7 +197,7 @@ if %ERRORLEVEL% neq 0 (
|
|||||||
echo [OK] ethend online
|
echo [OK] ethend online
|
||||||
echo.
|
echo.
|
||||||
echo Service Status:
|
echo Service Status:
|
||||||
curl -s "http://localhost:%PORT%/api/services" 2>nul | !NODE_CMD! -e "const d=require('fs').readFileSync(0,'utf-8');const s=JSON.parse(d);for(const[k,v]of Object.entries(s)){const icon=v.status==='running'?'+':' ';const pid=v.pid?' (PID:'+v.pid+')':'';const upt=v.uptime?' uptime:'+Math.round(v.uptime/1000)+'s':'';console.log(' ['+icon+'] '+v.name.padEnd(18)+v.status+pid+upt);}" 2>nul
|
curl -s "http://localhost:%ETHEND_PORT%/api/services" 2>nul | !NODE_CMD! -e "const d=require('fs').readFileSync(0,'utf-8');const s=JSON.parse(d);for(const[k,v]of Object.entries(s)){const icon=v.status==='running'?'+':' ';const pid=v.pid?' (PID:'+v.pid+')':'';const upt=v.uptime?' uptime:'+Math.round(v.uptime/1000)+'s':'';console.log(' ['+icon+'] '+v.name.padEnd(18)+v.status+pid+upt);}" 2>nul
|
||||||
echo.
|
echo.
|
||||||
echo Database:
|
echo Database:
|
||||||
powershell -Command "$tcp=New-Object Net.Sockets.TcpClient;try{$tcp.Connect('127.0.0.1',5432);' [OK] PostgreSQL online'}catch{' [--] PostgreSQL offline'};$tcp.Dispose()" 2>nul
|
powershell -Command "$tcp=New-Object Net.Sockets.TcpClient;try{$tcp.Connect('127.0.0.1',5432);' [OK] PostgreSQL online'}catch{' [--] PostgreSQL offline'};$tcp.Dispose()" 2>nul
|
||||||
@@ -230,14 +230,18 @@ if not "%SVC_ID%"=="" goto :build_one
|
|||||||
|
|
||||||
:build_all
|
:build_all
|
||||||
echo [INFO] Building all backend services...
|
echo [INFO] Building all backend services...
|
||||||
set SERVICES=memory-service:backend/memory-service tool-engine:backend/tool-engine iot-debug-service:backend/iot-debug-service voice-service:backend/voice-service ai-core:backend/ai-core plugin-manager:backend/plugin-manager platform-bridge:backend/platform-bridge gateway:backend/gateway
|
set SERVICES=memory-service:backend/memory-service iot-debug-service:backend/iot-debug-service voice-service:backend/voice-service ai-core:backend/ai-core plugin-manager:backend/cyrene-plugins platform-bridge:backend/platform-bridge gateway:backend/gateway
|
||||||
for %%s in (%SERVICES%) do (
|
for %%s in (%SERVICES%) do (
|
||||||
for /f "tokens=1,2 delims=:" %%a in ("%%s") do (
|
for /f "tokens=1,2 delims=:" %%a in ("%%s") do (
|
||||||
if exist "%ROOT%%%b" (
|
if exist "%ROOT%%%b" (
|
||||||
echo Building %%a...
|
echo Building %%a...
|
||||||
cd /d "%ROOT%%%b"
|
cd /d "%ROOT%%%b"
|
||||||
set GOWORK=off
|
set GOWORK=off
|
||||||
go build -o main.exe .\cmd\main.go 2>&1
|
if "%%a"=="plugin-manager" (
|
||||||
|
go build -o main.exe .\cmd\plugin-manager\ 2>&1
|
||||||
|
) else (
|
||||||
|
go build -o main.exe .\cmd\main.go 2>&1
|
||||||
|
)
|
||||||
if !ERRORLEVEL!==0 (
|
if !ERRORLEVEL!==0 (
|
||||||
echo [OK] %%a
|
echo [OK] %%a
|
||||||
) else (
|
) else (
|
||||||
@@ -251,7 +255,7 @@ echo [OK] Build complete
|
|||||||
exit /b 0
|
exit /b 0
|
||||||
|
|
||||||
:build_one
|
:build_one
|
||||||
for %%s in (memory-service:backend/memory-service tool-engine:backend/tool-engine iot-debug-service:backend/iot-debug-service voice-service:backend/voice-service ai-core:backend/ai-core plugin-manager:backend/plugin-manager platform-bridge:backend/platform-bridge gateway:backend/gateway) do (
|
for %%s in (memory-service:backend/memory-service iot-debug-service:backend/iot-debug-service voice-service:backend/voice-service ai-core:backend/ai-core plugin-manager:backend/cyrene-plugins platform-bridge:backend/platform-bridge gateway:backend/gateway) do (
|
||||||
for /f "tokens=1,2 delims=:" %%a in ("%%s") do (
|
for /f "tokens=1,2 delims=:" %%a in ("%%s") do (
|
||||||
if "%%a"=="%SVC_ID%" (
|
if "%%a"=="%SVC_ID%" (
|
||||||
if not exist "%ROOT%%%b" (
|
if not exist "%ROOT%%%b" (
|
||||||
@@ -261,7 +265,11 @@ for %%s in (memory-service:backend/memory-service tool-engine:backend/tool-engin
|
|||||||
echo [INFO] Building %%a...
|
echo [INFO] Building %%a...
|
||||||
cd /d "%ROOT%%%b"
|
cd /d "%ROOT%%%b"
|
||||||
set GOWORK=off
|
set GOWORK=off
|
||||||
go build -o main.exe .\cmd\main.go 2>&1
|
if "%%a"=="plugin-manager" (
|
||||||
|
go build -o main.exe .\cmd\plugin-manager\ 2>&1
|
||||||
|
) else (
|
||||||
|
go build -o main.exe .\cmd\main.go 2>&1
|
||||||
|
)
|
||||||
if !ERRORLEVEL!==0 (echo [OK] %%a) else (echo [FAIL] %%a)
|
if !ERRORLEVEL!==0 (echo [OK] %%a) else (echo [FAIL] %%a)
|
||||||
cd /d "%ROOT%"
|
cd /d "%ROOT%"
|
||||||
exit /b !ERRORLEVEL!
|
exit /b !ERRORLEVEL!
|
||||||
@@ -269,7 +277,7 @@ for %%s in (memory-service:backend/memory-service tool-engine:backend/tool-engin
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
echo [ERROR] Unknown service: %SVC_ID%
|
echo [ERROR] Unknown service: %SVC_ID%
|
||||||
echo Available: memory-service, tool-engine, iot-debug-service, voice-service, ai-core, plugin-manager, platform-bridge, gateway
|
echo Available: memory-service, iot-debug-service, voice-service, ai-core, plugin-manager, platform-bridge, gateway
|
||||||
exit /b 1
|
exit /b 1
|
||||||
|
|
||||||
:: ==========================================
|
:: ==========================================
|
||||||
|
|||||||
@@ -185,11 +185,10 @@ db_stop() {
|
|||||||
# ========== 服务编译 ==========
|
# ========== 服务编译 ==========
|
||||||
SERVICES=(
|
SERVICES=(
|
||||||
"memory-service:backend/memory-service"
|
"memory-service:backend/memory-service"
|
||||||
"tool-engine:backend/tool-engine"
|
|
||||||
"iot-debug-service:backend/iot-debug-service"
|
"iot-debug-service:backend/iot-debug-service"
|
||||||
"voice-service:backend/voice-service"
|
"voice-service:backend/voice-service"
|
||||||
"ai-core:backend/ai-core"
|
"ai-core:backend/ai-core"
|
||||||
"plugin-manager:backend/plugin-manager"
|
"plugin-manager:backend/cyrene-plugins"
|
||||||
"platform-bridge:backend/platform-bridge"
|
"platform-bridge:backend/platform-bridge"
|
||||||
"gateway:backend/gateway"
|
"gateway:backend/gateway"
|
||||||
)
|
)
|
||||||
@@ -204,7 +203,11 @@ build_service() {
|
|||||||
$IS_WIN && binary="main.exe"
|
$IS_WIN && binary="main.exe"
|
||||||
|
|
||||||
cd "$ROOT/$dir"
|
cd "$ROOT/$dir"
|
||||||
if GOWORK=off go build -o "$binary" ./cmd/main.go 2>&1; then
|
local build_target="./cmd/main.go"
|
||||||
|
if [ "$id" = "plugin-manager" ]; then
|
||||||
|
build_target="./cmd/plugin-manager/"
|
||||||
|
fi
|
||||||
|
if GOWORK=off go build -o "$binary" "$build_target" 2>&1; then
|
||||||
echo -e "${GREEN} ✓ $label 编译完成${NC}"
|
echo -e "${GREEN} ✓ $label 编译完成${NC}"
|
||||||
return 0
|
return 0
|
||||||
else
|
else
|
||||||
|
|||||||
@@ -147,7 +147,7 @@ export const SERVICES = {
|
|||||||
},
|
},
|
||||||
'plugin-manager': {
|
'plugin-manager': {
|
||||||
name: '插件管理器',
|
name: '插件管理器',
|
||||||
cwd: path.join(ROOT, 'backend/plugin-manager'),
|
cwd: path.join(ROOT, 'backend/cyrene-plugins'),
|
||||||
command: './main',
|
command: './main',
|
||||||
env: {
|
env: {
|
||||||
PORT: '8094',
|
PORT: '8094',
|
||||||
@@ -156,7 +156,7 @@ export const SERVICES = {
|
|||||||
healthUrl: IS_DOCKER ? 'http://plugin-manager:8094/api/v1/health' : 'http://localhost:8094/api/v1/health',
|
healthUrl: IS_DOCKER ? 'http://plugin-manager:8094/api/v1/health' : 'http://localhost:8094/api/v1/health',
|
||||||
port: 8094,
|
port: 8094,
|
||||||
buildCommand: 'go',
|
buildCommand: 'go',
|
||||||
buildArgs: ['build', '-o', isWin ? 'main.exe' : 'main', './cmd/main.go'],
|
buildArgs: ['build', '-o', isWin ? 'main.exe' : 'main', './cmd/plugin-manager/'],
|
||||||
goBin: GO_BIN,
|
goBin: GO_BIN,
|
||||||
},
|
},
|
||||||
'platform-bridge': {
|
'platform-bridge': {
|
||||||
|
|||||||
Generated
+156
@@ -8,6 +8,7 @@
|
|||||||
"name": "cyrene-frontend",
|
"name": "cyrene-frontend",
|
||||||
"version": "0.1.0",
|
"version": "0.1.0",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
|
"@ricky0123/vad-web": "^0.0.30",
|
||||||
"react": "^18.3.1",
|
"react": "^18.3.1",
|
||||||
"react-dom": "^18.3.1",
|
"react-dom": "^18.3.1",
|
||||||
"zustand": "^4.5.5"
|
"zustand": "^4.5.5"
|
||||||
@@ -848,6 +849,78 @@
|
|||||||
"node": ">= 8"
|
"node": ">= 8"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/@protobufjs/aspromise": {
|
||||||
|
"version": "1.1.2",
|
||||||
|
"resolved": "https://registry.npmmirror.com/@protobufjs/aspromise/-/aspromise-1.1.2.tgz",
|
||||||
|
"integrity": "sha512-j+gKExEuLmKwvz3OgROXtrJ2UG2x8Ch2YZUxahh+s1F2HZ+wAceUNLkvy6zKCPVRkU++ZWQrdxsUeQXmcg4uoQ==",
|
||||||
|
"license": "BSD-3-Clause"
|
||||||
|
},
|
||||||
|
"node_modules/@protobufjs/base64": {
|
||||||
|
"version": "1.1.2",
|
||||||
|
"resolved": "https://registry.npmmirror.com/@protobufjs/base64/-/base64-1.1.2.tgz",
|
||||||
|
"integrity": "sha512-AZkcAA5vnN/v4PDqKyMR5lx7hZttPDgClv83E//FMNhR2TMcLUhfRUBHCmSl0oi9zMgDDqRUJkSxO3wm85+XLg==",
|
||||||
|
"license": "BSD-3-Clause"
|
||||||
|
},
|
||||||
|
"node_modules/@protobufjs/codegen": {
|
||||||
|
"version": "2.0.5",
|
||||||
|
"resolved": "https://registry.npmmirror.com/@protobufjs/codegen/-/codegen-2.0.5.tgz",
|
||||||
|
"integrity": "sha512-zgXFLzW3Ap33e6d0Wlj4MGIm6Ce8O89n/apUaGNB/jx+hw+ruWEp7EwGUshdLKVRCxZW12fp9r40E1mQrf/34g==",
|
||||||
|
"license": "BSD-3-Clause"
|
||||||
|
},
|
||||||
|
"node_modules/@protobufjs/eventemitter": {
|
||||||
|
"version": "1.1.1",
|
||||||
|
"resolved": "https://registry.npmmirror.com/@protobufjs/eventemitter/-/eventemitter-1.1.1.tgz",
|
||||||
|
"integrity": "sha512-vW1GmwMZNnL+gMRaovlh9yZX74kc+TTU3FObkkurpMaRtBfLP3ldjS9KQWlwZgraRE0+dheEEoAxdzcJQ8eXZg==",
|
||||||
|
"license": "BSD-3-Clause"
|
||||||
|
},
|
||||||
|
"node_modules/@protobufjs/fetch": {
|
||||||
|
"version": "1.1.1",
|
||||||
|
"resolved": "https://registry.npmmirror.com/@protobufjs/fetch/-/fetch-1.1.1.tgz",
|
||||||
|
"integrity": "sha512-GpptLrs57adMSuHi3VNj0mAF8dwh36LMaYF6XyJ6JMWlVsc+t42tm1HSEDmOs3A8fC9yyeisgLhsTVQokOZ0zw==",
|
||||||
|
"license": "BSD-3-Clause",
|
||||||
|
"dependencies": {
|
||||||
|
"@protobufjs/aspromise": "^1.1.1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/@protobufjs/float": {
|
||||||
|
"version": "1.0.2",
|
||||||
|
"resolved": "https://registry.npmmirror.com/@protobufjs/float/-/float-1.0.2.tgz",
|
||||||
|
"integrity": "sha512-Ddb+kVXlXst9d+R9PfTIxh1EdNkgoRe5tOX6t01f1lYWOvJnSPDBlG241QLzcyPdoNTsblLUdujGSE4RzrTZGQ==",
|
||||||
|
"license": "BSD-3-Clause"
|
||||||
|
},
|
||||||
|
"node_modules/@protobufjs/inquire": {
|
||||||
|
"version": "1.1.2",
|
||||||
|
"resolved": "https://registry.npmmirror.com/@protobufjs/inquire/-/inquire-1.1.2.tgz",
|
||||||
|
"integrity": "sha512-pa0vFRuws4wkvaXKK1uXZMAwAX4/t8ANaJo45iw/oQHNQ9q5xUzwgFmVJGXiga2BeN+zpX7Vf9vmsiIa2J+MUw==",
|
||||||
|
"license": "BSD-3-Clause"
|
||||||
|
},
|
||||||
|
"node_modules/@protobufjs/path": {
|
||||||
|
"version": "1.1.2",
|
||||||
|
"resolved": "https://registry.npmmirror.com/@protobufjs/path/-/path-1.1.2.tgz",
|
||||||
|
"integrity": "sha512-6JOcJ5Tm08dOHAbdR3GrvP+yUUfkjG5ePsHYczMFLq3ZmMkAD98cDgcT2iA1lJ9NVwFd4tH/iSSoe44YWkltEA==",
|
||||||
|
"license": "BSD-3-Clause"
|
||||||
|
},
|
||||||
|
"node_modules/@protobufjs/pool": {
|
||||||
|
"version": "1.1.0",
|
||||||
|
"resolved": "https://registry.npmmirror.com/@protobufjs/pool/-/pool-1.1.0.tgz",
|
||||||
|
"integrity": "sha512-0kELaGSIDBKvcgS4zkjz1PeddatrjYcmMWOlAuAPwAeccUrPHdUqo/J6LiymHHEiJT5NrF1UVwxY14f+fy4WQw==",
|
||||||
|
"license": "BSD-3-Clause"
|
||||||
|
},
|
||||||
|
"node_modules/@protobufjs/utf8": {
|
||||||
|
"version": "1.1.1",
|
||||||
|
"resolved": "https://registry.npmmirror.com/@protobufjs/utf8/-/utf8-1.1.1.tgz",
|
||||||
|
"integrity": "sha512-oOAWABowe8EAbMyWKM0tYDKi8Yaox52D+HWZhAIJqQXbqe0xI/GV7FhLWqlEKreMkfDjshR5FKgi3mnle0h6Eg==",
|
||||||
|
"license": "BSD-3-Clause"
|
||||||
|
},
|
||||||
|
"node_modules/@ricky0123/vad-web": {
|
||||||
|
"version": "0.0.30",
|
||||||
|
"resolved": "https://registry.npmmirror.com/@ricky0123/vad-web/-/vad-web-0.0.30.tgz",
|
||||||
|
"integrity": "sha512-cJyYrh4YeeUBJcbR9Bic/bFDyB9qBkAepvpuWM3vLxnAi7bC3VHzf51UeNdT+OtY4D7MLAgV8iJMc4z41ZnaWg==",
|
||||||
|
"license": "ISC",
|
||||||
|
"dependencies": {
|
||||||
|
"onnxruntime-web": "^1.17.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/@rolldown/pluginutils": {
|
"node_modules/@rolldown/pluginutils": {
|
||||||
"version": "1.0.0-beta.27",
|
"version": "1.0.0-beta.27",
|
||||||
"resolved": "https://registry.npmjs.org/@rolldown/pluginutils/-/pluginutils-1.0.0-beta.27.tgz",
|
"resolved": "https://registry.npmjs.org/@rolldown/pluginutils/-/pluginutils-1.0.0-beta.27.tgz",
|
||||||
@@ -1257,6 +1330,15 @@
|
|||||||
"dev": true,
|
"dev": true,
|
||||||
"license": "MIT"
|
"license": "MIT"
|
||||||
},
|
},
|
||||||
|
"node_modules/@types/node": {
|
||||||
|
"version": "25.9.2",
|
||||||
|
"resolved": "https://registry.npmmirror.com/@types/node/-/node-25.9.2.tgz",
|
||||||
|
"integrity": "sha512-G05zqtJhcDLb8uslf5EjCxXg9G1KQxiV8OS0R26IC//Eoyitzqe8z37I7cqvnZlrlSfgocQRfSn/AHBZJJFyGw==",
|
||||||
|
"license": "MIT",
|
||||||
|
"dependencies": {
|
||||||
|
"undici-types": ">=7.24.0 <7.24.7"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/@types/prop-types": {
|
"node_modules/@types/prop-types": {
|
||||||
"version": "15.7.15",
|
"version": "15.7.15",
|
||||||
"resolved": "https://registry.npmjs.org/@types/prop-types/-/prop-types-15.7.15.tgz",
|
"resolved": "https://registry.npmjs.org/@types/prop-types/-/prop-types-15.7.15.tgz",
|
||||||
@@ -1704,6 +1786,12 @@
|
|||||||
"node": ">=8"
|
"node": ">=8"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/flatbuffers": {
|
||||||
|
"version": "25.9.23",
|
||||||
|
"resolved": "https://registry.npmmirror.com/flatbuffers/-/flatbuffers-25.9.23.tgz",
|
||||||
|
"integrity": "sha512-MI1qs7Lo4Syw0EOzUl0xjs2lsoeqFku44KpngfIduHBYvzm8h2+7K8YMQh1JtVVVrUvhLpNwqVi4DERegUJhPQ==",
|
||||||
|
"license": "Apache-2.0"
|
||||||
|
},
|
||||||
"node_modules/fraction.js": {
|
"node_modules/fraction.js": {
|
||||||
"version": "5.3.4",
|
"version": "5.3.4",
|
||||||
"resolved": "https://registry.npmjs.org/fraction.js/-/fraction.js-5.3.4.tgz",
|
"resolved": "https://registry.npmjs.org/fraction.js/-/fraction.js-5.3.4.tgz",
|
||||||
@@ -1766,6 +1854,12 @@
|
|||||||
"node": ">=10.13.0"
|
"node": ">=10.13.0"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/guid-typescript": {
|
||||||
|
"version": "1.0.9",
|
||||||
|
"resolved": "https://registry.npmmirror.com/guid-typescript/-/guid-typescript-1.0.9.tgz",
|
||||||
|
"integrity": "sha512-Y8T4vYhEfwJOTbouREvG+3XDsjr8E3kIr7uf+JZ0BYloFsttiHU0WfvANVsR7TxNUJa/WpCnw/Ino/p+DeBhBQ==",
|
||||||
|
"license": "ISC"
|
||||||
|
},
|
||||||
"node_modules/hasown": {
|
"node_modules/hasown": {
|
||||||
"version": "2.0.3",
|
"version": "2.0.3",
|
||||||
"resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.3.tgz",
|
"resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.3.tgz",
|
||||||
@@ -1903,6 +1997,12 @@
|
|||||||
"dev": true,
|
"dev": true,
|
||||||
"license": "MIT"
|
"license": "MIT"
|
||||||
},
|
},
|
||||||
|
"node_modules/long": {
|
||||||
|
"version": "5.3.2",
|
||||||
|
"resolved": "https://registry.npmmirror.com/long/-/long-5.3.2.tgz",
|
||||||
|
"integrity": "sha512-mNAgZ1GmyNhD7AuqnTG3/VQ26o760+ZYBPKjPvugO8+nLbYfX6TVpJPseBvopbdY+qpZ/lKUnmEc1LeZYS3QAA==",
|
||||||
|
"license": "Apache-2.0"
|
||||||
|
},
|
||||||
"node_modules/loose-envify": {
|
"node_modules/loose-envify": {
|
||||||
"version": "1.4.0",
|
"version": "1.4.0",
|
||||||
"resolved": "https://registry.npmjs.org/loose-envify/-/loose-envify-1.4.0.tgz",
|
"resolved": "https://registry.npmjs.org/loose-envify/-/loose-envify-1.4.0.tgz",
|
||||||
@@ -2024,6 +2124,26 @@
|
|||||||
"node": ">= 6"
|
"node": ">= 6"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/onnxruntime-common": {
|
||||||
|
"version": "1.26.0",
|
||||||
|
"resolved": "https://registry.npmmirror.com/onnxruntime-common/-/onnxruntime-common-1.26.0.tgz",
|
||||||
|
"integrity": "sha512-qVyMR4lcWgbkc4getFV+GQijsTnbg/siteoqcDwa3sI/LxbrMSNw4ePyvCq/ymdQaRomCA7YuWmhzsswxvymdw==",
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
|
"node_modules/onnxruntime-web": {
|
||||||
|
"version": "1.26.0",
|
||||||
|
"resolved": "https://registry.npmmirror.com/onnxruntime-web/-/onnxruntime-web-1.26.0.tgz",
|
||||||
|
"integrity": "sha512-LbRr/8zZt2xilI2smrVQGGKINo0U46i8qJp+UXyMBGfqN7KjnH1BiwCwLwyNIVV4i9CKFv7Sf4PwLKWnT8/bEA==",
|
||||||
|
"license": "MIT",
|
||||||
|
"dependencies": {
|
||||||
|
"flatbuffers": "^25.1.24",
|
||||||
|
"guid-typescript": "^1.0.9",
|
||||||
|
"long": "^5.2.3",
|
||||||
|
"onnxruntime-common": "1.26.0",
|
||||||
|
"platform": "^1.3.6",
|
||||||
|
"protobufjs": "^7.2.4"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/path-parse": {
|
"node_modules/path-parse": {
|
||||||
"version": "1.0.7",
|
"version": "1.0.7",
|
||||||
"resolved": "https://registry.npmjs.org/path-parse/-/path-parse-1.0.7.tgz",
|
"resolved": "https://registry.npmjs.org/path-parse/-/path-parse-1.0.7.tgz",
|
||||||
@@ -2071,6 +2191,12 @@
|
|||||||
"node": ">= 6"
|
"node": ">= 6"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/platform": {
|
||||||
|
"version": "1.3.6",
|
||||||
|
"resolved": "https://registry.npmmirror.com/platform/-/platform-1.3.6.tgz",
|
||||||
|
"integrity": "sha512-fnWVljUchTro6RiCFvCXBbNhJc2NijN7oIQxbwsyL0buWJPG85v81ehlHI9fXrJsMNgTofEoWIQeClKpgxFLrg==",
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
"node_modules/postcss": {
|
"node_modules/postcss": {
|
||||||
"version": "8.5.14",
|
"version": "8.5.14",
|
||||||
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.14.tgz",
|
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.14.tgz",
|
||||||
@@ -2234,6 +2360,30 @@
|
|||||||
"dev": true,
|
"dev": true,
|
||||||
"license": "MIT"
|
"license": "MIT"
|
||||||
},
|
},
|
||||||
|
"node_modules/protobufjs": {
|
||||||
|
"version": "7.6.2",
|
||||||
|
"resolved": "https://registry.npmmirror.com/protobufjs/-/protobufjs-7.6.2.tgz",
|
||||||
|
"integrity": "sha512-N9EiLovGEQOJSPF26Ij7qUGvahfEnq0eeYZ02aigIedkmz1qZSwjnP9SBITHJuF/6MYbIW4HDN8zdYjsjqJKXQ==",
|
||||||
|
"hasInstallScript": true,
|
||||||
|
"license": "BSD-3-Clause",
|
||||||
|
"dependencies": {
|
||||||
|
"@protobufjs/aspromise": "^1.1.2",
|
||||||
|
"@protobufjs/base64": "^1.1.2",
|
||||||
|
"@protobufjs/codegen": "^2.0.5",
|
||||||
|
"@protobufjs/eventemitter": "^1.1.1",
|
||||||
|
"@protobufjs/fetch": "^1.1.1",
|
||||||
|
"@protobufjs/float": "^1.0.2",
|
||||||
|
"@protobufjs/inquire": "^1.1.2",
|
||||||
|
"@protobufjs/path": "^1.1.2",
|
||||||
|
"@protobufjs/pool": "^1.1.0",
|
||||||
|
"@protobufjs/utf8": "^1.1.1",
|
||||||
|
"@types/node": ">=13.7.0",
|
||||||
|
"long": "^5.3.2"
|
||||||
|
},
|
||||||
|
"engines": {
|
||||||
|
"node": ">=12.0.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/queue-microtask": {
|
"node_modules/queue-microtask": {
|
||||||
"version": "1.2.3",
|
"version": "1.2.3",
|
||||||
"resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz",
|
"resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz",
|
||||||
@@ -2623,6 +2773,12 @@
|
|||||||
"node": ">=14.17"
|
"node": ">=14.17"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/undici-types": {
|
||||||
|
"version": "7.24.6",
|
||||||
|
"resolved": "https://registry.npmmirror.com/undici-types/-/undici-types-7.24.6.tgz",
|
||||||
|
"integrity": "sha512-WRNW+sJgj5OBN4/0JpHFqtqzhpbnV0GuB+OozA9gCL7a993SmU+1JBZCzLNxYsbMfIeDL+lTsphD5jN5N+n0zg==",
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
"node_modules/update-browserslist-db": {
|
"node_modules/update-browserslist-db": {
|
||||||
"version": "1.2.3",
|
"version": "1.2.3",
|
||||||
"resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.2.3.tgz",
|
"resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.2.3.tgz",
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
"preview": "vite preview"
|
"preview": "vite preview"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
|
"@ricky0123/vad-web": "^0.0.30",
|
||||||
"react": "^18.3.1",
|
"react": "^18.3.1",
|
||||||
"react-dom": "^18.3.1",
|
"react-dom": "^18.3.1",
|
||||||
"zustand": "^4.5.5"
|
"zustand": "^4.5.5"
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ function setHashSessionId(sessionId: string | null) {
|
|||||||
|
|
||||||
export default function App() {
|
export default function App() {
|
||||||
const { isLoggedIn, login, register, loading: authLoading, userId } = useAuth();
|
const { isLoggedIn, login, register, loading: authLoading, userId } = useAuth();
|
||||||
const { send } = useChat();
|
const { send, sendVoiceStreamMessage } = useChat();
|
||||||
const { loadSessionsFromServer, ensureMainSession, setCurrentSessionId, setMessages, loadMessagesFromServer, sessions, currentSessionId } = useSessionStore();
|
const { loadSessionsFromServer, ensureMainSession, setCurrentSessionId, setMessages, loadMessagesFromServer, sessions, currentSessionId } = useSessionStore();
|
||||||
|
|
||||||
const [authMode, setAuthMode] = useState<'login' | 'register'>('login');
|
const [authMode, setAuthMode] = useState<'login' | 'register'>('login');
|
||||||
@@ -330,41 +330,42 @@ export default function App() {
|
|||||||
return (
|
return (
|
||||||
<ErrorBoundary>
|
<ErrorBoundary>
|
||||||
<AppLayout>
|
<AppLayout>
|
||||||
<PageRouter onSend={send} />
|
<PageRouter onSend={send} onSendVoiceStream={sendVoiceStreamMessage} />
|
||||||
</AppLayout>
|
</AppLayout>
|
||||||
</ErrorBoundary>
|
</ErrorBoundary>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
type SendFn = (content: string, mode?: import('@/types/chat').ChatMode, attachments?: import('@/types/chat').MessageAttachment[]) => void;
|
type SendFn = (content: string, mode?: import('@/types/chat').ChatMode, attachments?: import('@/types/chat').MessageAttachment[]) => void;
|
||||||
|
type SendVoiceStreamFn = (msg: import('@/types/chat').WSClientMessage) => void;
|
||||||
|
|
||||||
function PageRouter({ onSend }: { onSend: SendFn }) {
|
function PageRouter({ onSend, onSendVoiceStream }: { onSend: SendFn; onSendVoiceStream: SendVoiceStreamFn }) {
|
||||||
const currentPage = usePageStore((s) => s.currentPage);
|
const currentPage = usePageStore((s) => s.currentPage);
|
||||||
const isAdmin = isAdminUser(localStorage.getItem('user_id') || '');
|
const isAdmin = isAdminUser(localStorage.getItem('user_id') || '');
|
||||||
|
|
||||||
switch (currentPage) {
|
switch (currentPage) {
|
||||||
case 'admin-models':
|
case 'admin-models':
|
||||||
if (!isAdmin) return <ChatPage onSend={onSend} />;
|
if (!isAdmin) return <ChatPage onSend={onSend} onSendVoiceStream={onSendVoiceStream} />;
|
||||||
return <ModelsAdminPage />;
|
return <ModelsAdminPage />;
|
||||||
case 'admin-dashboard':
|
case 'admin-dashboard':
|
||||||
if (!isAdmin) return <ChatPage onSend={onSend} />;
|
if (!isAdmin) return <ChatPage onSend={onSend} onSendVoiceStream={onSendVoiceStream} />;
|
||||||
return <AdminDashboard />;
|
return <AdminDashboard />;
|
||||||
case 'profile':
|
case 'profile':
|
||||||
return <ProfilePage />;
|
return <ProfilePage />;
|
||||||
case 'chat':
|
case 'chat':
|
||||||
default:
|
default:
|
||||||
return <ChatPage onSend={onSend} />;
|
return <ChatPage onSend={onSend} onSendVoiceStream={onSendVoiceStream} />;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function ChatPage({ onSend }: { onSend: SendFn }) {
|
function ChatPage({ onSend, onSendVoiceStream }: { onSend: SendFn; onSendVoiceStream: SendVoiceStreamFn }) {
|
||||||
return (
|
return (
|
||||||
<div className="flex flex-col h-full overflow-hidden">
|
<div className="flex flex-col h-full overflow-hidden">
|
||||||
<div className="flex-1 min-h-0 overflow-hidden">
|
<div className="flex-1 min-h-0 overflow-hidden">
|
||||||
<ChatContainer />
|
<ChatContainer />
|
||||||
</div>
|
</div>
|
||||||
<div className="flex-shrink-0">
|
<div className="flex-shrink-0">
|
||||||
<ChatInput onSend={onSend} />
|
<ChatInput onSend={onSend} onSendVoiceStream={onSendVoiceStream} />
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
import { useState, useRef, useCallback, useEffect } from 'react';
|
import { useState, useRef, useCallback, useEffect } from 'react';
|
||||||
import type { ChatMode, MessageAttachment } from '@/types/chat';
|
import type { ChatMode, MessageAttachment } from '@/types/chat';
|
||||||
import { useSpeechRecognition } from '@/hooks/useSpeechRecognition';
|
import { useSpeechRecognition } from '@/hooks/useSpeechRecognition';
|
||||||
|
import { useVoiceInput } from '@/hooks/useVoiceInput';
|
||||||
import { uploadFile } from '@/api/files';
|
import { uploadFile } from '@/api/files';
|
||||||
import { useChatStore } from '@/store/chatStore';
|
import { useChatStore } from '@/store/chatStore';
|
||||||
|
|
||||||
interface ChatInputProps {
|
interface ChatInputProps {
|
||||||
onSend: (content: string, mode: ChatMode, attachments?: MessageAttachment[]) => void;
|
onSend: (content: string, mode: ChatMode, attachments?: MessageAttachment[]) => void;
|
||||||
|
onSendVoiceStream?: (msg: import('@/types/chat').WSClientMessage) => void;
|
||||||
disabled?: boolean;
|
disabled?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -19,7 +21,7 @@ const MAX_IMAGE_SIZE = 10 * 1024 * 1024; // 10MB
|
|||||||
const SUPPORTED_IMAGE_TYPES = ['image/jpeg', 'image/png', 'image/gif', 'image/webp', 'image/bmp'];
|
const SUPPORTED_IMAGE_TYPES = ['image/jpeg', 'image/png', 'image/gif', 'image/webp', 'image/bmp'];
|
||||||
const MAX_IMAGES = 5;
|
const MAX_IMAGES = 5;
|
||||||
|
|
||||||
export function ChatInput({ onSend, disabled }: ChatInputProps) {
|
export function ChatInput({ onSend, onSendVoiceStream, disabled }: ChatInputProps) {
|
||||||
const [content, setContent] = useState('');
|
const [content, setContent] = useState('');
|
||||||
const [mode, setMode] = useState<ChatMode>('text');
|
const [mode, setMode] = useState<ChatMode>('text');
|
||||||
const [pendingImages, setPendingImages] = useState<PendingImage[]>([]);
|
const [pendingImages, setPendingImages] = useState<PendingImage[]>([]);
|
||||||
@@ -30,27 +32,50 @@ export function ChatInput({ onSend, disabled }: ChatInputProps) {
|
|||||||
const isTyping = useChatStore((s) => s.isTyping);
|
const isTyping = useChatStore((s) => s.isTyping);
|
||||||
|
|
||||||
const {
|
const {
|
||||||
isListening,
|
isListening: isSRListening,
|
||||||
isSupported,
|
isSupported: isSRSpported,
|
||||||
isFallbackMode,
|
isFallbackMode,
|
||||||
interimText,
|
interimText,
|
||||||
finalText,
|
finalText,
|
||||||
error,
|
error: srError,
|
||||||
startListening,
|
startListening: startSR,
|
||||||
stopListening,
|
stopListening: stopSR,
|
||||||
resetText,
|
resetText,
|
||||||
} = useSpeechRecognition();
|
} = useSpeechRecognition();
|
||||||
|
|
||||||
// 当 finalText 更新时,追加到输入框
|
// VAD-based voice input (primary when supported)
|
||||||
|
const {
|
||||||
|
isListening: isVADListening,
|
||||||
|
isSpeaking: isVADSpeaking,
|
||||||
|
isSupported: isVADSupported,
|
||||||
|
interimText: vadInterimText,
|
||||||
|
finalText: vadFinalText,
|
||||||
|
error: vadError,
|
||||||
|
startListening: startVAD,
|
||||||
|
stopListening: stopVAD,
|
||||||
|
} = useVoiceInput({
|
||||||
|
onTranscription: (text: string) => {
|
||||||
|
setContent((prev) => {
|
||||||
|
const trimmed = prev.trimEnd();
|
||||||
|
return (trimmed ? trimmed + ' ' : '') + text;
|
||||||
|
});
|
||||||
|
},
|
||||||
|
sendMessage: onSendVoiceStream,
|
||||||
|
});
|
||||||
|
|
||||||
|
const isListening = isVADSupported ? isVADListening : isSRListening;
|
||||||
|
const voiceError = isVADSupported ? vadError : srError;
|
||||||
|
|
||||||
|
// 当 SR finalText 更新时,追加到输入框 (仅非 VAD 模式)
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (finalText) {
|
if (!isVADSupported && finalText) {
|
||||||
setContent((prev) => {
|
setContent((prev) => {
|
||||||
const trimmed = prev.trimEnd();
|
const trimmed = prev.trimEnd();
|
||||||
return (trimmed ? trimmed + ' ' : '') + finalText;
|
return (trimmed ? trimmed + ' ' : '') + finalText;
|
||||||
});
|
});
|
||||||
resetText();
|
resetText();
|
||||||
}
|
}
|
||||||
}, [finalText, resetText]);
|
}, [isVADSupported, finalText, resetText]);
|
||||||
|
|
||||||
const handleSend = useCallback(async () => {
|
const handleSend = useCallback(async () => {
|
||||||
const trimmed = content.trim();
|
const trimmed = content.trim();
|
||||||
@@ -121,13 +146,13 @@ export function ChatInput({ onSend, disabled }: ChatInputProps) {
|
|||||||
if (e.key === 'V' && e.ctrlKey && e.shiftKey) {
|
if (e.key === 'V' && e.ctrlKey && e.shiftKey) {
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
if (isListening) {
|
if (isListening) {
|
||||||
stopListening();
|
isVADSupported ? stopVAD() : stopSR();
|
||||||
} else {
|
} else {
|
||||||
startListening();
|
isVADSupported ? startVAD() : startSR();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[handleSend, isListening, startListening, stopListening]
|
[handleSend, isListening, isVADSupported, startVAD, stopVAD, startSR, stopSR]
|
||||||
);
|
);
|
||||||
|
|
||||||
// 粘贴图片
|
// 粘贴图片
|
||||||
@@ -260,11 +285,11 @@ export function ChatInput({ onSend, disabled }: ChatInputProps) {
|
|||||||
|
|
||||||
const handleVoiceToggle = useCallback(() => {
|
const handleVoiceToggle = useCallback(() => {
|
||||||
if (isListening) {
|
if (isListening) {
|
||||||
stopListening();
|
isVADSupported ? stopVAD() : stopSR();
|
||||||
} else {
|
} else {
|
||||||
startListening();
|
isVADSupported ? startVAD() : startSR();
|
||||||
}
|
}
|
||||||
}, [isListening, startListening, stopListening]);
|
}, [isListening, isVADSupported, startVAD, stopVAD, startSR, stopSR]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
@@ -305,8 +330,17 @@ export function ChatInput({ onSend, disabled }: ChatInputProps) {
|
|||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{/* 实时识别文本提示 */}
|
{/* VAD 语音状态提示 */}
|
||||||
{isListening && interimText && (
|
{isVADSupported && isVADListening && (
|
||||||
|
<div className="text-sm text-pink-500 dark:text-pink-400 italic px-1" aria-live="polite">
|
||||||
|
{isVADSpeaking
|
||||||
|
? (vadInterimText || '检测到语音,正在识别...')
|
||||||
|
: '正在聆听...'}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* 实时识别文本提示 (仅非 VAD 模式) */}
|
||||||
|
{!isVADSupported && isListening && interimText && (
|
||||||
<div
|
<div
|
||||||
className="interim-text text-sm text-pink-500 dark:text-pink-400 italic px-1"
|
className="interim-text text-sm text-pink-500 dark:text-pink-400 italic px-1"
|
||||||
aria-live="polite"
|
aria-live="polite"
|
||||||
@@ -317,12 +351,12 @@ export function ChatInput({ onSend, disabled }: ChatInputProps) {
|
|||||||
)}
|
)}
|
||||||
|
|
||||||
{/* 错误提示 */}
|
{/* 错误提示 */}
|
||||||
{error && (
|
{voiceError && (
|
||||||
<div
|
<div
|
||||||
className="text-xs text-red-500 dark:text-red-400 px-1"
|
className="text-xs text-red-500 dark:text-red-400 px-1"
|
||||||
role="alert"
|
role="alert"
|
||||||
>
|
>
|
||||||
⚠️ {error}
|
⚠️ {voiceError}
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
@@ -434,18 +468,20 @@ export function ChatInput({ onSend, disabled }: ChatInputProps) {
|
|||||||
className="flex-1 resize-none rounded-xl border border-pink-200 dark:border-pink-800 bg-white dark:bg-gray-800 px-4 py-2 text-sm text-gray-700 dark:text-gray-200 placeholder-gray-400 focus:outline-none focus:ring-2 focus:ring-pink-400 focus:border-transparent disabled:opacity-50"
|
className="flex-1 resize-none rounded-xl border border-pink-200 dark:border-pink-800 bg-white dark:bg-gray-800 px-4 py-2 text-sm text-gray-700 dark:text-gray-200 placeholder-gray-400 focus:outline-none focus:ring-2 focus:ring-pink-400 focus:border-transparent disabled:opacity-50"
|
||||||
/>
|
/>
|
||||||
|
|
||||||
{/* 语音输入按钮 (仅浏览器支持时显示) */}
|
{/* 语音输入按钮 (VAD 或浏览器 SpeechRecognition/MediaRecorder 支持时显示) */}
|
||||||
{isSupported && (
|
{(isVADSupported || isSRSpported) && (
|
||||||
<button
|
<button
|
||||||
onClick={handleVoiceToggle}
|
onClick={handleVoiceToggle}
|
||||||
disabled={disabled || uploading}
|
disabled={disabled || uploading}
|
||||||
aria-label={isListening ? '停止语音输入' : '开始语音输入'}
|
aria-label={isListening ? '停止语音输入' : '开始语音输入'}
|
||||||
aria-pressed={isListening}
|
aria-pressed={isListening}
|
||||||
title={isListening ? '停止聆听 (Ctrl+Shift+V)' : '语音输入 (Ctrl+Shift+V)'}
|
title={isListening ? '停止聆听 (Ctrl+Shift+V)' : isVADSupported ? '语音输入 (自动检测说话)' : '语音输入 (Ctrl+Shift+V)'}
|
||||||
className={`p-2 rounded-xl transition-all flex-shrink-0 border-2 ${
|
className={`p-2 rounded-xl transition-all flex-shrink-0 border-2 ${
|
||||||
isListening
|
isListening
|
||||||
? 'voice-btn-active bg-red-500 border-red-500 text-white'
|
? 'voice-btn-active bg-red-500 border-red-500 text-white'
|
||||||
: 'bg-gray-100 dark:bg-gray-700 border-gray-200 dark:border-gray-600 text-gray-500 hover:text-red-500 hover:border-red-300'
|
: isVADSpeaking
|
||||||
|
? 'bg-yellow-400 border-yellow-400 text-white'
|
||||||
|
: 'bg-gray-100 dark:bg-gray-700 border-gray-200 dark:border-gray-600 text-gray-500 hover:text-red-500 hover:border-red-300'
|
||||||
} disabled:opacity-40 disabled:cursor-not-allowed`}
|
} disabled:opacity-40 disabled:cursor-not-allowed`}
|
||||||
>
|
>
|
||||||
<svg
|
<svg
|
||||||
@@ -461,7 +497,7 @@ export function ChatInput({ onSend, disabled }: ChatInputProps) {
|
|||||||
)}
|
)}
|
||||||
|
|
||||||
{/* 不支持时显示禁用按钮 */}
|
{/* 不支持时显示禁用按钮 */}
|
||||||
{!isSupported && (
|
{!isVADSupported && !isSRSpported && (
|
||||||
<button
|
<button
|
||||||
disabled
|
disabled
|
||||||
title="您的浏览器不支持语音识别"
|
title="您的浏览器不支持语音识别"
|
||||||
@@ -508,7 +544,10 @@ export function ChatInput({ onSend, disabled }: ChatInputProps) {
|
|||||||
{/* 语音输入状态提示 */}
|
{/* 语音输入状态提示 */}
|
||||||
{isListening && (
|
{isListening && (
|
||||||
<p className="text-xs text-red-400 text-center animate-pulse">
|
<p className="text-xs text-red-400 text-center animate-pulse">
|
||||||
{isFallbackMode ? '🎤 后端语音识别中...' : '🎤 正在聆听...'}
|
{isVADSupported
|
||||||
|
? (isVADSpeaking ? '🔊 检测到语音,正在识别...' : '🎤 正在聆听...')
|
||||||
|
: (isFallbackMode ? '🎤 后端语音识别中...' : '🎤 正在聆听...')
|
||||||
|
}
|
||||||
<span className="text-gray-400 ml-2">(Ctrl+Shift+V 停止)</span>
|
<span className="text-gray-400 ml-2">(Ctrl+Shift+V 停止)</span>
|
||||||
</p>
|
</p>
|
||||||
)}
|
)}
|
||||||
|
|||||||
@@ -43,11 +43,19 @@ export function useChat() {
|
|||||||
[addMessage, setTyping, sendMessage]
|
[addMessage, setTyping, sendMessage]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const sendVoiceStreamMessage = useCallback(
|
||||||
|
(msg: import('@/types/chat').WSClientMessage) => {
|
||||||
|
sendMessage(msg);
|
||||||
|
},
|
||||||
|
[sendMessage]
|
||||||
|
);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
messages,
|
messages,
|
||||||
isTyping,
|
isTyping,
|
||||||
isConnected,
|
isConnected,
|
||||||
send,
|
send,
|
||||||
|
sendVoiceStreamMessage,
|
||||||
clearMessages,
|
clearMessages,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,221 @@
|
|||||||
|
import { useState, useRef, useCallback, useEffect } from 'react';
|
||||||
|
import { MicVAD, utils } from '@ricky0123/vad-web';
|
||||||
|
import { transcribeAudio } from '@/api/voice';
|
||||||
|
import { useChatStore } from '@/store/chatStore';
|
||||||
|
import type { WSClientMessage } from '@/types/chat';
|
||||||
|
|
||||||
|
interface UseVoiceInputOptions {
|
||||||
|
onTranscription: (text: string) => void;
|
||||||
|
language?: string;
|
||||||
|
/** 提供后启用流式模式:通过 WebSocket 分片发送音频,实时返回中间结果 */
|
||||||
|
sendMessage?: (msg: WSClientMessage) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface UseVoiceInputReturn {
|
||||||
|
isListening: boolean;
|
||||||
|
isSpeaking: boolean;
|
||||||
|
isSupported: boolean;
|
||||||
|
/** 流式模式:voice_interim 中间识别文本 */
|
||||||
|
interimText: string;
|
||||||
|
/** 流式模式:voice_final 最终识别文本 */
|
||||||
|
finalText: string;
|
||||||
|
error: string | null;
|
||||||
|
startListening: () => Promise<void>;
|
||||||
|
stopListening: () => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useVoiceInput({
|
||||||
|
onTranscription,
|
||||||
|
language = 'zh',
|
||||||
|
sendMessage,
|
||||||
|
}: UseVoiceInputOptions): UseVoiceInputReturn {
|
||||||
|
const [isListening, setIsListening] = useState(false);
|
||||||
|
const [isSpeaking, setIsSpeaking] = useState(false);
|
||||||
|
const [error, setError] = useState<string | null>(null);
|
||||||
|
|
||||||
|
const vadRef = useRef<MicVAD | null>(null);
|
||||||
|
const inSpeechRef = useRef(false);
|
||||||
|
const seqRef = useRef(0);
|
||||||
|
|
||||||
|
const isSupported =
|
||||||
|
typeof window !== 'undefined' &&
|
||||||
|
typeof AudioWorkletNode !== 'undefined' &&
|
||||||
|
typeof WebAssembly !== 'undefined';
|
||||||
|
|
||||||
|
const isStreaming = !!sendMessage;
|
||||||
|
|
||||||
|
// Subscribe to streaming results from store
|
||||||
|
const voiceInterimText = useChatStore((s) => s.voiceInterimText);
|
||||||
|
const voiceFinalText = useChatStore((s) => s.voiceFinalText);
|
||||||
|
|
||||||
|
const stopListening = useCallback(() => {
|
||||||
|
if (vadRef.current) {
|
||||||
|
vadRef.current.pause();
|
||||||
|
}
|
||||||
|
inSpeechRef.current = false;
|
||||||
|
setIsListening(false);
|
||||||
|
setIsSpeaking(false);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const startListening = useCallback(async () => {
|
||||||
|
if (!isSupported) return;
|
||||||
|
|
||||||
|
setError(null);
|
||||||
|
seqRef.current = 0;
|
||||||
|
|
||||||
|
try {
|
||||||
|
if (vadRef.current) {
|
||||||
|
await vadRef.current.destroy();
|
||||||
|
vadRef.current = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
vadRef.current = await MicVAD.new({
|
||||||
|
onSpeechStart: () => {
|
||||||
|
setIsSpeaking(true);
|
||||||
|
inSpeechRef.current = true;
|
||||||
|
seqRef.current = 0;
|
||||||
|
|
||||||
|
if (isStreaming) {
|
||||||
|
sendMessage!({
|
||||||
|
type: 'voice_stream_start',
|
||||||
|
format: 'wav',
|
||||||
|
language,
|
||||||
|
timestamp: Date.now(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
onFrameProcessed: async (_probabilities, frame) => {
|
||||||
|
if (!isStreaming || !inSpeechRef.current) return;
|
||||||
|
|
||||||
|
// Accumulate audio frames during speech for the voice_stream_chunk
|
||||||
|
// Send every ~300ms of audio (10 frames at ~30ms each)
|
||||||
|
const frameSeq = seqRef.current++;
|
||||||
|
const CHUNK_INTERVAL = 10; // send every 10 frames
|
||||||
|
|
||||||
|
if (frameSeq % CHUNK_INTERVAL === 0) {
|
||||||
|
try {
|
||||||
|
const wavBuffer = utils.encodeWAV(frame);
|
||||||
|
const base64 = arrayBufferToBase64(wavBuffer);
|
||||||
|
|
||||||
|
sendMessage!({
|
||||||
|
type: 'voice_stream_chunk',
|
||||||
|
audio_data: base64,
|
||||||
|
sequence: Math.floor(frameSeq / CHUNK_INTERVAL),
|
||||||
|
timestamp: Date.now(),
|
||||||
|
});
|
||||||
|
} catch {
|
||||||
|
// Ignore encoding errors for individual chunks
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
onSpeechEnd: async (audio: Float32Array) => {
|
||||||
|
setIsSpeaking(false);
|
||||||
|
inSpeechRef.current = false;
|
||||||
|
|
||||||
|
if (isStreaming) {
|
||||||
|
// Send the last chunk with accumulated audio if any
|
||||||
|
try {
|
||||||
|
const wavBuffer = utils.encodeWAV(audio);
|
||||||
|
const base64 = arrayBufferToBase64(wavBuffer);
|
||||||
|
|
||||||
|
sendMessage!({
|
||||||
|
type: 'voice_stream_chunk',
|
||||||
|
audio_data: base64,
|
||||||
|
sequence: -1,
|
||||||
|
timestamp: Date.now(),
|
||||||
|
});
|
||||||
|
} catch {
|
||||||
|
// Ignore
|
||||||
|
}
|
||||||
|
|
||||||
|
// Signal end of voice stream — gateway returns voice_final
|
||||||
|
sendMessage!({
|
||||||
|
type: 'voice_stream_end',
|
||||||
|
timestamp: Date.now(),
|
||||||
|
});
|
||||||
|
// onTranscription will be called in useEffect when voiceFinalText updates
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// REST mode: send full audio for transcription
|
||||||
|
const wavBuffer = utils.encodeWAV(audio);
|
||||||
|
const wavBlob = new Blob([wavBuffer], { type: 'audio/wav' });
|
||||||
|
|
||||||
|
try {
|
||||||
|
const result = await transcribeAudio(wavBlob, language);
|
||||||
|
if (result.error) {
|
||||||
|
setError(result.error);
|
||||||
|
} else if (result.data?.text) {
|
||||||
|
onTranscription(result.data.text);
|
||||||
|
}
|
||||||
|
} catch (err) {
|
||||||
|
setError(err instanceof Error ? err.message : '语音识别失败');
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
onVADMisfire: () => {
|
||||||
|
setIsSpeaking(false);
|
||||||
|
inSpeechRef.current = false;
|
||||||
|
|
||||||
|
if (isStreaming) {
|
||||||
|
sendMessage!({
|
||||||
|
type: 'voice_stream_end',
|
||||||
|
timestamp: Date.now(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
startOnLoad: true,
|
||||||
|
});
|
||||||
|
|
||||||
|
setIsListening(true);
|
||||||
|
} catch (err) {
|
||||||
|
const message =
|
||||||
|
err instanceof DOMException && err.name === 'NotAllowedError'
|
||||||
|
? '麦克风权限被拒绝'
|
||||||
|
: err instanceof Error
|
||||||
|
? err.message
|
||||||
|
: 'VAD 初始化失败';
|
||||||
|
setError(message);
|
||||||
|
}
|
||||||
|
}, [isSupported, language, isStreaming, sendMessage, onTranscription]);
|
||||||
|
|
||||||
|
// In streaming mode, watch for voice_final from the server
|
||||||
|
useEffect(() => {
|
||||||
|
if (isStreaming && voiceFinalText) {
|
||||||
|
onTranscription(voiceFinalText);
|
||||||
|
useChatStore.getState().setVoiceFinalText('');
|
||||||
|
}
|
||||||
|
}, [isStreaming, voiceFinalText, onTranscription]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
return () => {
|
||||||
|
if (vadRef.current) {
|
||||||
|
vadRef.current.destroy();
|
||||||
|
vadRef.current = null;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
return {
|
||||||
|
isListening,
|
||||||
|
isSpeaking,
|
||||||
|
isSupported,
|
||||||
|
interimText: isStreaming ? voiceInterimText : '',
|
||||||
|
finalText: isStreaming ? voiceFinalText : '',
|
||||||
|
error,
|
||||||
|
startListening,
|
||||||
|
stopListening,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
function arrayBufferToBase64(buffer: ArrayBuffer): string {
|
||||||
|
const bytes = new Uint8Array(buffer);
|
||||||
|
let binary = '';
|
||||||
|
for (let i = 0; i < bytes.byteLength; i++) {
|
||||||
|
binary += String.fromCharCode(bytes[i]);
|
||||||
|
}
|
||||||
|
return btoa(binary);
|
||||||
|
}
|
||||||
@@ -254,7 +254,7 @@ function handleServerMessage(msg: WSServerMessage) {
|
|||||||
client_info: msg.client_info,
|
client_info: msg.client_info,
|
||||||
audioUrl: msg.full_audio_url,
|
audioUrl: msg.full_audio_url,
|
||||||
segments: msg.segments,
|
segments: msg.segments,
|
||||||
metadata: msg.tool_calls ? { tool_calls: msg.tool_calls } : undefined,
|
metadata: msg.metadata || (msg.tool_calls ? { tool_calls: msg.tool_calls } : undefined),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
setTyping(false);
|
setTyping(false);
|
||||||
@@ -462,6 +462,18 @@ function handleServerMessage(msg: WSServerMessage) {
|
|||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
case 'voice_interim':
|
||||||
|
if (msg.text !== undefined) {
|
||||||
|
useChatStore.getState().setVoiceInterimText(msg.text);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
|
case 'voice_final':
|
||||||
|
if (msg.text !== undefined) {
|
||||||
|
useChatStore.getState().setVoiceFinalText(msg.text);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
case 'pong':
|
case 'pong':
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,10 @@ interface ChatStore {
|
|||||||
isLoadingHistory: boolean;
|
isLoadingHistory: boolean;
|
||||||
historyPage: number;
|
historyPage: number;
|
||||||
|
|
||||||
|
// 流式语音识别状态
|
||||||
|
voiceInterimText: string;
|
||||||
|
voiceFinalText: string;
|
||||||
|
|
||||||
// 多气泡消息队列:确保气泡依次出现 + 逐字动画
|
// 多气泡消息队列:确保气泡依次出现 + 逐字动画
|
||||||
messageQueue: Message[];
|
messageQueue: Message[];
|
||||||
|
|
||||||
@@ -37,6 +41,8 @@ interface ChatStore {
|
|||||||
clearMessages: () => void;
|
clearMessages: () => void;
|
||||||
|
|
||||||
setContinuousMode: (enabled: boolean) => void;
|
setContinuousMode: (enabled: boolean) => void;
|
||||||
|
setVoiceInterimText: (text: string) => void;
|
||||||
|
setVoiceFinalText: (text: string) => void;
|
||||||
setBackgroundThinkingStatus: (status: BackgroundThinkingStatus) => void;
|
setBackgroundThinkingStatus: (status: BackgroundThinkingStatus) => void;
|
||||||
setIoTDevices: (devices: IoTDevice[]) => void;
|
setIoTDevices: (devices: IoTDevice[]) => void;
|
||||||
|
|
||||||
@@ -54,6 +60,8 @@ export const useChatStore = create<ChatStore>((set) => ({
|
|||||||
backgroundThinkingStatus: 'idle',
|
backgroundThinkingStatus: 'idle',
|
||||||
iotDevices: [],
|
iotDevices: [],
|
||||||
iotDevicesLastUpdated: null,
|
iotDevicesLastUpdated: null,
|
||||||
|
voiceInterimText: '',
|
||||||
|
voiceFinalText: '',
|
||||||
hasMoreMessages: false,
|
hasMoreMessages: false,
|
||||||
isLoadingHistory: false,
|
isLoadingHistory: false,
|
||||||
historyPage: 1,
|
historyPage: 1,
|
||||||
@@ -143,6 +151,9 @@ export const useChatStore = create<ChatStore>((set) => ({
|
|||||||
|
|
||||||
setContinuousMode: (enabled) => set({ continuousMode: enabled }),
|
setContinuousMode: (enabled) => set({ continuousMode: enabled }),
|
||||||
|
|
||||||
|
setVoiceInterimText: (text) => set({ voiceInterimText: text }),
|
||||||
|
setVoiceFinalText: (text) => set({ voiceFinalText: text, voiceInterimText: '' }),
|
||||||
|
|
||||||
setBackgroundThinkingStatus: (status) => set({ backgroundThinkingStatus: status }),
|
setBackgroundThinkingStatus: (status) => set({ backgroundThinkingStatus: status }),
|
||||||
|
|
||||||
setIoTDevices: (devices) =>
|
setIoTDevices: (devices) =>
|
||||||
|
|||||||
@@ -0,0 +1,17 @@
|
|||||||
|
// 页面路由 Store — 管理当前显示的页面
|
||||||
|
|
||||||
|
import { create } from 'zustand';
|
||||||
|
|
||||||
|
export type PageId = 'chat' | 'admin-models' | 'admin-dashboard' | 'profile';
|
||||||
|
|
||||||
|
interface PageState {
|
||||||
|
currentPage: PageId;
|
||||||
|
setPage: (page: PageId) => void;
|
||||||
|
goToChat: () => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const usePageStore = create<PageState>((set) => ({
|
||||||
|
currentPage: 'chat',
|
||||||
|
setPage: (page) => set({ currentPage: page }),
|
||||||
|
goToChat: () => set({ currentPage: 'chat' }),
|
||||||
|
}));
|
||||||
@@ -105,11 +105,14 @@ export interface StreamSegment {
|
|||||||
|
|
||||||
/** WebSocket 客户端消息 */
|
/** WebSocket 客户端消息 */
|
||||||
export interface WSClientMessage {
|
export interface WSClientMessage {
|
||||||
type: 'message' | 'voice_input' | 'ping' | 'history';
|
type: 'message' | 'voice_input' | 'voice_stream_start' | 'voice_stream_chunk' | 'voice_stream_end' | 'ping' | 'history';
|
||||||
session_id?: string;
|
session_id?: string;
|
||||||
mode?: ChatMode;
|
mode?: ChatMode;
|
||||||
content?: string;
|
content?: string;
|
||||||
audio_data?: string; // base64
|
audio_data?: string; // base64
|
||||||
|
format?: string; // 音频格式 (voice_stream_start): webm, wav, pcm, opus
|
||||||
|
language?: string; // 识别语言 (voice_stream_start): zh, en, ja, ko, auto
|
||||||
|
sequence?: number; // 音频分片序号 (voice_stream_chunk)
|
||||||
attachments?: MessageAttachment[];
|
attachments?: MessageAttachment[];
|
||||||
timestamp: number;
|
timestamp: number;
|
||||||
client_id?: string;
|
client_id?: string;
|
||||||
@@ -139,7 +142,7 @@ export interface AppNotification extends NotificationData {
|
|||||||
|
|
||||||
/** WebSocket 服务端消息 */
|
/** WebSocket 服务端消息 */
|
||||||
export interface WSServerMessage {
|
export interface WSServerMessage {
|
||||||
type: 'stream_start' | 'response' | 'segment' | 'audio' | 'error' | 'device_update' | 'pong' | 'history_response' | 'stream_chunk' | 'stream_end' | 'background_thinking' | 'notification' | 'multi_message' | 'stream_segments' | 'review' | 'thinking' | 'tool_progress' | 'system_info';
|
type: 'stream_start' | 'response' | 'segment' | 'audio' | 'error' | 'device_update' | 'pong' | 'history_response' | 'stream_chunk' | 'stream_end' | 'background_thinking' | 'notification' | 'multi_message' | 'stream_segments' | 'review' | 'thinking' | 'tool_progress' | 'system_info' | 'voice_interim' | 'voice_final';
|
||||||
message_id?: string;
|
message_id?: string;
|
||||||
text?: string;
|
text?: string;
|
||||||
content?: string;
|
content?: string;
|
||||||
@@ -161,6 +164,7 @@ export interface WSServerMessage {
|
|||||||
notification?: NotificationData;
|
notification?: NotificationData;
|
||||||
tool_progress?: ToolProgressInfo;
|
tool_progress?: ToolProgressInfo;
|
||||||
system_info?: SystemInfoPayload;
|
system_info?: SystemInfoPayload;
|
||||||
|
metadata?: Record<string, unknown>;
|
||||||
protocol_version?: number;
|
protocol_version?: number;
|
||||||
timestamp: number;
|
timestamp: number;
|
||||||
client_info?: ClientInfo;
|
client_info?: ClientInfo;
|
||||||
|
|||||||
@@ -0,0 +1,100 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
批量 WEM → WAV 转换,使用 vgmstream-cli + ffmpeg 标准化。
|
||||||
|
输出: 22050Hz, mono, 16-bit PCM WAV
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
VGMSTREAM = r"D:\Project\Code\Uni\Cyrene\scripts\voice\tools\vgmstream\vgmstream-cli.exe"
|
||||||
|
RAW_DIR = r"D:\Project\Code\Uni\Cyrene-Voice-Model\data\raw"
|
||||||
|
CLEANED_DIR = r"D:\Project\Code\Uni\Cyrene-Voice-Model\data\cleaned"
|
||||||
|
|
||||||
|
# 要转换的子目录(按优先级)
|
||||||
|
TARGETS = [
|
||||||
|
"VoBanks27", "VoBanks28", "VoBanks29", "VoBanks30", "VoBanks31",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def convert_wem_to_wav(wem_path: str, wav_path: str) -> bool:
|
||||||
|
"""vgmstream WEM→临时WAV, ffmpeg 标准化 → 最终WAV (22050Hz mono s16)"""
|
||||||
|
os.makedirs(os.path.dirname(wav_path), exist_ok=True)
|
||||||
|
|
||||||
|
# 跳过已存在且非空的文件
|
||||||
|
if os.path.exists(wav_path) and os.path.getsize(wav_path) > 100:
|
||||||
|
return True
|
||||||
|
|
||||||
|
tmp_path = wav_path + ".tmp.wav"
|
||||||
|
try:
|
||||||
|
# Step 1: vgmstream → temp WAV
|
||||||
|
result = subprocess.run(
|
||||||
|
[VGMSTREAM, "-o", tmp_path, wem_path],
|
||||||
|
capture_output=True, timeout=30,
|
||||||
|
)
|
||||||
|
if result.returncode != 0 or not os.path.exists(tmp_path):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Step 2: ffmpeg → 标准化 22050Hz mono s16
|
||||||
|
result = subprocess.run(
|
||||||
|
["ffmpeg", "-y", "-i", tmp_path,
|
||||||
|
"-ar", "22050", "-ac", "1", "-sample_fmt", "s16",
|
||||||
|
wav_path],
|
||||||
|
capture_output=True, timeout=30,
|
||||||
|
)
|
||||||
|
if result.returncode != 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return os.path.exists(wav_path) and os.path.getsize(wav_path) > 100
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" FAIL [{os.path.basename(wem_path)}]: {e}")
|
||||||
|
return False
|
||||||
|
finally:
|
||||||
|
# 清理临时文件
|
||||||
|
if os.path.exists(tmp_path):
|
||||||
|
os.remove(tmp_path)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print("=== 批量 WEM → WAV 转换 (VoBanks) ===\n")
|
||||||
|
|
||||||
|
total = 0
|
||||||
|
ok = 0
|
||||||
|
fail = 0
|
||||||
|
|
||||||
|
for target in TARGETS:
|
||||||
|
src_dir = os.path.join(RAW_DIR, target)
|
||||||
|
dst_dir = os.path.join(CLEANED_DIR, target)
|
||||||
|
|
||||||
|
if not os.path.isdir(src_dir):
|
||||||
|
print(f"SKIP: {target} (not found)")
|
||||||
|
continue
|
||||||
|
|
||||||
|
wem_files = sorted(Path(src_dir).glob("*.wem"))
|
||||||
|
if not wem_files:
|
||||||
|
print(f"SKIP: {target} (empty)")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"[{target}] {len(wem_files)} files...")
|
||||||
|
|
||||||
|
for i, wem in enumerate(wem_files):
|
||||||
|
wav = os.path.join(dst_dir, wem.stem + ".wav")
|
||||||
|
if convert_wem_to_wav(str(wem), wav):
|
||||||
|
ok += 1
|
||||||
|
else:
|
||||||
|
fail += 1
|
||||||
|
total += 1
|
||||||
|
|
||||||
|
if (i + 1) % 50 == 0:
|
||||||
|
print(f" {i+1}/{len(wem_files)} (ok:{ok} fail:{fail})")
|
||||||
|
|
||||||
|
print(f" Done: {len(wem_files)} files\n")
|
||||||
|
|
||||||
|
print(f"=== 转换完成: {ok} ok, {fail} fail, {total} total ===")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,73 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# 批量 WEM → WAV 转换 (使用 vgmstream-cli)
|
||||||
|
set -e
|
||||||
|
|
||||||
|
RAW_DIR="D:/Project/Code/Uni/Cyrene-Voice-Model/data/raw"
|
||||||
|
CLEANED_DIR="D:/Project/Code/Uni/Cyrene-Voice-Model/data/cleaned"
|
||||||
|
VGMSTREAM="D:/Project/Code/Uni/Cyrene/scripts/voice/tools/vgmstream/vgmstream-cli.exe"
|
||||||
|
|
||||||
|
if [ ! -f "$VGMSTREAM" ]; then
|
||||||
|
echo "错误: 找不到 vgmstream-cli.exe"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "=== 批量 WEM → WAV 转换 ==="
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
TOTAL=0
|
||||||
|
SUCCESS=0
|
||||||
|
FAILED=0
|
||||||
|
|
||||||
|
while IFS= read -r -d '' wem_file; do
|
||||||
|
TOTAL=$((TOTAL + 1))
|
||||||
|
|
||||||
|
# 输出路径: cleaned/ 目录下保持相同子目录结构,改 .wem 为 .wav
|
||||||
|
rel_path="${wem_file#$RAW_DIR/}"
|
||||||
|
wav_file="${CLEANED_DIR}/${rel_path%.wem}.wav"
|
||||||
|
wav_dir="$(dirname "$wav_file")"
|
||||||
|
|
||||||
|
mkdir -p "$wav_dir"
|
||||||
|
|
||||||
|
# 跳过已转换的
|
||||||
|
if [ -f "$wav_file" ] && [ "$(stat -c%s "$wav_file" 2>/dev/null || echo 0)" -gt 100 ]; then
|
||||||
|
SUCCESS=$((SUCCESS + 1))
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 转换
|
||||||
|
if cmd.exe //c "$VGMSTREAM -o \"$wav_file\" \"$wem_file\"" 2>/dev/null; then
|
||||||
|
SUCCESS=$((SUCCESS + 1))
|
||||||
|
else
|
||||||
|
FAILED=$((FAILED + 1))
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 进度显示
|
||||||
|
if [ $((TOTAL % 100)) -eq 0 ]; then
|
||||||
|
echo " 进度: $TOTAL 文件 (成功: $SUCCESS, 失败: $FAILED)"
|
||||||
|
fi
|
||||||
|
done < <(find "$RAW_DIR" -name "*.wem" -print0)
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=== 转换完成 ==="
|
||||||
|
echo "总计: $TOTAL | 成功: $SUCCESS | 失败: $FAILED"
|
||||||
|
|
||||||
|
# 统计分类
|
||||||
|
echo ""
|
||||||
|
echo "音频时长分布:"
|
||||||
|
find "$CLEANED_DIR" -name "*.wav" | while read wav; do
|
||||||
|
dur=$(ffprobe -v quiet -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 "$wav" 2>/dev/null || echo "0")
|
||||||
|
echo "$dur"
|
||||||
|
done | awk '
|
||||||
|
{ d = $1 + 0 }
|
||||||
|
d < 1 { lt1++ }
|
||||||
|
d < 3 { lt3++ }
|
||||||
|
d < 10 { lt10++ }
|
||||||
|
d < 30 { lt30++ }
|
||||||
|
d >= 30 { gt30++ }
|
||||||
|
END {
|
||||||
|
printf " < 1s: %d\n", lt1
|
||||||
|
printf " 1-3s: %d\n", lt3
|
||||||
|
printf " 3-10s: %d\n", lt10
|
||||||
|
printf " 10-30s: %d\n", lt30
|
||||||
|
printf " > 30s: %d\n", gt30
|
||||||
|
}'
|
||||||
@@ -0,0 +1,175 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
通过 Chrome DevTools Protocol 下载文件。
|
||||||
|
需要 Chrome 以 --remote-debugging-port=9222 启动。
|
||||||
|
|
||||||
|
用法:
|
||||||
|
python cdp_download.py <url> <output_path>
|
||||||
|
python cdp_download.py https://github.com/.../vgmstream-win.zip tools/vgmstream.zip
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
from websocket import create_connection
|
||||||
|
|
||||||
|
|
||||||
|
class CDPClient:
|
||||||
|
def __init__(self, ws_url: str):
|
||||||
|
self.ws = create_connection(ws_url, origin="http://127.0.0.1:9222")
|
||||||
|
self._id = 0
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self._pending = {}
|
||||||
|
self._events = []
|
||||||
|
self._running = True
|
||||||
|
|
||||||
|
# Start background reader
|
||||||
|
self._reader_thread = threading.Thread(target=self._read_loop, daemon=True)
|
||||||
|
self._reader_thread.start()
|
||||||
|
|
||||||
|
def _read_loop(self):
|
||||||
|
while self._running:
|
||||||
|
try:
|
||||||
|
msg = json.loads(self.ws.recv())
|
||||||
|
msg_id = msg.get('id')
|
||||||
|
if msg_id is not None:
|
||||||
|
with self._lock:
|
||||||
|
self._pending[msg_id] = msg
|
||||||
|
else:
|
||||||
|
self._events.append(msg)
|
||||||
|
except Exception:
|
||||||
|
if self._running:
|
||||||
|
time.sleep(0.1)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
def send(self, method: str, params: dict = None) -> dict:
|
||||||
|
self._id += 1
|
||||||
|
msg_id = self._id
|
||||||
|
payload = json.dumps({
|
||||||
|
'id': msg_id,
|
||||||
|
'method': method,
|
||||||
|
'params': params or {}
|
||||||
|
})
|
||||||
|
self.ws.send(payload)
|
||||||
|
|
||||||
|
# Wait for response
|
||||||
|
timeout = 60
|
||||||
|
start = time.time()
|
||||||
|
while time.time() - start < timeout:
|
||||||
|
with self._lock:
|
||||||
|
if msg_id in self._pending:
|
||||||
|
result = self._pending.pop(msg_id)
|
||||||
|
if 'error' in result:
|
||||||
|
raise Exception(f"CDP Error: {result['error']}")
|
||||||
|
return result.get('result', {})
|
||||||
|
time.sleep(0.1)
|
||||||
|
raise TimeoutError(f"CDP command {method} timed out")
|
||||||
|
|
||||||
|
def wait_for_event(self, event_type: str, timeout: float = 60) -> dict:
|
||||||
|
start = time.time()
|
||||||
|
while time.time() - start < timeout:
|
||||||
|
for i, evt in enumerate(self._events):
|
||||||
|
if evt.get('method') == event_type:
|
||||||
|
return self._events.pop(i)['params']
|
||||||
|
time.sleep(0.2)
|
||||||
|
raise TimeoutError(f"Event {event_type} not received within {timeout}s")
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self._running = False
|
||||||
|
self.ws.close()
|
||||||
|
|
||||||
|
|
||||||
|
def download_via_cdp(url: str, output_path: str, cdp_url: str = "http://localhost:9222"):
|
||||||
|
"""
|
||||||
|
使用 Chrome CDP 下载文件。
|
||||||
|
"""
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
# 1. 创建新标签页
|
||||||
|
print(f"[CDP] 创建标签页...")
|
||||||
|
req = urllib.request.Request(f"{cdp_url}/json/new", method='PUT')
|
||||||
|
resp = urllib.request.urlopen(req)
|
||||||
|
tab = json.loads(resp.read())
|
||||||
|
ws_url = tab['webSocketDebuggerUrl']
|
||||||
|
print(f"[CDP] 标签页: {tab['id']}")
|
||||||
|
|
||||||
|
client = CDPClient(ws_url)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 2. 启用必要的域
|
||||||
|
print(f"[CDP] 启用 Page...")
|
||||||
|
client.send('Page.enable')
|
||||||
|
|
||||||
|
# 3. 设置下载目录
|
||||||
|
download_dir = os.path.abspath(os.path.dirname(output_path))
|
||||||
|
os.makedirs(download_dir, exist_ok=True)
|
||||||
|
print(f"[CDP] 下载目录: {download_dir}")
|
||||||
|
|
||||||
|
client.send('Browser.setDownloadBehavior', {
|
||||||
|
'behavior': 'allow',
|
||||||
|
'downloadPath': download_dir
|
||||||
|
})
|
||||||
|
|
||||||
|
# 4. 导航到下载 URL
|
||||||
|
print(f"[CDP] 导航到: {url}")
|
||||||
|
client.send('Page.navigate', {'url': url})
|
||||||
|
|
||||||
|
# 5. 等待下载完成
|
||||||
|
print(f"[CDP] 等待下载开始...")
|
||||||
|
will_begin = client.wait_for_event('Browser.downloadWillBegin', timeout=30)
|
||||||
|
guid = will_begin['guid']
|
||||||
|
suggested_name = will_begin.get('suggestedFilename', 'unknown')
|
||||||
|
print(f"[CDP] 下载开始: {suggested_name} (guid={guid})")
|
||||||
|
|
||||||
|
# 等待下载进度完成
|
||||||
|
print(f"[CDP] 等待下载完成...")
|
||||||
|
while True:
|
||||||
|
progress = client.wait_for_event('Browser.downloadProgress', timeout=120)
|
||||||
|
state = progress.get('state', '')
|
||||||
|
if state == 'completed':
|
||||||
|
print(f"[CDP] 下载完成")
|
||||||
|
break
|
||||||
|
elif state == 'canceled':
|
||||||
|
raise Exception("下载被取消")
|
||||||
|
elif state == 'interrupted':
|
||||||
|
print(f"[CDP] 下载中断, 重试...")
|
||||||
|
client.send('Browser.resumeDownload', {'guid': guid})
|
||||||
|
else:
|
||||||
|
received = progress.get('receivedBytes', 0)
|
||||||
|
total = progress.get('totalBytes', 0)
|
||||||
|
if total > 0:
|
||||||
|
print(f"[CDP] 进度: {received}/{total} ({100*received//total}%)")
|
||||||
|
|
||||||
|
# 6. 移动到目标路径
|
||||||
|
downloaded_file = os.path.join(download_dir, suggested_name)
|
||||||
|
if os.path.exists(downloaded_file) and downloaded_file != output_path:
|
||||||
|
if os.path.exists(output_path):
|
||||||
|
os.remove(output_path)
|
||||||
|
os.rename(downloaded_file, output_path)
|
||||||
|
print(f"[CDP] 文件保存到: {output_path}")
|
||||||
|
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
finally:
|
||||||
|
client.close()
|
||||||
|
# 关闭标签页
|
||||||
|
urllib.request.urlopen(urllib.request.Request(
|
||||||
|
f"{cdp_url}/json/close/{tab['id']}", method='PUT'))
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
if len(sys.argv) < 3:
|
||||||
|
print("用法: python cdp_download.py <url> <output_path>")
|
||||||
|
print("示例: python cdp_download.py https://example.com/file.zip tools/file.zip")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
url = sys.argv[1]
|
||||||
|
output = sys.argv[2]
|
||||||
|
download_via_cdp(url, output)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@@ -0,0 +1,125 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
将 .wem (Wwise Encoded Media) 文件批量转换为 .wav 格式。
|
||||||
|
使用 ffmpeg 进行转换(需预先安装 ffmpeg)。
|
||||||
|
|
||||||
|
.wem 文件本质上是 RIFF/WAVE 容器,内部编码可能是:
|
||||||
|
- PCM 16-bit (ffmpeg 直接支持)
|
||||||
|
- Wwise ADPCM (ffmpeg 需要额外解码器)
|
||||||
|
- Vorbis (部分 ffmpeg 版本支持)
|
||||||
|
|
||||||
|
用法:
|
||||||
|
python convert_wem.py <input_dir> <output_dir>
|
||||||
|
python convert_wem.py ./wem_output/ ./wav_output/
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def convert_wem_to_wav(wem_path: str, wav_path: str) -> bool:
|
||||||
|
"""使用 ffmpeg 将单个 .wem 文件转为 .wav."""
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
['ffmpeg', '-y', '-i', wem_path,
|
||||||
|
'-ar', '22050', # 22.05kHz (与 persona.yaml 设定的训练格式一致)
|
||||||
|
'-ac', '1', # mono
|
||||||
|
'-sample_fmt', 's16', # 16-bit
|
||||||
|
wav_path],
|
||||||
|
capture_output=True,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
if result.returncode == 0 and os.path.getsize(wav_path) > 100:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
# 部分 WEM 是 Vorbis 编码,需要用不同方式
|
||||||
|
return _convert_wem_vorbis(wem_path, wav_path)
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ffmpeg 错误 [{os.path.basename(wem_path)}]: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_wem_vorbis(wem_path: str, wav_path: str) -> bool:
|
||||||
|
"""尝试处理 Vorbis 编码的 WEM 文件 (带 Wwise 头的 Ogg Vorbis)."""
|
||||||
|
try:
|
||||||
|
# 方式: 跳过 RIFF 头 + fmt chunk, 直接取 Vorbis 数据
|
||||||
|
# .wem 的 Vorbis 数据从 "vorb" chunk 开始
|
||||||
|
with open(wem_path, 'rb') as f:
|
||||||
|
data = f.read()
|
||||||
|
|
||||||
|
# 查找 "vorb" 标识
|
||||||
|
vorb_pos = data.find(b'vorb')
|
||||||
|
if vorb_pos == -1:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 重新封装为标准 Ogg (在 vorb 数据前加 OggS 头)
|
||||||
|
# 简化方法: 用 ffmpeg 的 libvorbis 解码
|
||||||
|
# 如果上面失败了,尝试用 -f s16le 强制读取
|
||||||
|
result = subprocess.run(
|
||||||
|
['ffmpeg', '-y', '-f', 's16le',
|
||||||
|
'-ar', '48000', '-ac', '1',
|
||||||
|
'-i', wem_path,
|
||||||
|
'-ar', '22050', '-ac', '1',
|
||||||
|
wav_path],
|
||||||
|
capture_output=True,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
return result.returncode == 0 and os.path.getsize(wav_path) > 100
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def convert_directory(input_dir: str, output_dir: str) -> tuple[int, int]:
|
||||||
|
"""批量转换目录中所有 .wem 文件."""
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
wem_files = sorted(Path(input_dir).glob('*.wem'))
|
||||||
|
|
||||||
|
if not wem_files:
|
||||||
|
print(f"在 {input_dir} 中未找到 .wem 文件")
|
||||||
|
return 0, 0
|
||||||
|
|
||||||
|
print(f"找到 {len(wem_files)} 个 .wem 文件,开始转换...")
|
||||||
|
success = 0
|
||||||
|
failed = 0
|
||||||
|
|
||||||
|
for i, wem_path in enumerate(wem_files):
|
||||||
|
wav_name = wem_path.stem + '.wav'
|
||||||
|
wav_path = os.path.join(output_dir, wav_name)
|
||||||
|
|
||||||
|
# 跳过已存在的
|
||||||
|
if os.path.exists(wav_path) and os.path.getsize(wav_path) > 100:
|
||||||
|
success += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
if convert_wem_to_wav(str(wem_path), wav_path):
|
||||||
|
success += 1
|
||||||
|
else:
|
||||||
|
failed += 1
|
||||||
|
|
||||||
|
if (i + 1) % 50 == 0:
|
||||||
|
print(f" 进度: {i+1}/{len(wem_files)} (成功: {success}, 失败: {failed})")
|
||||||
|
|
||||||
|
print(f"\n转换完成: {success} 成功, {failed} 失败")
|
||||||
|
return success, failed
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="批量转换 .wem → .wav (需要 ffmpeg)")
|
||||||
|
parser.add_argument('input_dir', help='包含 .wem 文件的输入目录')
|
||||||
|
parser.add_argument('output_dir', help='输出目录')
|
||||||
|
parser.add_argument('--single', nargs=2, metavar=('WEM', 'WAV'),
|
||||||
|
help='转换单个文件')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.single:
|
||||||
|
ok = convert_wem_to_wav(args.single[0], args.single[1])
|
||||||
|
print(f"{'OK' if ok else 'FAILED'}: {args.single[0]} -> {args.single[1]}")
|
||||||
|
else:
|
||||||
|
convert_directory(args.input_dir, args.output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@@ -0,0 +1,161 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# ============================================================
|
||||||
|
# 昔涟语音提取管线
|
||||||
|
#
|
||||||
|
# 步骤:
|
||||||
|
# 1. 从 HSR 音频包提取 .wem (本脚本)
|
||||||
|
# 2. 用 ww2ogg/vgmstream 转换 .wem → .wav (需手动安装工具)
|
||||||
|
# 3. 用 ffmpeg 标准化音频格式
|
||||||
|
# ============================================================
|
||||||
|
set -e
|
||||||
|
|
||||||
|
HSR_AUDIO_DIR="D:/MeowG/Honkai:Star_Rail/StarRail_Data/Persistent/Audio/AudioPackage/Windows/Chinese(PRC)"
|
||||||
|
RAW_DIR="D:/Project/Code/Uni/Cyrene-Voice-Model/data/raw"
|
||||||
|
CLEANED_DIR="D:/Project/Code/Uni/Cyrene-Voice-Model/data/cleaned"
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||||
|
|
||||||
|
echo "=== 昔涟语音提取管线 ==="
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# ---- 阶段 1: 从 .pck 提取 .wem ----
|
||||||
|
echo "[阶段 1/4] 提取候选 .pck 文件..."
|
||||||
|
|
||||||
|
# 昔涟是 3.x 角色,语音在以下文件中:
|
||||||
|
# - VoBanks 27-31 (最新角色语音库)
|
||||||
|
# - External_del_4.0_chapter5_* (3.0 主线昔涟出场)
|
||||||
|
# - External_del_4.1_chapter_* (3.1 主线昔涟出场)
|
||||||
|
TARGETS=(
|
||||||
|
"VoBanks27.pck"
|
||||||
|
"VoBanks28.pck"
|
||||||
|
"VoBanks29.pck"
|
||||||
|
"VoBanks30.pck"
|
||||||
|
"VoBanks31.pck"
|
||||||
|
"External_del_4.0_chapter5_0.pck"
|
||||||
|
"External_del_4.0_chapter5_1.pck"
|
||||||
|
"External_del_4.0_chapter5_2.pck"
|
||||||
|
"External_del_4.1_chapter_0.pck"
|
||||||
|
"External_del_4.1_chapter_1.pck"
|
||||||
|
"External_del_4.1_chapter_2.pck"
|
||||||
|
)
|
||||||
|
|
||||||
|
for target in "${TARGETS[@]}"; do
|
||||||
|
pck_path="${HSR_AUDIO_DIR}/${target}"
|
||||||
|
if [ -f "$pck_path" ]; then
|
||||||
|
echo " 提取: $target ($(du -h "$pck_path" | cut -f1))"
|
||||||
|
python3 "${SCRIPT_DIR}/extract_pck.py" "$pck_path" "${RAW_DIR}/${target%.pck}/"
|
||||||
|
else
|
||||||
|
echo " 跳过: $target (文件不存在)"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "[阶段 1/4] 完成: .wem 文件已提取到 ${RAW_DIR}/"
|
||||||
|
|
||||||
|
# ---- 阶段 2: .wem → .wav 转换 ----
|
||||||
|
echo ""
|
||||||
|
echo "[阶段 2/4] 需要转换 .wem → .wav"
|
||||||
|
echo ""
|
||||||
|
echo " HSR 使用 Wwise 专有编码 (0xFFFF),ffmpeg 无法直接解码。"
|
||||||
|
echo " 请使用以下工具之一进行转换:"
|
||||||
|
echo ""
|
||||||
|
echo " 方案 A — vgmstream CLI (推荐, 最简单):"
|
||||||
|
echo " 下载: https://github.com/vgmstream/vgmstream/releases"
|
||||||
|
echo " 解压后将 vgmstream-cli.exe 放入 scripts/voice/tools/"
|
||||||
|
echo " 然后运行本脚本的 --convert 模式"
|
||||||
|
echo ""
|
||||||
|
echo " 方案 B — AnimeWwise GUI (最强大, 保留原始文件名):"
|
||||||
|
echo " 下载: https://github.com/Escartem/AnimeWwise/releases"
|
||||||
|
echo " 直接打开 GUI, 选择 HSR 目录, 导出昔涟语音"
|
||||||
|
echo ""
|
||||||
|
echo " 方案 C — ww2ogg + revorb (传统方案):"
|
||||||
|
echo " 下载 ww2ogg.exe + revorb.exe + packed_codebooks.bin"
|
||||||
|
echo " 放入 scripts/voice/tools/"
|
||||||
|
echo ""
|
||||||
|
echo " 安装工具后, 运行: $0 --convert"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# ---- 阶段 3 (条件): 批量转换 ----
|
||||||
|
if [ "$1" = "--convert" ]; then
|
||||||
|
echo "[阶段 3/4] 转换 .wem → .wav..."
|
||||||
|
|
||||||
|
TOOLS_DIR="${SCRIPT_DIR}/tools"
|
||||||
|
|
||||||
|
# 优先使用 vgmstream
|
||||||
|
if [ -f "${TOOLS_DIR}/vgmstream-cli.exe" ]; then
|
||||||
|
echo " 使用 vgmstream-cli..."
|
||||||
|
find "${RAW_DIR}" -name "*.wem" | while read wem; do
|
||||||
|
wav="${wem%.wem}.wav"
|
||||||
|
if [ ! -f "$wav" ]; then
|
||||||
|
"${TOOLS_DIR}/vgmstream-cli.exe" -o "$wav" "$wem" 2>/dev/null
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
elif [ -f "${TOOLS_DIR}/ww2ogg.exe" ]; then
|
||||||
|
echo " 使用 ww2ogg + ffmpeg..."
|
||||||
|
find "${RAW_DIR}" -name "*.wem" | while read wem; do
|
||||||
|
ogg="${wem%.wem}.ogg"
|
||||||
|
wav="${wem%.wem}.wav"
|
||||||
|
if [ ! -f "$wav" ]; then
|
||||||
|
"${TOOLS_DIR}/ww2ogg.exe" "$wem" -o "$ogg" --pcb "${TOOLS_DIR}/packed_codebooks.bin" 2>/dev/null
|
||||||
|
ffmpeg -y -i "$ogg" -ar 22050 -ac 1 -sample_fmt s16 "$wav" 2>/dev/null
|
||||||
|
rm -f "$ogg"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
else
|
||||||
|
echo " 错误: 未找到转换工具, 请先安装 vgmstream 或 ww2ogg"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "[阶段 3/4] 完成"
|
||||||
|
|
||||||
|
# ---- 阶段 4: 音频标准化 + 分类 ----
|
||||||
|
echo ""
|
||||||
|
echo "[阶段 4/4] 标准化 + 分类..."
|
||||||
|
|
||||||
|
# 按音频时长初步分类 (语音通常 1-15 秒)
|
||||||
|
mkdir -p "${CLEANED_DIR}/daily" "${CLEANED_DIR}/battle" \
|
||||||
|
"${CLEANED_DIR}/emotional" "${CLEANED_DIR}/story"
|
||||||
|
|
||||||
|
find "${RAW_DIR}" -name "*.wav" | while read wav; do
|
||||||
|
# 获取时长
|
||||||
|
duration=$(ffprobe -v quiet -show_entries format=duration \
|
||||||
|
-of default=noprint_wrappers=1:nokey=1 "$wav" 2>/dev/null || echo "0")
|
||||||
|
dur_float=$(echo "$duration" | awk '{print int($1 * 1000)}')
|
||||||
|
|
||||||
|
basename=$(basename "$wav" .wav)
|
||||||
|
parent=$(basename "$(dirname "$wav")")
|
||||||
|
|
||||||
|
# 分类逻辑
|
||||||
|
if [ "$dur_float" -lt 500 ]; then
|
||||||
|
# < 0.5s: 可能是战斗短语音 / 语气词
|
||||||
|
target_dir="${CLEANED_DIR}/battle"
|
||||||
|
elif [ "$dur_float" -gt 15000 ]; then
|
||||||
|
# > 15s: 可能是剧情长对话
|
||||||
|
target_dir="${CLEANED_DIR}/story"
|
||||||
|
elif echo "$parent" | grep -qi "chapter"; then
|
||||||
|
target_dir="${CLEANED_DIR}/story"
|
||||||
|
elif echo "$parent" | grep -qi "vobanks"; then
|
||||||
|
# VoBanks 包含战斗 + 日常语音, 需要人工筛选
|
||||||
|
target_dir="${CLEANED_DIR}/daily"
|
||||||
|
else
|
||||||
|
target_dir="${CLEANED_DIR}/daily"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 用 ffmpeg 标准化: 22.05kHz mono 16bit
|
||||||
|
ffmpeg -y -i "$wav" -ar 22050 -ac 1 -sample_fmt s16 \
|
||||||
|
"${target_dir}/${parent}_${basename}.wav" 2>/dev/null
|
||||||
|
|
||||||
|
echo -n "."
|
||||||
|
done
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "[阶段 4/4] 完成: 音频已分类到 ${CLEANED_DIR}/"
|
||||||
|
echo ""
|
||||||
|
echo "文件分布:"
|
||||||
|
echo " 日常对话: $(ls "${CLEANED_DIR}/daily/" 2>/dev/null | wc -l) 个"
|
||||||
|
echo " 战斗语音: $(ls "${CLEANED_DIR}/battle/" 2>/dev/null | wc -l) 个"
|
||||||
|
echo " 情感表达: $(ls "${CLEANED_DIR}/emotional/" 2>/dev/null | wc -l) 个"
|
||||||
|
echo " 剧情对话: $(ls "${CLEANED_DIR}/story/" 2>/dev/null | wc -l) 个"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=== 管线完成 ==="
|
||||||
@@ -0,0 +1,99 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
从 Honkai: Star Rail 的 .pck (AKPK/Wwise SoundBank) 文件中提取 .wem 音频。
|
||||||
|
|
||||||
|
用法:
|
||||||
|
python extract_pck.py <input.pck> <output_dir>
|
||||||
|
python extract_pck.py VoBanks31.pck ./output/
|
||||||
|
python extract_pck.py --all <pck_dir> <output_dir>
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def find_wem_files(data: bytes) -> list[tuple[int, int, str]]:
|
||||||
|
"""
|
||||||
|
扫描数据中的所有 RIFF/WAVE 块 (.wem 文件).
|
||||||
|
返回 [(offset, size, riff_type), ...] 列表.
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
pos = 0
|
||||||
|
data_len = len(data)
|
||||||
|
while pos < data_len - 12:
|
||||||
|
if data[pos:pos + 4] == b'RIFF':
|
||||||
|
chunk_size = struct.unpack_from('<I', data, pos + 4)[0]
|
||||||
|
riff_type = data[pos + 8:pos + 12]
|
||||||
|
# .wem 文件的 riff_type 是 b'WAVE'
|
||||||
|
if riff_type == b'WAVE' and chunk_size > 100:
|
||||||
|
total_size = chunk_size + 8 # RIFF header + data
|
||||||
|
if pos + total_size <= data_len:
|
||||||
|
results.append((pos, total_size, riff_type.decode('ascii', errors='replace')))
|
||||||
|
# 跳过已匹配的块
|
||||||
|
pos += 8 + chunk_size
|
||||||
|
continue
|
||||||
|
pos += 1
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def extract_pck(pck_path: str, output_dir: str, prefix: str = "") -> list[str]:
|
||||||
|
"""
|
||||||
|
从单个 .pck 文件提取所有 .wem 文件到 output_dir.
|
||||||
|
返回提取的文件路径列表.
|
||||||
|
"""
|
||||||
|
pck_name = Path(pck_path).stem
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
with open(pck_path, 'rb') as f:
|
||||||
|
data = f.read()
|
||||||
|
|
||||||
|
print(f"[{pck_name}] 文件大小: {len(data):,} bytes, 扫描 WEM 块...")
|
||||||
|
wem_entries = find_wem_files(data)
|
||||||
|
print(f"[{pck_name}] 找到 {len(wem_entries)} 个音频文件")
|
||||||
|
|
||||||
|
extracted = []
|
||||||
|
for i, (offset, size, riff_type) in enumerate(wem_entries):
|
||||||
|
# 使用 offset 作为唯一 ID (Wwise 文件 ID 就是 offset 的某种映射)
|
||||||
|
wem_data = data[offset:offset + size]
|
||||||
|
|
||||||
|
if prefix:
|
||||||
|
filename = f"{prefix}_{i:04d}_{offset:08x}.wem"
|
||||||
|
else:
|
||||||
|
filename = f"{pck_name}_{i:04d}_{offset:08x}.wem"
|
||||||
|
|
||||||
|
out_path = os.path.join(output_dir, filename)
|
||||||
|
with open(out_path, 'wb') as f:
|
||||||
|
f.write(wem_data)
|
||||||
|
extracted.append(out_path)
|
||||||
|
|
||||||
|
total_mb = sum(wem_entries[i][1] for i in range(len(wem_entries))) / 1024 / 1024
|
||||||
|
print(f"[{pck_name}] 提取完成: {len(extracted)} 文件, {total_mb:.1f} MB")
|
||||||
|
return extracted
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="从 HSR .pck 文件提取 .wem 音频")
|
||||||
|
parser.add_argument('input', help='输入 .pck 文件路径,或 --all 模式下的目录路径')
|
||||||
|
parser.add_argument('output', help='输出目录')
|
||||||
|
parser.add_argument('--all', action='store_true', help='批量模式:提取目录中所有 .pck 文件')
|
||||||
|
parser.add_argument('--prefix', default='', help='输出文件名前缀')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.all:
|
||||||
|
pck_dir = Path(args.input)
|
||||||
|
pck_files = sorted(pck_dir.glob('*.pck'))
|
||||||
|
print(f"批量模式: 找到 {len(pck_files)} 个 .pck 文件")
|
||||||
|
total_extracted = 0
|
||||||
|
for pck_file in pck_files:
|
||||||
|
extracted = extract_pck(str(pck_file), args.output, args.prefix)
|
||||||
|
total_extracted += len(extracted)
|
||||||
|
print(f"\n总计: {total_extracted} 个音频文件")
|
||||||
|
else:
|
||||||
|
extract_pck(args.input, args.output, args.prefix)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
Copyright (c) 2008-2025 Adam Gashlin, Fastelbja, Ronny Elfert, bnnm,
|
||||||
|
Christopher Snowhill, NicknineTheEagle, bxaimc,
|
||||||
|
Thealexbarney, CyberBotX, EdnessP, et al
|
||||||
|
|
||||||
|
Portions Copyright (c) 2004-2008, Marko Kreen
|
||||||
|
Portions Copyright 2001-2007 jagarl / Kazunori Ueno <jagarl@creator.club.ne.jp>
|
||||||
|
Portions Copyright (c) 1998, Justin Frankel/Nullsoft Inc.
|
||||||
|
Portions Copyright (C) 2006 Nullsoft, Inc.
|
||||||
|
Portions Copyright (c) 2005-2007 Paul Hsieh
|
||||||
|
Portions Copyright (C) 2000-2004 Leshade Entis, Entis-soft.
|
||||||
|
Portions Public Domain originating with Sun Microsystems
|
||||||
|
|
||||||
|
Permission to use, copy, modify, and distribute this software for any
|
||||||
|
purpose with or without fee is hereby granted, provided that the above
|
||||||
|
copyright notice and this permission notice appear in all copies.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
@@ -0,0 +1,102 @@
|
|||||||
|
# vgmstream
|
||||||
|
This is vgmstream, a library for playing streamed (prerecorded) video game audio.
|
||||||
|
|
||||||
|
Some of vgmstream's features:
|
||||||
|
- Decodes [hundreds of video game music formats and codecs](doc/FORMATS.md), from typical
|
||||||
|
game engine files to obscure single-game codecs, aiming for high accuracy and compatibility.
|
||||||
|
- Support for looped BGM, using file's internal metadata for smooth transitions, with accurate
|
||||||
|
sample counts.
|
||||||
|
- [Subsongs](doc/USAGE.md#subsongs), playing a format's multiple internal songs separately.
|
||||||
|
- Many types of companion files (data split into multiple files) and custom containers.
|
||||||
|
- Encryption keys, internal stream names, and other unusual cases found in game audio.
|
||||||
|
- [TXTH](doc/TXTH.md) function, to add external support for extra formats, including raw audio in
|
||||||
|
many forms.
|
||||||
|
- [TXTP](doc/TXTP.md) function, for real-time and per-file config, like forced looping, removing
|
||||||
|
channels, playing certain subsong, or fusing multiple files into a single one.
|
||||||
|
- Simple [external tagging](doc/USAGE.md#tagging) via .m3u files.
|
||||||
|
- [Plugins](#getting-vgmstream) are available for various media player software and operating systems.
|
||||||
|
|
||||||
|
The main development repository: https://github.com/vgmstream/vgmstream/
|
||||||
|
|
||||||
|
Automated builds with the latest changes: https://vgmstream.org
|
||||||
|
(https://github.com/vgmstream/vgmstream-releases/releases/tag/nightly)
|
||||||
|
|
||||||
|
Numbered releases: https://github.com/vgmstream/vgmstream/releases
|
||||||
|
|
||||||
|
Help can be found here: https://www.hcs64.com/
|
||||||
|
|
||||||
|
More documentation: https://github.com/vgmstream/vgmstream/tree/master/doc
|
||||||
|
|
||||||
|
## Getting vgmstream
|
||||||
|
There are multiple end-user components:
|
||||||
|
- [vgmstream-cli](doc/USAGE.md#testexevgmstream-cli-command-line-decoder): A command-line decoder.
|
||||||
|
- [in_vgmstream](doc/USAGE.md#in_vgmstream-winamp-plugin): A Winamp plugin.
|
||||||
|
- [foo_input_vgmstream](doc/USAGE.md#foo_input_vgmstream-foobar2000-plugin): A foobar2000 component.
|
||||||
|
- [xmp-vgmstream](doc/USAGE.md#xmp-vgmstream-xmplay-plugin): An XMPlay plugin.
|
||||||
|
- [vgmstream.so](doc/USAGE.md#audacious-plugin): An Audacious plugin.
|
||||||
|
- [vgmstream123](doc/USAGE.md#vgmstream123-command-line-player): A command-line player.
|
||||||
|
|
||||||
|
The main library (plain *vgmstream*) is the code that handles the internal conversion, while the
|
||||||
|
above components are what you use to get sound.
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
If you want to convert game audio to `.wav`, get *vgmstream-cli* then drag-and-drop one
|
||||||
|
or more files to the executable (support may vary per O.S. or distro). This should create
|
||||||
|
`(file.extension).wav`, if the format is supported. You can also try the online web player
|
||||||
|
instead. See: https://vgmstream.org
|
||||||
|
|
||||||
|
More user-friendly would be installing a player like *foobar2000* (on Windows) or *Audacious*
|
||||||
|
(on Linux) and the vgmstream plugin. Then you can directly listen your files and set options like
|
||||||
|
infinite looping, or convert to `.wav` with the player's options (also easier to use if your file
|
||||||
|
has multiple "subsongs").
|
||||||
|
|
||||||
|
See [components](doc/USAGE.md#components) in the *usage guide* for full install instructions and
|
||||||
|
explanations. The aim is feature parity, but there are a few differences between them due to
|
||||||
|
missing parts on vgmstream's side or lack of support in the player.
|
||||||
|
|
||||||
|
Note that vgmstream cannot *encode* (convert from `.wav` to a game format), it only *decodes*
|
||||||
|
(plays game audio).
|
||||||
|
|
||||||
|
### Windows binaries
|
||||||
|
Prebuilt binaries:
|
||||||
|
- https://vgmstream.org (latest)
|
||||||
|
- https://github.com/vgmstream/vgmstream/releases (infrequent numbered releases)
|
||||||
|
|
||||||
|
The foobar2000 component is also available on https://www.foobar2000.org based on current
|
||||||
|
release.
|
||||||
|
|
||||||
|
You may also try the alternative versions (irregularly) built by [bnnm](https://github.com/bnnm):
|
||||||
|
- https://github.com/bnnm/vgmstream-builds/raw/master/bin/vgmstream-latest-test-u.zip
|
||||||
|
|
||||||
|
Or compile from source, see the [build guide](doc/BUILD.md).
|
||||||
|
|
||||||
|
### Linux binaries
|
||||||
|
A prebuilt CLI binary is available. It's statically linked and should work on systems running
|
||||||
|
Linux kernel v3.2 and above:
|
||||||
|
- https://vgmstream.org (latest)
|
||||||
|
- https://github.com/vgmstream/vgmstream/releases (infrequent numbered releases)
|
||||||
|
|
||||||
|
Building from source will also give you *vgmstream.so* (Audacious plugin), and *vgmstream123*
|
||||||
|
(command-line player), which can't be statically linked.
|
||||||
|
|
||||||
|
When building it needs several external libraries. For a quick script for Debian and Ubuntu-style
|
||||||
|
distros run `./make-build-cmake.sh`. The script will need to install dependencies first, so you
|
||||||
|
may prefer to run steps manually, which the [build guide](doc/BUILD.md) describes in detail.
|
||||||
|
|
||||||
|
### macOS binaries
|
||||||
|
A prebuilt CLI binary is available:
|
||||||
|
- https://vgmstream.org (latest)
|
||||||
|
- https://github.com/vgmstream/vgmstream/releases (infrequent numbered releases)
|
||||||
|
|
||||||
|
Otherwise follow the [build guide](doc/BUILD.md).
|
||||||
|
|
||||||
|
|
||||||
|
## More info
|
||||||
|
- [Usage guide](doc/USAGE.md)
|
||||||
|
- [List of supported audio formats](doc/FORMATS.md)
|
||||||
|
- [Build guide](doc/BUILD.md)
|
||||||
|
- [TXTH file format](doc/TXTH.md)
|
||||||
|
- [TXTP file format](doc/TXTP.md)
|
||||||
|
|
||||||
|
|
||||||
|
Enjoy! *hcs*
|
||||||
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,296 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
将 Wwise Vorbis .wem 文件转换为标准 Ogg Vorbis (.ogg) 文件。
|
||||||
|
纯 Python 实现,不依赖 ww2ogg 或 revorb 外部工具。
|
||||||
|
|
||||||
|
基于 Wwise RIFF/Vorbis 格式:
|
||||||
|
- Codec ID: 0xFFFF
|
||||||
|
- Vorbis 数据存储在 "vorb" chunk 中
|
||||||
|
- 数据包可直接封装为 Ogg 容器
|
||||||
|
|
||||||
|
用法:
|
||||||
|
python wem2ogg.py <input.wem> <output.ogg>
|
||||||
|
python wem2ogg.py --batch <input_dir> <output_dir>
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import struct
|
||||||
|
import sys
|
||||||
|
import zlib
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
# Ogg 页类型
|
||||||
|
OGG_HEADER = 0x02
|
||||||
|
OGG_FIRST_DATA = 0x00
|
||||||
|
OGG_CONTINUED = 0x00
|
||||||
|
OGG_LAST = 0x04
|
||||||
|
|
||||||
|
# CRC32 表 (预计算)
|
||||||
|
_crc_table = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_crc_table():
|
||||||
|
global _crc_table
|
||||||
|
if _crc_table is None:
|
||||||
|
_crc_table = []
|
||||||
|
for i in range(256):
|
||||||
|
r = i << 24
|
||||||
|
for _ in range(8):
|
||||||
|
if r & 0x80000000:
|
||||||
|
r = (r << 1) ^ 0x04c11db7
|
||||||
|
else:
|
||||||
|
r <<= 1
|
||||||
|
_crc_table.append(r & 0xffffffff)
|
||||||
|
return _crc_table
|
||||||
|
|
||||||
|
|
||||||
|
def ogg_crc32(data: bytes) -> int:
|
||||||
|
table = _get_crc_table()
|
||||||
|
crc = 0
|
||||||
|
for b in data:
|
||||||
|
crc = (crc << 8) ^ table[((crc >> 24) & 0xff) ^ b]
|
||||||
|
crc &= 0xffffffff
|
||||||
|
return crc
|
||||||
|
|
||||||
|
|
||||||
|
def make_ogg_page(segment_data: bytes, granule: int,
|
||||||
|
header_type: int, stream_serial: int = 0,
|
||||||
|
page_index: int = 0) -> bytes:
|
||||||
|
"""构造一个 Ogg 页."""
|
||||||
|
# 将数据分割成最多 255 字节的段
|
||||||
|
segments = []
|
||||||
|
pos = 0
|
||||||
|
while pos < len(segment_data):
|
||||||
|
seg_len = min(255, len(segment_data) - pos)
|
||||||
|
segments.append(seg_len)
|
||||||
|
pos += seg_len
|
||||||
|
|
||||||
|
num_segments = len(segments)
|
||||||
|
page_header = bytearray(27 + num_segments)
|
||||||
|
|
||||||
|
# OggS 签名
|
||||||
|
page_header[0:4] = b'OggS'
|
||||||
|
# Version
|
||||||
|
page_header[4] = 0
|
||||||
|
# Header type
|
||||||
|
page_header[5] = header_type
|
||||||
|
# Granule position (8 bytes, little-endian)
|
||||||
|
struct.pack_into('<q', page_header, 6, granule)
|
||||||
|
# Stream serial
|
||||||
|
struct.pack_into('<I', page_header, 14, stream_serial)
|
||||||
|
# Page index
|
||||||
|
struct.pack_into('<I', page_header, 18, page_index)
|
||||||
|
# Checksum (先填 0)
|
||||||
|
struct.pack_into('<I', page_header, 22, 0)
|
||||||
|
# Number of segments
|
||||||
|
page_header[26] = num_segments
|
||||||
|
# Segment table
|
||||||
|
for i, seg_len in enumerate(segments):
|
||||||
|
page_header[27 + i] = seg_len
|
||||||
|
|
||||||
|
# 计算 CRC
|
||||||
|
full_page = bytearray(page_header) + bytearray(segment_data)
|
||||||
|
crc = ogg_crc32(bytes(full_page))
|
||||||
|
struct.pack_into('<I', full_page, 22, crc)
|
||||||
|
|
||||||
|
return bytes(full_page)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_vorbis_packets(wem_path: str) -> list[bytes]:
|
||||||
|
"""从 WEM 文件中提取 Vorbis 数据包."""
|
||||||
|
with open(wem_path, 'rb') as f:
|
||||||
|
data = f.read()
|
||||||
|
|
||||||
|
# 验证 RIFF 头
|
||||||
|
if data[:4] != b'RIFF':
|
||||||
|
raise ValueError("不是有效的 RIFF 文件")
|
||||||
|
|
||||||
|
# 查找 "vorb" chunk
|
||||||
|
pos = 12 # 跳过 RIFF 头
|
||||||
|
vorb_data = None
|
||||||
|
|
||||||
|
while pos < len(data) - 8:
|
||||||
|
chunk_id = data[pos:pos + 4]
|
||||||
|
chunk_size = struct.unpack_from('<I', data, pos + 4)[0]
|
||||||
|
|
||||||
|
if chunk_id == b'vorb' or chunk_id == b'data':
|
||||||
|
vorb_start = pos + 8
|
||||||
|
vorb_data = data[vorb_start:vorb_start + chunk_size]
|
||||||
|
break
|
||||||
|
|
||||||
|
# 对齐到 2 字节边界
|
||||||
|
pos += 8 + chunk_size
|
||||||
|
if chunk_size % 2:
|
||||||
|
pos += 1
|
||||||
|
|
||||||
|
if vorb_data is None:
|
||||||
|
raise ValueError("未找到 vorb/data chunk")
|
||||||
|
|
||||||
|
# 解析 Vorbis 数据包
|
||||||
|
# 前 4 字节: 数据包数量 (实际上可能是样本数)
|
||||||
|
setup_offset = struct.unpack_from('<I', vorb_data, 0)[0]
|
||||||
|
|
||||||
|
# 每个数据包: [2 bytes: granule/size info][packet data]
|
||||||
|
packets = []
|
||||||
|
pos = 4 # 跳过 setup offset
|
||||||
|
|
||||||
|
# 第一个数据包是 Vorbis 头 (identification header)
|
||||||
|
# Wwise 格式: 2 bytes granule + 2 bytes size (或只是 2 bytes size)
|
||||||
|
# 尝试解析...
|
||||||
|
|
||||||
|
# 通常第一个 packet 是 setup 数据
|
||||||
|
# 格式: 对于每个 packet:
|
||||||
|
# - uint16: 如果最高位为 1,这是 granule 的高位部分
|
||||||
|
# 实际上 Wwise Vorbis 数据包格式比较复杂
|
||||||
|
|
||||||
|
# 简化处理: 跳过 4 字节后就是连续的 Vorbis 数据包
|
||||||
|
# 每个 packet 前 2 字节表示该 packet 的大小
|
||||||
|
# packet_size & 0x8000: granule 在下一个 packet 变化
|
||||||
|
|
||||||
|
remaining = vorb_data[4:]
|
||||||
|
while len(remaining) > 2:
|
||||||
|
# 读取 packet 大小 (可能用 2 或 4 字节)
|
||||||
|
header = struct.unpack_from('<H', remaining, 0)[0]
|
||||||
|
has_granule = (header & 0x8000) != 0
|
||||||
|
pkt_size = header & 0x7FFF
|
||||||
|
|
||||||
|
if pkt_size == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
offset = 2
|
||||||
|
granule_val = 0
|
||||||
|
if has_granule:
|
||||||
|
granule_val = struct.unpack_from('<H', remaining, offset)[0]
|
||||||
|
offset += 2
|
||||||
|
|
||||||
|
if offset + pkt_size > len(remaining):
|
||||||
|
break
|
||||||
|
|
||||||
|
packet = remaining[offset:offset + pkt_size]
|
||||||
|
packets.append(packet)
|
||||||
|
remaining = remaining[offset + pkt_size:]
|
||||||
|
|
||||||
|
return packets
|
||||||
|
|
||||||
|
|
||||||
|
def wem_to_ogg(wem_path: str, ogg_path: str) -> bool:
|
||||||
|
"""转换单个 .wem 文件为 .ogg."""
|
||||||
|
try:
|
||||||
|
packets = extract_vorbis_packets(wem_path)
|
||||||
|
|
||||||
|
if len(packets) < 3:
|
||||||
|
print(f" 警告: {os.path.basename(wem_path)} 只有 {len(packets)} 个数据包")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Vorbis 头三个数据包:
|
||||||
|
# 1. Identification header
|
||||||
|
# 2. Comment header
|
||||||
|
# 3. Setup header
|
||||||
|
# 后续: 音频数据包
|
||||||
|
|
||||||
|
ident_pkt = packets[0]
|
||||||
|
comment_pkt = packets[1]
|
||||||
|
setup_pkt = packets[2]
|
||||||
|
audio_packets = packets[3:]
|
||||||
|
|
||||||
|
# 构造 Ogg 文件
|
||||||
|
ogg_data = bytearray()
|
||||||
|
|
||||||
|
# 第 0 页: Identification header
|
||||||
|
ogg_data += make_ogg_page(ident_pkt, granule=0,
|
||||||
|
header_type=OGG_HEADER,
|
||||||
|
page_index=0)
|
||||||
|
|
||||||
|
# 第 1 页: Comment + Setup headers
|
||||||
|
header_data = comment_pkt + setup_pkt
|
||||||
|
ogg_data += make_ogg_page(header_data, granule=0,
|
||||||
|
header_type=OGG_FIRST_DATA,
|
||||||
|
page_index=1)
|
||||||
|
|
||||||
|
# 后续页: 音频数据 (每页放尽可能多的 packet)
|
||||||
|
page_idx = 2
|
||||||
|
granule = 0
|
||||||
|
buf = bytearray()
|
||||||
|
|
||||||
|
for pkt in audio_packets:
|
||||||
|
# Vorbis granule = 累积样本数
|
||||||
|
# 粗略估计: 每个 packet 约 576 或 1024 samples
|
||||||
|
granule += 576
|
||||||
|
|
||||||
|
if len(buf) + len(pkt) > 45000: # Ogg 页最大约 65KB
|
||||||
|
ogg_data += make_ogg_page(bytes(buf), granule=granule,
|
||||||
|
header_type=OGG_CONTINUED,
|
||||||
|
page_index=page_idx)
|
||||||
|
page_idx += 1
|
||||||
|
buf = bytearray()
|
||||||
|
|
||||||
|
buf += pkt
|
||||||
|
|
||||||
|
# 最后一页
|
||||||
|
if buf:
|
||||||
|
ogg_data += make_ogg_page(bytes(buf), granule=granule,
|
||||||
|
header_type=OGG_LAST,
|
||||||
|
page_index=page_idx)
|
||||||
|
|
||||||
|
with open(ogg_path, 'wb') as f:
|
||||||
|
f.write(ogg_data)
|
||||||
|
|
||||||
|
return os.path.getsize(ogg_path) > 100
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" 错误 [{os.path.basename(wem_path)}]: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def batch_convert(input_dir: str, output_dir: str) -> tuple[int, int]:
|
||||||
|
"""批量转换目录中所有 .wem 文件."""
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
wem_files = sorted(Path(input_dir).glob('*.wem'))
|
||||||
|
|
||||||
|
if not wem_files:
|
||||||
|
print(f"在 {input_dir} 中未找到 .wem 文件")
|
||||||
|
return 0, 0
|
||||||
|
|
||||||
|
print(f"找到 {len(wem_files)} 个 .wem 文件")
|
||||||
|
success = 0
|
||||||
|
failed = 0
|
||||||
|
|
||||||
|
for i, wem_path in enumerate(wem_files):
|
||||||
|
ogg_name = wem_path.stem + '.ogg'
|
||||||
|
ogg_path = os.path.join(output_dir, ogg_name)
|
||||||
|
|
||||||
|
if wem_to_ogg(str(wem_path), ogg_path):
|
||||||
|
success += 1
|
||||||
|
else:
|
||||||
|
failed += 1
|
||||||
|
|
||||||
|
if (i + 1) % 50 == 0:
|
||||||
|
print(f" 进度: {i+1}/{len(wem_files)} (成功: {success})")
|
||||||
|
|
||||||
|
print(f"\n转换完成: {success} 成功, {failed} 失败")
|
||||||
|
return success, failed
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="WEM → OGG 转换器 (纯 Python)")
|
||||||
|
parser.add_argument('input', help='输入 .wem 文件或目录')
|
||||||
|
parser.add_argument('output', help='输出 .ogg 文件或目录')
|
||||||
|
parser.add_argument('--batch', action='store_true',
|
||||||
|
help='批量模式: 转换目录中所有 .wem 文件')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.batch:
|
||||||
|
batch_convert(args.input, args.output)
|
||||||
|
else:
|
||||||
|
ok = wem_to_ogg(args.input, args.output)
|
||||||
|
if ok:
|
||||||
|
print(f"OK: {args.input} → {args.output}")
|
||||||
|
else:
|
||||||
|
print(f"FAILED: {args.input}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user