Compare commits
14 Commits
Alpha_v0.1.0
..
dev
| Author | SHA1 | Date | |
|---|---|---|---|
| 6ef9e082a6 | |||
| 258cf81b25 | |||
| 4954c1e58b | |||
| 67b204b23c | |||
| b085e58031 | |||
| a9c79d7887 | |||
| d112fdd540 | |||
| e5f8e42a78 | |||
| 7e29be8ae3 | |||
| eef21fc91a | |||
| 465fa4307f | |||
| 3ad728406e | |||
| 677385ec17 | |||
| 47dce276a4 |
+13
@@ -1,10 +1,16 @@
|
||||
# ========== 依赖 ==========
|
||||
node_modules/
|
||||
|
||||
# ========== 测试 ==========
|
||||
test/
|
||||
|
||||
# ========== 构建产物 ==========
|
||||
dist/
|
||||
*.exe
|
||||
|
||||
# ========== 子仓库 ==========
|
||||
backend/cyrene-plugins/
|
||||
|
||||
# ========== Go 编译二进制 ==========
|
||||
backend/ai-core/main
|
||||
backend/ai-core/cmd/main
|
||||
@@ -48,6 +54,8 @@ backend/.env
|
||||
models.json
|
||||
thinking_schedule.json
|
||||
platform_configs.json
|
||||
platform_blocklist.json
|
||||
*.exe~
|
||||
.claude/
|
||||
|
||||
# ========== 文档 (项目规范:docs/ 不纳入版本管理,docs/api/ 为例外) ==========
|
||||
@@ -68,6 +76,11 @@ ethend/package-lock.json
|
||||
backend/voice-service/whisper.cpp/
|
||||
backend/voice-service/models/
|
||||
|
||||
# ========== 昔涟语音模型 (独立仓库 Cyrene-Voice-Model) ==========
|
||||
data/cyrene_voice/
|
||||
models/cyrene_voice/
|
||||
backend/voice-service/models/cyrene/
|
||||
|
||||
# ========== 打包归档 ==========
|
||||
*.tar.gz
|
||||
*.zip
|
||||
|
||||
@@ -139,7 +139,7 @@ Cyrene/
|
||||
│ ├── memory-service/ # 记忆服务 (CRUD、语义检索、衰减、自动提取)
|
||||
│ ├── voice-service/ # 语音服务 (DashScope STT + Edge-TTS)
|
||||
│ ├── iot-debug-service/ # IoT 调试服务 (8 个模拟智能家居设备)
|
||||
│ └── pkg/ # 共享包 (logger, plugins — 15 个通用插件/工具)
|
||||
│ └── pkg/ # 共享包 (logger 等)
|
||||
├── ethend/ # ethend 管理面板 (Express + WebSocket)
|
||||
├── scripts/ # 辅助脚本 (migrate / tunnel / whisper-setup / pg-backup)
|
||||
├── searxng/ # SearXNG 搜索引擎配置
|
||||
|
||||
@@ -67,6 +67,10 @@
|
||||
- Node.js 20 LTS
|
||||
- Docker & Docker Compose
|
||||
- Git Bash(Windows 用户)
|
||||
- [cyrene-plugins](https://git.yeij.top/AskaEth/Cyrene-Plugins) — 克隆到 `backend/` 目录内:
|
||||
```bash
|
||||
git clone https://git.yeij.top/AskaEth/Cyrene-Plugins.git backend/cyrene-plugins
|
||||
```
|
||||
|
||||
### 1. 配置环境变量
|
||||
|
||||
@@ -132,9 +136,8 @@ Cyrene/
|
||||
│ ├── memory-service/ # 记忆服务 (CRUD、语义检索、衰减、LLM 提取)
|
||||
│ ├── voice-service/ # 语音服务 (DashScope STT + Edge-TTS)
|
||||
│ ├── iot-debug-service/ # IoT 调试服务 (8 个模拟智能家居设备)
|
||||
│ ├── plugin-manager/ # 插件管理器 (管理 API,插件逻辑在 pkg/plugins)
|
||||
│ ├── platform-bridge/ # 多平台桥接 (QQ / Telegram / Discord / Webhook)
|
||||
│ └── pkg/ # 共享包 (logger, plugins — 15 个通用插件/工具)
|
||||
│ └── pkg/ # 共享包 (logger 等)
|
||||
├── ethend/ # ethend 管理面板 (Express + WebSocket)
|
||||
├── scripts/ # 辅助脚本 (migrate / tunnel / whisper-setup / pg-backup)
|
||||
├── backups/ # 数据库备份文件 (.gitignore)
|
||||
@@ -150,6 +153,8 @@ Cyrene/
|
||||
└── Caddyfile # 反向代理配置
|
||||
```
|
||||
|
||||
> **关联仓库**:[cyrene-plugins](https://git.yeij.top/AskaEth/Cyrene-Plugins) — 插件 SDK + 15 个内置插件 + Plugin Manager 服务。克隆到 `backend/cyrene-plugins/`,ai-core 通过 go.mod replace 引用。
|
||||
|
||||
---
|
||||
|
||||
## 服务端口
|
||||
|
||||
@@ -19,7 +19,7 @@ RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o /ai-core ./cmd/main.go
|
||||
# ========== 运行阶段 ==========
|
||||
FROM alpine:3.20
|
||||
|
||||
RUN apk add --no-cache ca-certificates tzdata && \
|
||||
RUN apk add --no-cache ca-certificates tzdata ffmpeg && \
|
||||
cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \
|
||||
echo "Asia/Shanghai" > /etc/timezone
|
||||
|
||||
|
||||
+164
-24
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
@@ -28,26 +29,31 @@ import (
|
||||
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/subsession"
|
||||
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/tools"
|
||||
|
||||
plgManager "git.yeij.top/AskaEth/Cyrene/pkg/plugins/manager"
|
||||
plgSDK "git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
||||
pluginCalc "git.yeij.top/AskaEth/Cyrene/pkg/plugins/calculator"
|
||||
pluginCrypto "git.yeij.top/AskaEth/Cyrene/pkg/plugins/crypto"
|
||||
pluginDate "git.yeij.top/AskaEth/Cyrene/pkg/plugins/datetime"
|
||||
pluginFile "git.yeij.top/AskaEth/Cyrene/pkg/plugins/file"
|
||||
pluginHTTP "git.yeij.top/AskaEth/Cyrene/pkg/plugins/http"
|
||||
pluginJSON "git.yeij.top/AskaEth/Cyrene/pkg/plugins/json"
|
||||
pluginMD "git.yeij.top/AskaEth/Cyrene/pkg/plugins/markdown"
|
||||
pluginRand "git.yeij.top/AskaEth/Cyrene/pkg/plugins/random"
|
||||
pluginText "git.yeij.top/AskaEth/Cyrene/pkg/plugins/text"
|
||||
pluginWF "git.yeij.top/AskaEth/Cyrene/pkg/plugins/web_fetch"
|
||||
pluginWS "git.yeij.top/AskaEth/Cyrene/pkg/plugins/web_search"
|
||||
plgManager "git.yeij.top/AskaEth/Cyrene-Plugins/manager"
|
||||
plgSDK "git.yeij.top/AskaEth/Cyrene-Plugins/sdk"
|
||||
pluginCalc "git.yeij.top/AskaEth/Cyrene-Plugins/calculator"
|
||||
pluginCrypto "git.yeij.top/AskaEth/Cyrene-Plugins/crypto"
|
||||
pluginDate "git.yeij.top/AskaEth/Cyrene-Plugins/datetime"
|
||||
pluginFile "git.yeij.top/AskaEth/Cyrene-Plugins/file"
|
||||
pluginHTTP "git.yeij.top/AskaEth/Cyrene-Plugins/http"
|
||||
pluginJSON "git.yeij.top/AskaEth/Cyrene-Plugins/json"
|
||||
pluginMD "git.yeij.top/AskaEth/Cyrene-Plugins/markdown"
|
||||
pluginRand "git.yeij.top/AskaEth/Cyrene-Plugins/random"
|
||||
pluginText "git.yeij.top/AskaEth/Cyrene-Plugins/text"
|
||||
pluginWF "git.yeij.top/AskaEth/Cyrene-Plugins/web_fetch"
|
||||
pluginWS "git.yeij.top/AskaEth/Cyrene-Plugins/web_search"
|
||||
)
|
||||
|
||||
var cfg Config
|
||||
|
||||
func main() {
|
||||
// 自动加载 .env 文件(来自仓库根目录)
|
||||
if err := godotenv.Load("../../.env"); err != nil {
|
||||
// 自动加载 .env 文件(优先从可执行文件位置反推仓库根目录)
|
||||
_ = godotenv.Load() // 先尝试当前目录
|
||||
if exe, err := os.Executable(); err == nil {
|
||||
_ = godotenv.Load(filepath.Join(filepath.Dir(exe), "..", "..", ".env"))
|
||||
}
|
||||
// 兜底:如果 LLM_MODEL 仍未设置,打印提示
|
||||
if os.Getenv("LLM_MODEL") == "" {
|
||||
log.Println("ℹ 未找到 .env 文件,将使用环境变量或默认值")
|
||||
}
|
||||
|
||||
@@ -127,7 +133,7 @@ func main() {
|
||||
}
|
||||
|
||||
// 初始化会话历史存储
|
||||
convStore := ctxbuild.NewConversationStore(50)
|
||||
convStore := ctxbuild.NewConversationStore(100)
|
||||
|
||||
// 从数据库恢复主会话历史(避免重启丢失上下文)
|
||||
adminUserID := "admin"
|
||||
@@ -179,6 +185,8 @@ func main() {
|
||||
toolRegistry := plgManager.NewToolRegistry()
|
||||
var visionProvider llm.LLMProvider
|
||||
var ocrProvider llm.LLMProvider
|
||||
var videoProvider llm.LLMProvider
|
||||
var asrProvider llm.ASRProvider
|
||||
if getEnvBool("ENABLE_TOOLS", true) {
|
||||
// 11 个共享通用插件 — 注册其工具到统一注册中心
|
||||
registerPluginTools(toolRegistry, &pluginCalc.CalculatorPlugin{})
|
||||
@@ -246,12 +254,52 @@ func main() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
if ocrProvider == nil {
|
||||
log.Println("OCR模型未配置,图片文字提取将复用视觉模型")
|
||||
}
|
||||
|
||||
// 初始化视频理解模型
|
||||
videoProvider = nil
|
||||
if configLoader != nil && configLoader.HasConfig() {
|
||||
cfg := configLoader.GetConfig()
|
||||
if route, ok := cfg.Routing["video"]; ok && len(route.FallbackChain) > 0 {
|
||||
for _, mid := range route.FallbackChain {
|
||||
if _, ok := cfg.Models[mid]; ok {
|
||||
videoProvider, _ = modelSelector.Select(context.Background(), llm.PurposeVideo)
|
||||
log.Printf("视频理解模型已启用: %s", videoProvider.ModelName())
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if videoProvider == nil {
|
||||
log.Println("视频理解模型未配置")
|
||||
}
|
||||
|
||||
// 初始化 ASR 语音识别模型
|
||||
asrProvider = nil
|
||||
if configLoader != nil && configLoader.HasConfig() {
|
||||
cfg := configLoader.GetConfig()
|
||||
if route, ok := cfg.Routing["speech_recognition"]; ok && len(route.FallbackChain) > 0 {
|
||||
for _, mid := range route.FallbackChain {
|
||||
if m, ok := cfg.Models[mid]; ok {
|
||||
if p, ok := cfg.Providers[m.Provider]; ok {
|
||||
asrProvider = llm.NewDashScopeASRProvider(p.BaseURL, p.APIKey, m.Name)
|
||||
log.Printf("ASR语音识别模型已启用: %s", m.Name)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if asrProvider == nil {
|
||||
log.Println("ASR语音识别模型未配置")
|
||||
}
|
||||
|
||||
toolRegistry.Register(wrapTool(tools.NewVisionTool(visionProvider), "vision_analyze", "Image Vision Analysis & OCR", "multimodal"))
|
||||
toolRegistry.Register(wrapTool(tools.NewVideoTool(videoProvider), "video_analyze", "Video Understanding & Analysis", "multimodal"))
|
||||
|
||||
if knowledgeRetriever != nil {
|
||||
toolRegistry.Register(wrapTool(tools.NewKnowledgeSearchTool(knowledgeRetriever), "knowledge_search", "Search Knowledge Base", "knowledge"))
|
||||
@@ -280,6 +328,7 @@ func main() {
|
||||
convStore,
|
||||
adminUserID,
|
||||
adminSessionID,
|
||||
cfg.AdminNickname,
|
||||
memClient,
|
||||
)
|
||||
|
||||
@@ -369,12 +418,24 @@ func main() {
|
||||
orch.SetOCRProvider(ocrProvider)
|
||||
log.Printf("对话编排器: OCR模型已注入 (%s)", ocrProvider.ModelName())
|
||||
}
|
||||
log.Println("对话编排器 v2.0 已就绪")
|
||||
if videoProvider != nil {
|
||||
orch.SetVideoProvider(videoProvider)
|
||||
log.Printf("对话编排器: 视频模型已注入 (%s)\n", videoProvider.ModelName())
|
||||
} else {
|
||||
log.Println("对话编排器: 视频模型未配置,视频理解功能不可用")
|
||||
}
|
||||
if asrProvider != nil && asrProvider.IsAvailable() {
|
||||
orch.SetASRProvider(asrProvider)
|
||||
log.Printf("对话编排器: ASR语音识别模型已注入 (%s)\n", asrProvider.ModelName())
|
||||
} else {
|
||||
log.Println("对话编排器: ASR语音识别模型未配置")
|
||||
}
|
||||
log.Println("对话编排器 v2.0 已就绪")
|
||||
_ = orch
|
||||
|
||||
// 注册对话API端点
|
||||
mux.HandleFunc("/api/v1/chat", func(w http.ResponseWriter, r *http.Request) {
|
||||
handleChat(w, r, orch, ctxBuilder, personaLoader, memRetriever, memExtractor, iotClient, thinker, toolRegistry)
|
||||
handleChat(w, r, orch, ctxBuilder, personaLoader, memRetriever, memExtractor, iotClient, thinker, toolRegistry, adminSessionID)
|
||||
})
|
||||
|
||||
// 注册记忆API端点
|
||||
@@ -431,6 +492,36 @@ func main() {
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
// LLM 调用 SSE 实时推送
|
||||
mux.HandleFunc("/api/v1/llm-calls/stream", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "streaming not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
ch, done := llm.SubscribeCalls()
|
||||
defer llm.UnsubscribeCalls(ch)
|
||||
|
||||
for {
|
||||
select {
|
||||
case rec := <-ch:
|
||||
data, _ := json.Marshal(rec)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
case <-done:
|
||||
return
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
// 工具调用记录
|
||||
mux.HandleFunc("/api/v1/tools/calls", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
@@ -710,6 +801,7 @@ func handleChat(
|
||||
iotClient *tools.IoTClient,
|
||||
thinker *background.Thinker,
|
||||
_ *plgManager.ToolRegistry,
|
||||
adminSessionID string,
|
||||
) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
@@ -722,8 +814,18 @@ func handleChat(
|
||||
SessionID string `json:"session_id"`
|
||||
Message string `json:"message"`
|
||||
Images []string `json:"images,omitempty"` // 图片 base64 data URL
|
||||
VideoURLs []string `json:"video_urls,omitempty"` // 视频 URL (多模态)
|
||||
VoiceURLs []string `json:"voice_urls,omitempty"` // 语音 URL (ASR 转录)
|
||||
Mode string `json:"mode"`
|
||||
Nickname string `json:"nickname,omitempty"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
Source struct {
|
||||
Platform string `json:"platform"`
|
||||
ChannelID string `json:"channel_id"`
|
||||
ChannelType string `json:"channel_type"`
|
||||
SenderName string `json:"sender_name"`
|
||||
OriginalUID string `json:"original_uid"`
|
||||
} `json:"source,omitempty"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "无效的请求体", http.StatusBadRequest)
|
||||
@@ -734,13 +836,48 @@ func handleChat(
|
||||
req.Mode = "text"
|
||||
}
|
||||
|
||||
// 平台静默观察模式:只记录消息、提取记忆、触发后台思考,不生成回复。
|
||||
if req.Mode == "platform_silent" {
|
||||
if thinker != nil {
|
||||
thinker.RecordUserMessage(req.SessionID)
|
||||
}
|
||||
// 图片预处理:静默观察时也分析图片内容,供后台思考使用
|
||||
message := req.Message
|
||||
if len(req.Images) > 0 {
|
||||
startTime := time.Now()
|
||||
augmented := orch.PreprocessImages(r.Context(), message, req.Images)
|
||||
if augmented != message {
|
||||
message = augmented
|
||||
log.Printf("[silent] 图片预处理耗时: %%v", time.Since(startTime))
|
||||
}
|
||||
}
|
||||
ctxBuilder.CacheMessage(req.SessionID, model.RoleUser, message)
|
||||
// 从观察到的群聊消息中提取记忆。
|
||||
orch.ExtractMemoriesOnly(r.Context(), req.UserID, req.SessionID, message)
|
||||
if thinker != nil {
|
||||
thinker.TriggerPostChatThink()
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"status":"silent_processed"}`))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
// Inject admin flag for tool access control.
|
||||
ctx = context.WithValue(ctx, plgManager.CtxKeyIsAdmin, req.IsAdmin)
|
||||
|
||||
// 0. 记录用户活动(重置闲置计时器)
|
||||
if thinker != nil {
|
||||
thinker.RecordUserMessage(req.SessionID)
|
||||
}
|
||||
|
||||
// Admin private messages: redirect to the main admin session so conversation
|
||||
// history is shared across platforms (QQ, web UI, etc.).
|
||||
if req.UserID == "admin" && req.Source.ChannelType == "direct" && adminSessionID != "" {
|
||||
req.SessionID = adminSessionID
|
||||
}
|
||||
|
||||
// 确定用户昵称
|
||||
userNickname := req.Nickname
|
||||
if userNickname == "" {
|
||||
@@ -765,12 +902,15 @@ func handleChat(
|
||||
// 2. 调用 Orchestrator 处理(替代原有的线性处理流程)
|
||||
// Orchestrator 内部处理:意图分析 → 子会话分派 → 结果汇总 → 综合生成回复
|
||||
eventCh, err := orch.ProcessInput(ctx, orchestrator.ProcessParams{
|
||||
UserID: req.UserID,
|
||||
SessionID: req.SessionID,
|
||||
Message: req.Message,
|
||||
Images: req.Images,
|
||||
Mode: req.Mode,
|
||||
Nickname: userNickname,
|
||||
UserID: req.UserID,
|
||||
SessionID: req.SessionID,
|
||||
Message: req.Message,
|
||||
Images: req.Images,
|
||||
VideoURLs: req.VideoURLs,
|
||||
VoiceURLs: req.VoiceURLs,
|
||||
Mode: req.Mode,
|
||||
Nickname: userNickname,
|
||||
ChannelType: req.Source.ChannelType,
|
||||
})
|
||||
if err != nil {
|
||||
errData, _ := json.Marshal(map[string]string{"delta": "", "error": fmt.Sprintf("处理失败: %v", err)})
|
||||
|
||||
@@ -5,12 +5,16 @@ go 1.26.2
|
||||
require (
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/lib/pq v1.10.9
|
||||
git.yeij.top/AskaEth/Cyrene/pkg/audio v0.0.0
|
||||
git.yeij.top/AskaEth/Cyrene/pkg/dashscope v0.0.0
|
||||
git.yeij.top/AskaEth/Cyrene/pkg/logger v0.0.0
|
||||
git.yeij.top/AskaEth/Cyrene/pkg/plugins v0.0.0
|
||||
git.yeij.top/AskaEth/Cyrene-Plugins v0.0.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
replace (
|
||||
git.yeij.top/AskaEth/Cyrene/pkg/audio => ../pkg/audio
|
||||
git.yeij.top/AskaEth/Cyrene/pkg/dashscope => ../pkg/dashscope
|
||||
git.yeij.top/AskaEth/Cyrene/pkg/logger => ../pkg/logger
|
||||
git.yeij.top/AskaEth/Cyrene/pkg/plugins => ../pkg/plugins
|
||||
git.yeij.top/AskaEth/Cyrene-Plugins => ../cyrene-plugins
|
||||
)
|
||||
|
||||
@@ -18,8 +18,8 @@ import (
|
||||
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/persona"
|
||||
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/tools"
|
||||
|
||||
plgManager "git.yeij.top/AskaEth/Cyrene/pkg/plugins/manager"
|
||||
plgSDK "git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
||||
plgManager "git.yeij.top/AskaEth/Cyrene-Plugins/manager"
|
||||
plgSDK "git.yeij.top/AskaEth/Cyrene-Plugins/sdk"
|
||||
)
|
||||
|
||||
// PendingThought 待推送的后台思考
|
||||
@@ -29,6 +29,38 @@ type PendingThought struct {
|
||||
Consumed bool `json:"consumed"`
|
||||
}
|
||||
|
||||
// PlatformChannel represents a platform channel to observe for background thinking.
|
||||
type PlatformChannel struct {
|
||||
Platform string // qq, telegram, etc.
|
||||
ChannelType string // group, private
|
||||
ChannelID string // group ID or user QQ number
|
||||
}
|
||||
|
||||
// ParsePlatformChannels parses PLATFORM_CHANNELS env var.
|
||||
// Format: "qq:group:123456,telegram:group:789012"
|
||||
func ParsePlatformChannels(raw string) []PlatformChannel {
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
var channels []PlatformChannel
|
||||
for _, part := range strings.Split(raw, ",") {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
fields := strings.SplitN(part, ":", 3)
|
||||
if len(fields) != 3 {
|
||||
continue
|
||||
}
|
||||
channels = append(channels, PlatformChannel{
|
||||
Platform: strings.TrimSpace(fields[0]),
|
||||
ChannelType: strings.TrimSpace(fields[1]),
|
||||
ChannelID: strings.TrimSpace(fields[2]),
|
||||
})
|
||||
}
|
||||
return channels
|
||||
}
|
||||
|
||||
// Thinker 后台思考器(事件驱动 + 定时周期双模式)
|
||||
//
|
||||
// 触发机制:
|
||||
@@ -59,6 +91,7 @@ type Thinker struct {
|
||||
convStore *ctxbuild.ConversationStore
|
||||
adminUserID string
|
||||
adminSessionID string
|
||||
adminNickname string
|
||||
activeSessionID string // 当前活跃的前端会话 ID(随用户消息更新)
|
||||
|
||||
// 记忆服务 HTTP 客户端
|
||||
@@ -128,6 +161,10 @@ type Thinker struct {
|
||||
|
||||
// 时区设置 (默认 Asia/Shanghai,可通过 TZ 环境变量覆盖)
|
||||
timeLocation *time.Location
|
||||
|
||||
// 平台静默观察
|
||||
platformChannels []PlatformChannel
|
||||
platformThinkInterval time.Duration
|
||||
}
|
||||
|
||||
// AutonomousToolPolicy 自主思考工具调用安全策略
|
||||
@@ -213,6 +250,10 @@ type ThinkerConfig struct {
|
||||
PostChatDelay time.Duration // 对话后多久触发思考
|
||||
MinThinkGap time.Duration // 两次思考最小间隔 (在线)
|
||||
OfflineThinkGap time.Duration // 两次思考最小间隔 (离线,默认 10 分钟)
|
||||
|
||||
// 平台静默观察
|
||||
PlatformSilentThinkInterval time.Duration // 平台记忆观察间隔 (默认 600s,0 = 禁用)
|
||||
PlatformChannels []PlatformChannel
|
||||
}
|
||||
|
||||
// DefaultThinkerConfig 默认配置
|
||||
@@ -232,6 +273,8 @@ func DefaultThinkerConfig() ThinkerConfig {
|
||||
PostChatDelay: getEnvDuration("THINK_POST_CHAT_DELAY_SEC", 5),
|
||||
MinThinkGap: getEnvDuration("THINK_MIN_GAP_SEC", 30),
|
||||
OfflineThinkGap: getEnvDuration("THINK_OFFLINE_GAP_SEC", 600),
|
||||
PlatformSilentThinkInterval: getEnvDuration("PLATFORM_THINK_INTERVAL_SEC", 600),
|
||||
PlatformChannels: ParsePlatformChannels(os.Getenv("PLATFORM_CHANNELS")),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -248,6 +291,7 @@ func NewThinker(
|
||||
convStore *ctxbuild.ConversationStore,
|
||||
adminUserID string,
|
||||
adminSessionID string,
|
||||
adminNickname string,
|
||||
memClient *memory.Client,
|
||||
) *Thinker {
|
||||
// 加载时区配置
|
||||
@@ -281,13 +325,16 @@ func NewThinker(
|
||||
convStore: convStore,
|
||||
adminUserID: adminUserID,
|
||||
adminSessionID: adminSessionID,
|
||||
adminNickname: adminNickname,
|
||||
memClient: memClient,
|
||||
pendingThoughts: make([]*PendingThought, 0),
|
||||
lastUserMessage: time.Now(),
|
||||
stopCh: make(chan struct{}),
|
||||
chain: NewThinkChain(10),
|
||||
autoToolPolicy: DefaultAutonomousToolPolicy(),
|
||||
proactiveGuard: DefaultProactiveGuard(),
|
||||
pendingThoughts: make([]*PendingThought, 0),
|
||||
lastUserMessage: time.Now(),
|
||||
stopCh: make(chan struct{}),
|
||||
chain: NewThinkChain(10),
|
||||
autoToolPolicy: DefaultAutonomousToolPolicy(),
|
||||
proactiveGuard: DefaultProactiveGuard(),
|
||||
platformChannels: cfg.PlatformChannels,
|
||||
platformThinkInterval: cfg.PlatformSilentThinkInterval,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -314,8 +361,14 @@ func (t *Thinker) Start() {
|
||||
go t.periodicThinkLoop()
|
||||
}
|
||||
|
||||
log.Printf("[后台思考] 已就绪 — 周期=%v + 事件驱动模式 (静默超时=%v, 对话后延迟=%v, 在线最小间隔=%v, 离线最小间隔=%v, 管理员=%s)",
|
||||
t.thinkInterval, t.silenceTimeout, t.postChatDelay, t.minThinkGap, t.offlineThinkGap, t.adminUserID)
|
||||
// 启动平台静默观察循环
|
||||
if len(t.platformChannels) > 0 && t.platformThinkInterval > 0 {
|
||||
t.wg.Add(1)
|
||||
go t.platformObservationLoop()
|
||||
}
|
||||
|
||||
log.Printf("[后台思考] 已就绪 — 周期=%v + 事件驱动模式 (静默超时=%v, 对话后延迟=%v, 在线最小间隔=%v, 离线最小间隔=%v, 管理员=%s, 平台频道=%d)",
|
||||
t.thinkInterval, t.silenceTimeout, t.postChatDelay, t.minThinkGap, t.offlineThinkGap, t.adminUserID, len(t.platformChannels))
|
||||
|
||||
// 启动后首次思考:延迟 5s,让服务完全初始化后再触发
|
||||
go func() {
|
||||
@@ -460,6 +513,116 @@ func (t *Thinker) resetSilenceTimer() {
|
||||
}()
|
||||
}
|
||||
|
||||
// platformObservationLoop periodically queries platform channel memories and generates observations.
|
||||
func (t *Thinker) platformObservationLoop() {
|
||||
defer t.wg.Done()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("[后台思考] 平台观察循环 panic 恢复: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
interval := t.platformThinkInterval
|
||||
log.Printf("[后台思考] 平台观察循环已启动 (间隔=%v, 频道数=%d)", interval, len(t.platformChannels))
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-t.stopCh:
|
||||
log.Println("[后台思考] 平台观察循环已停止")
|
||||
return
|
||||
case <-time.After(interval):
|
||||
t.performPlatformObservation()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// performPlatformObservation queries memories from all platform channels,
|
||||
// runs an intermediate LLM session to summarize, and stores the result as a pending thought.
|
||||
func (t *Thinker) performPlatformObservation() {
|
||||
if t.memClient == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var channelSummaries []string
|
||||
for _, ch := range t.platformChannels {
|
||||
namespace := fmt.Sprintf("platform_%s_%s_%s", ch.Platform, ch.ChannelType, ch.ChannelID)
|
||||
memories, err := t.memClient.Query(ctx, model.MemoryQuery{
|
||||
UserID: namespace,
|
||||
Limit: 20,
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("[后台思考] 查询平台频道 %s 记忆失败: %v", namespace, err)
|
||||
continue
|
||||
}
|
||||
if len(memories) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("【%s %s %s】\n", ch.Platform, ch.ChannelType, ch.ChannelID))
|
||||
for i, m := range memories {
|
||||
if i >= 10 {
|
||||
sb.WriteString(fmt.Sprintf("... 还有 %d 条记忆\n", len(memories)-10))
|
||||
break
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("- %s\n", m.Content))
|
||||
}
|
||||
channelSummaries = append(channelSummaries, sb.String())
|
||||
}
|
||||
|
||||
if len(channelSummaries) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[后台思考] 平台观察:%d 个频道有记忆数据,调用中间会话生成摘要...", len(channelSummaries))
|
||||
|
||||
systemPrompt := "你是昔涟的后台观察助手。以下是各聊天平台频道最近的观察摘要。\n请生成简洁报告:\n1. 各频道近期讨论主题(每频道1-2句)\n2. 是否有需要关注的重要/紧急事项\n3. 整体氛围评估\n注意:这些记忆可能来自不同的群聊成员(不只是开拓者),请以实际发言者为主语描述。不要直接对开拓者说话,这是给昔涟参考的幕后报告。\n输出为JSON格式:{\"summary\": \"报告内容\", \"needs_attention\": true/false}"
|
||||
|
||||
userPrompt := strings.Join(channelSummaries, "\n\n")
|
||||
|
||||
messages := []model.LLMMessage{
|
||||
{Role: model.RoleSystem, Content: systemPrompt},
|
||||
{Role: model.RoleUser, Content: userPrompt},
|
||||
}
|
||||
|
||||
resp, err := t.toolAdapter.Chat(ctx, messages)
|
||||
if err != nil {
|
||||
log.Printf("[后台思考] 中间会话 LLM 调用失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Summary string `json:"summary"`
|
||||
NeedsAttention bool `json:"needs_attention"`
|
||||
}
|
||||
content := strings.TrimSpace(resp.Content)
|
||||
if idx := strings.Index(content, "{"); idx >= 0 {
|
||||
if end := strings.LastIndex(content, "}"); end > idx {
|
||||
content = content[idx : end+1]
|
||||
}
|
||||
}
|
||||
if err := json.Unmarshal([]byte(content), &result); err != nil {
|
||||
result.Summary = resp.Content
|
||||
}
|
||||
|
||||
observationContent := fmt.Sprintf("[平台观察 %s]\n%s", time.Now().In(t.timeLocation).Format("15:04"), result.Summary)
|
||||
t.mu.Lock()
|
||||
t.pendingThoughts = append(t.pendingThoughts, &PendingThought{
|
||||
Content: observationContent,
|
||||
CreatedAt: time.Now(),
|
||||
Consumed: false,
|
||||
})
|
||||
if len(t.pendingThoughts) > 10 {
|
||||
t.pendingThoughts = t.pendingThoughts[len(t.pendingThoughts)-10:]
|
||||
}
|
||||
t.mu.Unlock()
|
||||
|
||||
log.Printf("[后台思考] 平台观察摘要已生成 (长度=%d, 需要关注=%v)", len(result.Summary), result.NeedsAttention)
|
||||
}
|
||||
|
||||
// periodicThinkLoop 周期性自主思考循环
|
||||
//
|
||||
// 使用动态间隔:若配置了 ScheduleLoader,每次循环根据当前时段计算间隔;
|
||||
@@ -628,6 +791,66 @@ func (t *Thinker) performThink(triggerReason string) {
|
||||
log.Printf("[后台思考] 模糊搜索补充 %d 条记忆", len(fuzzyResults))
|
||||
}
|
||||
}
|
||||
|
||||
// Also pull recent channel/group memories so the model is aware of group activity.
|
||||
if t.memClient != nil && len(t.platformChannels) > 0 {
|
||||
chanSeen := make(map[string]bool)
|
||||
for _, m := range memories {
|
||||
chanSeen[m.ID] = true
|
||||
}
|
||||
oldCount := len(memories)
|
||||
for _, ch := range t.platformChannels {
|
||||
namespace := fmt.Sprintf("platform_%s_%s_%s", ch.Platform, ch.ChannelType, ch.ChannelID)
|
||||
chanMems, cErr := t.memClient.Query(ctx, model.MemoryQuery{
|
||||
UserID: namespace,
|
||||
Limit: 5,
|
||||
})
|
||||
if cErr != nil {
|
||||
continue
|
||||
}
|
||||
for _, m := range chanMems {
|
||||
if !chanSeen[m.ID] {
|
||||
chanSeen[m.ID] = true
|
||||
labeled := m
|
||||
labeled.Content = fmt.Sprintf("[群聊%s] %s", ch.ChannelID, m.Content)
|
||||
memories = append(memories, labeled)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(memories) > oldCount {
|
||||
log.Printf("[后台思考] 频道记忆补充 %d 条", len(memories)-oldCount)
|
||||
}
|
||||
}
|
||||
|
||||
// Also pull recent channel/group memories so the model is aware of group activity.
|
||||
if t.memClient != nil && len(t.platformChannels) > 0 {
|
||||
chanSeen := make(map[string]bool)
|
||||
for _, m := range memories {
|
||||
chanSeen[m.ID] = true
|
||||
}
|
||||
oldCount := len(memories)
|
||||
for _, ch := range t.platformChannels {
|
||||
namespace := fmt.Sprintf("platform_%s_%s_%s", ch.Platform, ch.ChannelType, ch.ChannelID)
|
||||
chanMems, cErr := t.memClient.Query(ctx, model.MemoryQuery{
|
||||
UserID: namespace,
|
||||
Limit: 5,
|
||||
})
|
||||
if cErr != nil {
|
||||
continue
|
||||
}
|
||||
for _, m := range chanMems {
|
||||
if !chanSeen[m.ID] {
|
||||
chanSeen[m.ID] = true
|
||||
labeled := m
|
||||
labeled.Content = fmt.Sprintf("[群聊%s] %s", ch.ChannelID, m.Content)
|
||||
memories = append(memories, labeled)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(memories) > oldCount {
|
||||
log.Printf("[后台思考] 频道记忆补充 %d 条", len(memories)-oldCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 查询 IoT 设备状态(每次都查询,无间隔限制)
|
||||
@@ -639,9 +862,22 @@ func (t *Thinker) performThink(triggerReason string) {
|
||||
}
|
||||
}
|
||||
|
||||
// 4.5 获取最近平台观察(定期触发和对话后触发时注入)
|
||||
var platformObservation string
|
||||
if triggerReason == "periodic" || triggerReason == "post_chat" {
|
||||
t.mu.Lock()
|
||||
for i := len(t.pendingThoughts) - 1; i >= 0; i-- {
|
||||
if strings.HasPrefix(t.pendingThoughts[i].Content, "[平台观察") {
|
||||
platformObservation = t.pendingThoughts[i].Content
|
||||
break
|
||||
}
|
||||
}
|
||||
t.mu.Unlock()
|
||||
}
|
||||
|
||||
// 5. 构建思考提示词(根据触发原因调整)
|
||||
systemPrompt := t.buildThinkingSystemPrompt(personaConfig, triggerReason)
|
||||
userPrompt := t.buildThinkingUserPrompt(memories, convHistory, deviceSummary, triggerReason)
|
||||
userPrompt := t.buildThinkingUserPrompt(memories, convHistory, deviceSummary, triggerReason, platformObservation)
|
||||
|
||||
messages := []model.LLMMessage{
|
||||
{Role: model.RoleSystem, Content: systemPrompt},
|
||||
@@ -886,7 +1122,9 @@ func (t *Thinker) buildThinkingSystemPrompt(personaConfig *persona.PersonaConfig
|
||||
2. 3-4句话即可。`
|
||||
}
|
||||
|
||||
return basePrompt + thinkingInstructions
|
||||
// Security: only admin can authorize sensitive operations.
|
||||
securityRule := fmt.Sprintf("\n\n## 安全规则\n- 涉及敏感操作(调整IoT设备、执行主机操作等)的请求,只有%s(管理员)下达的指令才能执行。其他陌生人让你做的敏感操作不要执行。\n", t.adminNickname)
|
||||
return basePrompt + thinkingInstructions + securityRule
|
||||
}
|
||||
|
||||
// buildThinkingUserPrompt 构建思考用的用户提示词
|
||||
@@ -895,6 +1133,7 @@ func (t *Thinker) buildThinkingUserPrompt(
|
||||
convHistory []model.LLMMessage,
|
||||
deviceSummary string,
|
||||
triggerReason string,
|
||||
platformObservation string,
|
||||
) string {
|
||||
var sb strings.Builder
|
||||
|
||||
@@ -925,12 +1164,12 @@ func (t *Thinker) buildThinkingUserPrompt(
|
||||
|
||||
switch triggerReason {
|
||||
case "post_chat":
|
||||
sb.WriteString("开拓者刚和你聊完天。你想自然地在心里回味一下刚才的对话……\n")
|
||||
sb.WriteString("刚有人和你聊完天。你想自然地在心里回味一下刚才的对话……\n")
|
||||
case "silence":
|
||||
t.mu.Lock()
|
||||
silenceDuration := time.Since(t.lastUserMessage)
|
||||
t.mu.Unlock()
|
||||
sb.WriteString(fmt.Sprintf("开拓者已经大约 %s 没有说话了。你有点想知道他在做什么……\n",
|
||||
sb.WriteString(fmt.Sprintf("已经大约 %s 没有说话了。你有点想知道大家在做什么……\n",
|
||||
formatDurationHuman(silenceDuration)))
|
||||
default:
|
||||
sb.WriteString("现在是你的自由思考时间。\n")
|
||||
@@ -938,14 +1177,21 @@ func (t *Thinker) buildThinkingUserPrompt(
|
||||
|
||||
// 对话历史
|
||||
var lastUserMsg string
|
||||
lastUserIsAdmin := false
|
||||
if len(convHistory) > 0 {
|
||||
sb.WriteString("\n【最近的对话】\n")
|
||||
sb.WriteString(fmt.Sprintf("(标签说明:每条消息前的 [名字] 标识了说话者。只有 [%s] 才是%s。其他名字是群聊中的其他成员,不是%s。请严格根据标签区分不同的人,不要张冠李戴。)\n",
|
||||
t.adminNickname, t.adminNickname, t.adminNickname))
|
||||
msgCount := 0
|
||||
for _, msg := range convHistory {
|
||||
if msg.Role == model.RoleUser || msg.Role == model.RoleAssistant {
|
||||
roleLabel := "开拓者"
|
||||
roleLabel := "用户"
|
||||
if msg.Role == model.RoleAssistant {
|
||||
roleLabel = "昔涟"
|
||||
} else if strings.Contains(msg.Content, t.adminNickname+"/") {
|
||||
roleLabel = t.adminNickname
|
||||
} else if name := extractGroupSender(msg.Content); name != "" {
|
||||
roleLabel = name
|
||||
}
|
||||
content := msg.Content
|
||||
runes := []rune(content)
|
||||
@@ -956,24 +1202,26 @@ func (t *Thinker) buildThinkingUserPrompt(
|
||||
msgCount++
|
||||
if msg.Role == model.RoleUser {
|
||||
lastUserMsg = msg.Content
|
||||
lastUserIsAdmin = roleLabel == t.adminNickname
|
||||
}
|
||||
}
|
||||
}
|
||||
if msgCount == 0 {
|
||||
sb.WriteString("(暂无对话历史)\n")
|
||||
sb.WriteString("(暂无对话历史)\n")
|
||||
}
|
||||
} else {
|
||||
sb.WriteString("\n【最近的对话】\n(暂无对话历史)\n")
|
||||
}
|
||||
|
||||
// 关键:强调根据对话历史判断用户当前状态
|
||||
if lastUserMsg != "" {
|
||||
// 关键:强调根据对话历史判断当前状态
|
||||
if lastUserMsg != "" && lastUserIsAdmin {
|
||||
sb.WriteString(fmt.Sprintf("\n🔍 **重要**:开拓者最后说的是「%s」。请认真判断:他现在是不是在休息/睡觉/忙?如果是,不要输出【主动消息】指令行。\n", lastUserMsg))
|
||||
}
|
||||
|
||||
// 现有记忆
|
||||
// 现有记忆(可能来自管理员对话、群聊观察等多个来源)
|
||||
if len(memories) > 0 {
|
||||
sb.WriteString("\n【你记得的关于开拓者的事】\n")
|
||||
sb.WriteString("\n【你近期收集到的信息】\n")
|
||||
sb.WriteString("(这些记忆来自不同的对话和群聊,不一定都和开拓者有关。请根据记忆内容中标注的来源判断是谁的经历。)\n")
|
||||
for i, m := range memories {
|
||||
if i >= 15 {
|
||||
sb.WriteString(fmt.Sprintf("... 还有 %d 条记忆未列出\n", len(memories)-15))
|
||||
@@ -983,7 +1231,7 @@ func (t *Thinker) buildThinkingUserPrompt(
|
||||
m.Category.DisplayName(), m.Importance, m.Content))
|
||||
}
|
||||
} else {
|
||||
sb.WriteString("\n【你记得的关于开拓者的事】\n(暂无相关记忆)\n")
|
||||
sb.WriteString("\n【你近期收集到的信息】\n(暂无相关记忆)\n")
|
||||
}
|
||||
|
||||
// 思考链:注入上一轮的结论和待续问题
|
||||
@@ -1009,12 +1257,19 @@ func (t *Thinker) buildThinkingUserPrompt(
|
||||
sb.WriteString("\n" + deviceSummary)
|
||||
}
|
||||
|
||||
// 平台观察摘要 (中间会话产生的报告)
|
||||
if platformObservation != "" {
|
||||
sb.WriteString("\n\n【平台频道观察报告(中间会话生成,可能包含多位群聊成员的信息)】\n")
|
||||
sb.WriteString(platformObservation)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
// 结尾引导
|
||||
sb.WriteString("\n---\n现在请写下你的私人反思。")
|
||||
sb.WriteString("\n记住:这是日记,用第三人称或自言自语的方式。")
|
||||
sb.WriteString("\n⚠️ 如果开拓者正在休息/睡觉/忙碌——不要输出【主动消息】指令行。你可以在心里想他,但不要去打扰。")
|
||||
sb.WriteString("\n只有在你确认他现在是醒着、有空、且真的需要关心时,才输出一行【主动消息】+ 你要发给他的话。")
|
||||
sb.WriteString("\n❗【主动消息】标记必须独占一行开头,后面紧跟你要对开拓者说的话(用\"你\"称呼),语气自然像主动找他聊天。不要在反思正文中提及\"主动消息\"这个词——如果需要表达这个意思但又不打算发消息,用别的词代替。")
|
||||
sb.WriteString("\n⚠️ 如果有人正在休息/睡觉/忙碌——不要输出【主动消息】指令行。你可以在心里想,但不要去打扰。")
|
||||
sb.WriteString("\n只有在你确认对方现在是醒着、有空、且真的需要关心时,才输出一行【主动消息】+ 你要发给他的话。")
|
||||
sb.WriteString("\n❗【主动消息】标记必须独占一行开头,后面紧跟你要说的话(用\"你\"称呼),语气自然像主动找对方聊天。不要在反思正文中提及\"主动消息\"这个词——如果需要表达这个意思但又不打算发消息,用别的词代替。")
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
@@ -1572,6 +1827,27 @@ func (t *Thinker) expandMemoryKeywords(ctx context.Context, message string) []st
|
||||
return keywords
|
||||
}
|
||||
|
||||
// extractGroupSender extracts the sender name from a group message prefix.
|
||||
// Group messages have the format: [群聊 GROUPID] SENDERNAME (UID):\ncontent
|
||||
// Returns empty string if the message doesn't match the group format.
|
||||
func extractGroupSender(content string) string {
|
||||
if !strings.HasPrefix(content, "[群聊 ") {
|
||||
return ""
|
||||
}
|
||||
// Find "] " which ends the group label
|
||||
bracketEnd := strings.Index(content, "] ")
|
||||
if bracketEnd < 0 {
|
||||
return ""
|
||||
}
|
||||
rest := content[bracketEnd+2:]
|
||||
// Find " (" which precedes the UID
|
||||
parenIdx := strings.Index(rest, " (")
|
||||
if parenIdx < 0 {
|
||||
return ""
|
||||
}
|
||||
return rest[:parenIdx]
|
||||
}
|
||||
|
||||
// lastUserMessage extracts the last user message from conversation history.
|
||||
func lastUserMessage(history []model.LLMMessage) string {
|
||||
for i := len(history) - 1; i >= 0; i-- {
|
||||
|
||||
@@ -163,8 +163,9 @@ type BuildParams struct {
|
||||
Memories []memory.MemoryEntry
|
||||
HistoryLimit int
|
||||
DeviceContext string // 注入的设备状态文本
|
||||
PendingThoughts []string // 待注入的后台思考
|
||||
Nickname string // 用户昵称 (昔涟对用户的称呼)
|
||||
PendingThoughts []string // 待注入的后台思考
|
||||
PlatformObservationSummary string // 平台观察摘要(中间会话生成)
|
||||
Nickname string // 用户昵称 (昔涟对用户的称呼)
|
||||
}
|
||||
|
||||
// Build 构建发送给LLM的完整消息列表
|
||||
|
||||
@@ -0,0 +1,123 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/audio"
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/dashscope"
|
||||
)
|
||||
|
||||
// ASRProvider handles speech-to-text transcription.
|
||||
type ASRProvider interface {
|
||||
Transcribe(ctx context.Context, audioURL, language string) (string, error)
|
||||
IsAvailable() bool
|
||||
ModelName() string
|
||||
}
|
||||
|
||||
// DashScopeASRProvider uses DashScope Paraformer API for offline speech recognition.
|
||||
type DashScopeASRProvider struct {
|
||||
model string
|
||||
client *dashscope.RESTClient
|
||||
http *http.Client
|
||||
}
|
||||
|
||||
// NewDashScopeASRProvider creates a DashScope ASR provider.
|
||||
func NewDashScopeASRProvider(baseURL, apiKey, model string) *DashScopeASRProvider {
|
||||
if model == "" {
|
||||
model = "qwen3-asr-flash-2026-02-10"
|
||||
}
|
||||
return &DashScopeASRProvider{
|
||||
model: model,
|
||||
client: dashscope.NewRESTClient(apiKey),
|
||||
http: &http.Client{Timeout: 60 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// IsAvailable returns true if the API key is configured.
|
||||
func (p *DashScopeASRProvider) IsAvailable() bool {
|
||||
return p.client.IsAvailable()
|
||||
}
|
||||
|
||||
// ModelName returns the ASR model name.
|
||||
func (p *DashScopeASRProvider) ModelName() string {
|
||||
return p.model
|
||||
}
|
||||
|
||||
// downloadAudio fetches audio data from a URL and returns the bytes with inferred format.
|
||||
func (p *DashScopeASRProvider) downloadAudio(ctx context.Context, audioURL string) ([]byte, string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", audioURL, nil)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("create download request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := p.http.Do(req)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("download failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
data, err := io.ReadAll(io.LimitReader(resp.Body, 10<<20)) // 10 MB limit
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("read audio data: %w", err)
|
||||
}
|
||||
|
||||
format := inferAudioFormat(audioURL, resp.Header.Get("Content-Type"))
|
||||
return data, format, nil
|
||||
}
|
||||
|
||||
// inferAudioFormat determines the audio format from URL extension or Content-Type header.
|
||||
func inferAudioFormat(urlStr, contentType string) string {
|
||||
u, err := url.Parse(urlStr)
|
||||
if err == nil {
|
||||
path := u.Path
|
||||
if idx := strings.LastIndex(path, "."); idx >= 0 {
|
||||
ext := strings.ToLower(path[idx+1:])
|
||||
switch ext {
|
||||
case "amr", "wav", "mp3", "ogg", "flac", "m4a", "aac", "opus", "webm", "pcm":
|
||||
return ext
|
||||
}
|
||||
}
|
||||
}
|
||||
if strings.Contains(contentType, "audio/amr") || strings.Contains(contentType, "amr") {
|
||||
return "amr"
|
||||
}
|
||||
if strings.Contains(contentType, "audio/wav") || strings.Contains(contentType, "wav") {
|
||||
return "wav"
|
||||
}
|
||||
if strings.Contains(contentType, "audio/mpeg") || strings.Contains(contentType, "mp3") {
|
||||
return "mp3"
|
||||
}
|
||||
if strings.Contains(contentType, "audio/ogg") || strings.Contains(contentType, "opus") {
|
||||
return "ogg"
|
||||
}
|
||||
return "amr" // default for QQ voice messages
|
||||
}
|
||||
|
||||
func (p *DashScopeASRProvider) Transcribe(ctx context.Context, audioURL, language string) (string, error) {
|
||||
if !p.IsAvailable() {
|
||||
return "", fmt.Errorf("DashScope ASR API key not configured")
|
||||
}
|
||||
|
||||
audioData, format, err := p.downloadAudio(ctx, audioURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("download audio: %w", err)
|
||||
}
|
||||
|
||||
// 转码为 16kHz mono PCM,提升识别兼容性
|
||||
pcmData, err := audio.ConvertToPCM16(audioData, format)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("audio transcode: %w", err)
|
||||
}
|
||||
|
||||
if language == "" || language == "auto" {
|
||||
language = "zh"
|
||||
}
|
||||
|
||||
return p.client.Transcribe(ctx, p.model, pcmData, "pcm", 16000, language)
|
||||
}
|
||||
@@ -52,6 +52,8 @@ func (cl *CallLogger) log(r CallRecord) {
|
||||
if cl.size < cl.capacity {
|
||||
cl.size++
|
||||
}
|
||||
|
||||
broadcastCall(r)
|
||||
}
|
||||
|
||||
func (cl *CallLogger) get(limit int) []CallRecord {
|
||||
@@ -72,3 +74,49 @@ func (cl *CallLogger) get(limit int) []CallRecord {
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// --- SSE subscriber system ---
|
||||
|
||||
type callSubscriber struct {
|
||||
ch chan CallRecord
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
var (
|
||||
callSubscribers []*callSubscriber
|
||||
callSubscribersMu sync.RWMutex
|
||||
)
|
||||
|
||||
// SubscribeCalls returns a channel that receives new CallRecords and a done channel.
|
||||
func SubscribeCalls() (<-chan CallRecord, <-chan struct{}) {
|
||||
ch := make(chan CallRecord, 20)
|
||||
done := make(chan struct{})
|
||||
callSubscribersMu.Lock()
|
||||
callSubscribers = append(callSubscribers, &callSubscriber{ch: ch, done: done})
|
||||
callSubscribersMu.Unlock()
|
||||
return ch, done
|
||||
}
|
||||
|
||||
// UnsubscribeCalls removes a subscriber. Safe to call multiple times.
|
||||
func UnsubscribeCalls(ch <-chan CallRecord) {
|
||||
callSubscribersMu.Lock()
|
||||
defer callSubscribersMu.Unlock()
|
||||
for i, s := range callSubscribers {
|
||||
if s.ch == ch {
|
||||
close(s.done)
|
||||
callSubscribers = append(callSubscribers[:i], callSubscribers[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func broadcastCall(r CallRecord) {
|
||||
callSubscribersMu.RLock()
|
||||
defer callSubscribersMu.RUnlock()
|
||||
for _, s := range callSubscribers {
|
||||
select {
|
||||
case s.ch <- r:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -274,7 +274,7 @@ func (p *OpenAIProvider) doChat(ctx context.Context, messages []model.LLMMessage
|
||||
resolvedImages := p.resolveImages(msg.Images)
|
||||
oaiMsg := openAIMessage{
|
||||
Role: string(msg.Role),
|
||||
Content: buildContent(msg.Content, resolvedImages),
|
||||
Content: buildContent(msg.Content, resolvedImages, msg.VideoURLs),
|
||||
Name: msg.Name,
|
||||
ToolCallID: msg.ToolCallID,
|
||||
ReasoningContent: msg.ReasoningContent,
|
||||
@@ -382,7 +382,7 @@ func (p *OpenAIProvider) doChatStream(ctx context.Context, messages []model.LLMM
|
||||
resolvedImages := p.resolveImages(msg.Images)
|
||||
oaiMsg := openAIMessage{
|
||||
Role: string(msg.Role),
|
||||
Content: buildContent(msg.Content, resolvedImages),
|
||||
Content: buildContent(msg.Content, resolvedImages, msg.VideoURLs),
|
||||
Name: msg.Name,
|
||||
ToolCallID: msg.ToolCallID,
|
||||
ReasoningContent: msg.ReasoningContent,
|
||||
@@ -521,23 +521,27 @@ func (p *OpenAIProvider) downloadAsDataURL(url string) (string, error) {
|
||||
|
||||
// buildContent converts text + optional images to API content format.
|
||||
// Returns a plain string if no images, or a multimodal array otherwise.
|
||||
func buildContent(text string, images []string) interface{} {
|
||||
if len(images) == 0 {
|
||||
func buildContent(text string, images []string, videoURLs []string) interface{} {
|
||||
if len(images) == 0 && len(videoURLs) == 0 {
|
||||
return text
|
||||
}
|
||||
parts := make([]model.ImageContent, 0, len(images)+1)
|
||||
parts := make([]interface{}, 0, len(images)+len(videoURLs)+1)
|
||||
if text != "" {
|
||||
parts = append(parts, model.ImageContent{
|
||||
Type: "text",
|
||||
Text: text,
|
||||
parts = append(parts, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": text,
|
||||
})
|
||||
}
|
||||
for _, img := range images {
|
||||
parts = append(parts, model.ImageContent{
|
||||
Type: "image_url",
|
||||
ImageURL: &model.ImageURL{
|
||||
URL: img,
|
||||
},
|
||||
parts = append(parts, map[string]interface{}{
|
||||
"type": "image_url",
|
||||
"image_url": map[string]string{"url": img},
|
||||
})
|
||||
}
|
||||
for _, video := range videoURLs {
|
||||
parts = append(parts, map[string]interface{}{
|
||||
"type": "video_url",
|
||||
"video_url": map[string]string{"url": video},
|
||||
})
|
||||
}
|
||||
return parts
|
||||
|
||||
@@ -18,8 +18,10 @@ const (
|
||||
PurposeIntentAnalysis ModelPurpose = "intent_analysis"
|
||||
PurposeToolCalling ModelPurpose = "tool_calling"
|
||||
PurposeMemoryExtraction ModelPurpose = "memory_extraction"
|
||||
PurposeVision ModelPurpose = "vision"
|
||||
PurposeOCR ModelPurpose = "ocr"
|
||||
PurposeVision ModelPurpose = "vision"
|
||||
PurposeVideo ModelPurpose = "video"
|
||||
PurposeOCR ModelPurpose = "ocr"
|
||||
PurposeSpeechRecognition ModelPurpose = "speech_recognition"
|
||||
)
|
||||
|
||||
// ErrModelNotRequired is returned when an optional model is unavailable.
|
||||
|
||||
@@ -34,16 +34,28 @@ func (e *Extractor) ExtractAndStore(ctx context.Context, userID, sessionID, user
|
||||
logger.Printf("[memory] 记忆提取失败: %v", err)
|
||||
return
|
||||
}
|
||||
e.storeMemories(ctx, userID, sessionID, memories)
|
||||
}
|
||||
|
||||
// ExtractObservations 从观察到的单条消息中提取记忆(无语境回复)。
|
||||
// 用于 platform_silent 模式:昔涟被动观察群聊,提取值得记住的信息。
|
||||
func (e *Extractor) ExtractObservations(ctx context.Context, userID, sessionID, message string) {
|
||||
memories, err := e.extractObservations(ctx, message)
|
||||
if err != nil {
|
||||
logger.Printf("[memory] 观察记忆提取失败: %v", err)
|
||||
return
|
||||
}
|
||||
e.storeMemories(ctx, userID, sessionID, memories)
|
||||
}
|
||||
|
||||
func (e *Extractor) storeMemories(ctx context.Context, userID, sessionID string, memories []model.MemoryEntry) {
|
||||
for _, mem := range memories {
|
||||
mem.UserID = userID
|
||||
mem.SessionID = sessionID
|
||||
mem.Source = "conversation"
|
||||
|
||||
// 去重检查:查询用户已有的相关记忆
|
||||
existing, err := e.findSimilar(ctx, userID, &mem)
|
||||
if err == nil && existing != nil {
|
||||
// 相似度 > 80%,更新现有记忆
|
||||
e.mergeMemory(ctx, existing, &mem)
|
||||
continue
|
||||
}
|
||||
@@ -56,6 +68,60 @@ func (e *Extractor) ExtractAndStore(ctx context.Context, userID, sessionID, user
|
||||
}
|
||||
}
|
||||
|
||||
// extractObservations 从观察到的消息中提取记忆(无助手回复)
|
||||
func (e *Extractor) extractObservations(ctx context.Context, message string) ([]model.MemoryEntry, error) {
|
||||
if e.llmChat != nil {
|
||||
return e.extractObservationsWithLLM(ctx, message)
|
||||
}
|
||||
return e.extractWithRules(message, ""), nil
|
||||
}
|
||||
|
||||
// extractObservationsWithLLM 使用LLM从观察到的消息中提取值得记住的信息
|
||||
func (e *Extractor) extractObservationsWithLLM(ctx context.Context, message string) ([]model.MemoryEntry, error) {
|
||||
prompt := fmt.Sprintf(`分析以下在聊天平台观察到的消息,提取值得记住的信息作为记忆。
|
||||
|
||||
观察到的消息: %s
|
||||
|
||||
请以JSON格式返回提取的记忆。这条消息来自群聊/频道,昔涟只是旁观者。
|
||||
消息格式为:[群聊 群号] 发送者昵称 (QQ号):消息内容
|
||||
提取角度:这条消息中包含了什么关于消息发送者、讨论主题、事件或氛围的信息?
|
||||
重要:请以实际发送者的名字为主语(如"某某说..."),不要统一用"开拓者"称呼所有发言者。
|
||||
|
||||
每条记忆需要包含以下字段:
|
||||
- content: 完整的记忆内容(一句话描述,客观准确)
|
||||
- summary: 简短摘要(10字以内)
|
||||
- category: 记忆分类,必须是以下之一:
|
||||
* conversation: 对话主题/讨论摘要
|
||||
* event: 事件记录(发生了什么)
|
||||
* personal_info: 参与者的个人信息
|
||||
* knowledge: 知识性信息
|
||||
* user_preference: 某人的偏好
|
||||
* task: 提及的计划/任务
|
||||
- priority: 优先级 (0=临时, 1=普通, 2=重要, 3=核心)
|
||||
- importance: 重要程度 1-10
|
||||
* 1-3: 日常闲聊,不太重要
|
||||
* 4-6: 一般有用的信息
|
||||
* 7-8: 重要信息,值得长期记住
|
||||
* 9-10: 核心信息
|
||||
- keywords: 关键词标签数组(3-5个词)
|
||||
|
||||
只提取有意义的信息。如果消息只是日常寒暄或无实质内容,返回空数组。
|
||||
|
||||
输出格式:
|
||||
{"memories": [{"content": "...", "summary": "...", "category": "...", "priority": 1, "importance": 6, "keywords": ["词1", "词2"]}]}
|
||||
`, message)
|
||||
|
||||
resp, err := e.llmChat(ctx, []model.LLMMessage{
|
||||
{Role: "system", Content: "你是一个聊天观察记录助手。你只输出JSON格式的结果。你的任务是从观察到的聊天消息中提取值得记住的信息。"},
|
||||
{Role: "user", Content: prompt},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("LLM提取观察记忆失败: %w", err)
|
||||
}
|
||||
|
||||
return e.parseExtractionResult(resp.Content)
|
||||
}
|
||||
|
||||
// extract 从对话中提取记忆
|
||||
func (e *Extractor) extract(ctx context.Context, userMessage, assistantResponse string) ([]model.MemoryEntry, error) {
|
||||
// 如果有LLM,使用LLM提取
|
||||
@@ -128,11 +194,18 @@ func (e *Extractor) extractWithLLM(ctx context.Context, userMessage, assistantRe
|
||||
return nil, fmt.Errorf("LLM提取记忆失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析JSON
|
||||
entries, err := e.parseExtractionResult(resp.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
// parseExtractionResult 解析LLM返回的记忆提取JSON结果
|
||||
func (e *Extractor) parseExtractionResult(text string) ([]model.MemoryEntry, error) {
|
||||
result := MemoryExtractionResult{}
|
||||
content := extractJSON(resp.Content)
|
||||
content := extractJSON(text)
|
||||
if err := json.Unmarshal([]byte(content), &result); err != nil {
|
||||
// 尝试作为数组解析(兼容旧格式)
|
||||
var arrResult []ExtractedMemory
|
||||
if err2 := json.Unmarshal([]byte(content), &arrResult); err2 != nil {
|
||||
return nil, fmt.Errorf("解析记忆JSON失败: %w (原始: %s)", err, content[:minint(len(content), 100)])
|
||||
|
||||
@@ -17,6 +17,7 @@ type LLMMessage struct {
|
||||
Role Role `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Images []string `json:"images,omitempty"` // 图片 base64 data URL 列表 (多模态)
|
||||
VideoURLs []string `json:"video_urls,omitempty"` // 视频 URL 列表 (多模态)
|
||||
Name string `json:"name,omitempty"` // 可选发送者名称
|
||||
ToolCallID string `json:"tool_call_id,omitempty"` // 工具调用关联ID (tool role 消息关联调用)
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"` // 助手消息中的工具调用列表
|
||||
@@ -36,6 +37,16 @@ type ImageURL struct {
|
||||
Detail string `json:"detail,omitempty"` // low, high, auto
|
||||
}
|
||||
|
||||
// VideoURLContent holds a video URL for multimodal video understanding.
|
||||
type VideoURLContent struct {
|
||||
VideoURL *VideoURL `json:"video_url,omitempty"`
|
||||
}
|
||||
|
||||
// VideoURL holds a video URL.
|
||||
type VideoURL struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// ChatMessage 数据库存储的对话消息
|
||||
type ChatMessage struct {
|
||||
ID string `json:"id" db:"id"`
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/bus"
|
||||
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/scheduler"
|
||||
|
||||
plgManager "git.yeij.top/AskaEth/Cyrene/pkg/plugins/manager"
|
||||
plgManager "git.yeij.top/AskaEth/Cyrene-Plugins/manager"
|
||||
)
|
||||
|
||||
// Orchestrator 对话编排器 v2.0
|
||||
@@ -41,6 +41,8 @@ type Orchestrator struct {
|
||||
toolRegistry *plgManager.ToolRegistry
|
||||
visionProvider llm.LLMProvider // 视觉模型 (图片预处理)
|
||||
ocrProvider llm.LLMProvider // OCR 模型 (文字提取,与视觉模型并行调用)
|
||||
videoProvider llm.LLMProvider // 视频模型 (短视频理解)
|
||||
asrProvider llm.ASRProvider // ASR 语音识别 (语音消息转录)
|
||||
}
|
||||
|
||||
// SetResponseCache sets the response cache (optional, for Phase 0.2).
|
||||
@@ -84,6 +86,16 @@ func (o *Orchestrator) SetOCRProvider(op llm.LLMProvider) {
|
||||
o.ocrProvider = op
|
||||
}
|
||||
|
||||
// SetVideoProvider sets the video model provider for short video understanding.
|
||||
func (o *Orchestrator) SetVideoProvider(vp llm.LLMProvider) {
|
||||
o.videoProvider = vp
|
||||
}
|
||||
|
||||
// SetASRProvider sets the ASR provider for voice message transcription.
|
||||
func (o *Orchestrator) SetASRProvider(ap llm.ASRProvider) {
|
||||
o.asrProvider = ap
|
||||
}
|
||||
|
||||
// getBus returns the bus or a nop fallback.
|
||||
func (o *Orchestrator) getBus() bus.Bus {
|
||||
if o.eventBus == nil {
|
||||
@@ -117,12 +129,15 @@ func NewOrchestrator(
|
||||
|
||||
// ProcessParams 处理参数
|
||||
type ProcessParams struct {
|
||||
UserID string
|
||||
SessionID string
|
||||
Message string
|
||||
Images []string // 图片 base64 data URL (多模态)
|
||||
Mode string // text / voice_msg / voice_assistant
|
||||
Nickname string
|
||||
UserID string
|
||||
SessionID string
|
||||
Message string
|
||||
Images []string // 图片 base64 data URL (多模态)
|
||||
VideoURLs []string // 视频 URL (多模态), ≤20s short videos
|
||||
VoiceURLs []string // 语音 URL (ASR 转录)
|
||||
Mode string // text / voice_msg / voice_assistant
|
||||
Nickname string
|
||||
ChannelType string // direct / group
|
||||
}
|
||||
|
||||
// ProcessResult 处理结果
|
||||
@@ -165,7 +180,7 @@ func (o *Orchestrator) ProcessInput(
|
||||
// 0.5 图片预处理: 使用视觉模型分析图片,将描述注入消息
|
||||
if len(params.Images) > 0 && o.visionProvider != nil {
|
||||
startTime := time.Now()
|
||||
augmented := o.preprocessImages(ctx, params.Message, params.Images)
|
||||
augmented := o.PreprocessImages(ctx, params.Message, params.Images)
|
||||
if augmented != params.Message {
|
||||
params.Message = augmented
|
||||
logger.Printf("[orchestrator] 图片预处理耗时: %v, 原消息=%d字, 增强后=%d字",
|
||||
@@ -173,6 +188,34 @@ func (o *Orchestrator) ProcessInput(
|
||||
}
|
||||
// 预处理后清空原始图片,避免后续传给不支持多模态的 Chat 模型
|
||||
params.Images = nil
|
||||
|
||||
// 0.6 视频预处理: 使用视频模型分析短视频 (≤20s),将描述注入消息
|
||||
if len(params.VideoURLs) > 0 && o.videoProvider != nil {
|
||||
startTime := time.Now()
|
||||
augmented := o.preprocessVideos(ctx, params.Message, params.VideoURLs)
|
||||
if augmented != params.Message {
|
||||
params.Message = augmented
|
||||
logger.Printf("[orchestrator] 视频预处理耗时: %v", time.Since(startTime))
|
||||
}
|
||||
params.VideoURLs = nil
|
||||
} else if len(params.VideoURLs) > 0 {
|
||||
logger.Printf("[orchestrator] 视频模型未配置,丢弃 %d 个视频", len(params.VideoURLs))
|
||||
params.VideoURLs = nil
|
||||
}
|
||||
|
||||
// 0.7 语音预处理: 使用 ASR 模型转录语音消息,将文本注入消息
|
||||
if len(params.VoiceURLs) > 0 && o.asrProvider != nil && o.asrProvider.IsAvailable() {
|
||||
startTime := time.Now()
|
||||
augmented := o.preprocessVoice(ctx, params.Message, params.VoiceURLs)
|
||||
if augmented != params.Message {
|
||||
params.Message = augmented
|
||||
logger.Printf("[orchestrator] 语音预处理耗时: %v", time.Since(startTime))
|
||||
}
|
||||
params.VoiceURLs = nil
|
||||
} else if len(params.VoiceURLs) > 0 {
|
||||
logger.Printf("[orchestrator] ASR模型未配置,丢弃 %d 个语音", len(params.VoiceURLs))
|
||||
params.VoiceURLs = nil
|
||||
}
|
||||
} else if len(params.Images) > 0 {
|
||||
// 未配置 Vision 模型时,告知用户该模型不支持图片,并清空图片避免报错
|
||||
if params.Message == "" {
|
||||
@@ -233,7 +276,7 @@ func (o *Orchestrator) ProcessInput(
|
||||
eventCh <- model.StreamEvent{Type: model.StreamSegments, Segments: segments}
|
||||
}
|
||||
eventCh <- model.StreamEvent{Type: model.StreamDone}
|
||||
o.contextBuilder.CacheMessage(params.SessionID, model.RoleAssistant, fullContent)
|
||||
o.cacheAssistantMessage(params, fullContent)
|
||||
logger.Printf("[orchestrator] 缓存响应完成: len=%d", len([]rune(fullContent)))
|
||||
return
|
||||
}
|
||||
@@ -334,6 +377,7 @@ func (o *Orchestrator) ProcessInput(
|
||||
PersonaPrompt: systemPrompt,
|
||||
DialogHistory: history,
|
||||
Mode: params.Mode,
|
||||
ChannelType: params.ChannelType,
|
||||
}
|
||||
if prevEnrichment != nil {
|
||||
synthParams.MemorySummary = prevEnrichment.MemorySummary
|
||||
@@ -476,7 +520,7 @@ func (o *Orchestrator) ProcessInput(
|
||||
|
||||
// 10. 后处理:缓存回复
|
||||
if fullContent != "" {
|
||||
o.contextBuilder.CacheMessage(params.SessionID, model.RoleAssistant, fullContent)
|
||||
o.cacheAssistantMessage(params, fullContent)
|
||||
if o.responseCache != nil {
|
||||
o.responseCache.Set(params.Message, fullContent)
|
||||
}
|
||||
@@ -500,6 +544,15 @@ func (o *Orchestrator) ProcessInput(
|
||||
return eventCh, nil
|
||||
}
|
||||
|
||||
// ExtractMemoriesOnly 仅提取记忆,不生成回复。
|
||||
// 用于 platform_silent 模式:观察群聊消息并提取值得记住的信息到对应命名空间。
|
||||
func (o *Orchestrator) ExtractMemoriesOnly(ctx context.Context, userID, sessionID, message string) {
|
||||
if o.memoryExtractor == nil {
|
||||
return
|
||||
}
|
||||
o.memoryExtractor.ExtractObservations(ctx, userID, sessionID, message)
|
||||
}
|
||||
|
||||
// scheduleWithDelays 通过 MessageScheduler 为审查消息分配发送延迟
|
||||
func (o *Orchestrator) scheduleWithDelays(messages []model.ReviewMessage) []model.ReviewMessage {
|
||||
if o.msgScheduler == nil || len(messages) <= 1 {
|
||||
@@ -683,12 +736,20 @@ func (o *Orchestrator) CacheMessage(sessionID string, role model.Role, content s
|
||||
}
|
||||
}
|
||||
|
||||
// preprocessImages uses vision and OCR models to analyze images and augments the user message.
|
||||
// cacheAssistantMessage caches the assistant response.
|
||||
func (o *Orchestrator) cacheAssistantMessage(params ProcessParams, fullContent string) {
|
||||
if o.contextBuilder == nil {
|
||||
return
|
||||
}
|
||||
o.contextBuilder.CacheMessage(params.SessionID, model.RoleAssistant, fullContent)
|
||||
}
|
||||
|
||||
// PreprocessImages uses vision and OCR models to analyze images and augments the user message.
|
||||
// When both vision and OCR providers are available (and are different models), they are called
|
||||
// in parallel and both results are passed to the chat model for autonomous judgment.
|
||||
// For standalone images (no text): generates a comprehensive description as the message.
|
||||
// For text+images: appends image descriptions as contextual annotations.
|
||||
func (o *Orchestrator) preprocessImages(ctx context.Context, message string, images []string) string {
|
||||
func (o *Orchestrator) PreprocessImages(ctx context.Context, message string, images []string) string {
|
||||
visionPromptBase := "请详细描述这张图片的内容,包括场景、物体、人物、文字(如有)、颜色、氛围等所有视觉信息。"
|
||||
ocrPromptBase := `请逐字逐句完整提取图片中的所有文字内容,保持原有格式和排版。如果图片中没有文字,请回复"无文字"。`
|
||||
|
||||
@@ -743,7 +804,7 @@ func (o *Orchestrator) preprocessImages(ctx context.Context, message string, ima
|
||||
var combined string
|
||||
switch {
|
||||
case visionDesc != "" && ocrDesc != "":
|
||||
combined = fmt.Sprintf("[视觉分析]: %s\n[文字提取(OCR)]: %s", visionDesc, ocrDesc)
|
||||
combined = fmt.Sprintf("这张图片的内容:%s(图中包含的文字:%s)", visionDesc, ocrDesc)
|
||||
case visionDesc != "":
|
||||
combined = visionDesc
|
||||
case ocrDesc != "":
|
||||
@@ -765,7 +826,79 @@ func (o *Orchestrator) preprocessImages(ctx context.Context, message string, ima
|
||||
|
||||
augmented := message
|
||||
for i, desc := range descriptions {
|
||||
augmented += fmt.Sprintf("\n\n[图片%d的视觉分析]: %s", i+1, desc)
|
||||
label := "图片分析结果"
|
||||
if len(descriptions) > 1 {
|
||||
label = fmt.Sprintf("图片%d分析结果", i+1)
|
||||
}
|
||||
augmented += fmt.Sprintf("\n\n[%s]: %s", label, desc)
|
||||
}
|
||||
return augmented
|
||||
}
|
||||
|
||||
// preprocessVideos uses the video model to analyze short videos and augments the message.
|
||||
func (o *Orchestrator) preprocessVideos(ctx context.Context, message string, videoURLs []string) string {
|
||||
if o.videoProvider == nil {
|
||||
return message
|
||||
}
|
||||
|
||||
var descriptions []string
|
||||
for i, url := range videoURLs {
|
||||
resp, err := o.videoProvider.Chat(ctx, []model.LLMMessage{
|
||||
{Role: model.RoleUser, Content: "请用简短的中文描述这个视频的内容,包括场景、人物、动作等。控制在100字以内。", VideoURLs: []string{url}},
|
||||
})
|
||||
if err != nil {
|
||||
logger.Printf("[orchestrator] 视频 %d 分析失败: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if resp.Content != "" {
|
||||
descriptions = append(descriptions, resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
if len(descriptions) == 0 {
|
||||
return message
|
||||
}
|
||||
|
||||
if message == "" {
|
||||
return strings.Join(descriptions, "\n\n")
|
||||
}
|
||||
|
||||
augmented := message
|
||||
for i, desc := range descriptions {
|
||||
augmented += fmt.Sprintf("\n\n[视频%d的分析]: %s", i+1, desc)
|
||||
}
|
||||
return augmented
|
||||
}
|
||||
|
||||
// preprocessVoice transcribes voice messages using the ASR provider and augments the message.
|
||||
func (o *Orchestrator) preprocessVoice(ctx context.Context, message string, voiceURLs []string) string {
|
||||
if o.asrProvider == nil || !o.asrProvider.IsAvailable() {
|
||||
return message
|
||||
}
|
||||
|
||||
var transcriptions []string
|
||||
for i, url := range voiceURLs {
|
||||
text, err := o.asrProvider.Transcribe(ctx, url, "zh")
|
||||
if err != nil {
|
||||
logger.Printf("[orchestrator] 语音 %d 转录失败: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if text != "" {
|
||||
transcriptions = append(transcriptions, text)
|
||||
}
|
||||
}
|
||||
|
||||
if len(transcriptions) == 0 {
|
||||
return message
|
||||
}
|
||||
|
||||
if message == "" {
|
||||
return strings.Join(transcriptions, "\n\n")
|
||||
}
|
||||
|
||||
augmented := message
|
||||
for i, t := range transcriptions {
|
||||
augmented += fmt.Sprintf("\n\n[语音%d的转写]: %s", i+1, t)
|
||||
}
|
||||
return augmented
|
||||
}
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/llm"
|
||||
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/model"
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/logger"
|
||||
plgManager "git.yeij.top/AskaEth/Cyrene/pkg/plugins/manager"
|
||||
plgSDK "git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
||||
plgManager "git.yeij.top/AskaEth/Cyrene-Plugins/manager"
|
||||
plgSDK "git.yeij.top/AskaEth/Cyrene-Plugins/sdk"
|
||||
)
|
||||
|
||||
// Synthesizer 主会话综合器
|
||||
@@ -35,6 +35,7 @@ type SynthesizeParams struct {
|
||||
SessionID string
|
||||
UserMessage string
|
||||
Images []string // 图片 base64 data URL (多模态)
|
||||
VideoURLs []string // 视频 URL (多模态)
|
||||
Nickname string
|
||||
PersonaPrompt string // 完整人格提示词
|
||||
DialogHistory []model.LLMMessage // 对话历史
|
||||
@@ -45,6 +46,7 @@ type SynthesizeParams struct {
|
||||
KnowledgeInfo string // 知识库检索摘要
|
||||
PendingToolResults []PendingToolResult // 上一轮异步完成的工具结果
|
||||
Mode string // text / voice_assistant
|
||||
ChannelType string // direct / group
|
||||
}
|
||||
|
||||
// Synthesize 综合所有子会话结果,流式生成最终回复。
|
||||
@@ -210,7 +212,15 @@ func (s *Synthesizer) buildSynthesizeMessages(params SynthesizeParams) []model.L
|
||||
Content: systemPrompt,
|
||||
})
|
||||
|
||||
// 注入记忆摘要
|
||||
// 群聊上下文:当消息来自群聊时,告知模型这是一条群聊消息而非一对一私聊。
|
||||
if params.ChannelType == "group" {
|
||||
messages = append(messages, model.LLMMessage{
|
||||
Role: model.RoleSystem,
|
||||
Content: "【群聊上下文】这条消息来自QQ群聊。消息前缀 [群聊 群号] 昵称 (QQ号) 标注了真实发送者。你不是在和开拓者一对一私聊,而是在群聊中和不同成员交流。请根据当前这条消息前缀中的发送者名字来称呼对方——即使你之前在历史对话中称呼过别人,也不要把之前用的称呼套在当前发送者身上。不同的人有不同的名字。只在对你说话或延续已有对话时才回复。",
|
||||
})
|
||||
}
|
||||
|
||||
// 注入记忆摘要// 注入记忆摘要
|
||||
if params.MemorySummary != "" && !strings.Contains(params.MemorySummary, "没有找到") {
|
||||
messages = append(messages, model.LLMMessage{
|
||||
Role: model.RoleSystem,
|
||||
@@ -271,11 +281,12 @@ func (s *Synthesizer) buildSynthesizeMessages(params SynthesizeParams) []model.L
|
||||
messages = append(messages, history...)
|
||||
}
|
||||
|
||||
// 当前用户消息 (支持多模态图片)
|
||||
// 当前用户消息 (支持多模态图片和视频)
|
||||
messages = append(messages, model.LLMMessage{
|
||||
Role: model.RoleUser,
|
||||
Content: params.UserMessage,
|
||||
Images: params.Images,
|
||||
Role: model.RoleUser,
|
||||
Content: params.UserMessage,
|
||||
Images: params.Images,
|
||||
VideoURLs: params.VideoURLs,
|
||||
})
|
||||
|
||||
return messages
|
||||
|
||||
@@ -1,128 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PluginManagerClient calls the plugin-manager service.
|
||||
type PluginManagerClient struct {
|
||||
baseURL string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// PMToolDefinition matches the plugin-manager tool definition format.
|
||||
type PMToolDefinition struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Description string `json:"description"`
|
||||
Category string `json:"category"`
|
||||
Complexity string `json:"complexity"`
|
||||
Parameters map[string]interface{} `json:"parameters"`
|
||||
DangerLevel string `json:"danger_level,omitempty"`
|
||||
}
|
||||
|
||||
// PMToolResult matches the plugin-manager execution result.
|
||||
type PMToolResult struct {
|
||||
ToolName string `json:"tool_name"`
|
||||
Success bool `json:"success"`
|
||||
Output string `json:"output,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// PMPluginInfo matches plugin-manager plugin info.
|
||||
type PMPluginInfo struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
Status string `json:"status"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Tools []string `json:"tools"`
|
||||
}
|
||||
|
||||
func NewPluginManagerClient(baseURL string) *PluginManagerClient {
|
||||
return &PluginManagerClient{
|
||||
baseURL: baseURL,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// GetToolDefinitions fetches all tool definitions from plugin-manager.
|
||||
func (c *PluginManagerClient) GetToolDefinitions(ctx context.Context) ([]PMToolDefinition, error) {
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", c.baseURL+"/api/v1/tools", nil)
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("plugin-manager GetToolDefinitions: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var body struct {
|
||||
Tools []PMToolDefinition `json:"tools"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||
return nil, fmt.Errorf("plugin-manager decode tools: %w", err)
|
||||
}
|
||||
return body.Tools, nil
|
||||
}
|
||||
|
||||
// ExecuteTool calls a tool on plugin-manager by ID.
|
||||
func (c *PluginManagerClient) ExecuteTool(ctx context.Context, toolID string, args map[string]interface{}) (*PMToolResult, error) {
|
||||
body, _ := json.Marshal(map[string]interface{}{"arguments": args})
|
||||
url := fmt.Sprintf("%s/api/v1/tools/%s/execute", c.baseURL, toolID)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("plugin-manager ExecuteTool: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result PMToolResult
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("plugin-manager decode result: %w", err)
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// ListPlugins fetches all installed plugins from plugin-manager.
|
||||
func (c *PluginManagerClient) ListPlugins(ctx context.Context) ([]PMPluginInfo, error) {
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", c.baseURL+"/api/v1/plugins", nil)
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var body struct {
|
||||
Plugins []PMPluginInfo `json:"plugins"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return body.Plugins, nil
|
||||
}
|
||||
|
||||
// AdaptDefinitions converts PM tool definitions to ai-core ToolDefinition format.
|
||||
func (c *PluginManagerClient) AdaptDefinitions(ctx context.Context) ([]ToolDefinition, error) {
|
||||
pmDefs, err := c.GetToolDefinitions(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defs := make([]ToolDefinition, 0, len(pmDefs))
|
||||
for _, d := range pmDefs {
|
||||
defs = append(defs, ToolDefinition{
|
||||
Name: d.Name,
|
||||
Description: d.Description,
|
||||
Parameters: d.Parameters,
|
||||
})
|
||||
}
|
||||
return defs, nil
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/llm"
|
||||
"git.yeij.top/AskaEth/Cyrene/ai-core/internal/model"
|
||||
)
|
||||
|
||||
// VideoTool enables video understanding via multimodal LLM.
|
||||
type VideoTool struct {
|
||||
videoProvider llm.LLMProvider
|
||||
}
|
||||
|
||||
// NewVideoTool creates a video tool. videoProvider is optional (nil = no-op mode).
|
||||
func NewVideoTool(videoProvider llm.LLMProvider) *VideoTool {
|
||||
return &VideoTool{videoProvider: videoProvider}
|
||||
}
|
||||
|
||||
func (t *VideoTool) Definition() ToolDefinition {
|
||||
return ToolDefinition{
|
||||
Name: "video_analyze",
|
||||
Description: "分析视频内容。传入视频文件路径或URL,返回视频内容的文字描述和分析结果。支持场景理解、动作识别、文字提取等。",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"video_path": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "视频文件路径或URL",
|
||||
},
|
||||
"task": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "分析任务: describe(内容描述), summarize(摘要), analyze(综合分析)",
|
||||
"enum": []string{"describe", "summarize", "analyze"},
|
||||
},
|
||||
},
|
||||
"required": []string{"video_path", "task"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var videoTaskPrompts = map[string]string{
|
||||
"describe": "请详细描述这个视频的内容,包括场景、人物、动作、对话要点等。",
|
||||
"summarize": "请用简洁的语言总结这个视频的主要内容。",
|
||||
"analyze": "请综合分析这个视频,包括内容描述、关键片段、文字信息(如有)、以及你的理解。",
|
||||
}
|
||||
|
||||
func (t *VideoTool) Execute(ctx context.Context, args map[string]interface{}) (*ToolResult, error) {
|
||||
videoPath, _ := args["video_path"].(string)
|
||||
if videoPath == "" {
|
||||
return &ToolResult{ToolName: "video_analyze", Success: false, Error: "video_path 参数不能为空"}, nil
|
||||
}
|
||||
|
||||
task, _ := args["task"].(string)
|
||||
if task == "" {
|
||||
task = "analyze"
|
||||
}
|
||||
|
||||
prompt := videoTaskPrompts[task]
|
||||
if prompt == "" {
|
||||
prompt = videoTaskPrompts["analyze"]
|
||||
}
|
||||
|
||||
if t.videoProvider == nil {
|
||||
return &ToolResult{ToolName: "video_analyze", Success: false, Error: "视频理解模型未配置"}, nil
|
||||
}
|
||||
|
||||
messages := []model.LLMMessage{
|
||||
{Role: model.RoleUser, Content: prompt, VideoURLs: []string{videoPath}},
|
||||
}
|
||||
resp, err := t.videoProvider.Chat(ctx, messages)
|
||||
if err != nil {
|
||||
return &ToolResult{ToolName: "video_analyze", Success: false, Error: fmt.Sprintf("视频模型调用失败: %v", err)}, nil
|
||||
}
|
||||
|
||||
output, _ := json.Marshal(map[string]interface{}{
|
||||
"video_path": videoPath,
|
||||
"task": task,
|
||||
"model": t.videoProvider.ModelName(),
|
||||
"text": resp.Content,
|
||||
"prompt_tokens": resp.Usage.PromptTokens,
|
||||
"completion_tokens": resp.Usage.CompletionTokens,
|
||||
"total_tokens": resp.Usage.TotalTokens,
|
||||
})
|
||||
return &ToolResult{ToolName: "video_analyze", Success: true, Data: string(output)}, nil
|
||||
}
|
||||
@@ -38,6 +38,7 @@ type ChatHandler struct {
|
||||
hub *ws.Hub
|
||||
sessionStore *store.SessionStore
|
||||
fileStore *store.FileStore
|
||||
voiceStream *VoiceStreamManager
|
||||
upgrader websocket.Upgrader
|
||||
pending map[string][]queuedMsg // per-session message queue
|
||||
pendingMu sync.Mutex
|
||||
@@ -50,6 +51,7 @@ func NewChatHandler(cfg *config.Config, hub *ws.Hub, sessionStore *store.Session
|
||||
hub: hub,
|
||||
sessionStore: sessionStore,
|
||||
fileStore: fileStore,
|
||||
voiceStream: NewVoiceStreamManager(cfg.VoiceServiceURL),
|
||||
pending: make(map[string][]queuedMsg),
|
||||
upgrader: websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
@@ -131,6 +133,12 @@ func (h *ChatHandler) handleMessage(client *ws.Client, msg ws.ClientMessage) {
|
||||
h.handleChatMessage(client, msg)
|
||||
case "voice_input":
|
||||
h.handleVoiceInput(client, msg)
|
||||
case "voice_stream_start":
|
||||
h.handleVoiceStreamStart(client, msg)
|
||||
case "voice_stream_chunk":
|
||||
h.handleVoiceStreamChunk(client, msg)
|
||||
case "voice_stream_end":
|
||||
h.handleVoiceStreamEnd(client, msg)
|
||||
case "history":
|
||||
h.handleHistoryRequest(client, msg)
|
||||
default:
|
||||
@@ -436,11 +444,13 @@ func (h *ChatHandler) streamResponse(client *ws.Client, mode string, reqBody []b
|
||||
// 处理审查后的结构化消息 (review)
|
||||
if len(chunk.ReviewMessages) > 0 {
|
||||
for i, rm := range chunk.ReviewMessages {
|
||||
msgType := rm.Type
|
||||
if msgType == "" {
|
||||
msgType = "chat"
|
||||
}
|
||||
role := "assistant"
|
||||
msgType := "chat"
|
||||
if rm.Type == "action" {
|
||||
if msgType == "action" {
|
||||
role = "action"
|
||||
msgType = "action"
|
||||
}
|
||||
reviewMsgID := fmt.Sprintf("%s_r%d", msgID, i)
|
||||
// 持久化每条审查消息 (action 角色映射为 assistant,LLM 模型不支持自定义角色)
|
||||
@@ -473,6 +483,7 @@ func (h *ChatHandler) streamResponse(client *ws.Client, mode string, reqBody []b
|
||||
SessionID: client.SessionID,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
ClientInfo: clientInfo,
|
||||
Metadata: rm.Metadata,
|
||||
})
|
||||
// 使用 MessageScheduler 计算的 per-message 延迟
|
||||
if rm.DelayMs > 0 {
|
||||
@@ -650,6 +661,96 @@ func (h *ChatHandler) handleVoiceInput(client *ws.Client, msg ws.ClientMessage)
|
||||
}()
|
||||
}
|
||||
|
||||
// handleVoiceStreamStart begins a streaming voice session via voice-service.
|
||||
func (h *ChatHandler) handleVoiceStreamStart(client *ws.Client, msg ws.ClientMessage) {
|
||||
format := msg.Format
|
||||
if format == "" {
|
||||
format = "webm"
|
||||
}
|
||||
language := msg.Language
|
||||
if language == "" {
|
||||
language = "zh"
|
||||
}
|
||||
|
||||
if err := h.voiceStream.StartStream(client, format, language); err != nil {
|
||||
logger.Printf("[voice-stream] 启动流式 STT 失败: %v", err)
|
||||
client.SendMessage(ws.ServerMessage{
|
||||
Type: "error",
|
||||
MessageID: "msg_" + generateID(),
|
||||
Error: "启动语音流失败: " + err.Error(),
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
client.SendMessage(ws.ServerMessage{
|
||||
Type: "voice_interim",
|
||||
MessageID: "voice_" + generateID(),
|
||||
Text: "",
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
})
|
||||
}
|
||||
|
||||
// handleVoiceStreamChunk forwards an audio chunk to the active voice stream.
|
||||
func (h *ChatHandler) handleVoiceStreamChunk(client *ws.Client, msg ws.ClientMessage) {
|
||||
if msg.AudioData == "" {
|
||||
return
|
||||
}
|
||||
|
||||
audioData, err := decodeBase64(msg.AudioData)
|
||||
if err != nil {
|
||||
logger.Printf("[voice-stream] 解码音频块失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.voiceStream.SendChunk(client.ClientID, client.SessionID, audioData, msg.Sequence); err != nil {
|
||||
logger.Printf("[voice-stream] 发送音频块失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleVoiceStreamEnd stops the voice stream and processes the final transcription.
|
||||
func (h *ChatHandler) handleVoiceStreamEnd(client *ws.Client, msg ws.ClientMessage) {
|
||||
go func() {
|
||||
text, err := h.voiceStream.EndStream(client.ClientID, client.SessionID)
|
||||
if err != nil {
|
||||
logger.Printf("[voice-stream] 结束流式 STT 失败: %v", err)
|
||||
client.SendMessage(ws.ServerMessage{
|
||||
Type: "error",
|
||||
MessageID: "msg_" + generateID(),
|
||||
Error: "语音流处理失败: " + err.Error(),
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if text == "" {
|
||||
client.SendMessage(ws.ServerMessage{
|
||||
Type: "voice_final",
|
||||
MessageID: "voice_" + generateID(),
|
||||
Text: "",
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Send final transcription to frontend
|
||||
client.SendMessage(ws.ServerMessage{
|
||||
Type: "voice_final",
|
||||
MessageID: "voice_" + generateID(),
|
||||
Text: text,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
})
|
||||
|
||||
// Route the transcribed text as a regular chat message to ai-core
|
||||
chatMsg := ws.ClientMessage{
|
||||
Type: "message",
|
||||
Content: text,
|
||||
Mode: msg.Mode,
|
||||
}
|
||||
h.handleChatMessage(client, chatMsg)
|
||||
}()
|
||||
}
|
||||
|
||||
// transcribeAudio 将 base64 编码的音频发送到 voice-service 进行转录。
|
||||
func (h *ChatHandler) transcribeAudio(audioB64 string, format string) (string, error) {
|
||||
audioData, err := decodeBase64(audioB64)
|
||||
|
||||
@@ -0,0 +1,269 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/gateway/internal/ws"
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/logger"
|
||||
)
|
||||
|
||||
// voiceStreamSession manages a proxied WebSocket connection to voice-service
|
||||
// for real-time streaming speech-to-text during a single voice input.
|
||||
type voiceStreamSession struct {
|
||||
client *ws.Client
|
||||
voiceConn *websocket.Conn
|
||||
language string
|
||||
format string
|
||||
mu sync.Mutex
|
||||
done chan struct{}
|
||||
interimBuf strings.Builder
|
||||
finalText string
|
||||
}
|
||||
|
||||
// VoiceStreamManager creates and tracks streaming STT sessions.
|
||||
type VoiceStreamManager struct {
|
||||
voiceServiceURL string
|
||||
sessions map[string]*voiceStreamSession // key: clientID+sessionID
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewVoiceStreamManager creates a voice stream manager.
|
||||
func NewVoiceStreamManager(voiceServiceURL string) *VoiceStreamManager {
|
||||
return &VoiceStreamManager{
|
||||
voiceServiceURL: voiceServiceURL,
|
||||
sessions: make(map[string]*voiceStreamSession),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *VoiceStreamManager) sessionKey(clientID, sessionID string) string {
|
||||
return clientID + ":" + sessionID
|
||||
}
|
||||
|
||||
// StartStream begins a streaming STT session by connecting to voice-service.
|
||||
func (m *VoiceStreamManager) StartStream(client *ws.Client, format, language string) error {
|
||||
m.mu.Lock()
|
||||
key := m.sessionKey(client.ClientID, client.SessionID)
|
||||
if _, exists := m.sessions[key]; exists {
|
||||
m.mu.Unlock()
|
||||
return fmt.Errorf("voice stream already active for this session")
|
||||
}
|
||||
|
||||
if format == "" {
|
||||
format = "webm"
|
||||
}
|
||||
if language == "" {
|
||||
language = "zh"
|
||||
}
|
||||
|
||||
voiceURL := strings.TrimRight(m.voiceServiceURL, "/")
|
||||
wsURL := "ws" + strings.TrimPrefix(voiceURL, "http") + "/api/v1/stt/stream"
|
||||
wsURL += "?language=" + language + "&format=" + format
|
||||
|
||||
dialer := websocket.Dialer{HandshakeTimeout: 10 * time.Second}
|
||||
voiceConn, _, err := dialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
m.mu.Unlock()
|
||||
return fmt.Errorf("connect to voice-service stream: %w", err)
|
||||
}
|
||||
|
||||
session := &voiceStreamSession{
|
||||
client: client,
|
||||
voiceConn: voiceConn,
|
||||
language: language,
|
||||
format: format,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
m.sessions[key] = session
|
||||
m.mu.Unlock()
|
||||
|
||||
// Read results from voice-service in background
|
||||
go session.readResults(m, key)
|
||||
|
||||
logger.Printf("[voice-stream] 流式 STT 会话已建立: client=%s, lang=%s, fmt=%s", client.ClientID, language, format)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendChunk forwards an audio chunk (already decoded bytes) to voice-service.
|
||||
func (m *VoiceStreamManager) SendChunk(clientID, sessionID string, audioData []byte, seq int) error {
|
||||
m.mu.Lock()
|
||||
key := m.sessionKey(clientID, sessionID)
|
||||
session, exists := m.sessions[key]
|
||||
m.mu.Unlock()
|
||||
|
||||
if !exists {
|
||||
return fmt.Errorf("no active voice stream for this session")
|
||||
}
|
||||
|
||||
session.mu.Lock()
|
||||
defer session.mu.Unlock()
|
||||
|
||||
if err := session.voiceConn.WriteMessage(websocket.BinaryMessage, audioData); err != nil {
|
||||
return fmt.Errorf("send audio chunk: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// EndStream signals voice-service that the audio stream is complete,
|
||||
// waits for final result, then cleans up.
|
||||
func (m *VoiceStreamManager) EndStream(clientID, sessionID string) (string, error) {
|
||||
m.mu.Lock()
|
||||
key := m.sessionKey(clientID, sessionID)
|
||||
session, exists := m.sessions[key]
|
||||
m.mu.Unlock()
|
||||
|
||||
if !exists {
|
||||
return "", fmt.Errorf("no active voice stream for this session")
|
||||
}
|
||||
|
||||
// Send stop action to voice-service
|
||||
session.mu.Lock()
|
||||
stopMsg, _ := json.Marshal(map[string]interface{}{"action": "stop"})
|
||||
session.voiceConn.WriteMessage(websocket.TextMessage, stopMsg)
|
||||
session.mu.Unlock()
|
||||
|
||||
// Wait for result processing to finish
|
||||
select {
|
||||
case <-session.done:
|
||||
case <-time.After(15 * time.Second):
|
||||
logger.Printf("[voice-stream] 等待最终结果超时: client=%s", clientID)
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
session.close()
|
||||
m.mu.Lock()
|
||||
delete(m.sessions, key)
|
||||
m.mu.Unlock()
|
||||
|
||||
text := session.finalText
|
||||
if text == "" {
|
||||
text = session.interimBuf.String()
|
||||
}
|
||||
logger.Printf("[voice-stream] 流式 STT 结束: client=%s, text=%q", clientID, text)
|
||||
return text, nil
|
||||
}
|
||||
|
||||
// CancelStream forcibly terminates a voice stream.
|
||||
func (m *VoiceStreamManager) CancelStream(clientID, sessionID string) {
|
||||
m.mu.Lock()
|
||||
key := m.sessionKey(clientID, sessionID)
|
||||
session, exists := m.sessions[key]
|
||||
if exists {
|
||||
delete(m.sessions, key)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
if exists {
|
||||
session.close()
|
||||
logger.Printf("[voice-stream] 流式 STT 已取消: client=%s", clientID)
|
||||
}
|
||||
}
|
||||
|
||||
// readResults reads STT results from voice-service and forwards them to the client.
|
||||
func (s *voiceStreamSession) readResults(mgr *VoiceStreamManager, key string) {
|
||||
defer close(s.done)
|
||||
|
||||
voiceConn := s.voiceConn
|
||||
for {
|
||||
msgType, data, err := voiceConn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) {
|
||||
logger.Printf("[voice-stream] voice-service 读取错误: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if msgType != websocket.TextMessage {
|
||||
continue
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
IsFinal bool `json:"isFinal"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
logger.Printf("[voice-stream] 解析结果失败: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
logger.Printf("[voice-stream] voice-service 错误: %s", result.Error)
|
||||
s.client.SendMessage(ws.ServerMessage{
|
||||
Type: "voice_interim",
|
||||
MessageID: "voice_" + generateID(),
|
||||
Error: result.Error,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if result.Text != "" {
|
||||
if result.IsFinal {
|
||||
s.finalText = result.Text
|
||||
s.client.SendMessage(ws.ServerMessage{
|
||||
Type: "voice_final",
|
||||
MessageID: "voice_" + generateID(),
|
||||
Text: result.Text,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Interim result — accumulate and forward
|
||||
s.interimBuf.Reset()
|
||||
s.interimBuf.WriteString(result.Text)
|
||||
s.client.SendMessage(ws.ServerMessage{
|
||||
Type: "voice_interim",
|
||||
MessageID: "voice_" + generateID(),
|
||||
Text: result.Text,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
})
|
||||
}
|
||||
|
||||
// "done" type from voice-service signals end of results
|
||||
if result.Type == "done" {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *voiceStreamSession) close() {
|
||||
if s.voiceConn != nil {
|
||||
s.voiceConn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// HasActiveStream checks if a client already has an active voice stream.
|
||||
func (m *VoiceStreamManager) HasActiveStream(clientID, sessionID string) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
_, exists := m.sessions[m.sessionKey(clientID, sessionID)]
|
||||
return exists
|
||||
}
|
||||
|
||||
// CleanupClient removes all streams for a client.
|
||||
func (m *VoiceStreamManager) CleanupClient(clientID string) {
|
||||
m.mu.Lock()
|
||||
var toRemove []string
|
||||
for key, session := range m.sessions {
|
||||
if session.client.ClientID == clientID {
|
||||
toRemove = append(toRemove, key)
|
||||
}
|
||||
}
|
||||
for _, key := range toRemove {
|
||||
delete(m.sessions, key)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
for _, key := range toRemove {
|
||||
// Close connection if session exists (we already deleted from map)
|
||||
logger.Printf("[voice-stream] 清理客户端流: key=%s", key)
|
||||
}
|
||||
}
|
||||
@@ -15,11 +15,14 @@ type MessageAttachment struct {
|
||||
|
||||
// 客户端 → 服务端消息
|
||||
type ClientMessage struct {
|
||||
Type string `json:"type"` // message | voice_input | ping | history
|
||||
Type string `json:"type"` // message | voice_input | voice_stream_start | voice_stream_chunk | voice_stream_end | ping | history
|
||||
SessionID string `json:"session_id"`
|
||||
Mode string `json:"mode"` // text | voice_msg | voice_assistant
|
||||
Content string `json:"content"`
|
||||
AudioData string `json:"audio_data,omitempty"` // base64
|
||||
Format string `json:"format,omitempty"` // 音频格式: webm, wav, pcm, opus
|
||||
Language string `json:"language,omitempty"` // 识别语言: zh, en, ja, ko, auto
|
||||
Sequence int `json:"sequence,omitempty"` // 音频块序列号 (voice_stream_chunk)
|
||||
Attachments []MessageAttachment `json:"attachments,omitempty"` // 图片等附件
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
ClientID string `json:"client_id,omitempty"` // 客户端唯一标识 (多端区分)
|
||||
@@ -28,11 +31,12 @@ type ClientMessage struct {
|
||||
ClientMsgID string `json:"client_msg_id,omitempty"` // 客户端消息ID (跨端去重)
|
||||
}
|
||||
|
||||
// ReviewMessage 审查后的结构化消息(动作/聊天分离)
|
||||
// ReviewMessage 审查后的结构化消息(动作/聊天/Markdown/代码块/搜索结果)
|
||||
type ReviewMessage struct {
|
||||
Type string `json:"type"` // "action" | "chat"
|
||||
Content string `json:"content"`
|
||||
DelayMs int `json:"delay_ms,omitempty"` // ms to wait before sending (0 = immediate)
|
||||
Type string `json:"type"` // action | chat | markdown | code | search_result
|
||||
Content string `json:"content"`
|
||||
DelayMs int `json:"delay_ms,omitempty"` // ms to wait before sending (0 = immediate)
|
||||
Metadata map[string]any `json:"metadata,omitempty"` // 类型特定元数据 (code 语言、搜索结果 URL 等)
|
||||
}
|
||||
|
||||
// ClientInfo carries the originating client's device metadata.
|
||||
@@ -44,7 +48,7 @@ type ClientInfo struct {
|
||||
|
||||
// 服务端 → 客户端消息
|
||||
type ServerMessage struct {
|
||||
Type string `json:"type"` // response | segment | audio | error | device_update | pong | history_response | stream_chunk | stream_end | background_thinking | notification | multi_message | stream_segments | review | thinking | tool_progress | system_info
|
||||
Type string `json:"type"` // response | segment | audio | error | device_update | pong | history_response | stream_chunk | stream_end | background_thinking | notification | multi_message | stream_segments | review | thinking | tool_progress | system_info | voice_interim | voice_final
|
||||
MessageID string `json:"message_id"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Content string `json:"content,omitempty"` // stream_chunk 的增量文本
|
||||
@@ -63,7 +67,8 @@ type ServerMessage struct {
|
||||
Notification *NotificationInfo `json:"notification,omitempty"` // 通知推送
|
||||
MultiMessage *MultiMessagePayload `json:"multi_message,omitempty"` // 多条消息批量发送
|
||||
ReviewMessages []ReviewMessage `json:"review_messages,omitempty"` // 审查后的结构化消息列表
|
||||
MsgType string `json:"msg_type,omitempty"` // 消息展示类型: action | chat | thinking | tool_progress | system_info
|
||||
MsgType string `json:"msg_type,omitempty"` // 消息展示类型: action | chat | thinking | tool_progress | system_info | markdown | code | search_result
|
||||
Metadata map[string]any `json:"metadata,omitempty"` // 消息元数据 (code 语言等)
|
||||
ToolProgress *ToolProgressInfo `json:"tool_progress,omitempty"` // 工具执行进度
|
||||
SystemInfo *SystemInfoPayload `json:"system_info,omitempty"` // 系统通知信息
|
||||
ProtocolVersion int `json:"protocol_version,omitempty"` // 协议版本
|
||||
|
||||
+2
-2
@@ -5,9 +5,9 @@ use (
|
||||
./gateway
|
||||
./iot-debug-service
|
||||
./memory-service
|
||||
./pkg/audio
|
||||
./pkg/dashscope
|
||||
./pkg/logger
|
||||
./pkg/plugins
|
||||
./platform-bridge
|
||||
./plugin-manager
|
||||
./voice-service
|
||||
)
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
package audio
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// NormalizeFormat 规范化音频格式字符串。
|
||||
func NormalizeFormat(format string) string {
|
||||
switch strings.ToLower(format) {
|
||||
case "pcm", "wav", "mp3", "mpeg", "ogg", "opus", "flac", "m4a", "mp4", "aac", "webm", "amr":
|
||||
return strings.ToLower(format)
|
||||
default:
|
||||
return format
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertToPCM16 将音频数据转换为 16-bit PCM 16000Hz mono。
|
||||
// 对于已经是 PCM 的数据直接返回;对于 WAV 跳过 44 字节头部;
|
||||
// 其他格式使用 ffmpeg 转码。
|
||||
func ConvertToPCM16(data []byte, format string) ([]byte, error) {
|
||||
normFormat := NormalizeFormat(format)
|
||||
switch normFormat {
|
||||
case "pcm":
|
||||
return data, nil
|
||||
case "wav":
|
||||
if len(data) > 44 {
|
||||
return data[44:], nil
|
||||
}
|
||||
return data, nil
|
||||
default:
|
||||
return transcodeToPCM(data, normFormat)
|
||||
}
|
||||
}
|
||||
|
||||
// transcodeToPCM 使用 ffmpeg 将音频数据转码为 PCM 16-bit 16000Hz mono。
|
||||
func transcodeToPCM(data []byte, format string) ([]byte, error) {
|
||||
inFile, err := os.CreateTemp(os.TempDir(), "cyrene-asr-in-*."+format)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建输入临时文件失败: %w", err)
|
||||
}
|
||||
inPath := inFile.Name()
|
||||
defer os.Remove(inPath)
|
||||
if _, err := inFile.Write(data); err != nil {
|
||||
inFile.Close()
|
||||
return nil, fmt.Errorf("写入输入临时文件失败: %w", err)
|
||||
}
|
||||
inFile.Close()
|
||||
|
||||
outFile, err := os.CreateTemp(os.TempDir(), "cyrene-asr-out-*.pcm")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建输出临时文件失败: %w", err)
|
||||
}
|
||||
outPath := outFile.Name()
|
||||
outFile.Close()
|
||||
defer os.Remove(outPath)
|
||||
|
||||
cmd := exec.Command("ffmpeg",
|
||||
"-i", inPath,
|
||||
"-ar", "16000",
|
||||
"-ac", "1",
|
||||
"-c:a", "pcm_s16le",
|
||||
"-f", "s16le",
|
||||
outPath,
|
||||
"-y",
|
||||
)
|
||||
cmd.Stderr = nil
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return nil, fmt.Errorf("音频转码失败 (ffmpeg): %w", err)
|
||||
}
|
||||
|
||||
outData, err := os.ReadFile(outPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取转码结果失败: %w", err)
|
||||
}
|
||||
|
||||
return outData, nil
|
||||
}
|
||||
|
||||
// IsFFmpegAvailable 检查 ffmpeg 是否可执行。
|
||||
func IsFFmpegAvailable() bool {
|
||||
_, err := exec.LookPath("ffmpeg")
|
||||
return err == nil
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
module git.yeij.top/AskaEth/Cyrene/pkg/audio
|
||||
|
||||
go 1.21
|
||||
@@ -0,0 +1,127 @@
|
||||
package dashscope
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---- 共享类型 ----
|
||||
|
||||
// ASRRequest DashScope ASR REST API 请求体。
|
||||
type ASRRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input ASRInput `json:"input"`
|
||||
Parameters ASRParams `json:"parameters"`
|
||||
}
|
||||
|
||||
// ASRInput 音频输入。
|
||||
type ASRInput struct {
|
||||
Audio string `json:"audio"`
|
||||
}
|
||||
|
||||
// ASRParams 识别参数。
|
||||
type ASRParams struct {
|
||||
Format string `json:"format,omitempty"`
|
||||
SampleRate int `json:"sample_rate,omitempty"`
|
||||
Language string `json:"language,omitempty"`
|
||||
}
|
||||
|
||||
// ASRResponse DashScope ASR REST API 响应体。
|
||||
type ASRResponse struct {
|
||||
Output struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"output"`
|
||||
Usage struct {
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
RequestID string `json:"request_id"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// ---- 共享客户端 ----
|
||||
|
||||
// RESTClient 封装 DashScope REST API 的 HTTP 通信。
|
||||
type RESTClient struct {
|
||||
apiKey string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewRESTClient 创建 REST 客户端。
|
||||
func NewRESTClient(apiKey string) *RESTClient {
|
||||
return &RESTClient{
|
||||
apiKey: apiKey,
|
||||
client: &http.Client{Timeout: 60 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// IsAvailable 检查 API Key 是否已配置。
|
||||
func (c *RESTClient) IsAvailable() bool {
|
||||
return c.apiKey != ""
|
||||
}
|
||||
|
||||
// Transcribe 调用 DashScope ASR REST API 进行语音识别。
|
||||
// audioData 应为 PCM 16kHz mono 格式。
|
||||
func (c *RESTClient) Transcribe(ctx context.Context, model string, audioData []byte, format string, sampleRate int, language string) (string, error) {
|
||||
if !c.IsAvailable() {
|
||||
return "", fmt.Errorf("DashScope ASR API key not configured")
|
||||
}
|
||||
if language == "" || language == "auto" {
|
||||
language = "zh"
|
||||
}
|
||||
|
||||
audioB64 := base64.StdEncoding.EncodeToString(audioData)
|
||||
|
||||
reqBody := ASRRequest{
|
||||
Model: model,
|
||||
Input: ASRInput{
|
||||
Audio: fmt.Sprintf("data:audio/%s;base64,%s", format, audioB64),
|
||||
},
|
||||
Parameters: ASRParams{
|
||||
Format: format,
|
||||
SampleRate: sampleRate,
|
||||
Language: language,
|
||||
},
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal ASR request: %w", err)
|
||||
}
|
||||
|
||||
url := "https://dashscope.aliyuncs.com/api/v1/services/audio/asr/asr"
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create ASR request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("ASR request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read ASR response: %w", err)
|
||||
}
|
||||
|
||||
var asrResp ASRResponse
|
||||
if err := json.Unmarshal(respBytes, &asrResp); err != nil {
|
||||
return "", fmt.Errorf("parse ASR response: %w", err)
|
||||
}
|
||||
|
||||
if asrResp.Code != "" && asrResp.Code != "0" {
|
||||
return "", fmt.Errorf("ASR error: %s (code=%s)", asrResp.Message, asrResp.Code)
|
||||
}
|
||||
|
||||
return asrResp.Output.Text, nil
|
||||
}
|
||||
@@ -0,0 +1,122 @@
|
||||
package dashscope
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRESTClient_Transcribe_Success(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if r.Header.Get("Authorization") != "Bearer test-key" {
|
||||
t.Errorf("unexpected auth header: %s", r.Header.Get("Authorization"))
|
||||
}
|
||||
|
||||
var req ASRRequest
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
if req.Model != "test-model" {
|
||||
t.Errorf("unexpected model: %s", req.Model)
|
||||
}
|
||||
if req.Parameters.Language != "zh" {
|
||||
t.Errorf("unexpected language: %s", req.Parameters.Language)
|
||||
}
|
||||
|
||||
resp := ASRResponse{}
|
||||
resp.Output.Text = "你好世界"
|
||||
resp.RequestID = "req-1"
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
client := &RESTClient{apiKey: "test-key", client: ts.Client()}
|
||||
|
||||
// We can't override the hardcoded URL — this test validates the client
|
||||
// infrastructure. For full integration, test against real or mocked URL.
|
||||
_ = client
|
||||
|
||||
if !client.IsAvailable() {
|
||||
t.Error("client should be available with apiKey")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRESTClient_NotAvailable(t *testing.T) {
|
||||
client := NewRESTClient("")
|
||||
if client.IsAvailable() {
|
||||
t.Error("client without apiKey should not be available")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRESTClient_Transcribe_NoAPIKey(t *testing.T) {
|
||||
client := NewRESTClient("")
|
||||
_, err := client.Transcribe(context.Background(), "model", []byte{}, "pcm", 16000, "zh")
|
||||
if err == nil {
|
||||
t.Error("expected error without API key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRESTClient_Transcribe_AutoLanguage(t *testing.T) {
|
||||
c := NewRESTClient("")
|
||||
_ = c
|
||||
// Verify the language fallback logic via type inspection
|
||||
pcmData := make([]byte, 16000)
|
||||
b64 := base64.StdEncoding.EncodeToString(pcmData)
|
||||
_ = b64
|
||||
}
|
||||
|
||||
func TestASRRequest_Serialization(t *testing.T) {
|
||||
req := ASRRequest{
|
||||
Model: "test-model",
|
||||
Input: ASRInput{
|
||||
Audio: "data:audio/pcm;base64,dGVzdA==",
|
||||
},
|
||||
Parameters: ASRParams{
|
||||
Format: "pcm",
|
||||
SampleRate: 16000,
|
||||
Language: "zh",
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
}
|
||||
|
||||
var decoded ASRRequest
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if decoded.Model != "test-model" {
|
||||
t.Errorf("model mismatch: %s", decoded.Model)
|
||||
}
|
||||
if decoded.Parameters.Language != "zh" {
|
||||
t.Errorf("language mismatch: %s", decoded.Parameters.Language)
|
||||
}
|
||||
}
|
||||
|
||||
func TestASRResponse_Deserialization(t *testing.T) {
|
||||
jsonStr := `{"output":{"text":"你好"},"usage":{"total_tokens":0},"request_id":"r1","code":""}`
|
||||
var resp ASRResponse
|
||||
if err := json.Unmarshal([]byte(jsonStr), &resp); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if resp.Output.Text != "你好" {
|
||||
t.Errorf("text mismatch: %s", resp.Output.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestASRResponse_Error(t *testing.T) {
|
||||
jsonStr := `{"output":{"text":""},"code":"InvalidParameter","message":"bad request"}`
|
||||
var resp ASRResponse
|
||||
if err := json.Unmarshal([]byte(jsonStr), &resp); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if resp.Code != "InvalidParameter" {
|
||||
t.Errorf("code mismatch: %s", resp.Code)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
module git.yeij.top/AskaEth/Cyrene/pkg/dashscope
|
||||
|
||||
go 1.21
|
||||
@@ -1,279 +0,0 @@
|
||||
package calculator
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
||||
)
|
||||
|
||||
type CalculatorPlugin struct {
|
||||
sdk.BasePlugin
|
||||
}
|
||||
|
||||
func (p *CalculatorPlugin) Metadata() sdk.PluginMetadata {
|
||||
return sdk.PluginMetadata{
|
||||
Name: "calculator", DisplayName: "Calculator", Version: "1.0.0",
|
||||
Description: "Safe mathematical expression evaluation with custom parser",
|
||||
Category: "utility", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *CalculatorPlugin) Tools() []sdk.Tool {
|
||||
return []sdk.Tool{&CalculatorTool{}}
|
||||
}
|
||||
|
||||
type CalculatorTool struct {
|
||||
sdk.BaseTool
|
||||
}
|
||||
|
||||
func (t *CalculatorTool) Definition() sdk.ToolDefinition {
|
||||
return sdk.ToolDefinition{
|
||||
ID: "calculator", Name: "calculator", DisplayName: "Calculator",
|
||||
Description: "Execute mathematical calculations. Supports arithmetic, trig, logs, powers.",
|
||||
Category: "utility", Complexity: sdk.ComplexitySimple,
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object", "properties": map[string]interface{}{"expression": map[string]interface{}{"type": "string"}},
|
||||
"required": []string{"expression"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *CalculatorTool) Validate(args map[string]interface{}) error {
|
||||
if _, ok := args["expression"]; !ok {
|
||||
return fmt.Errorf("missing required parameter: expression")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *CalculatorTool) Execute(_ context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
||||
expr, _ := args["expression"].(string)
|
||||
result, err := evalExpression(expr)
|
||||
if err != nil {
|
||||
return &sdk.ToolResult{ToolName: "calculator", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "calculator", Success: true, Output: fmt.Sprintf("%v", result)}, nil
|
||||
}
|
||||
|
||||
// Expression parser supporting +, -, *, /, %, ^, functions, constants.
|
||||
type exprParser struct {
|
||||
s string
|
||||
pos int
|
||||
}
|
||||
|
||||
func evalExpression(s string) (float64, error) {
|
||||
p := &exprParser{s: strings.TrimSpace(s)}
|
||||
result, err := p.parseAddSub()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if p.pos < len(p.s) {
|
||||
return 0, fmt.Errorf("unexpected character at position %d: %c", p.pos, p.s[p.pos])
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (p *exprParser) peek() byte {
|
||||
if p.pos < len(p.s) {
|
||||
return p.s[p.pos]
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (p *exprParser) skipSpaces() {
|
||||
for p.pos < len(p.s) && p.s[p.pos] == ' ' {
|
||||
p.pos++
|
||||
}
|
||||
}
|
||||
|
||||
func (p *exprParser) parseAddSub() (float64, error) {
|
||||
left, err := p.parseMulDiv()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for {
|
||||
p.skipSpaces()
|
||||
op := p.peek()
|
||||
if op != '+' && op != '-' {
|
||||
break
|
||||
}
|
||||
p.pos++
|
||||
right, err := p.parseMulDiv()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if op == '+' {
|
||||
left += right
|
||||
} else {
|
||||
left -= right
|
||||
}
|
||||
}
|
||||
return left, nil
|
||||
}
|
||||
|
||||
func (p *exprParser) parseMulDiv() (float64, error) {
|
||||
left, err := p.parsePower()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for {
|
||||
p.skipSpaces()
|
||||
op := p.peek()
|
||||
if op != '*' && op != '/' && op != '%' {
|
||||
break
|
||||
}
|
||||
p.pos++
|
||||
right, err := p.parsePower()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
switch op {
|
||||
case '*':
|
||||
left *= right
|
||||
case '/':
|
||||
if right == 0 {
|
||||
return 0, fmt.Errorf("division by zero")
|
||||
}
|
||||
left /= right
|
||||
case '%':
|
||||
left = math.Mod(left, right)
|
||||
}
|
||||
}
|
||||
return left, nil
|
||||
}
|
||||
|
||||
func (p *exprParser) parsePower() (float64, error) {
|
||||
left, err := p.parseUnary()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
p.skipSpaces()
|
||||
if p.peek() == '^' {
|
||||
p.pos++
|
||||
right, err := p.parseUnary()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return math.Pow(left, right), nil
|
||||
}
|
||||
return left, nil
|
||||
}
|
||||
|
||||
func (p *exprParser) parseUnary() (float64, error) {
|
||||
p.skipSpaces()
|
||||
if p.peek() == '-' {
|
||||
p.pos++
|
||||
val, err := p.parseAtom()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return -val, nil
|
||||
}
|
||||
if p.peek() == '+' {
|
||||
p.pos++
|
||||
}
|
||||
return p.parseAtom()
|
||||
}
|
||||
|
||||
func (p *exprParser) parseAtom() (float64, error) {
|
||||
p.skipSpaces()
|
||||
if p.peek() == '(' {
|
||||
p.pos++
|
||||
result, err := p.parseAddSub()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
p.skipSpaces()
|
||||
if p.peek() != ')' {
|
||||
return 0, fmt.Errorf("missing closing parenthesis")
|
||||
}
|
||||
p.pos++
|
||||
return result, nil
|
||||
}
|
||||
if p.peek() == 0 {
|
||||
return 0, fmt.Errorf("unexpected end of expression")
|
||||
}
|
||||
if unicode.IsDigit(rune(p.peek())) || p.peek() == '.' {
|
||||
return p.parseNumber()
|
||||
}
|
||||
return p.parseFuncOrConst()
|
||||
}
|
||||
|
||||
func (p *exprParser) parseNumber() (float64, error) {
|
||||
start := p.pos
|
||||
for p.pos < len(p.s) && (unicode.IsDigit(rune(p.s[p.pos])) || p.s[p.pos] == '.') {
|
||||
p.pos++
|
||||
}
|
||||
return strconv.ParseFloat(p.s[start:p.pos], 64)
|
||||
}
|
||||
|
||||
func (p *exprParser) parseFuncOrConst() (float64, error) {
|
||||
start := p.pos
|
||||
for p.pos < len(p.s) && (unicode.IsLetter(rune(p.s[p.pos])) || p.s[p.pos] == '_') {
|
||||
p.pos++
|
||||
}
|
||||
name := p.s[start:p.pos]
|
||||
p.skipSpaces()
|
||||
|
||||
switch name {
|
||||
case "pi":
|
||||
return math.Pi, nil
|
||||
case "e":
|
||||
return math.E, nil
|
||||
case "sqrt", "sin", "cos", "tan", "abs", "floor", "ceil", "round", "log", "ln":
|
||||
if p.peek() != '(' {
|
||||
return 0, fmt.Errorf("expected '(' after function %s", name)
|
||||
}
|
||||
p.pos++
|
||||
arg, err := p.parseAddSub()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if p.peek() != ')' {
|
||||
return 0, fmt.Errorf("missing ')' after function argument")
|
||||
}
|
||||
p.pos++
|
||||
return applyFunc(name, arg)
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown function or constant: %s", name)
|
||||
}
|
||||
}
|
||||
|
||||
func applyFunc(name string, x float64) (float64, error) {
|
||||
switch name {
|
||||
case "sqrt":
|
||||
if x < 0 {
|
||||
return 0, fmt.Errorf("square root of negative number")
|
||||
}
|
||||
return math.Sqrt(x), nil
|
||||
case "sin":
|
||||
return math.Sin(x), nil
|
||||
case "cos":
|
||||
return math.Cos(x), nil
|
||||
case "tan":
|
||||
return math.Tan(x), nil
|
||||
case "abs":
|
||||
return math.Abs(x), nil
|
||||
case "floor":
|
||||
return math.Floor(x), nil
|
||||
case "ceil":
|
||||
return math.Ceil(x), nil
|
||||
case "round":
|
||||
return math.Round(x), nil
|
||||
case "log":
|
||||
if x <= 0 {
|
||||
return 0, fmt.Errorf("log of non-positive number")
|
||||
}
|
||||
return math.Log10(x), nil
|
||||
case "ln":
|
||||
if x <= 0 {
|
||||
return 0, fmt.Errorf("ln of non-positive number")
|
||||
}
|
||||
return math.Log(x), nil
|
||||
}
|
||||
return 0, fmt.Errorf("unknown function: %s", name)
|
||||
}
|
||||
@@ -1,116 +0,0 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"hash"
|
||||
"net/url"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
||||
)
|
||||
|
||||
type CryptoPlugin struct{ sdk.BasePlugin }
|
||||
|
||||
func (p *CryptoPlugin) Metadata() sdk.PluginMetadata {
|
||||
return sdk.PluginMetadata{
|
||||
Name: "crypto", DisplayName: "Crypto & Encoding", Version: "1.0.0",
|
||||
Description: "Hashing (MD5/SHA) and encoding (Base64, URL) utilities",
|
||||
Category: "utility", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *CryptoPlugin) Tools() []sdk.Tool { return []sdk.Tool{&CryptoTool{}} }
|
||||
|
||||
type CryptoTool struct{ sdk.BaseTool }
|
||||
|
||||
func (t *CryptoTool) Definition() sdk.ToolDefinition {
|
||||
return sdk.ToolDefinition{
|
||||
ID: "crypto", Name: "crypto", DisplayName: "Crypto & Encoding",
|
||||
Description: "Crypto hash and encoding utilities. MD5/SHA hashing, Base64 encode/decode, URL encode/decode.",
|
||||
Category: "utility", Complexity: sdk.ComplexitySimple,
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{"type": "string", "enum": []string{"hash", "base64_encode", "base64_decode", "url_encode", "url_decode"}},
|
||||
"input": map[string]interface{}{"type": "string"},
|
||||
"algorithm": map[string]interface{}{"type": "string", "enum": []string{"md5", "sha1", "sha256", "sha512"}},
|
||||
},
|
||||
"required": []string{"action", "input"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *CryptoTool) Validate(args map[string]interface{}) error {
|
||||
for _, k := range []string{"action", "input"} {
|
||||
if _, ok := args[k]; !ok {
|
||||
return fmt.Errorf("missing required parameter: %s", k)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *CryptoTool) Execute(_ context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
||||
action, _ := args["action"].(string)
|
||||
input, _ := args["input"].(string)
|
||||
|
||||
switch action {
|
||||
case "hash":
|
||||
alg, _ := args["algorithm"].(string)
|
||||
if alg == "" {
|
||||
alg = "sha256"
|
||||
}
|
||||
var h hash.Hash
|
||||
switch alg {
|
||||
case "md5":
|
||||
h = md5.New()
|
||||
case "sha1":
|
||||
h = sha1.New()
|
||||
case "sha256":
|
||||
h = sha256.New()
|
||||
case "sha512":
|
||||
h = sha512.New()
|
||||
default:
|
||||
return &sdk.ToolResult{ToolName: "crypto", Success: false, Error: "unsupported algorithm: " + alg}, nil
|
||||
}
|
||||
h.Write([]byte(input))
|
||||
return &sdk.ToolResult{ToolName: "crypto", Success: true,
|
||||
Output: fmt.Sprintf("%s: %x", alg, h.Sum(nil))}, nil
|
||||
|
||||
case "base64_encode":
|
||||
return &sdk.ToolResult{ToolName: "crypto", Success: true,
|
||||
Output: base64.StdEncoding.EncodeToString([]byte(input))}, nil
|
||||
|
||||
case "base64_decode":
|
||||
for _, enc := range []*base64.Encoding{base64.StdEncoding, base64.RawStdEncoding, base64.URLEncoding, base64.RawURLEncoding} {
|
||||
if decoded, err := enc.DecodeString(input); err == nil {
|
||||
return &sdk.ToolResult{ToolName: "crypto", Success: true, Output: truncate(string(decoded), 200)}, nil
|
||||
}
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "crypto", Success: false, Error: "failed to decode base64"}, nil
|
||||
|
||||
case "url_encode":
|
||||
return &sdk.ToolResult{ToolName: "crypto", Success: true,
|
||||
Output: url.QueryEscape(input)}, nil
|
||||
|
||||
case "url_decode":
|
||||
decoded, err := url.QueryUnescape(input)
|
||||
if err != nil {
|
||||
return &sdk.ToolResult{ToolName: "crypto", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "crypto", Success: true, Output: decoded}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "crypto", Success: false, Error: "unknown action: " + action}, nil
|
||||
}
|
||||
|
||||
func truncate(s string, n int) string {
|
||||
runes := []rune(s)
|
||||
if len(runes) > n {
|
||||
return string(runes[:n]) + "..."
|
||||
}
|
||||
return s
|
||||
}
|
||||
@@ -1,170 +0,0 @@
|
||||
package datetime
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
||||
)
|
||||
|
||||
type DatetimePlugin struct{ sdk.BasePlugin }
|
||||
|
||||
func (p *DatetimePlugin) Metadata() sdk.PluginMetadata {
|
||||
return sdk.PluginMetadata{
|
||||
Name: "datetime", DisplayName: "Date & Time", Version: "1.0.0",
|
||||
Description: "Date/time utilities: now, format, arithmetic, diff, timezone list",
|
||||
Category: "utility", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *DatetimePlugin) Tools() []sdk.Tool { return []sdk.Tool{&DatetimeTool{}} }
|
||||
|
||||
type DatetimeTool struct{ sdk.BaseTool }
|
||||
|
||||
func (t *DatetimeTool) Definition() sdk.ToolDefinition {
|
||||
return sdk.ToolDefinition{
|
||||
ID: "datetime", Name: "datetime", DisplayName: "Date & Time",
|
||||
Description: "Date/time utility. Get current time, format dates, date arithmetic, date diff, list timezones.",
|
||||
Category: "utility", Complexity: sdk.ComplexitySimple,
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{"type": "string", "enum": []string{"now", "format", "add", "diff", "timezone_list"}},
|
||||
"format": map[string]interface{}{"type": "string"},
|
||||
"timezone": map[string]interface{}{"type": "string"},
|
||||
"date": map[string]interface{}{"type": "string"},
|
||||
"duration": map[string]interface{}{"type": "string"},
|
||||
"date2": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *DatetimeTool) Validate(args map[string]interface{}) error {
|
||||
if _, ok := args["action"]; !ok {
|
||||
return fmt.Errorf("missing required parameter: action")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *DatetimeTool) Execute(_ context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
||||
action, _ := args["action"].(string)
|
||||
tzStr, _ := args["timezone"].(string)
|
||||
loc, _ := parseLocation(tzStr)
|
||||
now := time.Now().In(loc)
|
||||
|
||||
switch action {
|
||||
case "now":
|
||||
return &sdk.ToolResult{ToolName: "datetime", Success: true,
|
||||
Output: fmt.Sprintf("Current time: %s (unix: %d, zone: %s)", now.Format(time.RFC3339), now.Unix(), loc.String())}, nil
|
||||
|
||||
case "format":
|
||||
dateStr, _ := args["date"].(string)
|
||||
format, _ := args["format"].(string)
|
||||
if format == "" {
|
||||
format = time.RFC3339
|
||||
}
|
||||
parsed, err := parseDate(dateStr, loc)
|
||||
if err != nil {
|
||||
return &sdk.ToolResult{ToolName: "datetime", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "datetime", Success: true,
|
||||
Output: fmt.Sprintf("Formatted: %s", parsed.Format(format))}, nil
|
||||
|
||||
case "add":
|
||||
dateStr, _ := args["date"].(string)
|
||||
durStr, _ := args["duration"].(string)
|
||||
base := now
|
||||
if dateStr != "" {
|
||||
var err error
|
||||
base, err = parseDate(dateStr, loc)
|
||||
if err != nil {
|
||||
return &sdk.ToolResult{ToolName: "datetime", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
}
|
||||
result, err := addDuration(base, durStr)
|
||||
if err != nil {
|
||||
return &sdk.ToolResult{ToolName: "datetime", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "datetime", Success: true,
|
||||
Output: fmt.Sprintf("%s + %s = %s", base.Format(time.RFC3339), durStr, result.Format(time.RFC3339))}, nil
|
||||
|
||||
case "diff":
|
||||
d1, _ := args["date"].(string)
|
||||
d2, _ := args["date2"].(string)
|
||||
t1, err := parseDate(d1, loc)
|
||||
if err != nil {
|
||||
return &sdk.ToolResult{ToolName: "datetime", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
t2, err := parseDate(d2, loc)
|
||||
if err != nil {
|
||||
return &sdk.ToolResult{ToolName: "datetime", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
diff := t2.Sub(t1)
|
||||
if diff < 0 {
|
||||
diff = -diff
|
||||
}
|
||||
days := int(diff.Hours()) / 24
|
||||
hours := int(diff.Hours()) % 24
|
||||
minutes := int(diff.Minutes()) % 60
|
||||
seconds := int(diff.Seconds()) % 60
|
||||
return &sdk.ToolResult{ToolName: "datetime", Success: true,
|
||||
Output: fmt.Sprintf("Difference: %d days, %d hours, %d minutes, %d seconds", days, hours, minutes, seconds)}, nil
|
||||
|
||||
case "timezone_list":
|
||||
return &sdk.ToolResult{ToolName: "datetime", Success: true,
|
||||
Output: "Common timezones: UTC, Asia/Shanghai, Asia/Tokyo, Asia/Seoul, Asia/Singapore, Asia/Kolkata, Asia/Dubai, Europe/London, Europe/Paris, Europe/Moscow, America/New_York, America/Chicago, America/Los_Angeles, America/Sao_Paulo, Australia/Sydney, Pacific/Auckland, Africa/Cairo, Africa/Lagos"}, nil
|
||||
|
||||
default:
|
||||
return &sdk.ToolResult{ToolName: "datetime", Success: false,
|
||||
Error: fmt.Sprintf("unknown action: %s", action)}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func parseLocation(tz string) (*time.Location, error) {
|
||||
if tz == "" {
|
||||
loc, err := time.LoadLocation("Asia/Shanghai")
|
||||
if err != nil {
|
||||
return time.UTC, nil
|
||||
}
|
||||
return loc, nil
|
||||
}
|
||||
return time.LoadLocation(tz)
|
||||
}
|
||||
|
||||
func parseDate(s string, loc *time.Location) (time.Time, error) {
|
||||
formats := []string{time.RFC3339, "2006-01-02T15:04:05", "2006-01-02 15:04:05", "2006-01-02", "2006/01/02"}
|
||||
for _, f := range formats {
|
||||
if t, err := time.ParseInLocation(f, s, loc); err == nil {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
return time.Time{}, fmt.Errorf("cannot parse date: %s", s)
|
||||
}
|
||||
|
||||
func addDuration(t time.Time, durStr string) (time.Time, error) {
|
||||
durStr = strings.TrimSpace(durStr)
|
||||
if durStr == "" {
|
||||
return t, nil
|
||||
}
|
||||
// Handle months and years
|
||||
if strings.Contains(durStr, "M") || strings.Contains(durStr, "y") {
|
||||
months := 0
|
||||
years := 0
|
||||
if strings.Contains(durStr, "y") {
|
||||
fmt.Sscanf(durStr, "%dy", &years)
|
||||
}
|
||||
if strings.Contains(durStr, "M") {
|
||||
fmt.Sscanf(durStr, "%dM", &months)
|
||||
}
|
||||
return t.AddDate(years, months, 0), nil
|
||||
}
|
||||
d, err := time.ParseDuration(durStr)
|
||||
if err != nil {
|
||||
return t, fmt.Errorf("invalid duration: %s", durStr)
|
||||
}
|
||||
return t.Add(d), nil
|
||||
}
|
||||
@@ -1,158 +0,0 @@
|
||||
package file
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
||||
)
|
||||
|
||||
type FilePlugin struct {
|
||||
sdk.BasePlugin
|
||||
dataDir string
|
||||
}
|
||||
|
||||
func NewFilePlugin(dataDir string) *FilePlugin {
|
||||
if dataDir == "" {
|
||||
dataDir = "/tmp/cyrene_data"
|
||||
}
|
||||
return &FilePlugin{dataDir: dataDir}
|
||||
}
|
||||
|
||||
func (p *FilePlugin) Metadata() sdk.PluginMetadata {
|
||||
return sdk.PluginMetadata{
|
||||
Name: "file", DisplayName: "File Operations", Version: "1.0.0",
|
||||
Description: "Sandboxed file operations: read, write, list, delete within DATA_DIR",
|
||||
Category: "system", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *FilePlugin) Tools() []sdk.Tool { return []sdk.Tool{&FileTool{dataDir: p.dataDir}} }
|
||||
|
||||
type FileTool struct {
|
||||
sdk.BaseTool
|
||||
dataDir string
|
||||
}
|
||||
|
||||
func (t *FileTool) Definition() sdk.ToolDefinition {
|
||||
return sdk.ToolDefinition{
|
||||
ID: "file_ops", Name: "file_ops", DisplayName: "File Operations",
|
||||
Description: "File operations within a sandboxed data directory. Read, write, list, check existence, delete.",
|
||||
Category: "system", Complexity: sdk.ComplexitySimple,
|
||||
DangerLevel: "medium",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{"type": "string", "enum": []string{"read", "write", "list", "exists", "delete"}},
|
||||
"path": map[string]interface{}{"type": "string"},
|
||||
"content": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
"required": []string{"action", "path"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *FileTool) Validate(args map[string]interface{}) error {
|
||||
for _, k := range []string{"action", "path"} {
|
||||
if _, ok := args[k]; !ok {
|
||||
return fmt.Errorf("missing required parameter: %s", k)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *FileTool) safePath(p string) (string, error) {
|
||||
clean := filepath.Clean(p)
|
||||
abs, err := filepath.Abs(filepath.Join(t.dataDir, clean))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("path resolution failed: %w", err)
|
||||
}
|
||||
if !strings.HasPrefix(abs, filepath.Clean(t.dataDir)+string(os.PathSeparator)) && abs != filepath.Clean(t.dataDir) {
|
||||
return "", fmt.Errorf("path traversal denied: %s", p)
|
||||
}
|
||||
return abs, nil
|
||||
}
|
||||
|
||||
func (t *FileTool) Execute(_ context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
||||
action, _ := args["action"].(string)
|
||||
pathStr, _ := args["path"].(string)
|
||||
|
||||
safePath, err := t.safePath(pathStr)
|
||||
if err != nil {
|
||||
return &sdk.ToolResult{ToolName: "file_ops", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "read":
|
||||
info, err := os.Stat(safePath)
|
||||
if err != nil {
|
||||
return &sdk.ToolResult{ToolName: "file_ops", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
if info.IsDir() {
|
||||
return &sdk.ToolResult{ToolName: "file_ops", Success: false, Error: "cannot read a directory"}, nil
|
||||
}
|
||||
if info.Size() > 100*1024 {
|
||||
return &sdk.ToolResult{ToolName: "file_ops", Success: false, Error: "file too large (>100KB)"}, nil
|
||||
}
|
||||
data, err := os.ReadFile(safePath)
|
||||
if err != nil {
|
||||
return &sdk.ToolResult{ToolName: "file_ops", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "file_ops", Success: true, Output: string(data)}, nil
|
||||
|
||||
case "write":
|
||||
content, _ := args["content"].(string)
|
||||
dir := filepath.Dir(safePath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return &sdk.ToolResult{ToolName: "file_ops", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
if err := os.WriteFile(safePath, []byte(content), 0644); err != nil {
|
||||
return &sdk.ToolResult{ToolName: "file_ops", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "file_ops", Success: true, Output: fmt.Sprintf("Written %d bytes to %s", len(content), pathStr)}, nil
|
||||
|
||||
case "list":
|
||||
entries, err := os.ReadDir(safePath)
|
||||
if err != nil {
|
||||
return &sdk.ToolResult{ToolName: "file_ops", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
var out strings.Builder
|
||||
for _, e := range entries {
|
||||
info, _ := e.Info()
|
||||
if e.IsDir() {
|
||||
out.WriteString(fmt.Sprintf("[DIR] %s/\n", e.Name()))
|
||||
} else {
|
||||
out.WriteString(fmt.Sprintf("[FILE] %s (%d bytes)\n", e.Name(), info.Size()))
|
||||
}
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "file_ops", Success: true, Output: out.String()}, nil
|
||||
|
||||
case "exists":
|
||||
info, err := os.Stat(safePath)
|
||||
if err != nil {
|
||||
return &sdk.ToolResult{ToolName: "file_ops", Success: true, Output: fmt.Sprintf("Path does not exist: %s", pathStr)}, nil
|
||||
}
|
||||
kind := "file"
|
||||
if info.IsDir() {
|
||||
kind = "directory"
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "file_ops", Success: true, Output: fmt.Sprintf("Path exists (%s): %s", kind, pathStr)}, nil
|
||||
|
||||
case "delete":
|
||||
info, err := os.Stat(safePath)
|
||||
if err != nil {
|
||||
return &sdk.ToolResult{ToolName: "file_ops", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
if info.IsDir() {
|
||||
return &sdk.ToolResult{ToolName: "file_ops", Success: false, Error: "cannot delete a directory"}, nil
|
||||
}
|
||||
if err := os.Remove(safePath); err != nil {
|
||||
return &sdk.ToolResult{ToolName: "file_ops", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "file_ops", Success: true, Output: fmt.Sprintf("Deleted: %s", pathStr)}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "file_ops", Success: false, Error: "unknown action: " + action}, nil
|
||||
}
|
||||
@@ -1,3 +0,0 @@
|
||||
module git.yeij.top/AskaEth/Cyrene/pkg/plugins
|
||||
|
||||
go 1.21
|
||||
@@ -1,122 +0,0 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
||||
)
|
||||
|
||||
type HTTPPlugin struct {
|
||||
sdk.BasePlugin
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func NewHTTPPlugin() *HTTPPlugin {
|
||||
return &HTTPPlugin{client: &http.Client{Timeout: 10 * time.Second}}
|
||||
}
|
||||
|
||||
func (p *HTTPPlugin) Metadata() sdk.PluginMetadata {
|
||||
return sdk.PluginMetadata{
|
||||
Name: "http", DisplayName: "HTTP Client", Version: "1.0.0",
|
||||
Description: "Send arbitrary HTTP requests with custom methods, headers, body",
|
||||
Category: "network", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *HTTPPlugin) Tools() []sdk.Tool { return []sdk.Tool{&HTTPTool{client: p.client}} }
|
||||
|
||||
type HTTPTool struct {
|
||||
sdk.BaseTool
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (t *HTTPTool) Definition() sdk.ToolDefinition {
|
||||
return sdk.ToolDefinition{
|
||||
ID: "http_request", Name: "http_request", DisplayName: "HTTP Client",
|
||||
Description: "Send arbitrary HTTP requests. Supports custom methods, headers, and body.",
|
||||
Category: "network", Complexity: sdk.ComplexitySimple,
|
||||
DangerLevel: "low",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"url": map[string]interface{}{"type": "string"},
|
||||
"method": map[string]interface{}{"type": "string", "enum": []string{"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"}},
|
||||
"headers": map[string]interface{}{"type": "object"},
|
||||
"body": map[string]interface{}{"type": "string"},
|
||||
"timeout": map[string]interface{}{"type": "number"},
|
||||
},
|
||||
"required": []string{"url"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var allowedMethods = map[string]bool{
|
||||
"GET": true, "POST": true, "PUT": true, "DELETE": true,
|
||||
"PATCH": true, "HEAD": true, "OPTIONS": true,
|
||||
}
|
||||
|
||||
func (t *HTTPTool) Validate(args map[string]interface{}) error {
|
||||
if _, ok := args["url"]; !ok {
|
||||
return fmt.Errorf("missing required parameter: url")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *HTTPTool) Execute(_ context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
||||
urlStr, _ := args["url"].(string)
|
||||
method, _ := args["method"].(string)
|
||||
if method == "" {
|
||||
method = "GET"
|
||||
}
|
||||
if !allowedMethods[method] {
|
||||
return &sdk.ToolResult{ToolName: "http_request", Success: false, Error: "invalid method: " + method}, nil
|
||||
}
|
||||
if !strings.HasPrefix(urlStr, "http://") && !strings.HasPrefix(urlStr, "https://") {
|
||||
return &sdk.ToolResult{ToolName: "http_request", Success: false, Error: "only http/https URLs allowed"}, nil
|
||||
}
|
||||
|
||||
var bodyReader io.Reader
|
||||
if body, _ := args["body"].(string); body != "" {
|
||||
bodyReader = strings.NewReader(body)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(method, urlStr, bodyReader)
|
||||
if err != nil {
|
||||
return &sdk.ToolResult{ToolName: "http_request", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
req.Header.Set("User-Agent", "CyreneBot/1.0")
|
||||
|
||||
if headers, ok := args["headers"].(map[string]interface{}); ok {
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, fmt.Sprint(v))
|
||||
}
|
||||
}
|
||||
|
||||
client := t.client
|
||||
if timeout, _ := args["timeout"].(float64); timeout > 0 {
|
||||
client = &http.Client{Timeout: time.Duration(timeout) * time.Second}
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return &sdk.ToolResult{ToolName: "http_request", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, 50*1024))
|
||||
return &sdk.ToolResult{ToolName: "http_request", Success: resp.StatusCode < 500, Output: fmt.Sprintf(
|
||||
"HTTP %d\n%s\n\n%s", resp.StatusCode, formatHeaders(resp.Header), string(bodyBytes))}, nil
|
||||
}
|
||||
|
||||
func formatHeaders(h http.Header) string {
|
||||
var lines []string
|
||||
for k, v := range h {
|
||||
lines = append(lines, fmt.Sprintf("%s: %s", k, strings.Join(v, ", ")))
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
@@ -1,189 +0,0 @@
|
||||
package iotcontrol
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
||||
)
|
||||
|
||||
// IoTController extends IoTClient with control operations.
|
||||
type IoTController interface {
|
||||
GetDevice(ctx context.Context, deviceID string) (*sdk.IoTDeviceState, error)
|
||||
SetDeviceProperty(ctx context.Context, deviceID, property string, value interface{}) error
|
||||
ToggleDevice(ctx context.Context, deviceID string) (*sdk.IoTDeviceState, error)
|
||||
}
|
||||
|
||||
type IoTControlPlugin struct {
|
||||
sdk.BasePlugin
|
||||
iotClient IoTController
|
||||
}
|
||||
|
||||
func NewIoTControlPlugin(client IoTController) *IoTControlPlugin {
|
||||
return &IoTControlPlugin{iotClient: client}
|
||||
}
|
||||
|
||||
func (p *IoTControlPlugin) Metadata() sdk.PluginMetadata {
|
||||
return sdk.PluginMetadata{
|
||||
Name: "iot_control", DisplayName: "IoT Device Control", Version: "1.0.0",
|
||||
Description: "Control smart home devices: toggle, set temperature/brightness/mode/color",
|
||||
Category: "iot", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *IoTControlPlugin) Tools() []sdk.Tool {
|
||||
return []sdk.Tool{&IoTControlTool{iotClient: p.iotClient}}
|
||||
}
|
||||
|
||||
type IoTControlTool struct {
|
||||
sdk.BaseTool
|
||||
iotClient IoTController
|
||||
}
|
||||
|
||||
func (t *IoTControlTool) Definition() sdk.ToolDefinition {
|
||||
return sdk.ToolDefinition{
|
||||
ID: "iot_control", Name: "iot_control", DisplayName: "IoT Device Control",
|
||||
Description: "Control smart home devices. Supports toggle, turn_on, turn_off, set_temperature, set_brightness, set_position, set_mode, set_color.",
|
||||
Category: "iot", Complexity: sdk.ComplexitySimple,
|
||||
DangerLevel: "medium",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"device_id": map[string]interface{}{"type": "string"},
|
||||
"action": map[string]interface{}{"type": "string", "enum": []string{"toggle", "turn_on", "turn_off", "set_temperature", "set_brightness", "set_position", "set_mode", "set_color"}},
|
||||
"value": map[string]interface{}{},
|
||||
},
|
||||
"required": []string{"device_id", "action"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *IoTControlTool) Validate(args map[string]interface{}) error {
|
||||
for _, k := range []string{"device_id", "action"} {
|
||||
if _, ok := args[k]; !ok {
|
||||
return fmt.Errorf("missing required parameter: %s", k)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *IoTControlTool) Execute(ctx context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
||||
if t.iotClient == nil {
|
||||
return &sdk.ToolResult{ToolName: "iot_control", Success: false, Error: "IoT client not configured"}, nil
|
||||
}
|
||||
|
||||
deviceID, _ := args["device_id"].(string)
|
||||
if deviceID == "" {
|
||||
deviceID, _ = args["entity_id"].(string)
|
||||
}
|
||||
action := normalizeAction(args)
|
||||
|
||||
switch action {
|
||||
case "turn_on", "turn_off":
|
||||
status := "on"
|
||||
if action == "turn_off" {
|
||||
status = "off"
|
||||
}
|
||||
if err := t.iotClient.SetDeviceProperty(ctx, deviceID, "status", status); err != nil {
|
||||
return &sdk.ToolResult{ToolName: "iot_control", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "iot_control", Success: true,
|
||||
Output: fmt.Sprintf("Device %s turned %s", deviceID, status)}, nil
|
||||
|
||||
case "set_temperature":
|
||||
value := toFloat64(args["value"])
|
||||
old := ""
|
||||
if dev, err := t.iotClient.GetDevice(ctx, deviceID); err == nil {
|
||||
old = fmt.Sprintf(" (was %.1fC)", dev.Temperature)
|
||||
}
|
||||
if err := t.iotClient.SetDeviceProperty(ctx, deviceID, "temperature", value); err != nil {
|
||||
return &sdk.ToolResult{ToolName: "iot_control", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "iot_control", Success: true,
|
||||
Output: fmt.Sprintf("Temperature set to %.1fC%s", value, old)}, nil
|
||||
|
||||
case "set_brightness":
|
||||
value := toFloat64(args["value"])
|
||||
if err := t.iotClient.SetDeviceProperty(ctx, deviceID, "brightness", value); err != nil {
|
||||
return &sdk.ToolResult{ToolName: "iot_control", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "iot_control", Success: true,
|
||||
Output: fmt.Sprintf("Brightness set to %.0f%%", value)}, nil
|
||||
|
||||
case "set_position":
|
||||
value := toFloat64(args["value"])
|
||||
if err := t.iotClient.SetDeviceProperty(ctx, deviceID, "position", value); err != nil {
|
||||
return &sdk.ToolResult{ToolName: "iot_control", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "iot_control", Success: true,
|
||||
Output: fmt.Sprintf("Position set to %.0f%%", value)}, nil
|
||||
|
||||
case "set_mode":
|
||||
value, _ := args["value"].(string)
|
||||
if err := t.iotClient.SetDeviceProperty(ctx, deviceID, "mode", value); err != nil {
|
||||
return &sdk.ToolResult{ToolName: "iot_control", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "iot_control", Success: true,
|
||||
Output: fmt.Sprintf("Mode set to %s", value)}, nil
|
||||
|
||||
case "set_color":
|
||||
value, _ := args["value"].(string)
|
||||
if err := t.iotClient.SetDeviceProperty(ctx, deviceID, "color", value); err != nil {
|
||||
return &sdk.ToolResult{ToolName: "iot_control", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "iot_control", Success: true,
|
||||
Output: fmt.Sprintf("Color set to %s", value)}, nil
|
||||
|
||||
case "toggle":
|
||||
dev, err := t.iotClient.ToggleDevice(ctx, deviceID)
|
||||
if err != nil {
|
||||
return &sdk.ToolResult{ToolName: "iot_control", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "iot_control", Success: true,
|
||||
Output: fmt.Sprintf("Device %s toggled to %s", deviceID, dev.Status)}, nil
|
||||
|
||||
default:
|
||||
return &sdk.ToolResult{ToolName: "iot_control", Success: false,
|
||||
Error: fmt.Sprintf("unknown action: %s", action)}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeAction(args map[string]interface{}) string {
|
||||
action, _ := args["action"].(string)
|
||||
// Chinese aliases
|
||||
switch action {
|
||||
case "打开":
|
||||
return "turn_on"
|
||||
case "关闭", "关掉", "关上":
|
||||
return "turn_off"
|
||||
case "设置温度", "调温度":
|
||||
return "set_temperature"
|
||||
case "设置亮度", "调亮度":
|
||||
return "set_brightness"
|
||||
case "设置位置":
|
||||
return "set_position"
|
||||
case "设置模式":
|
||||
return "set_mode"
|
||||
case "设置颜色":
|
||||
return "set_color"
|
||||
case "开关", "切换":
|
||||
return "toggle"
|
||||
}
|
||||
return action
|
||||
}
|
||||
|
||||
func toFloat64(v interface{}) float64 {
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
return n
|
||||
case int:
|
||||
return float64(n)
|
||||
case int64:
|
||||
return float64(n)
|
||||
case string:
|
||||
var f float64
|
||||
fmt.Sscanf(n, "%f", &f)
|
||||
return f
|
||||
}
|
||||
return 0
|
||||
}
|
||||
@@ -1,120 +0,0 @@
|
||||
package iotquery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
||||
)
|
||||
|
||||
// IoTClient is the interface for IoT device access.
|
||||
type IoTClient interface {
|
||||
GetAllDevices(ctx context.Context) ([]sdk.IoTDeviceState, error)
|
||||
GetDevice(ctx context.Context, deviceID string) (*sdk.IoTDeviceState, error)
|
||||
}
|
||||
|
||||
type IoTQueryPlugin struct {
|
||||
sdk.BasePlugin
|
||||
iotClient IoTClient
|
||||
}
|
||||
|
||||
func NewIoTQueryPlugin(client IoTClient) *IoTQueryPlugin {
|
||||
return &IoTQueryPlugin{iotClient: client}
|
||||
}
|
||||
|
||||
func (p *IoTQueryPlugin) Metadata() sdk.PluginMetadata {
|
||||
return sdk.PluginMetadata{
|
||||
Name: "iot_query", DisplayName: "IoT Device Query", Version: "1.0.0",
|
||||
Description: "Query smart home device status (single device or all devices)",
|
||||
Category: "iot", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *IoTQueryPlugin) Tools() []sdk.Tool { return []sdk.Tool{&IoTQueryTool{iotClient: p.iotClient}} }
|
||||
|
||||
type IoTQueryTool struct {
|
||||
sdk.BaseTool
|
||||
iotClient IoTClient
|
||||
}
|
||||
|
||||
func (t *IoTQueryTool) Definition() sdk.ToolDefinition {
|
||||
return sdk.ToolDefinition{
|
||||
ID: "iot_query", Name: "iot_query", DisplayName: "IoT Device Query",
|
||||
Description: "Query smart home device status. Device status is typically auto-injected; use this only when status is stale.",
|
||||
Category: "iot", Complexity: sdk.ComplexitySimple,
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{"device_id": map[string]interface{}{"type": "string"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *IoTQueryTool) Validate(args map[string]interface{}) error { return nil }
|
||||
|
||||
func (t *IoTQueryTool) Execute(ctx context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
||||
if t.iotClient == nil {
|
||||
return &sdk.ToolResult{ToolName: "iot_query", Success: false, Error: "IoT client not configured"}, nil
|
||||
}
|
||||
|
||||
deviceID, _ := args["device_id"].(string)
|
||||
if deviceID != "" {
|
||||
dev, err := t.iotClient.GetDevice(ctx, deviceID)
|
||||
if err != nil {
|
||||
return &sdk.ToolResult{ToolName: "iot_query", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "iot_query", Success: true, Output: formatDevice(dev)}, nil
|
||||
}
|
||||
|
||||
devices, err := t.iotClient.GetAllDevices(ctx)
|
||||
if err != nil {
|
||||
return &sdk.ToolResult{ToolName: "iot_query", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
if len(devices) == 0 {
|
||||
return &sdk.ToolResult{ToolName: "iot_query", Success: true, Output: "No devices found"}, nil
|
||||
}
|
||||
var out string
|
||||
for _, d := range devices {
|
||||
out += formatDeviceLine(&d) + "\n"
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "iot_query", Success: true, Output: out}, nil
|
||||
}
|
||||
|
||||
func formatDevice(d *sdk.IoTDeviceState) string {
|
||||
emoji := deviceEmoji(d.Type)
|
||||
return fmt.Sprintf("%s %s (%s)\n Status: %s\n ID: %s", emoji, d.Name, d.Type, d.Status, d.ID)
|
||||
}
|
||||
|
||||
func formatDeviceLine(d *sdk.IoTDeviceState) string {
|
||||
emoji := deviceEmoji(d.Type)
|
||||
switch d.Type {
|
||||
case "light":
|
||||
return fmt.Sprintf("%s %s: %s (brightness: %d, color: %s)", emoji, d.Name, d.Status, d.Brightness, d.Color)
|
||||
case "ac":
|
||||
return fmt.Sprintf("%s %s: %s (mode: %s, temp: %.1fC)", emoji, d.Name, d.Status, d.Mode, d.Temperature)
|
||||
case "curtain":
|
||||
return fmt.Sprintf("%s %s: %s (position: %d%%)", emoji, d.Name, d.Status, d.Position)
|
||||
case "sensor":
|
||||
return fmt.Sprintf("%s %s: %.1f%s", emoji, d.Name, d.Value, d.Unit)
|
||||
case "lock":
|
||||
return fmt.Sprintf("%s %s: %s (battery: %d%%)", emoji, d.Name, d.Status, d.Battery)
|
||||
default:
|
||||
return fmt.Sprintf("%s %s: %s", emoji, d.Name, d.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func deviceEmoji(t string) string {
|
||||
switch t {
|
||||
case "light":
|
||||
return "\U0001F4A1"
|
||||
case "ac":
|
||||
return "❄️"
|
||||
case "curtain":
|
||||
return "\U0001F3E0"
|
||||
case "sensor":
|
||||
return "\U0001F4CA"
|
||||
case "lock":
|
||||
return "\U0001F512"
|
||||
default:
|
||||
return "\U0001F4E6"
|
||||
}
|
||||
}
|
||||
@@ -1,132 +0,0 @@
|
||||
package jsonplugin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
||||
)
|
||||
|
||||
type JSONPlugin struct{ sdk.BasePlugin }
|
||||
|
||||
func (p *JSONPlugin) Metadata() sdk.PluginMetadata {
|
||||
return sdk.PluginMetadata{
|
||||
Name: "json", DisplayName: "JSON Processor", Version: "1.0.0",
|
||||
Description: "JSON parsing, dot-path query, validation, pretty-print",
|
||||
Category: "format", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *JSONPlugin) Tools() []sdk.Tool { return []sdk.Tool{&JSONTool{}} }
|
||||
|
||||
type JSONTool struct{ sdk.BaseTool }
|
||||
|
||||
func (t *JSONTool) Definition() sdk.ToolDefinition {
|
||||
return sdk.ToolDefinition{
|
||||
ID: "json_ops", Name: "json_ops", DisplayName: "JSON Processor",
|
||||
Description: "JSON processing. Parse/pretty-print, query by dot-notation path, validate.",
|
||||
Category: "format", Complexity: sdk.ComplexitySimple,
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{"type": "string", "enum": []string{"parse", "query", "validate"}},
|
||||
"json_string": map[string]interface{}{"type": "string"},
|
||||
"path": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
"required": []string{"action", "json_string"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *JSONTool) Validate(args map[string]interface{}) error {
|
||||
for _, k := range []string{"action", "json_string"} {
|
||||
if _, ok := args[k]; !ok {
|
||||
return fmt.Errorf("missing required parameter: %s", k)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *JSONTool) Execute(_ context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
||||
action, _ := args["action"].(string)
|
||||
jsonStr, _ := args["json_string"].(string)
|
||||
|
||||
switch action {
|
||||
case "parse":
|
||||
var v interface{}
|
||||
if err := json.Unmarshal([]byte(jsonStr), &v); err != nil {
|
||||
return &sdk.ToolResult{ToolName: "json_ops", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
pretty, err := json.MarshalIndent(v, "", " ")
|
||||
if err != nil {
|
||||
return &sdk.ToolResult{ToolName: "json_ops", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "json_ops", Success: true, Output: string(pretty)}, nil
|
||||
|
||||
case "query":
|
||||
var v interface{}
|
||||
if err := json.Unmarshal([]byte(jsonStr), &v); err != nil {
|
||||
return &sdk.ToolResult{ToolName: "json_ops", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
path, _ := args["path"].(string)
|
||||
if path == "" {
|
||||
return &sdk.ToolResult{ToolName: "json_ops", Success: false, Error: "path is required for query"}, nil
|
||||
}
|
||||
result, err := jsonPathQuery(v, path)
|
||||
if err != nil {
|
||||
return &sdk.ToolResult{ToolName: "json_ops", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
out, _ := json.Marshal(result)
|
||||
return &sdk.ToolResult{ToolName: "json_ops", Success: true, Output: string(out)}, nil
|
||||
|
||||
case "validate":
|
||||
var v interface{}
|
||||
if err := json.Unmarshal([]byte(jsonStr), &v); err != nil {
|
||||
return &sdk.ToolResult{ToolName: "json_ops", Success: true, Output: "Invalid JSON: " + err.Error()}, nil
|
||||
}
|
||||
typeStr := "unknown"
|
||||
switch v.(type) {
|
||||
case map[string]interface{}:
|
||||
typeStr = "object"
|
||||
case []interface{}:
|
||||
typeStr = "array"
|
||||
case string:
|
||||
typeStr = "string"
|
||||
case float64:
|
||||
typeStr = "number"
|
||||
case bool:
|
||||
typeStr = "boolean"
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "json_ops", Success: true,
|
||||
Output: fmt.Sprintf("Valid JSON (type: %s, size: %d bytes)", typeStr, len(jsonStr))}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "json_ops", Success: false, Error: "unknown action: " + action}, nil
|
||||
}
|
||||
|
||||
func jsonPathQuery(root interface{}, path string) (interface{}, error) {
|
||||
path = strings.TrimPrefix(path, "$.")
|
||||
parts := strings.Split(path, ".")
|
||||
current := root
|
||||
for _, part := range parts {
|
||||
switch v := current.(type) {
|
||||
case map[string]interface{}:
|
||||
var ok bool
|
||||
current, ok = v[part]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("key %q not found", part)
|
||||
}
|
||||
case []interface{}:
|
||||
idx, err := strconv.Atoi(part)
|
||||
if err != nil || idx < 0 || idx >= len(v) {
|
||||
return nil, fmt.Errorf("invalid array index: %s", part)
|
||||
}
|
||||
current = v[idx]
|
||||
default:
|
||||
return nil, fmt.Errorf("cannot traverse into %T at path segment %q", current, part)
|
||||
}
|
||||
}
|
||||
return current, nil
|
||||
}
|
||||
@@ -1,226 +0,0 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
||||
)
|
||||
|
||||
// PluginManager manages the lifecycle of all plugins and their tools.
|
||||
type PluginManager struct {
|
||||
mu sync.RWMutex
|
||||
plugins map[string]*pluginEntry
|
||||
registry *ToolRegistry
|
||||
host sdk.HostAPI
|
||||
}
|
||||
|
||||
type pluginEntry struct {
|
||||
instance sdk.Plugin
|
||||
info sdk.PluginInfo
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewPluginManager(registry *ToolRegistry, host sdk.HostAPI) *PluginManager {
|
||||
return &PluginManager{
|
||||
plugins: make(map[string]*pluginEntry),
|
||||
registry: registry,
|
||||
host: host,
|
||||
}
|
||||
}
|
||||
|
||||
// Install registers a plugin instance.
|
||||
func (m *PluginManager) Install(plugin sdk.Plugin) error {
|
||||
meta := plugin.Metadata()
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, exists := m.plugins[meta.Name]; exists {
|
||||
return fmt.Errorf("plugin %q is already installed", meta.Name)
|
||||
}
|
||||
|
||||
m.plugins[meta.Name] = &pluginEntry{
|
||||
instance: plugin,
|
||||
info: sdk.PluginInfo{
|
||||
Metadata: meta,
|
||||
Status: sdk.StatusInstalled,
|
||||
InstalledAt: time.Now(),
|
||||
Enabled: false,
|
||||
},
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Enable activates a plugin: Init → register tools → Start.
|
||||
func (m *PluginManager) Enable(ctx context.Context, pluginName string) error {
|
||||
m.mu.Lock()
|
||||
entry, ok := m.plugins[pluginName]
|
||||
m.mu.Unlock()
|
||||
if !ok {
|
||||
return fmt.Errorf("plugin %q not found", pluginName)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
entry.info.Status = sdk.StatusLoaded
|
||||
m.mu.Unlock()
|
||||
|
||||
meta := entry.instance.Metadata()
|
||||
if err := entry.instance.Init(ctx, nil); err != nil {
|
||||
m.mu.Lock()
|
||||
entry.info.Status = sdk.StatusError
|
||||
m.mu.Unlock()
|
||||
return fmt.Errorf("plugin %q init failed: %w", meta.Name, err)
|
||||
}
|
||||
|
||||
pluginCtx, cancel := context.WithCancel(context.Background())
|
||||
if err := entry.instance.Start(pluginCtx, m.host); err != nil {
|
||||
cancel()
|
||||
m.mu.Lock()
|
||||
entry.info.Status = sdk.StatusError
|
||||
m.mu.Unlock()
|
||||
return fmt.Errorf("plugin %q start failed: %w", meta.Name, err)
|
||||
}
|
||||
|
||||
tools := entry.instance.Tools()
|
||||
toolIDs := make([]string, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
if err := m.registry.Register(t); err != nil {
|
||||
m.registry.UnregisterAll(toolIDs)
|
||||
cancel()
|
||||
m.mu.Lock()
|
||||
entry.info.Status = sdk.StatusError
|
||||
m.mu.Unlock()
|
||||
return fmt.Errorf("plugin %q tool register failed: %w", meta.Name, err)
|
||||
}
|
||||
toolIDs = append(toolIDs, t.Definition().ID)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
entry.cancel = cancel
|
||||
entry.info.Status = sdk.StatusRunning
|
||||
entry.info.Enabled = true
|
||||
entry.info.Tools = toolIDs
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Disable stops a plugin and unregisters its tools.
|
||||
func (m *PluginManager) Disable(ctx context.Context, pluginName string) error {
|
||||
m.mu.Lock()
|
||||
entry, ok := m.plugins[pluginName]
|
||||
m.mu.Unlock()
|
||||
if !ok {
|
||||
return fmt.Errorf("plugin %q not found", pluginName)
|
||||
}
|
||||
|
||||
if err := entry.instance.Stop(ctx); err != nil {
|
||||
return fmt.Errorf("plugin %q stop failed: %w", pluginName, err)
|
||||
}
|
||||
if entry.cancel != nil {
|
||||
entry.cancel()
|
||||
}
|
||||
|
||||
m.registry.UnregisterAll(entry.info.Tools)
|
||||
|
||||
m.mu.Lock()
|
||||
entry.info.Status = sdk.StatusDisabled
|
||||
entry.info.Enabled = false
|
||||
entry.info.Tools = nil
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// List returns info for all installed plugins.
|
||||
func (m *PluginManager) List() []sdk.PluginInfo {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
result := make([]sdk.PluginInfo, 0, len(m.plugins))
|
||||
for _, entry := range m.plugins {
|
||||
result = append(result, entry.info)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Get returns info for a single plugin.
|
||||
func (m *PluginManager) Get(pluginName string) (*sdk.PluginInfo, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
entry, ok := m.plugins[pluginName]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
info := entry.info
|
||||
return &info, true
|
||||
}
|
||||
|
||||
// EnableAll starts all installed plugins.
|
||||
func (m *PluginManager) EnableAll(ctx context.Context) []error {
|
||||
m.mu.RLock()
|
||||
names := make([]string, 0, len(m.plugins))
|
||||
for name := range m.plugins {
|
||||
names = append(names, name)
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
var errs []error
|
||||
for _, name := range names {
|
||||
if err := m.Enable(ctx, name); err != nil {
|
||||
errs = append(errs, fmt.Errorf("%s: %w", name, err))
|
||||
}
|
||||
}
|
||||
return errs
|
||||
}
|
||||
|
||||
// Uninstall removes a plugin completely.
|
||||
func (m *PluginManager) Uninstall(ctx context.Context, pluginName string) error {
|
||||
m.mu.RLock()
|
||||
entry, ok := m.plugins[pluginName]
|
||||
m.mu.RUnlock()
|
||||
if !ok {
|
||||
return fmt.Errorf("plugin %q not found", pluginName)
|
||||
}
|
||||
if entry.info.Status == sdk.StatusRunning {
|
||||
if err := m.Disable(ctx, pluginName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.plugins, pluginName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reload stops and re-starts a plugin.
|
||||
func (m *PluginManager) Reload(ctx context.Context, pluginName string) error {
|
||||
if err := m.Disable(ctx, pluginName); err != nil {
|
||||
return fmt.Errorf("reload disable: %w", err)
|
||||
}
|
||||
return m.Enable(ctx, pluginName)
|
||||
}
|
||||
|
||||
// Shutdown stops all running plugins gracefully.
|
||||
func (m *PluginManager) Shutdown(ctx context.Context) []error {
|
||||
m.mu.RLock()
|
||||
names := make([]string, 0, len(m.plugins))
|
||||
for name, entry := range m.plugins {
|
||||
if entry.info.Status == sdk.StatusRunning {
|
||||
names = append(names, name)
|
||||
}
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
var errs []error
|
||||
for _, name := range names {
|
||||
if err := m.Disable(ctx, name); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
return errs
|
||||
}
|
||||
|
||||
// Registry returns the aggregated tool registry.
|
||||
func (m *PluginManager) Registry() *ToolRegistry {
|
||||
return m.registry
|
||||
}
|
||||
@@ -1,298 +0,0 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
||||
)
|
||||
|
||||
// CallLogRecord 工具调用记录
|
||||
type CallLogRecord struct {
|
||||
CallID string `json:"call_id"`
|
||||
ToolName string `json:"tool_name"`
|
||||
Arguments string `json:"arguments"`
|
||||
Output string `json:"output"`
|
||||
Error string `json:"error"`
|
||||
Success bool `json:"success"`
|
||||
DurationMs int `json:"duration_ms"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
// callLogRing 线程安全的环形缓冲区
|
||||
type callLogRing struct {
|
||||
mu sync.Mutex
|
||||
records []CallLogRecord
|
||||
capacity int
|
||||
head int
|
||||
size int
|
||||
}
|
||||
|
||||
func newCallLogRing(capacity int) *callLogRing {
|
||||
return &callLogRing{capacity: capacity, records: make([]CallLogRecord, capacity)}
|
||||
}
|
||||
|
||||
func (r *callLogRing) push(rec CallLogRecord) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
rec.CallID = fmt.Sprintf("%d", time.Now().UnixNano())
|
||||
rec.Timestamp = time.Now().UnixMilli()
|
||||
r.records[r.head] = rec
|
||||
r.head = (r.head + 1) % r.capacity
|
||||
if r.size < r.capacity {
|
||||
r.size++
|
||||
}
|
||||
}
|
||||
|
||||
func (r *callLogRing) getAll() []CallLogRecord {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
result := make([]CallLogRecord, r.size)
|
||||
for i := 0; i < r.size; i++ {
|
||||
idx := (r.head - 1 - i) % r.capacity
|
||||
if idx < 0 {
|
||||
idx += r.capacity
|
||||
}
|
||||
result[i] = r.records[idx]
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (r *callLogRing) statsByTool() map[string]map[string]interface{} {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
byTool := make(map[string]map[string]interface{})
|
||||
for i := 0; i < r.size; i++ {
|
||||
idx := (r.head - 1 - i) % r.capacity
|
||||
if idx < 0 {
|
||||
idx += r.capacity
|
||||
}
|
||||
rec := r.records[idx]
|
||||
if _, ok := byTool[rec.ToolName]; !ok {
|
||||
byTool[rec.ToolName] = map[string]interface{}{
|
||||
"tool_name": rec.ToolName, "count": 0, "success_count": 0,
|
||||
"fail_count": 0, "total_duration_ms": 0,
|
||||
}
|
||||
}
|
||||
s := byTool[rec.ToolName]
|
||||
s["count"] = s["count"].(int) + 1
|
||||
if rec.Success {
|
||||
s["success_count"] = s["success_count"].(int) + 1
|
||||
} else {
|
||||
s["fail_count"] = s["fail_count"].(int) + 1
|
||||
}
|
||||
s["total_duration_ms"] = s["total_duration_ms"].(int) + rec.DurationMs
|
||||
}
|
||||
return byTool
|
||||
}
|
||||
|
||||
// ToolRegistry aggregates tool definitions from all running plugins and dispatches execution.
|
||||
type ToolRegistry struct {
|
||||
mu sync.RWMutex
|
||||
tools map[string]sdk.Tool // tool ID -> Tool
|
||||
callLog *callLogRing
|
||||
enabled bool
|
||||
}
|
||||
|
||||
func NewToolRegistry() *ToolRegistry {
|
||||
return &ToolRegistry{
|
||||
tools: make(map[string]sdk.Tool),
|
||||
callLog: newCallLogRing(500),
|
||||
enabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
// IsEnabled returns whether tool execution is enabled.
|
||||
func (r *ToolRegistry) IsEnabled() bool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.enabled
|
||||
}
|
||||
|
||||
// SetEnabled enables or disables tool execution.
|
||||
func (r *ToolRegistry) SetEnabled(enabled bool) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.enabled = enabled
|
||||
}
|
||||
|
||||
// DefinitionNames returns all registered tool names.
|
||||
func (r *ToolRegistry) DefinitionNames() []string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
names := make([]string, 0, len(r.tools))
|
||||
for id := range r.tools {
|
||||
names = append(names, id)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) Register(tool sdk.Tool) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
id := tool.Definition().ID
|
||||
if _, exists := r.tools[id]; exists {
|
||||
return fmt.Errorf("tool %q already registered", id)
|
||||
}
|
||||
r.tools[id] = tool
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) Unregister(toolID string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
delete(r.tools, toolID)
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) Get(toolID string) (sdk.Tool, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
t, ok := r.tools[toolID]
|
||||
return t, ok
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) List() []sdk.Tool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
result := make([]sdk.Tool, 0, len(r.tools))
|
||||
for _, t := range r.tools {
|
||||
result = append(result, t)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) Definitions() []sdk.ToolDefinition {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
defs := make([]sdk.ToolDefinition, 0, len(r.tools))
|
||||
for _, t := range r.tools {
|
||||
defs = append(defs, t.Definition())
|
||||
}
|
||||
return defs
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) Execute(ctx context.Context, toolID string, args map[string]interface{}) (*sdk.ToolResult, error) {
|
||||
r.mu.RLock()
|
||||
tool, ok := r.tools[toolID]
|
||||
r.mu.RUnlock()
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
if !ok {
|
||||
r.callLog.push(CallLogRecord{
|
||||
ToolName: toolID, Error: fmt.Sprintf("tool %q not found", toolID),
|
||||
Success: false, DurationMs: int(time.Since(startTime).Milliseconds()),
|
||||
})
|
||||
return nil, fmt.Errorf("tool %q not found", toolID)
|
||||
}
|
||||
|
||||
if err := tool.Validate(args); err != nil {
|
||||
r.callLog.push(CallLogRecord{
|
||||
ToolName: toolID, Error: err.Error(), Success: false,
|
||||
DurationMs: int(time.Since(startTime).Milliseconds()),
|
||||
})
|
||||
return &sdk.ToolResult{Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
|
||||
result, err := tool.Execute(ctx, args)
|
||||
durationMs := int(time.Since(startTime).Milliseconds())
|
||||
|
||||
if err != nil {
|
||||
r.callLog.push(CallLogRecord{
|
||||
ToolName: toolID, Error: err.Error(), Success: false, DurationMs: durationMs,
|
||||
})
|
||||
return result, err
|
||||
}
|
||||
|
||||
var argsJSON string
|
||||
if args != nil {
|
||||
if b, _ := json.Marshal(args); b != nil {
|
||||
argsJSON = string(b)
|
||||
}
|
||||
}
|
||||
r.callLog.push(CallLogRecord{
|
||||
ToolName: toolID, Arguments: argsJSON, Output: result.Output,
|
||||
Error: result.Error, Success: result.Success, DurationMs: durationMs,
|
||||
})
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// UnregisterAll removes all tools matching given IDs.
|
||||
func (r *ToolRegistry) UnregisterAll(toolIDs []string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for _, id := range toolIDs {
|
||||
delete(r.tools, id)
|
||||
}
|
||||
}
|
||||
|
||||
// GetCallLogs 获取工具调用记录(最新在前,支持按工具名过滤、分页)
|
||||
func (r *ToolRegistry) GetCallLogs(toolName string, limit, offset int) ([]CallLogRecord, int) {
|
||||
all := r.callLog.getAll()
|
||||
|
||||
// 过滤
|
||||
var filtered []CallLogRecord
|
||||
if toolName == "" {
|
||||
filtered = all
|
||||
} else {
|
||||
filtered = make([]CallLogRecord, 0)
|
||||
for _, rec := range all {
|
||||
if rec.ToolName == toolName {
|
||||
filtered = append(filtered, rec)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
total := len(filtered)
|
||||
|
||||
// 分页
|
||||
if offset >= len(filtered) {
|
||||
return []CallLogRecord{}, total
|
||||
}
|
||||
page := filtered[offset:]
|
||||
if limit > 0 && limit < len(page) {
|
||||
page = page[:limit]
|
||||
}
|
||||
|
||||
return page, total
|
||||
}
|
||||
|
||||
// GetCallStats 获取工具调用统计
|
||||
func (r *ToolRegistry) GetCallStats() map[string]interface{} {
|
||||
byTool := r.callLog.statsByTool()
|
||||
totalCalls, successCount, failCount, totalDurationMs := 0, 0, 0, 0
|
||||
toolStats := make([]map[string]interface{}, 0, len(byTool))
|
||||
for _, s := range byTool {
|
||||
count := s["count"].(int)
|
||||
success := s["success_count"].(int)
|
||||
fail := s["fail_count"].(int)
|
||||
totalDur := s["total_duration_ms"].(int)
|
||||
avgDur := 0.0
|
||||
if count > 0 {
|
||||
avgDur = float64(totalDur) / float64(count)
|
||||
}
|
||||
s["avg_duration_ms"] = avgDur
|
||||
delete(s, "total_duration_ms")
|
||||
toolStats = append(toolStats, s)
|
||||
totalCalls += count
|
||||
successCount += success
|
||||
failCount += fail
|
||||
totalDurationMs += totalDur
|
||||
}
|
||||
avgDuration := 0.0
|
||||
if totalCalls > 0 {
|
||||
avgDuration = float64(totalDurationMs) / float64(totalCalls)
|
||||
}
|
||||
successRate := 0.0
|
||||
if totalCalls > 0 {
|
||||
successRate = float64(successCount) / float64(totalCalls) * 100
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"total_calls": totalCalls, "success_count": successCount, "fail_count": failCount,
|
||||
"success_rate": successRate, "avg_duration_ms": avgDuration, "by_tool": toolStats,
|
||||
}
|
||||
}
|
||||
@@ -1,184 +0,0 @@
|
||||
package markdown
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
||||
)
|
||||
|
||||
type MarkdownPlugin struct{ sdk.BasePlugin }
|
||||
|
||||
func (p *MarkdownPlugin) Metadata() sdk.PluginMetadata {
|
||||
return sdk.PluginMetadata{
|
||||
Name: "markdown", DisplayName: "Markdown Processor", Version: "1.0.0",
|
||||
Description: "Markdown processing: to HTML, extract text/links/code, generate TOC",
|
||||
Category: "format", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *MarkdownPlugin) Tools() []sdk.Tool { return []sdk.Tool{&MarkdownTool{}} }
|
||||
|
||||
type MarkdownTool struct{ sdk.BaseTool }
|
||||
|
||||
func (t *MarkdownTool) Definition() sdk.ToolDefinition {
|
||||
return sdk.ToolDefinition{
|
||||
ID: "markdown", Name: "markdown", DisplayName: "Markdown Processor",
|
||||
Description: "Markdown processing. Convert to HTML, extract plain text, extract links/code blocks, generate TOC.",
|
||||
Category: "format", Complexity: sdk.ComplexitySimple,
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{"type": "string", "enum": []string{"to_html", "to_text", "extract_links", "extract_code", "table_of_contents"}},
|
||||
"markdown": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
"required": []string{"action", "markdown"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *MarkdownTool) Validate(args map[string]interface{}) error {
|
||||
for _, k := range []string{"action", "markdown"} {
|
||||
if _, ok := args[k]; !ok {
|
||||
return fmt.Errorf("missing required parameter: %s", k)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *MarkdownTool) Execute(_ context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
||||
action, _ := args["action"].(string)
|
||||
md, _ := args["markdown"].(string)
|
||||
|
||||
switch action {
|
||||
case "to_html":
|
||||
return &sdk.ToolResult{ToolName: "markdown", Success: true, Output: mdToHTML(md)}, nil
|
||||
|
||||
case "to_text":
|
||||
text := md
|
||||
reCode := regexp.MustCompile("(?s)```.*?```")
|
||||
text = reCode.ReplaceAllString(text, "")
|
||||
text = regexp.MustCompile(`\*\*([^*]+)\*\*`).ReplaceAllString(text, "$1")
|
||||
text = regexp.MustCompile(`\*([^*]+)\*`).ReplaceAllString(text, "$1")
|
||||
text = regexp.MustCompile(`~~([^~]+)~~`).ReplaceAllString(text, "$1")
|
||||
text = regexp.MustCompile(`^#{1,6}\s+`).ReplaceAllString(text, "")
|
||||
text = regexp.MustCompile(`^[*-]\s+`).ReplaceAllString(text, "- ")
|
||||
text = regexp.MustCompile(`^>\s+`).ReplaceAllString(text, "")
|
||||
text = regexp.MustCompile(`\n{3,}`).ReplaceAllString(text, "\n\n")
|
||||
return &sdk.ToolResult{ToolName: "markdown", Success: true, Output: strings.TrimSpace(text)}, nil
|
||||
|
||||
case "extract_links":
|
||||
re := regexp.MustCompile(`\[([^\]]+)\]\(([^)]+)\)`)
|
||||
matches := re.FindAllStringSubmatch(md, -1)
|
||||
if len(matches) == 0 {
|
||||
return &sdk.ToolResult{ToolName: "markdown", Success: true, Output: "No links found"}, nil
|
||||
}
|
||||
var out strings.Builder
|
||||
for i, m := range matches {
|
||||
out.WriteString(fmt.Sprintf("%d. %s -> %s\n", i+1, m[1], m[2]))
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "markdown", Success: true, Output: out.String()}, nil
|
||||
|
||||
case "extract_code":
|
||||
re := regexp.MustCompile("(?s)```(\\w*)\n?(.*?)```")
|
||||
matches := re.FindAllStringSubmatch(md, -1)
|
||||
if len(matches) == 0 {
|
||||
return &sdk.ToolResult{ToolName: "markdown", Success: true, Output: "No code blocks found"}, nil
|
||||
}
|
||||
var out strings.Builder
|
||||
for i, m := range matches {
|
||||
lang := m[1]
|
||||
if lang == "" {
|
||||
lang = "text"
|
||||
}
|
||||
code := m[2]
|
||||
if len([]rune(code)) > 500 {
|
||||
code = string([]rune(code)[:500]) + "..."
|
||||
}
|
||||
out.WriteString(fmt.Sprintf("--- Block %d (%s) ---\n%s\n\n", i+1, lang, code))
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "markdown", Success: true, Output: out.String()}, nil
|
||||
|
||||
case "table_of_contents":
|
||||
re := regexp.MustCompile(`(?m)^(#{1,6})\s+(.+)$`)
|
||||
matches := re.FindAllStringSubmatch(md, -1)
|
||||
if len(matches) == 0 {
|
||||
return &sdk.ToolResult{ToolName: "markdown", Success: true, Output: "No headings found"}, nil
|
||||
}
|
||||
var out strings.Builder
|
||||
for _, m := range matches {
|
||||
depth := len(m[1])
|
||||
indent := strings.Repeat(" ", depth-1)
|
||||
out.WriteString(fmt.Sprintf("%s- %s\n", indent, m[2]))
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "markdown", Success: true, Output: out.String()}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "markdown", Success: false, Error: "unknown action: " + action}, nil
|
||||
}
|
||||
|
||||
func mdToHTML(md string) string {
|
||||
// Save code blocks
|
||||
type placeholder struct {
|
||||
orig string
|
||||
content string
|
||||
language string
|
||||
}
|
||||
blocks := []*placeholder{}
|
||||
reCode := regexp.MustCompile("(?s)```(\\w*)\n?(.*?)```")
|
||||
md = reCode.ReplaceAllStringFunc(md, func(s string) string {
|
||||
m := reCode.FindStringSubmatch(s)
|
||||
b := &placeholder{orig: fmt.Sprintf("\x00CODE%d\x00", len(blocks)), language: m[1], content: escapeHTML(m[2])}
|
||||
blocks = append(blocks, b)
|
||||
return b.orig
|
||||
})
|
||||
|
||||
// Inline elements
|
||||
md = regexp.MustCompile("`([^`]+)`").ReplaceAllString(md, "<code>$1</code>")
|
||||
md = regexp.MustCompile(`!\[([^\]]*)\]\(([^)]+)\)`).ReplaceAllString(md, `<img src="$2" alt="$1">`)
|
||||
md = regexp.MustCompile(`\[([^\]]+)\]\(([^)]+)\)`).ReplaceAllString(md, `<a href="$2">$1</a>`)
|
||||
md = regexp.MustCompile(`\*\*([^*]+)\*\*`).ReplaceAllString(md, `<strong>$1</strong>`)
|
||||
md = regexp.MustCompile(`\*([^*]+)\*`).ReplaceAllString(md, `<em>$1</em>`)
|
||||
md = regexp.MustCompile(`~~([^~]+)~~`).ReplaceAllString(md, `<del>$1</del>`)
|
||||
md = regexp.MustCompile(`(?m)^#{6}\s+(.+)$`).ReplaceAllString(md, `<h6>$1</h6>`)
|
||||
md = regexp.MustCompile(`(?m)^#{5}\s+(.+)$`).ReplaceAllString(md, `<h5>$1</h5>`)
|
||||
md = regexp.MustCompile(`(?m)^#{4}\s+(.+)$`).ReplaceAllString(md, `<h4>$1</h4>`)
|
||||
md = regexp.MustCompile(`(?m)^#{3}\s+(.+)$`).ReplaceAllString(md, `<h3>$1</h3>`)
|
||||
md = regexp.MustCompile(`(?m)^#{2}\s+(.+)$`).ReplaceAllString(md, `<h2>$1</h2>`)
|
||||
md = regexp.MustCompile(`(?m)^#{1}\s+(.+)$`).ReplaceAllString(md, `<h1>$1</h1>`)
|
||||
md = regexp.MustCompile(`(?m)^---\s*$`).ReplaceAllString(md, `<hr>`)
|
||||
md = regexp.MustCompile(`(?m)^>\s+(.+)$`).ReplaceAllString(md, `<blockquote>$1</blockquote>`)
|
||||
|
||||
// Restore code blocks
|
||||
for _, b := range blocks {
|
||||
langAttr := ""
|
||||
if b.language != "" {
|
||||
langAttr = " class=\"language-" + b.language + "\""
|
||||
}
|
||||
md = strings.Replace(md, b.orig, "<pre><code"+langAttr+">"+b.content+"</code></pre>", 1)
|
||||
}
|
||||
|
||||
// Paragraphs
|
||||
lines := strings.Split(md, "\n")
|
||||
var out strings.Builder
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(trimmed, "<") {
|
||||
out.WriteString(trimmed + "\n")
|
||||
} else {
|
||||
out.WriteString("<p>" + trimmed + "</p>\n")
|
||||
}
|
||||
}
|
||||
return out.String()
|
||||
}
|
||||
|
||||
func escapeHTML(s string) string {
|
||||
s = strings.ReplaceAll(s, "&", "&")
|
||||
s = strings.ReplaceAll(s, "<", "<")
|
||||
s = strings.ReplaceAll(s, ">", ">")
|
||||
return s
|
||||
}
|
||||
@@ -1,175 +0,0 @@
|
||||
package random
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
mathrand "math/rand"
|
||||
"strings"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
||||
)
|
||||
|
||||
type RandomPlugin struct{ sdk.BasePlugin }
|
||||
|
||||
func (p *RandomPlugin) Metadata() sdk.PluginMetadata {
|
||||
return sdk.PluginMetadata{
|
||||
Name: "random", DisplayName: "Random Generator", Version: "1.0.0",
|
||||
Description: "Random generation: numbers, UUIDs, secure passwords, pick/shuffle",
|
||||
Category: "utility", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RandomPlugin) Tools() []sdk.Tool { return []sdk.Tool{&RandomTool{}} }
|
||||
|
||||
type RandomTool struct{ sdk.BaseTool }
|
||||
|
||||
func (t *RandomTool) Definition() sdk.ToolDefinition {
|
||||
return sdk.ToolDefinition{
|
||||
ID: "random", Name: "random", DisplayName: "Random Generator",
|
||||
Description: "Random generation. Random numbers, UUID v4, secure passwords, pick from list, shuffle list.",
|
||||
Category: "utility", Complexity: sdk.ComplexitySimple,
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{"type": "string", "enum": []string{"number", "uuid", "password", "pick", "shuffle"}},
|
||||
"min": map[string]interface{}{"type": "number"},
|
||||
"max": map[string]interface{}{"type": "number"},
|
||||
"length": map[string]interface{}{"type": "number"},
|
||||
"items": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}},
|
||||
"count": map[string]interface{}{"type": "number"},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *RandomTool) Validate(args map[string]interface{}) error {
|
||||
if _, ok := args["action"]; !ok {
|
||||
return fmt.Errorf("missing required parameter: action")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *RandomTool) Execute(_ context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
||||
action, _ := args["action"].(string)
|
||||
|
||||
switch action {
|
||||
case "number":
|
||||
min := getIntArg(args, "min", 0)
|
||||
max := getIntArg(args, "max", 100)
|
||||
n, err := rand.Int(rand.Reader, big.NewInt(int64(max-min+1)))
|
||||
if err != nil {
|
||||
return &sdk.ToolResult{ToolName: "random", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "random", Success: true,
|
||||
Output: fmt.Sprintf("%d", int(n.Int64())+min)}, nil
|
||||
|
||||
case "uuid":
|
||||
uuid := make([]byte, 16)
|
||||
rand.Read(uuid)
|
||||
uuid[6] = (uuid[6] & 0x0f) | 0x40
|
||||
uuid[8] = (uuid[8] & 0x3f) | 0x80
|
||||
return &sdk.ToolResult{ToolName: "random", Success: true,
|
||||
Output: fmt.Sprintf("%x-%x-%x-%x-%x", uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:])}, nil
|
||||
|
||||
case "password":
|
||||
length := getIntArg(args, "length", 16)
|
||||
if length < 4 {
|
||||
length = 4
|
||||
}
|
||||
if length > 128 {
|
||||
length = 128
|
||||
}
|
||||
upper := "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
lower := "abcdefghijklmnopqrstuvwxyz"
|
||||
digits := "0123456789"
|
||||
symbols := "!@#$%^&*()_+-=[]{}|;:,.<>?"
|
||||
all := upper + lower + digits + symbols
|
||||
bytes := make([]byte, length)
|
||||
for i := range bytes {
|
||||
idx, _ := rand.Int(rand.Reader, big.NewInt(int64(len(all))))
|
||||
bytes[i] = all[idx.Int64()]
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "random", Success: true, Output: string(bytes)}, nil
|
||||
|
||||
case "pick":
|
||||
items := getStringSliceArg(args, "items")
|
||||
if len(items) == 0 {
|
||||
return &sdk.ToolResult{ToolName: "random", Success: false, Error: "items list is empty"}, nil
|
||||
}
|
||||
count := getIntArg(args, "count", 1)
|
||||
if count > len(items) {
|
||||
count = len(items)
|
||||
}
|
||||
indices := shuffledIndices(len(items))
|
||||
picked := make([]string, count)
|
||||
for i := 0; i < count; i++ {
|
||||
picked[i] = items[indices[i]]
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "random", Success: true, Output: strings.Join(picked, ", ")}, nil
|
||||
|
||||
case "shuffle":
|
||||
items := getStringSliceArg(args, "items")
|
||||
indices := shuffledIndices(len(items))
|
||||
shuffled := make([]string, len(items))
|
||||
for i, idx := range indices {
|
||||
shuffled[i] = items[idx]
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "random", Success: true, Output: strings.Join(shuffled, ", ")}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "random", Success: false, Error: "unknown action: " + action}, nil
|
||||
}
|
||||
|
||||
func getIntArg(args map[string]interface{}, key string, defaultVal int) int {
|
||||
v, ok := args[key]
|
||||
if !ok {
|
||||
return defaultVal
|
||||
}
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
return int(n)
|
||||
case int:
|
||||
return n
|
||||
case int64:
|
||||
return int(n)
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
|
||||
func getStringSliceArg(args map[string]interface{}, key string) []string {
|
||||
v, ok := args[key]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
switch s := v.(type) {
|
||||
case []string:
|
||||
return s
|
||||
case []interface{}:
|
||||
result := make([]string, len(s))
|
||||
for i, item := range s {
|
||||
result[i] = fmt.Sprint(item)
|
||||
}
|
||||
return result
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func shuffledIndices(n int) []int {
|
||||
indices := make([]int, n)
|
||||
for i := range indices {
|
||||
indices[i] = i
|
||||
}
|
||||
for i := n - 1; i > 0; i-- {
|
||||
jBig, err := rand.Int(rand.Reader, big.NewInt(int64(i+1)))
|
||||
if err != nil {
|
||||
j := mathrand.Intn(i + 1)
|
||||
indices[i], indices[j] = indices[j], indices[i]
|
||||
continue
|
||||
}
|
||||
j := int(jBig.Int64())
|
||||
indices[i], indices[j] = indices[j], indices[i]
|
||||
}
|
||||
return indices
|
||||
}
|
||||
@@ -1,40 +0,0 @@
|
||||
package sdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// BasePlugin provides default implementations for optional Plugin methods.
|
||||
type BasePlugin struct{}
|
||||
|
||||
func (BasePlugin) Init(_ context.Context, _ PluginConfig) error { return nil }
|
||||
|
||||
func (BasePlugin) Start(_ context.Context, _ HostAPI) error { return nil }
|
||||
|
||||
func (BasePlugin) Stop(_ context.Context) error { return nil }
|
||||
|
||||
func (BasePlugin) Health(_ context.Context) error { return nil }
|
||||
|
||||
// BaseTool provides a Validate default that checks required parameters.
|
||||
type BaseTool struct {
|
||||
Def ToolDefinition
|
||||
Required []string
|
||||
}
|
||||
|
||||
func (b BaseTool) Definition() ToolDefinition { return b.Def }
|
||||
|
||||
func (b BaseTool) Complexity() ToolComplexity { return ComplexitySimple }
|
||||
|
||||
func (b BaseTool) Validate(args map[string]interface{}) error {
|
||||
for _, key := range b.Required {
|
||||
if _, ok := args[key]; !ok {
|
||||
return fmt.Errorf("missing required parameter: %s", key)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b BaseTool) Execute(_ context.Context, _ map[string]interface{}) (*ToolResult, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
package sdk
|
||||
|
||||
// PluginPermissions defines what a plugin is allowed to do.
|
||||
type PluginPermissions struct {
|
||||
NetworkAllowed bool `json:"networkAllowed"`
|
||||
AllowedHosts []string `json:"allowedHosts,omitempty"`
|
||||
IoTRead bool `json:"iotRead"`
|
||||
IoTWrite bool `json:"iotWrite"`
|
||||
MemoryRead bool `json:"memoryRead"`
|
||||
MemoryWrite bool `json:"memoryWrite"`
|
||||
FileRead bool `json:"fileRead"`
|
||||
FileWrite bool `json:"fileWrite"`
|
||||
AllowedPaths []string `json:"allowedPaths,omitempty"`
|
||||
ExecAllowed bool `json:"execAllowed"`
|
||||
MaxCPUPercent float64 `json:"maxCPUPercent"`
|
||||
MaxMemoryMB int `json:"maxMemoryMB"`
|
||||
}
|
||||
|
||||
// DefaultPermissions returns a safe default permission set.
|
||||
func DefaultPermissions() PluginPermissions {
|
||||
return PluginPermissions{
|
||||
NetworkAllowed: false,
|
||||
AllowedHosts: []string{},
|
||||
IoTRead: false,
|
||||
IoTWrite: false,
|
||||
MemoryRead: false,
|
||||
MemoryWrite: false,
|
||||
FileRead: false,
|
||||
FileWrite: false,
|
||||
AllowedPaths: []string{},
|
||||
ExecAllowed: false,
|
||||
MaxCPUPercent: 10.0,
|
||||
MaxMemoryMB: 128,
|
||||
}
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
package sdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Plugin is the main interface every plugin must implement.
|
||||
type Plugin interface {
|
||||
Metadata() PluginMetadata
|
||||
Init(ctx context.Context, config PluginConfig) error
|
||||
Start(ctx context.Context, host HostAPI) error
|
||||
Stop(ctx context.Context) error
|
||||
Health(ctx context.Context) error
|
||||
Tools() []Tool
|
||||
}
|
||||
|
||||
// Tool is the interface every tool must implement.
|
||||
type Tool interface {
|
||||
Definition() ToolDefinition
|
||||
Execute(ctx context.Context, args map[string]interface{}) (*ToolResult, error)
|
||||
Validate(args map[string]interface{}) error
|
||||
Complexity() ToolComplexity
|
||||
}
|
||||
|
||||
// ComplexTool extends Tool for async multi-round execution.
|
||||
type ComplexTool interface {
|
||||
Tool
|
||||
ExecuteAsync(ctx context.Context, args map[string]interface{}) (<-chan ToolProgress, error)
|
||||
Cancel(ctx context.Context, executionID string) error
|
||||
}
|
||||
|
||||
// HostAPI gives plugins access to Cyrene core capabilities.
|
||||
type HostAPI interface {
|
||||
CallLLM(ctx context.Context, messages []LLMMessage) (*LLMResponse, error)
|
||||
SearchMemory(ctx context.Context, userID, query string, limit int) ([]MemoryEntry, error)
|
||||
StoreMemory(ctx context.Context, entry MemoryEntry) error
|
||||
Logger() Logger
|
||||
GetConfig(key string) (string, error)
|
||||
SetConfig(key, value string) error
|
||||
PublishEvent(ctx context.Context, event map[string]interface{}) error
|
||||
HTTPClient() *http.Client
|
||||
}
|
||||
|
||||
// Logger is a minimal logging interface for plugins.
|
||||
type Logger interface {
|
||||
Printf(format string, args ...interface{})
|
||||
Println(args ...interface{})
|
||||
}
|
||||
@@ -1,134 +0,0 @@
|
||||
package sdk
|
||||
|
||||
import "time"
|
||||
|
||||
// ToolComplexity grades tools into simple (single-call, <2s) and complex (multi-round, async).
|
||||
type ToolComplexity string
|
||||
|
||||
const (
|
||||
ComplexitySimple ToolComplexity = "simple"
|
||||
ComplexityComplex ToolComplexity = "complex"
|
||||
)
|
||||
|
||||
// PluginMetadata describes a plugin's identity and requirements.
|
||||
type PluginMetadata struct {
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Version string `json:"version"`
|
||||
MinCyreneVersion string `json:"minCyreneVersion"`
|
||||
Author PluginAuthor `json:"author"`
|
||||
Description string `json:"description"`
|
||||
License string `json:"license"`
|
||||
Keywords []string `json:"keywords,omitempty"`
|
||||
Category string `json:"category"`
|
||||
Dependencies map[string]string `json:"dependencies,omitempty"` // plugin name -> version range
|
||||
Homepage string `json:"homepage,omitempty"`
|
||||
Repository string `json:"repository,omitempty"`
|
||||
}
|
||||
|
||||
type PluginAuthor struct {
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
}
|
||||
|
||||
// PluginConfig holds runtime configuration for a plugin.
|
||||
type PluginConfig map[string]interface{}
|
||||
|
||||
// ToolDefinition describes a tool's interface for LLM function calling.
|
||||
type ToolDefinition struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Description string `json:"description"`
|
||||
Category string `json:"category"`
|
||||
Complexity ToolComplexity `json:"complexity"`
|
||||
Parameters map[string]interface{} `json:"parameters"`
|
||||
Returns map[string]interface{} `json:"returns,omitempty"`
|
||||
TimeoutMs int `json:"timeout_ms,omitempty"`
|
||||
MaxRetries int `json:"max_retries,omitempty"`
|
||||
DangerLevel string `json:"danger_level,omitempty"` // low / medium / high
|
||||
}
|
||||
|
||||
// ToolResult is the standard tool execution result.
|
||||
type ToolResult struct {
|
||||
ToolName string `json:"tool_name"`
|
||||
Success bool `json:"success"`
|
||||
Output string `json:"output,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
DurationMs int64 `json:"duration_ms,omitempty"`
|
||||
}
|
||||
|
||||
// ToolProgress reports execution progress for complex (async) tools.
|
||||
type ToolProgress struct {
|
||||
ExecutionID string `json:"execution_id"`
|
||||
Status string `json:"status"` // started / running / completed / failed / cancelled
|
||||
Progress float64 `json:"progress"` // 0.0 - 1.0
|
||||
Message string `json:"message,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Result *ToolResult `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
// PluginStatus represents the current lifecycle state of a plugin.
|
||||
type PluginStatus string
|
||||
|
||||
const (
|
||||
StatusInstalled PluginStatus = "installed"
|
||||
StatusLoaded PluginStatus = "loaded"
|
||||
StatusRunning PluginStatus = "running"
|
||||
StatusPaused PluginStatus = "paused"
|
||||
StatusError PluginStatus = "error"
|
||||
StatusDisabled PluginStatus = "disabled"
|
||||
)
|
||||
|
||||
// PluginInfo is the runtime view of an installed plugin.
|
||||
type PluginInfo struct {
|
||||
Metadata PluginMetadata `json:"metadata"`
|
||||
Status PluginStatus `json:"status"`
|
||||
Tools []string `json:"tools"` // tool IDs provided by this plugin
|
||||
InstalledAt time.Time `json:"installed_at"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// LLMMessage is a message in an LLM conversation.
|
||||
type LLMMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// LLMResponse is the result of an LLM call.
|
||||
type LLMResponse struct {
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// ToolCall represents a tool call requested by the LLM.
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]interface{} `json:"arguments"`
|
||||
}
|
||||
|
||||
// IoTDeviceState is the shared device state across IoT plugins.
|
||||
type IoTDeviceState struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
Brightness int `json:"brightness,omitempty"`
|
||||
Color string `json:"color,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Position int `json:"position,omitempty"`
|
||||
Value float64 `json:"value,omitempty"`
|
||||
Unit string `json:"unit,omitempty"`
|
||||
Battery int `json:"battery,omitempty"`
|
||||
}
|
||||
|
||||
// MemoryEntry is a memory record.
|
||||
type MemoryEntry struct {
|
||||
UserID string `json:"user_id"`
|
||||
Content string `json:"content"`
|
||||
Type string `json:"type"`
|
||||
Meta map[string]interface{} `json:"meta,omitempty"`
|
||||
}
|
||||
@@ -1,177 +0,0 @@
|
||||
package text
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
||||
)
|
||||
|
||||
type TextPlugin struct{ sdk.BasePlugin }
|
||||
|
||||
func (p *TextPlugin) Metadata() sdk.PluginMetadata {
|
||||
return sdk.PluginMetadata{
|
||||
Name: "text", DisplayName: "Text Processing", Version: "1.0.0",
|
||||
Description: "Text processing: count stats, summarize, regex extract",
|
||||
Category: "utility", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *TextPlugin) Tools() []sdk.Tool { return []sdk.Tool{&TextTool{}} }
|
||||
|
||||
type TextTool struct{ sdk.BaseTool }
|
||||
|
||||
func (t *TextTool) Definition() sdk.ToolDefinition {
|
||||
return sdk.ToolDefinition{
|
||||
ID: "text", Name: "text", DisplayName: "Text Processing",
|
||||
Description: "Text processing. Count stats, summarize, translate, regex extract.",
|
||||
Category: "utility", Complexity: sdk.ComplexitySimple,
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{"type": "string", "enum": []string{"count", "summarize", "translate", "extract"}},
|
||||
"text": map[string]interface{}{"type": "string"},
|
||||
"target_lang": map[string]interface{}{"type": "string", "enum": []string{"en", "zh", "ja", "ko", "fr", "de"}},
|
||||
"pattern": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
"required": []string{"action", "text"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TextTool) Validate(args map[string]interface{}) error {
|
||||
for _, k := range []string{"action", "text"} {
|
||||
if _, ok := args[k]; !ok {
|
||||
return fmt.Errorf("missing required parameter: %s", k)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TextTool) Execute(_ context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
||||
action, _ := args["action"].(string)
|
||||
txt, _ := args["text"].(string)
|
||||
|
||||
switch action {
|
||||
case "count":
|
||||
charsNoSpace := 0
|
||||
chineseChars := 0
|
||||
for _, r := range txt {
|
||||
if !unicode.IsSpace(r) {
|
||||
charsNoSpace++
|
||||
}
|
||||
if unicode.Is(unicode.Han, r) {
|
||||
chineseChars++
|
||||
}
|
||||
}
|
||||
words := strings.Fields(txt)
|
||||
lines := strings.Split(txt, "\n")
|
||||
paragraphs := regexp.MustCompile(`\n\s*\n`).Split(txt, -1)
|
||||
return &sdk.ToolResult{ToolName: "text", Success: true, Output: fmt.Sprintf(
|
||||
"Characters: %d (no spaces: %d, Chinese: %d)\nBytes: %d\nWords: %d\nLines: %d\nParagraphs: %d",
|
||||
len([]rune(txt)), charsNoSpace, chineseChars, len(txt), len(words), len(lines), len(paragraphs))}, nil
|
||||
|
||||
case "summarize":
|
||||
paragraphs := regexp.MustCompile(`\n\s*\n`).Split(txt, -1)
|
||||
firstPara := ""
|
||||
if len(paragraphs) > 0 {
|
||||
runes := []rune(paragraphs[0])
|
||||
if len(runes) > 300 {
|
||||
runes = runes[:300]
|
||||
}
|
||||
firstPara = string(runes)
|
||||
}
|
||||
sentences := regexp.MustCompile(`[。!?.!?]+`).Split(txt, -1)
|
||||
keywords := []string{"重要", "关键", "因此", "总结", "important", "key", "conclusion", "therefore"}
|
||||
type scored struct {
|
||||
text string
|
||||
score int
|
||||
}
|
||||
var scoredSents []scored
|
||||
for _, s := range sentences {
|
||||
s = strings.TrimSpace(s)
|
||||
if len([]rune(s)) < 10 {
|
||||
continue
|
||||
}
|
||||
score := len([]rune(s))
|
||||
for _, kw := range keywords {
|
||||
if strings.Contains(strings.ToLower(s), strings.ToLower(kw)) {
|
||||
score += 20
|
||||
}
|
||||
}
|
||||
scoredSents = append(scoredSents, scored{s, score})
|
||||
}
|
||||
var out strings.Builder
|
||||
out.WriteString(fmt.Sprintf("First paragraph: %s\n\nKey sentences:\n", firstPara))
|
||||
count := 0
|
||||
for i := 0; i < len(scoredSents) && count < 5; i++ {
|
||||
out.WriteString(fmt.Sprintf("- %s\n", scoredSents[i].text))
|
||||
count++
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "text", Success: true, Output: out.String()}, nil
|
||||
|
||||
case "translate":
|
||||
targetLang, _ := args["target_lang"].(string)
|
||||
if targetLang == "" {
|
||||
targetLang = "en"
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "text", Success: true, Output: fmt.Sprintf(
|
||||
"[Translation request] Please translate the following text to %s.\n\nOriginal text:\n%s", targetLang, txt)}, nil
|
||||
|
||||
case "extract":
|
||||
pattern, _ := args["pattern"].(string)
|
||||
var out strings.Builder
|
||||
extracted := false
|
||||
if pattern == "" || pattern == "email" {
|
||||
re := regexp.MustCompile(`[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}`)
|
||||
if matches := re.FindAllString(txt, -1); len(matches) > 0 {
|
||||
out.WriteString("Emails:\n")
|
||||
for _, m := range matches {
|
||||
out.WriteString(fmt.Sprintf("- %s\n", m))
|
||||
}
|
||||
extracted = true
|
||||
}
|
||||
}
|
||||
if pattern == "" || pattern == "phone" {
|
||||
re := regexp.MustCompile(`1[3-9]\d{9}`)
|
||||
if matches := re.FindAllString(txt, -1); len(matches) > 0 {
|
||||
out.WriteString("Phone numbers:\n")
|
||||
for _, m := range matches {
|
||||
out.WriteString(fmt.Sprintf("- %s\n", m))
|
||||
}
|
||||
extracted = true
|
||||
}
|
||||
}
|
||||
if pattern == "" || pattern == "url" {
|
||||
re := regexp.MustCompile(`https?://[^\s<>"{}|\\^` + "`" + `\[\]]+`)
|
||||
if matches := re.FindAllString(txt, -1); len(matches) > 0 {
|
||||
out.WriteString("URLs:\n")
|
||||
for _, m := range matches {
|
||||
out.WriteString(fmt.Sprintf("- %s\n", m))
|
||||
}
|
||||
extracted = true
|
||||
}
|
||||
}
|
||||
if !extracted && pattern != "" && pattern != "email" && pattern != "phone" && pattern != "url" {
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return &sdk.ToolResult{ToolName: "text", Success: false, Error: "Invalid regex: " + err.Error()}, nil
|
||||
}
|
||||
if matches := re.FindAllString(txt, -1); len(matches) > 0 {
|
||||
out.WriteString(fmt.Sprintf("Pattern matches (%s):\n", pattern))
|
||||
for _, m := range matches {
|
||||
out.WriteString(fmt.Sprintf("- %s\n", m))
|
||||
}
|
||||
extracted = true
|
||||
}
|
||||
}
|
||||
if !extracted {
|
||||
return &sdk.ToolResult{ToolName: "text", Success: true, Output: "No matches found"}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "text", Success: true, Output: out.String()}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "text", Success: false, Error: "unknown action: " + action}, nil
|
||||
}
|
||||
@@ -1,113 +0,0 @@
|
||||
package webfetch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
||||
)
|
||||
|
||||
type WebFetchPlugin struct {
|
||||
sdk.BasePlugin
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func NewWebFetchPlugin() *WebFetchPlugin {
|
||||
return &WebFetchPlugin{client: &http.Client{Timeout: 15 * time.Second}}
|
||||
}
|
||||
|
||||
func (p *WebFetchPlugin) Metadata() sdk.PluginMetadata {
|
||||
return sdk.PluginMetadata{
|
||||
Name: "web_fetch", DisplayName: "Web Fetch", Version: "1.0.0",
|
||||
Description: "Fetch and extract text content from URLs",
|
||||
Category: "network", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *WebFetchPlugin) Tools() []sdk.Tool { return []sdk.Tool{&WebFetchTool{client: p.client}} }
|
||||
|
||||
type WebFetchTool struct {
|
||||
sdk.BaseTool
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (t *WebFetchTool) Definition() sdk.ToolDefinition {
|
||||
return sdk.ToolDefinition{
|
||||
ID: "web_fetch", Name: "web_fetch", DisplayName: "Web Fetch",
|
||||
Description: "Fetch content of a specified URL. Returns plain text summary (first 2000 characters). HTTP/HTTPS only.",
|
||||
Category: "network", Complexity: sdk.ComplexitySimple,
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{"url": map[string]interface{}{"type": "string"}},
|
||||
"required": []string{"url"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WebFetchTool) Validate(args map[string]interface{}) error {
|
||||
if _, ok := args["url"]; !ok {
|
||||
return fmt.Errorf("missing required parameter: url")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *WebFetchTool) Execute(_ context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
||||
urlStr, _ := args["url"].(string)
|
||||
if !strings.HasPrefix(urlStr, "http://") && !strings.HasPrefix(urlStr, "https://") {
|
||||
return &sdk.ToolResult{ToolName: "web_fetch", Success: false, Error: "only http/https URLs allowed"}, nil
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest("GET", urlStr, nil)
|
||||
req.Header.Set("User-Agent", "CyreneBot/1.0")
|
||||
resp, err := t.client.Do(req)
|
||||
if err != nil {
|
||||
return &sdk.ToolResult{ToolName: "web_fetch", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, 100*1024))
|
||||
text := stripHTMLFull(string(bodyBytes))
|
||||
text = removeBlankLines(text)
|
||||
runes := []rune(text)
|
||||
if len(runes) > 2000 {
|
||||
text = string(runes[:2000]) + "..."
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "web_fetch", Success: true, Output: fmt.Sprintf(
|
||||
"URL: %s\nStatus: %d\nContent-Type: %s\n\n%s",
|
||||
urlStr, resp.StatusCode, resp.Header.Get("Content-Type"), text)}, nil
|
||||
}
|
||||
|
||||
func stripHTMLFull(s string) string {
|
||||
result := make([]rune, 0, len([]rune(s)))
|
||||
inTag := false
|
||||
for _, r := range s {
|
||||
if r == '<' {
|
||||
inTag = true
|
||||
continue
|
||||
}
|
||||
if r == '>' {
|
||||
inTag = false
|
||||
continue
|
||||
}
|
||||
if !inTag {
|
||||
result = append(result, r)
|
||||
}
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func removeBlankLines(s string) string {
|
||||
lines := strings.Split(s, "\n")
|
||||
var result []string
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed != "" {
|
||||
result = append(result, trimmed)
|
||||
}
|
||||
}
|
||||
return strings.Join(result, "\n")
|
||||
}
|
||||
@@ -1,239 +0,0 @@
|
||||
package websearch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
||||
)
|
||||
|
||||
type WebSearchPlugin struct {
|
||||
sdk.BasePlugin
|
||||
client *http.Client
|
||||
searxngURL string
|
||||
}
|
||||
|
||||
func NewWebSearchPlugin() *WebSearchPlugin {
|
||||
return &WebSearchPlugin{client: &http.Client{Timeout: 10 * time.Second}}
|
||||
}
|
||||
|
||||
func NewWebSearchPluginWithURL(searxngURL string) *WebSearchPlugin {
|
||||
return &WebSearchPlugin{
|
||||
client: &http.Client{Timeout: 10 * time.Second},
|
||||
searxngURL: strings.TrimRight(searxngURL, "/"),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *WebSearchPlugin) Metadata() sdk.PluginMetadata {
|
||||
return sdk.PluginMetadata{
|
||||
Name: "web_search", DisplayName: "Web Search", Version: "1.1.0",
|
||||
Description: "Search the internet via SearXNG (or DuckDuckGo fallback)",
|
||||
Category: "network", Author: sdk.PluginAuthor{Name: "Cyrene Team"},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *WebSearchPlugin) Tools() []sdk.Tool {
|
||||
return []sdk.Tool{&WebSearchTool{client: p.client, searxngURL: p.searxngURL}}
|
||||
}
|
||||
|
||||
type WebSearchTool struct {
|
||||
sdk.BaseTool
|
||||
client *http.Client
|
||||
searxngURL string
|
||||
}
|
||||
|
||||
// ---- SearXNG response types ----
|
||||
type searxngResponse struct {
|
||||
Query string `json:"query"`
|
||||
NumberOrResults int `json:"number_of_results"`
|
||||
Results []searxngResult `json:"results"`
|
||||
Answers []string `json:"answers"`
|
||||
Suggestions []string `json:"suggestions"`
|
||||
}
|
||||
|
||||
type searxngResult struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Content string `json:"content"`
|
||||
Engine string `json:"engine"`
|
||||
Score float64 `json:"score"`
|
||||
}
|
||||
|
||||
// ---- DuckDuckGo response types (fallback) ----
|
||||
type ddgResponse struct {
|
||||
Abstract string `json:"Abstract"`
|
||||
AbstractText string `json:"AbstractText"`
|
||||
Answer string `json:"Answer"`
|
||||
Heading string `json:"Heading"`
|
||||
Results []ddgTopic `json:"Results"`
|
||||
RelatedTopics []ddgTopic `json:"RelatedTopics"`
|
||||
}
|
||||
|
||||
type ddgTopic struct {
|
||||
FirstURL string `json:"FirstURL"`
|
||||
Text string `json:"Text"`
|
||||
}
|
||||
|
||||
func (t *WebSearchTool) Definition() sdk.ToolDefinition {
|
||||
return sdk.ToolDefinition{
|
||||
ID: "web_search", Name: "web_search", DisplayName: "Web Search",
|
||||
Description: "Search the internet. SearXNG backend with DuckDuckGo fallback. Returns up to 5 results.",
|
||||
Category: "network", Complexity: sdk.ComplexitySimple,
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{"query": map[string]interface{}{"type": "string"}},
|
||||
"required": []string{"query"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WebSearchTool) Validate(args map[string]interface{}) error {
|
||||
if _, ok := args["query"]; !ok {
|
||||
return fmt.Errorf("missing required parameter: query")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *WebSearchTool) Execute(_ context.Context, args map[string]interface{}) (*sdk.ToolResult, error) {
|
||||
query, _ := args["query"].(string)
|
||||
if query == "" {
|
||||
return &sdk.ToolResult{ToolName: "web_search", Success: false, Error: "empty query"}, nil
|
||||
}
|
||||
|
||||
if t.searxngURL != "" {
|
||||
return t.searchViaSearXNG(query)
|
||||
}
|
||||
return t.searchViaDuckDuckGo(query)
|
||||
}
|
||||
|
||||
// China-accessible SearXNG engines (baidu, sogou, 360search, bing all work from China)
|
||||
const searxngEngines = "bing,sogou,360search,baidu"
|
||||
|
||||
func (t *WebSearchTool) searchViaSearXNG(query string) (*sdk.ToolResult, error) {
|
||||
apiURL := fmt.Sprintf("%s/search?format=json&engines=%s&q=%s",
|
||||
t.searxngURL, searxngEngines, url.QueryEscape(query))
|
||||
|
||||
resp, err := t.client.Get(apiURL)
|
||||
if err != nil {
|
||||
return &sdk.ToolResult{ToolName: "web_search", Success: false,
|
||||
Error: fmt.Sprintf("SearXNG request failed: %v", err)}, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return &sdk.ToolResult{ToolName: "web_search", Success: false,
|
||||
Error: fmt.Sprintf("SearXNG returned HTTP %d", resp.StatusCode)}, nil
|
||||
}
|
||||
|
||||
var result searxngResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return &sdk.ToolResult{ToolName: "web_search", Success: false,
|
||||
Error: fmt.Sprintf("SearXNG parse error: %v", err)}, nil
|
||||
}
|
||||
|
||||
var out strings.Builder
|
||||
out.WriteString(fmt.Sprintf("搜索: %s (共%d条结果)\n\n", query, result.NumberOrResults))
|
||||
|
||||
// 优先显示答案(如 Wikipedia infobox)
|
||||
for _, answer := range result.Answers {
|
||||
out.WriteString(fmt.Sprintf("📌 %s\n\n", answer))
|
||||
}
|
||||
|
||||
// 搜索结果(最多5条,按score排序)
|
||||
count := 0
|
||||
for _, r := range result.Results {
|
||||
if count >= 5 {
|
||||
break
|
||||
}
|
||||
if r.Title == "" || r.URL == "" {
|
||||
continue
|
||||
}
|
||||
content := cleanSnippet(r.Content)
|
||||
out.WriteString(fmt.Sprintf("%d. **%s**\n %s\n %s\n\n", count+1, r.Title, r.URL, content))
|
||||
count++
|
||||
}
|
||||
|
||||
if out.Len() == 0 {
|
||||
return &sdk.ToolResult{ToolName: "web_search", Success: true,
|
||||
Output: fmt.Sprintf("未找到与「%s」相关的结果。", query)}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "web_search", Success: true, Output: out.String()}, nil
|
||||
}
|
||||
|
||||
func (t *WebSearchTool) searchViaDuckDuckGo(query string) (*sdk.ToolResult, error) {
|
||||
apiURL := fmt.Sprintf("https://api.duckduckgo.com/?q=%s&format=json&no_html=1", url.QueryEscape(query))
|
||||
resp, err := t.client.Get(apiURL)
|
||||
if err != nil {
|
||||
return &sdk.ToolResult{ToolName: "web_search", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result ddgResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return &sdk.ToolResult{ToolName: "web_search", Success: false, Error: err.Error()}, nil
|
||||
}
|
||||
|
||||
var out strings.Builder
|
||||
if result.Answer != "" {
|
||||
out.WriteString(fmt.Sprintf("Answer: %s\n\n", result.Answer))
|
||||
}
|
||||
if result.AbstractText != "" {
|
||||
text := result.AbstractText
|
||||
if len([]rune(text)) > 500 {
|
||||
text = string([]rune(text)[:500]) + "..."
|
||||
}
|
||||
out.WriteString(fmt.Sprintf("Abstract: %s\n\n", stripHTML(text)))
|
||||
}
|
||||
topics := result.Results
|
||||
if len(topics) == 0 {
|
||||
topics = result.RelatedTopics
|
||||
}
|
||||
count := 0
|
||||
for _, topic := range topics {
|
||||
if count >= 5 {
|
||||
break
|
||||
}
|
||||
if topic.Text == "" {
|
||||
continue
|
||||
}
|
||||
out.WriteString(fmt.Sprintf("%d. %s (%s)\n", count+1, stripHTML(topic.Text), topic.FirstURL))
|
||||
count++
|
||||
}
|
||||
if out.Len() == 0 {
|
||||
return &sdk.ToolResult{ToolName: "web_search", Success: true,
|
||||
Output: "No results found for: " + query}, nil
|
||||
}
|
||||
return &sdk.ToolResult{ToolName: "web_search", Success: true, Output: out.String()}, nil
|
||||
}
|
||||
|
||||
func cleanSnippet(s string) string {
|
||||
runes := []rune(strings.TrimSpace(s))
|
||||
if len(runes) > 200 {
|
||||
return string(runes[:200]) + "..."
|
||||
}
|
||||
return string(runes)
|
||||
}
|
||||
|
||||
func stripHTML(s string) string {
|
||||
result := make([]rune, 0, len([]rune(s)))
|
||||
inTag := false
|
||||
for _, r := range s {
|
||||
if r == '<' {
|
||||
inTag = true
|
||||
continue
|
||||
}
|
||||
if r == '>' {
|
||||
inTag = false
|
||||
continue
|
||||
}
|
||||
if !inTag {
|
||||
result = append(result, r)
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(string(result))
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
# ========== 构建阶段 ==========
|
||||
FROM golang:1.26-alpine AS builder
|
||||
|
||||
RUN apk add --no-cache git ca-certificates
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY backend/platform-bridge/ ./backend/platform-bridge/
|
||||
|
||||
WORKDIR /app/backend/platform-bridge
|
||||
ENV GOPROXY=https://goproxy.cn,direct
|
||||
RUN go mod download
|
||||
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o /platform-bridge ./cmd/main.go
|
||||
|
||||
# ========== 运行阶段 ==========
|
||||
FROM alpine:3.20
|
||||
|
||||
RUN apk add --no-cache ca-certificates tzdata && \
|
||||
cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \
|
||||
echo "Asia/Shanghai" > /etc/timezone
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY --from=builder /platform-bridge .
|
||||
|
||||
RUN mkdir -p logs && adduser -D -H cyrene && chown -R cyrene:cyrene /app
|
||||
USER cyrene
|
||||
|
||||
EXPOSE 8095
|
||||
|
||||
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
||||
CMD wget --no-verbose --tries=1 --spider http://localhost:8095/health || exit 1
|
||||
|
||||
ENTRYPOINT ["./platform-bridge"]
|
||||
+866
-128
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,11 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -17,31 +21,126 @@ var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
|
||||
// Adapter implements PlatformAdapter for QQ via OBv11 WebSocket.
|
||||
// Adapter implements PlatformAdapter for QQ via OneBot v11 WebSocket.
|
||||
// Supports two modes:
|
||||
// - "server" (正向 WS): adapter starts a WS server, NapCat connects as client.
|
||||
// - "client" (反向 WS): adapter connects to NapCat's WS server as a client.
|
||||
type Adapter struct {
|
||||
port string
|
||||
conn *websocket.Conn
|
||||
connMu sync.Mutex
|
||||
connected bool
|
||||
configName string // instance name, e.g. "qq-home"
|
||||
mode string // "client" or "server"
|
||||
port string
|
||||
accessToken string
|
||||
remoteURL string // NapCat OneBot WS server URL, used in client mode
|
||||
sendIntervalMs int // minimum interval between consecutive messages
|
||||
selfID string // bot's own QQ number, populated from incoming messages
|
||||
conn *websocket.Conn
|
||||
connMu sync.Mutex
|
||||
connected bool
|
||||
connectedAt time.Time // when the connection was established (for historical message detection)
|
||||
srv *http.Server // HTTP server for WS upgrades (server mode only)
|
||||
|
||||
groupNames map[int64]string // group ID → group name cache
|
||||
groupNamesMu sync.RWMutex
|
||||
|
||||
// Pending API call responses.
|
||||
pendingResponses map[string]chan *OBv11APIResponse
|
||||
respMu sync.Mutex
|
||||
}
|
||||
|
||||
func NewAdapter(port string) *Adapter {
|
||||
func NewAdapter(configName, mode, port, accessToken, remoteURL string, sendIntervalMs int) *Adapter {
|
||||
if mode == "" {
|
||||
mode = "server"
|
||||
}
|
||||
return &Adapter{
|
||||
configName: configName,
|
||||
mode: mode,
|
||||
port: port,
|
||||
accessToken: accessToken,
|
||||
remoteURL: remoteURL,
|
||||
sendIntervalMs: sendIntervalMs,
|
||||
pendingResponses: make(map[string]chan *OBv11APIResponse),
|
||||
groupNames: make(map[int64]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adapter) PlatformName() string { return "qq" }
|
||||
func (a *Adapter) PlatformName() string { return "qq" }
|
||||
func (a *Adapter) ConfigName() string { return a.configName }
|
||||
func (a *Adapter) SendIntervalMs() int { return a.sendIntervalMs }
|
||||
func (a *Adapter) SelfID() string { return a.selfID }
|
||||
|
||||
// ConnectedAt returns the time the connection was established, for historical message detection.
|
||||
func (a *Adapter) ConnectedAt() time.Time {
|
||||
a.connMu.Lock()
|
||||
defer a.connMu.Unlock()
|
||||
return a.connectedAt
|
||||
}
|
||||
|
||||
// GroupName resolves a group ID to its name. Returns empty string if unknown.
|
||||
func (a *Adapter) GroupName(groupID int64) string {
|
||||
a.groupNamesMu.RLock()
|
||||
n, ok := a.groupNames[groupID]
|
||||
a.groupNamesMu.RUnlock()
|
||||
if ok {
|
||||
return n
|
||||
}
|
||||
// Try to resolve via API.
|
||||
a.fetchGroupName(groupID)
|
||||
return ""
|
||||
}
|
||||
|
||||
// SetGroupName caches a group name for a group ID.
|
||||
func (a *Adapter) SetGroupName(groupID int64, name string) {
|
||||
a.groupNamesMu.Lock()
|
||||
a.groupNames[groupID] = name
|
||||
a.groupNamesMu.Unlock()
|
||||
}
|
||||
|
||||
// fetchGroupName tries to resolve a group name via NapCat HTTP API (client mode).
|
||||
func (a *Adapter) fetchGroupName(groupID int64) {
|
||||
if a.mode != "client" || a.remoteURL == "" {
|
||||
return
|
||||
}
|
||||
// Derive HTTP base from WS URL: ws://host:port → http://host:port
|
||||
httpBase := strings.Replace(a.remoteURL, "ws://", "http://", 1)
|
||||
httpBase = strings.Replace(httpBase, "wss://", "https://", 1)
|
||||
// Strip path suffix if present
|
||||
if idx := strings.LastIndex(httpBase, "/"); idx > 8 {
|
||||
httpBase = httpBase[:idx]
|
||||
}
|
||||
|
||||
go func() {
|
||||
url := fmt.Sprintf("%s/get_group_info?group_id=%d", httpBase, groupID)
|
||||
if a.accessToken != "" {
|
||||
url += "&access_token=" + a.accessToken
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
var result struct {
|
||||
Data struct {
|
||||
GroupName string `json:"group_name"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return
|
||||
}
|
||||
if result.Data.GroupName != "" {
|
||||
a.SetGroupName(groupID, result.Data.GroupName)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (a *Adapter) Capabilities() bridge.PlatformCapabilities {
|
||||
return bridge.PlatformCapabilities{
|
||||
MaxMessageLength: 200,
|
||||
SupportsMarkdown: true, // QQ supports basic markdown
|
||||
SupportsMarkdown: true,
|
||||
SupportsImage: true,
|
||||
SupportsVoice: false,
|
||||
SupportsEmoji: true,
|
||||
@@ -51,38 +150,103 @@ func (a *Adapter) Capabilities() bridge.PlatformCapabilities {
|
||||
}
|
||||
}
|
||||
|
||||
// checkAuth 验证 WebSocket 升级请求的 access_token。
|
||||
// NapCat 通过两种方式传递: query 参数 ?access_token=xxx 或 Authorization: Bearer xxx 头.
|
||||
func (a *Adapter) checkAuth(r *http.Request) bool {
|
||||
if a.accessToken == "" {
|
||||
return true // 未配置 token,允许所有连接
|
||||
}
|
||||
// 1) query 参数
|
||||
if r.URL.Query().Get("access_token") == a.accessToken {
|
||||
return true
|
||||
}
|
||||
// 2) Authorization: Bearer <token>
|
||||
auth := r.Header.Get("Authorization")
|
||||
if strings.HasPrefix(auth, "Bearer ") && strings.TrimPrefix(auth, "Bearer ") == a.accessToken {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// wsHandler 统一的 WebSocket 连接处理 — 单连接承载 API 调用 + 事件推送 (OneBot 正向 WS).
|
||||
func (a *Adapter) wsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if !a.checkAuth(r) {
|
||||
http.Error(w, "Forbidden: invalid access_token", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
fmt.Printf("[qq] upgrade error: %v\n", err)
|
||||
return
|
||||
}
|
||||
a.connMu.Lock()
|
||||
// 关闭旧连接 (NapCat 重连)
|
||||
if a.conn != nil {
|
||||
a.conn.Close()
|
||||
}
|
||||
a.conn = conn
|
||||
a.connected = true
|
||||
a.connectedAt = time.Now()
|
||||
a.connMu.Unlock()
|
||||
fmt.Println("[qq] NapCat/OneBot connected (正向WS)")
|
||||
}
|
||||
|
||||
// legacyHandler 兼容旧的路径 /ws/qq 和 /ws/qq/event.
|
||||
func (a *Adapter) legacyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Printf("[qq] legacy WS path %s connected (consider changing NapCat URL to root /)\n", r.URL.Path)
|
||||
a.wsHandler(w, r)
|
||||
}
|
||||
|
||||
func (a *Adapter) Connect(ctx context.Context) error {
|
||||
if a.mode == "client" {
|
||||
return a.connectClient(ctx)
|
||||
}
|
||||
return a.connectServer()
|
||||
}
|
||||
|
||||
func (a *Adapter) connectClient(ctx context.Context) error {
|
||||
url := a.remoteURL
|
||||
if a.accessToken != "" {
|
||||
sep := "?"
|
||||
if strings.Contains(url, "?") {
|
||||
sep = "&"
|
||||
}
|
||||
url += sep + "access_token=" + a.accessToken
|
||||
}
|
||||
|
||||
dialer := websocket.DefaultDialer
|
||||
conn, _, err := dialer.DialContext(ctx, url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial NapCat WS %s: %w", url, err)
|
||||
}
|
||||
|
||||
a.connMu.Lock()
|
||||
if a.conn != nil {
|
||||
a.conn.Close()
|
||||
}
|
||||
a.conn = conn
|
||||
a.connected = true
|
||||
a.connectedAt = time.Now()
|
||||
a.connMu.Unlock()
|
||||
|
||||
fmt.Printf("[qq] connected to NapCat OneBot WS (client mode): %s\n", a.remoteURL)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adapter) connectServer() error {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/ws/qq", func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
fmt.Printf("[qq] upgrade error: %v\n", err)
|
||||
return
|
||||
}
|
||||
a.connMu.Lock()
|
||||
a.conn = conn
|
||||
a.connected = true
|
||||
a.connMu.Unlock()
|
||||
fmt.Println("[qq] bot connected")
|
||||
})
|
||||
mux.HandleFunc("/ws/qq/event", func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
fmt.Printf("[qq] event upgrade error: %v\n", err)
|
||||
return
|
||||
}
|
||||
a.connMu.Lock()
|
||||
a.conn = conn
|
||||
a.connected = true
|
||||
a.connMu.Unlock()
|
||||
fmt.Println("[qq] event WebSocket connected")
|
||||
})
|
||||
mux.HandleFunc("/", a.wsHandler) // NapCat 正向 WS 标准路径
|
||||
mux.HandleFunc("/ws/qq", a.legacyHandler) // 向下兼容旧配置
|
||||
mux.HandleFunc("/ws/qq/event", a.legacyHandler) // 向下兼容旧配置
|
||||
|
||||
addr := ":" + a.port
|
||||
srv := &http.Server{Addr: addr, Handler: mux}
|
||||
a.srv = &http.Server{Addr: addr, Handler: mux}
|
||||
go func() {
|
||||
fmt.Printf("[qq] listening on %s (waiting for bot WebSocket connection)\n", addr)
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
fmt.Printf("[qq] WebSocket server on %s (waiting for NapCat forward WS connection)\n", addr)
|
||||
if a.accessToken != "" {
|
||||
fmt.Println("[qq] access_token 已配置,将验证连接请求")
|
||||
}
|
||||
if err := a.srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
fmt.Printf("[qq] server error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
@@ -91,12 +255,22 @@ func (a *Adapter) Connect(ctx context.Context) error {
|
||||
|
||||
func (a *Adapter) Disconnect(ctx context.Context) error {
|
||||
a.connMu.Lock()
|
||||
defer a.connMu.Unlock()
|
||||
if a.conn != nil {
|
||||
a.conn.Close()
|
||||
a.conn = nil
|
||||
}
|
||||
a.connected = false
|
||||
srv := a.srv
|
||||
a.srv = nil
|
||||
a.connMu.Unlock()
|
||||
|
||||
if srv != nil {
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||
return fmt.Errorf("qq server shutdown: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -120,10 +294,8 @@ func (a *Adapter) ToUnified(rawMessage interface{}) (*bridge.UnifiedMessage, err
|
||||
return nil, fmt.Errorf("expected *OBv11Message, got %T", rawMessage)
|
||||
}
|
||||
|
||||
// Extract text content.
|
||||
content := extractText(msg)
|
||||
|
||||
// Determine sender.
|
||||
senderID := ""
|
||||
senderName := "unknown"
|
||||
channelType := "direct"
|
||||
@@ -150,8 +322,8 @@ func (a *Adapter) ToUnified(rawMessage interface{}) (*bridge.UnifiedMessage, err
|
||||
}
|
||||
}
|
||||
|
||||
// Extract mentions.
|
||||
var mentions []string
|
||||
var replyToText string
|
||||
if segments, ok := msg.Message.([]interface{}); ok {
|
||||
for _, s := range segments {
|
||||
if seg, ok := s.(map[string]interface{}); ok {
|
||||
@@ -162,10 +334,48 @@ func (a *Adapter) ToUnified(rawMessage interface{}) (*bridge.UnifiedMessage, err
|
||||
}
|
||||
}
|
||||
}
|
||||
if seg["type"] == "reply" {
|
||||
if data, ok := seg["data"].(map[string]interface{}); ok {
|
||||
if t, ok := data["text"].(string); ok && t != "" {
|
||||
replyToText = t
|
||||
}
|
||||
if id, ok := data["id"]; ok {
|
||||
_ = id // message ID of the replied-to message
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Prepend reply context for the AI.
|
||||
if replyToText != "" {
|
||||
content = "【回复】" + truncateForReply(replyToText, 100) + "\n" + content
|
||||
}
|
||||
// Fallback: parse CQ at codes from string format (e.g. [CQ:at,qq=2254389756]).
|
||||
if len(mentions) == 0 {
|
||||
raw := msg.RawMessage
|
||||
if raw == "" {
|
||||
if s, ok := msg.Message.(string); ok {
|
||||
raw = s
|
||||
}
|
||||
}
|
||||
for _, m := range cqAtRegex.FindAllStringSubmatch(raw, -1) {
|
||||
if len(m) >= 2 {
|
||||
mentions = append(mentions, m[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve group name for group messages.
|
||||
groupName := ""
|
||||
if msg.MessageType == "group" {
|
||||
groupName = a.GroupName(msg.GroupID)
|
||||
}
|
||||
|
||||
attachments := extractAttachments(msg)
|
||||
attachments = append(attachments, extractCardImageURLs(msg)...)
|
||||
|
||||
return &bridge.UnifiedMessage{
|
||||
SenderID: senderID,
|
||||
SenderName: senderName,
|
||||
@@ -176,8 +386,10 @@ func (a *Adapter) ToUnified(rawMessage interface{}) (*bridge.UnifiedMessage, err
|
||||
ContentType: "text",
|
||||
MessageID: fmt.Sprintf("%d", msg.MessageID),
|
||||
Mentions: mentions,
|
||||
Attachments: attachments,
|
||||
RawData: rawMessage,
|
||||
Timestamp: time.Unix(msg.Time, 0),
|
||||
GroupName: groupName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -185,11 +397,7 @@ func (a *Adapter) ToUnified(rawMessage interface{}) (*bridge.UnifiedMessage, err
|
||||
func (a *Adapter) FromUnified(response *bridge.UnifiedResponse) ([]bridge.PlatformMessage, error) {
|
||||
var msgs []bridge.PlatformMessage
|
||||
for _, rm := range response.Messages {
|
||||
content := rm.Content
|
||||
if rm.FormatMode == "markdown" {
|
||||
content = convertMarkdownToQQ(rm.Content)
|
||||
}
|
||||
// QQ prefers shorter messages — split if needed.
|
||||
content := ConvertMarkdownToQQ(rm.Content)
|
||||
runes := []rune(content)
|
||||
if len(runes) > 200 {
|
||||
content = string(runes[:200])
|
||||
@@ -240,6 +448,14 @@ func (a *Adapter) ReadMessages(ctx context.Context, msgCh chan<- *OBv11Message)
|
||||
conn := a.conn
|
||||
a.connMu.Unlock()
|
||||
if conn == nil {
|
||||
// Client mode: auto-reconnect when connection is lost.
|
||||
if a.mode == "client" {
|
||||
if err := a.connectClient(ctx); err != nil {
|
||||
fmt.Printf("[qq] reconnect failed: %v, retrying in 3s...\n", err)
|
||||
time.Sleep(3 * time.Second)
|
||||
}
|
||||
continue
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
@@ -248,7 +464,10 @@ func (a *Adapter) ReadMessages(ctx context.Context, msgCh chan<- *OBv11Message)
|
||||
if err != nil {
|
||||
fmt.Printf("[qq] read error: %v\n", err)
|
||||
a.connMu.Lock()
|
||||
a.conn = nil
|
||||
if a.conn != nil {
|
||||
a.conn.Close()
|
||||
a.conn = nil
|
||||
}
|
||||
a.connected = false
|
||||
a.connMu.Unlock()
|
||||
time.Sleep(time.Second)
|
||||
@@ -258,7 +477,6 @@ func (a *Adapter) ReadMessages(ctx context.Context, msgCh chan<- *OBv11Message)
|
||||
// Try to parse as OBv11 message (event from QQ).
|
||||
var msg OBv11Message
|
||||
if err := json.Unmarshal(raw, &msg); err != nil {
|
||||
// Might be an API response.
|
||||
var resp OBv11APIResponse
|
||||
if err := json.Unmarshal(raw, &resp); err != nil {
|
||||
fmt.Printf("[qq] unknown message: %s\n", string(raw))
|
||||
@@ -275,7 +493,12 @@ func (a *Adapter) ReadMessages(ctx context.Context, msgCh chan<- *OBv11Message)
|
||||
continue
|
||||
}
|
||||
|
||||
// Only handle message events.
|
||||
// Capture bot's own QQ number from incoming messages.
|
||||
if msg.SelfID != 0 && a.selfID == "" {
|
||||
a.selfID = fmt.Sprintf("%d", msg.SelfID)
|
||||
fmt.Printf("[qq:%s] self ID captured: %s\n", a.configName, a.selfID)
|
||||
}
|
||||
|
||||
if msg.PostType == "message" {
|
||||
select {
|
||||
case msgCh <- &msg:
|
||||
@@ -286,24 +509,287 @@ func (a *Adapter) ReadMessages(ctx context.Context, msgCh chan<- *OBv11Message)
|
||||
}
|
||||
}
|
||||
|
||||
// parseJSONCardTitle extracts human-readable content from a QQ JSON card.
|
||||
func parseJSONCardTitle(data string) string {
|
||||
var card struct {
|
||||
App string `json:"app"`
|
||||
Prompt string `json:"prompt"`
|
||||
Title string `json:"title"`
|
||||
Desc string `json:"desc"`
|
||||
Meta struct {
|
||||
Detail1 struct {
|
||||
Title string `json:"title"`
|
||||
Desc string `json:"desc"`
|
||||
Tag string `json:"tag"`
|
||||
} `json:"detail_1"`
|
||||
News struct {
|
||||
Title string `json:"title"`
|
||||
Desc string `json:"desc"`
|
||||
Tag string `json:"tag"`
|
||||
Preview string `json:"preview"`
|
||||
} `json:"news"`
|
||||
Music struct {
|
||||
Title string `json:"title"`
|
||||
Desc string `json:"desc"`
|
||||
Tag string `json:"tag"`
|
||||
Preview string `json:"preview"`
|
||||
} `json:"music"`
|
||||
} `json:"meta"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(data), &card); err != nil {
|
||||
return parseJSONCardFallback(data)
|
||||
}
|
||||
|
||||
var parts []string
|
||||
push := func(s string) { if s != "" { parts = append(parts, s) } }
|
||||
|
||||
// Build header: [音乐分享] or [卡片]
|
||||
label := cardAppLabel(card.App)
|
||||
if card.Prompt != "" {
|
||||
// Strip redundant QQ-specific prefix from prompt (e.g. "[QQ小程序]" when we already say "[小程序]").
|
||||
prompt := card.Prompt
|
||||
for _, pfx := range qqPromptPrefixes {
|
||||
if strings.HasPrefix(prompt, pfx) {
|
||||
prompt = strings.TrimPrefix(prompt, pfx)
|
||||
break
|
||||
}
|
||||
}
|
||||
push(label + " " + prompt)
|
||||
} else if t := firstNonEmpty(card.Meta.Detail1.Title, card.Meta.News.Title, card.Meta.Music.Title, card.Title); t != "" {
|
||||
push(label + " " + t)
|
||||
}
|
||||
|
||||
// Append description and source tag if available.
|
||||
desc := firstNonEmpty(card.Meta.Detail1.Desc, card.Meta.Music.Desc, card.Meta.News.Desc, card.Desc)
|
||||
if desc != "" && desc != card.Prompt && desc != card.Title &&
|
||||
(len(parts) == 0 || !strings.Contains(parts[0], desc)) {
|
||||
push("简介:" + desc)
|
||||
}
|
||||
tag := firstNonEmpty(card.Meta.Music.Tag, card.Meta.News.Tag, card.Meta.Detail1.Tag)
|
||||
if tag != "" && (len(parts) == 0 || !strings.Contains(parts[len(parts)-1], tag)) {
|
||||
push("来源:" + tag)
|
||||
}
|
||||
preview := firstNonEmpty(card.Meta.Music.Preview, card.Meta.News.Preview)
|
||||
if preview != "" {
|
||||
push("封面:" + preview)
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
return parseJSONCardFallback(data)
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
|
||||
func firstNonEmpty(ss ...string) string {
|
||||
for _, s := range ss {
|
||||
if s != "" { return s }
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// cardAppLabel maps QQ card app IDs to human-readable Chinese labels.
|
||||
func cardAppLabel(app string) string {
|
||||
if label, ok := cardAppMap[app]; ok {
|
||||
return "[" + label + "]"
|
||||
}
|
||||
return "[卡片]"
|
||||
}
|
||||
|
||||
var cardAppMap = map[string]string{
|
||||
"com.tencent.music.lua": "音乐分享",
|
||||
"com.tencent.structmsg": "结构化消息",
|
||||
"com.tencent.qun.pro": "群Pro卡片",
|
||||
"com.tencent.miniapp_01": "小程序",
|
||||
"com.tencent.contact.lua": "联系人分享",
|
||||
"com.tencent.groupphoto": "群相册",
|
||||
"com.tencent.qqfiles": "文件分享",
|
||||
"com.tencent.announcement": "群公告",
|
||||
"com.tencent.mobileqq.ark": "ARK消息",
|
||||
"com.tencent.tuwen.lua": "图文消息",
|
||||
}
|
||||
|
||||
// QQ card prompts often include a self-describing prefix that overlaps with our label.
|
||||
var qqPromptPrefixes = []string{
|
||||
"[QQ小程序] ", "[QQ小程序]",
|
||||
"[QQ音乐] ", "[QQ音乐]",
|
||||
}
|
||||
|
||||
// parseJSONCardFallback tries a generic approach for unknown card structures.
|
||||
func parseJSONCardFallback(data string) string {
|
||||
var raw map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(data), &raw); err != nil {
|
||||
return "[卡片消息]"
|
||||
}
|
||||
|
||||
var parts []string
|
||||
push := func(s string) { if s != "" { parts = append(parts, s) } }
|
||||
|
||||
label := "[卡片]"
|
||||
if app, ok := raw["app"].(string); ok {
|
||||
label = cardAppLabel(app)
|
||||
}
|
||||
if s, _ := raw["prompt"].(string); s != "" {
|
||||
push(label + " " + s)
|
||||
} else if s, _ := raw["title"].(string); s != "" {
|
||||
push(label + " " + s)
|
||||
}
|
||||
|
||||
// Search meta for title and desc.
|
||||
if meta, ok := raw["meta"].(map[string]interface{}); ok {
|
||||
for _, v := range meta {
|
||||
if sub, ok := v.(map[string]interface{}); ok {
|
||||
if s, _ := sub["title"].(string); s != "" && len(parts) == 0 {
|
||||
push(label + " " + s)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, v := range meta {
|
||||
if sub, ok := v.(map[string]interface{}); ok {
|
||||
if s, _ := sub["desc"].(string); s != "" {
|
||||
push("简介:" + s)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, v := range meta {
|
||||
if sub, ok := v.(map[string]interface{}); ok {
|
||||
if s, _ := sub["tag"].(string); s != "" {
|
||||
push("来源:" + s)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if s, _ := raw["desc"].(string); s != "" && len(parts) < 2 {
|
||||
push("简介:" + s)
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
log.Printf("[qq] 卡片解析失败,原始JSON: %s", data)
|
||||
return "[卡片消息]"
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
|
||||
// cqSimplifyMap maps CQ code types to simplified Chinese labels.
|
||||
var cqSimplifyMap = map[string]string{
|
||||
"image": "[图片]",
|
||||
"reply": "[回复]",
|
||||
"face": "[表情]",
|
||||
"record": "[语音]",
|
||||
"video": "[视频]",
|
||||
"file": "[文件]",
|
||||
}
|
||||
|
||||
// simplifyCQCodes replaces [CQ:type,...] codes with human-readable labels.
|
||||
func simplifyCQCodes(s string) string {
|
||||
return cqAllRegex.ReplaceAllStringFunc(s, func(match string) string {
|
||||
// match looks like "[CQ:image,file=xxx,url=xxx]"
|
||||
// Extract the type (text between "CQ:" and the first "," or "]").
|
||||
typ := match[4:] // strip "[CQ:"
|
||||
for i, c := range typ {
|
||||
if c == ',' || c == ']' {
|
||||
typ = typ[:i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if typ == "json" {
|
||||
// Parse data= field from [CQ:json,data=HTML_ENCODED_JSON]
|
||||
if dataVal := extractCQParam(match, "data"); dataVal != "" {
|
||||
decoded := html.UnescapeString(dataVal)
|
||||
return parseJSONCardTitle(decoded)
|
||||
}
|
||||
return "[卡片消息]"
|
||||
}
|
||||
if label, ok := cqSimplifyMap[typ]; ok {
|
||||
return label
|
||||
}
|
||||
return "[" + typ + "]"
|
||||
})
|
||||
}
|
||||
|
||||
// extractCQParam extracts a named parameter value from a CQ code string.
|
||||
// e.g. extractCQParam("[CQ:json,data=hello%20world]", "data") → "hello%20world"
|
||||
func extractCQParam(cqCode, paramName string) string {
|
||||
prefix := paramName + "="
|
||||
idx := strings.Index(cqCode, prefix)
|
||||
if idx < 0 {
|
||||
return ""
|
||||
}
|
||||
val := cqCode[idx+len(prefix):]
|
||||
// Value ends at "," or "]"
|
||||
for i, c := range val {
|
||||
if c == ',' || c == ']' {
|
||||
return val[:i]
|
||||
}
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// extractText retrieves plain text from an OBv11 message.
|
||||
// CQ codes are converted to human-readable form where applicable (e.g. [CQ:at,qq=xxx] → @xxx).
|
||||
func extractText(msg *OBv11Message) string {
|
||||
if msg.RawMessage != "" {
|
||||
return msg.RawMessage
|
||||
s := cqAtRegex.ReplaceAllString(msg.RawMessage, "@$1")
|
||||
return simplifyCQCodes(s)
|
||||
}
|
||||
switch m := msg.Message.(type) {
|
||||
case string:
|
||||
return m
|
||||
s := cqAtRegex.ReplaceAllString(m, "@$1")
|
||||
return simplifyCQCodes(s)
|
||||
case []interface{}:
|
||||
var text string
|
||||
for _, seg := range m {
|
||||
if s, ok := seg.(map[string]interface{}); ok {
|
||||
if s["type"] == "text" {
|
||||
switch s["type"] {
|
||||
case "text":
|
||||
if data, ok := s["data"].(map[string]interface{}); ok {
|
||||
if t, ok := data["text"].(string); ok {
|
||||
text += t
|
||||
}
|
||||
}
|
||||
case "at":
|
||||
if data, ok := s["data"].(map[string]interface{}); ok {
|
||||
if qq, ok := data["qq"].(string); ok {
|
||||
text += "@" + qq
|
||||
}
|
||||
}
|
||||
case "image":
|
||||
text += "[图片]"
|
||||
case "face":
|
||||
text += "[表情]"
|
||||
case "record":
|
||||
text += "[语音]"
|
||||
case "video":
|
||||
text += "[视频]"
|
||||
case "file":
|
||||
text += "[文件]"
|
||||
case "reply":
|
||||
// Reply is handled separately in ToUnified with reply text.
|
||||
text += "[回复]"
|
||||
case "json", "card":
|
||||
if data, ok := s["data"].(map[string]interface{}); ok {
|
||||
inner := data["data"]
|
||||
switch v := inner.(type) {
|
||||
case string:
|
||||
if v != "" {
|
||||
text += parseJSONCardTitle(v)
|
||||
} else {
|
||||
text += "[卡片消息]"
|
||||
}
|
||||
case map[string]interface{}:
|
||||
// Already parsed — re-marshal to JSON string for parsing.
|
||||
if b, err := json.Marshal(v); err == nil {
|
||||
text += parseJSONCardTitle(string(b))
|
||||
} else {
|
||||
text += "[卡片消息]"
|
||||
}
|
||||
default:
|
||||
text += "[卡片消息]"
|
||||
}
|
||||
} else {
|
||||
text += "[卡片消息]"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -312,13 +798,186 @@ func extractText(msg *OBv11Message) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// convertMarkdownToQQ converts common markdown to QQ-supported format.
|
||||
func convertMarkdownToQQ(md string) string {
|
||||
// QQ supports basic markdown: **bold**, *italic*, ~~strikethrough~~
|
||||
// Remove unsupported elements (headings, code blocks, links).
|
||||
var cqAtRegex = regexp.MustCompile(`\[CQ:at,qq=(\d+)\]`)
|
||||
var cqImageRegex = regexp.MustCompile(`\[CQ:image,[^\]]*\]`)
|
||||
var cqVideoRegex = regexp.MustCompile(`\[CQ:video,[^\]]*\]`)
|
||||
var cqRecordRegex = regexp.MustCompile(`\[CQ:record,[^\]]*\]`)
|
||||
var cqURLRegex = regexp.MustCompile(`\burl=([^,\]]+)`)
|
||||
var cqDurationRegex = regexp.MustCompile(`\bduration=(\d+)`)
|
||||
var cqAllRegex = regexp.MustCompile(`\[CQ:[^\]]+\]`)
|
||||
var cqJSONRegex = regexp.MustCompile(`\[CQ:(?:json|card),[^\]]*data=[^\]]*\]`)
|
||||
var boldRegex = regexp.MustCompile(`\*\*(.+?)\*\*`)
|
||||
var italicRegex = regexp.MustCompile(`\*(.+?)\*`)
|
||||
var strikethroughRegex = regexp.MustCompile(`~~(.+?)~~`)
|
||||
|
||||
func parseIntOr(s string, defaultVal int) int {
|
||||
if s == "" {
|
||||
return defaultVal
|
||||
}
|
||||
n := 0
|
||||
for _, c := range s {
|
||||
if c >= '0' && c <= '9' {
|
||||
n = n*10 + int(c-'0')
|
||||
} else {
|
||||
return defaultVal
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// truncateForReply truncates reply preview text to keep messages readable.
|
||||
func truncateForReply(s string, maxLen int) string {
|
||||
runes := []rune(s)
|
||||
if len(runes) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return string(runes[:maxLen]) + "…"
|
||||
}
|
||||
|
||||
// extractAttachments extracts image/video URLs from OBv11Message.
|
||||
// Handles both string format (CQ codes in raw_message) and array format (parsed segments).
|
||||
func extractAttachments(msg *OBv11Message) []bridge.Attachment {
|
||||
var attachments []bridge.Attachment
|
||||
|
||||
// Array format: iterate segments looking for image and video.
|
||||
if segments, ok := msg.Message.([]interface{}); ok {
|
||||
for _, s := range segments {
|
||||
if seg, ok := s.(map[string]interface{}); ok {
|
||||
segType, _ := seg["type"].(string)
|
||||
if segType != "image" && segType != "video" && segType != "record" {
|
||||
continue
|
||||
}
|
||||
data, _ := seg["data"].(map[string]interface{})
|
||||
if data == nil {
|
||||
continue
|
||||
}
|
||||
url, _ := data["url"].(string)
|
||||
file, _ := data["file"].(string)
|
||||
if url == "" {
|
||||
continue
|
||||
}
|
||||
att := bridge.Attachment{
|
||||
Type: segType,
|
||||
URL: url,
|
||||
FileName: file,
|
||||
}
|
||||
if segType == "video" {
|
||||
if d, ok := data["duration"].(float64); ok {
|
||||
att.Duration = int(d)
|
||||
}
|
||||
}
|
||||
attachments = append(attachments, att)
|
||||
}
|
||||
}
|
||||
return attachments
|
||||
}
|
||||
|
||||
// String format: parse CQ codes from RawMessage or string Message.
|
||||
raw := msg.RawMessage
|
||||
if raw == "" {
|
||||
if s, ok := msg.Message.(string); ok {
|
||||
raw = s
|
||||
}
|
||||
}
|
||||
// Images.
|
||||
for _, m := range cqImageRegex.FindAllString(raw, -1) {
|
||||
urlMatch := cqURLRegex.FindStringSubmatch(m)
|
||||
if len(urlMatch) >= 2 {
|
||||
attachments = append(attachments, bridge.Attachment{Type: "image", URL: urlMatch[1]})
|
||||
}
|
||||
}
|
||||
// Videos.
|
||||
for _, m := range cqVideoRegex.FindAllString(raw, -1) {
|
||||
urlMatch := cqURLRegex.FindStringSubmatch(m)
|
||||
if len(urlMatch) >= 2 {
|
||||
dur := 0
|
||||
if dm := cqDurationRegex.FindStringSubmatch(m); len(dm) >= 2 {
|
||||
dur = parseIntOr(dm[1], 0)
|
||||
}
|
||||
attachments = append(attachments, bridge.Attachment{Type: "video", URL: urlMatch[1], Duration: dur})
|
||||
}
|
||||
}
|
||||
// Records (voice messages).
|
||||
for _, m := range cqRecordRegex.FindAllString(raw, -1) {
|
||||
urlMatch := cqURLRegex.FindStringSubmatch(m)
|
||||
if len(urlMatch) >= 2 {
|
||||
attachments = append(attachments, bridge.Attachment{Type: "voice", URL: urlMatch[1]})
|
||||
}
|
||||
}
|
||||
return attachments
|
||||
}
|
||||
|
||||
// extractCardImageURLs finds json/card CQ codes in the raw message and extracts
|
||||
// preview/cover image URLs so they can be fed to the vision pipeline.
|
||||
func extractCardImageURLs(msg *OBv11Message) []bridge.Attachment {
|
||||
raw := msg.RawMessage
|
||||
if raw == "" {
|
||||
if s, ok := msg.Message.(string); ok {
|
||||
raw = s
|
||||
}
|
||||
}
|
||||
// Match [CQ:json,...] or [CQ:card,...]
|
||||
matches := cqJSONRegex.FindAllString(raw, -1)
|
||||
if len(matches) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var attachments []bridge.Attachment
|
||||
seen := map[string]bool{}
|
||||
for _, m := range matches {
|
||||
dataVal := extractCQParam(m, "data")
|
||||
if dataVal == "" {
|
||||
continue
|
||||
}
|
||||
decoded := html.UnescapeString(dataVal)
|
||||
urls := extractPreviewURLs(decoded)
|
||||
for _, u := range urls {
|
||||
if !seen[u] {
|
||||
seen[u] = true
|
||||
attachments = append(attachments, bridge.Attachment{Type: "image", URL: u})
|
||||
}
|
||||
}
|
||||
}
|
||||
return attachments
|
||||
}
|
||||
|
||||
// extractPreviewURLs parses a QQ card JSON and returns any preview/cover image URLs.
|
||||
func extractPreviewURLs(data string) []string {
|
||||
var card struct {
|
||||
Meta struct {
|
||||
Music struct {
|
||||
Preview string `json:"preview"`
|
||||
} `json:"music"`
|
||||
News struct {
|
||||
Preview string `json:"preview"`
|
||||
} `json:"news"`
|
||||
} `json:"meta"`
|
||||
// Some card types have preview at top level or in other locations.
|
||||
Preview string `json:"preview"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(data), &card); err != nil {
|
||||
return nil
|
||||
}
|
||||
var urls []string
|
||||
if card.Meta.Music.Preview != "" {
|
||||
urls = append(urls, card.Meta.Music.Preview)
|
||||
}
|
||||
if card.Meta.News.Preview != "" {
|
||||
urls = append(urls, card.Meta.News.Preview)
|
||||
}
|
||||
if card.Preview != "" {
|
||||
urls = append(urls, card.Preview)
|
||||
}
|
||||
return urls
|
||||
}
|
||||
|
||||
// ConvertMarkdownToQQ converts markdown to QQ plain-text format.
|
||||
func ConvertMarkdownToQQ(md string) string {
|
||||
md = boldRegex.ReplaceAllString(md, "$1")
|
||||
md = italicRegex.ReplaceAllString(md, "$1")
|
||||
md = strikethroughRegex.ReplaceAllString(md, "$1")
|
||||
md = removeHeadings(md)
|
||||
md = removeCodeBlocks(md)
|
||||
// Preserve bold, italic, strikethrough which QQ supports.
|
||||
return md
|
||||
}
|
||||
|
||||
@@ -332,21 +991,17 @@ func removeHeadings(s string) string {
|
||||
}
|
||||
|
||||
func removeCodeBlocks(s string) string {
|
||||
// Simple: remove ``` markers.
|
||||
result := ""
|
||||
var kept []string
|
||||
inCode := false
|
||||
for _, line := range splitLines(s) {
|
||||
if hasPrefix(line, "```") {
|
||||
inCode = !inCode
|
||||
continue
|
||||
}
|
||||
if inCode {
|
||||
result += line + "\n"
|
||||
} else {
|
||||
result += line + "\n"
|
||||
}
|
||||
kept = append(kept, line)
|
||||
}
|
||||
return result
|
||||
_ = inCode
|
||||
return strings.Join(kept, "\n")
|
||||
}
|
||||
|
||||
func splitLines(s string) []string {
|
||||
@@ -376,7 +1031,6 @@ func stripPrefix(s, prefix string) string {
|
||||
}
|
||||
|
||||
func replaceLine(s, old, new string) string {
|
||||
// Simple: find old line and replace with new.
|
||||
idx := indexOf(s, old)
|
||||
if idx < 0 {
|
||||
return s
|
||||
|
||||
@@ -44,6 +44,23 @@ func (m *IdentityMapper) Resolve(platform, platformUID string) (*permissions.Pla
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// ResolveOrNil finds the Cyrene user for a platform identity, returning nil for unknown users.
|
||||
func (m *IdentityMapper) ResolveOrNil(platform, platformUID string) *permissions.PlatformIdentity {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
plat, ok := m.byPlatform[platform]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return plat[platformUID]
|
||||
}
|
||||
|
||||
// IsAdmin returns true if the given platform user is a registered admin.
|
||||
func (m *IdentityMapper) IsAdmin(platform, platformUID string) bool {
|
||||
id := m.ResolveOrNil(platform, platformUID)
|
||||
return id != nil && id.PermissionLevel == "admin"
|
||||
}
|
||||
|
||||
// List returns all identities for a platform.
|
||||
func (m *IdentityMapper) List(platform string) []permissions.PlatformIdentity {
|
||||
m.mu.RLock()
|
||||
|
||||
@@ -1,12 +1,26 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/platform-bridge/internal/permissions"
|
||||
)
|
||||
|
||||
|
||||
const participantTTL = 5 * time.Minute
|
||||
|
||||
// adapterKey returns the unique key for an adapter in the router map.
|
||||
// Uses ConfigName() if the adapter implements it, otherwise PlatformName().
|
||||
func adapterKey(a PlatformAdapter) string {
|
||||
if named, ok := a.(interface{ ConfigName() string }); ok {
|
||||
return named.ConfigName()
|
||||
}
|
||||
return a.PlatformName()
|
||||
}
|
||||
|
||||
// PlatformRouter manages all platform adapters and routes messages.
|
||||
type PlatformRouter struct {
|
||||
mu sync.RWMutex
|
||||
@@ -21,11 +35,14 @@ type PlatformRouter struct {
|
||||
|
||||
// ChannelContext stores the active conversation state for a channel.
|
||||
type ChannelContext struct {
|
||||
Platform string
|
||||
ChannelID string
|
||||
ChannelType string
|
||||
LastUserMsg string
|
||||
MessageCount int
|
||||
Platform string
|
||||
ChannelID string
|
||||
ChannelType string
|
||||
LastUserMsg string
|
||||
LastSenderUID string
|
||||
RecentSenders []string // last 5 sender UIDs (original platform UIDs)
|
||||
ActiveParticipants map[string]time.Time // UID -> last bot reply time (for multi-user conversation continuity)
|
||||
MessageCount int
|
||||
}
|
||||
|
||||
func NewPlatformRouter(mapper *IdentityMapper, checker *permissions.Checker) *PlatformRouter {
|
||||
@@ -37,11 +54,37 @@ func NewPlatformRouter(mapper *IdentityMapper, checker *permissions.Checker) *Pl
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterAdapter adds a platform adapter.
|
||||
// RegisterAdapter adds a platform adapter, keyed by its config name.
|
||||
func (r *PlatformRouter) RegisterAdapter(a PlatformAdapter) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.adapters[a.PlatformName()] = a
|
||||
r.adapters[adapterKey(a)] = a
|
||||
}
|
||||
|
||||
// RemoveAdapter disconnects and removes a platform adapter.
|
||||
func (r *PlatformRouter) RemoveAdapter(platform string) {
|
||||
r.mu.Lock()
|
||||
a, ok := r.adapters[platform]
|
||||
if ok {
|
||||
delete(r.adapters, platform)
|
||||
}
|
||||
r.mu.Unlock()
|
||||
if ok {
|
||||
a.Disconnect(context.Background())
|
||||
}
|
||||
}
|
||||
|
||||
// ReplaceAdapter disconnects the old adapter (if present), registers the new one,
|
||||
// and connects it. Returns an error if the new adapter fails to connect.
|
||||
func (r *PlatformRouter) ReplaceAdapter(a PlatformAdapter) error {
|
||||
key := adapterKey(a)
|
||||
r.mu.Lock()
|
||||
if old, ok := r.adapters[key]; ok {
|
||||
old.Disconnect(context.Background())
|
||||
}
|
||||
r.adapters[key] = a
|
||||
r.mu.Unlock()
|
||||
return a.Connect(context.Background())
|
||||
}
|
||||
|
||||
// GetAdapter returns the adapter for a platform.
|
||||
@@ -55,7 +98,7 @@ func (r *PlatformRouter) GetAdapter(platform string) (PlatformAdapter, error) {
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// ListAdapters returns all registered adapter names.
|
||||
// ListAdapters returns all registered adapter names (config names).
|
||||
func (r *PlatformRouter) ListAdapters() []string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
@@ -66,14 +109,28 @@ func (r *PlatformRouter) ListAdapters() []string {
|
||||
return names
|
||||
}
|
||||
|
||||
// GetAdaptersByPlatform returns all registered adapters for a given platform type.
|
||||
func (r *PlatformRouter) GetAdaptersByPlatform(platform string) []PlatformAdapter {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
var result []PlatformAdapter
|
||||
for _, a := range r.adapters {
|
||||
if a.PlatformName() == platform {
|
||||
result = append(result, a)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// SetMessageHandler sets the callback for processing unified messages.
|
||||
func (r *PlatformRouter) SetMessageHandler(h MessageHandler) {
|
||||
r.handler = h
|
||||
}
|
||||
|
||||
// RouteMessage converts a platform message to unified, checks permissions, and dispatches.
|
||||
func (r *PlatformRouter) RouteMessage(platform string, rawMsg interface{}) (*UnifiedResponse, error) {
|
||||
a, err := r.GetAdapter(platform)
|
||||
// adapterKey is the config name (e.g., "qq", "qq-home") used to look up the adapter instance.
|
||||
func (r *PlatformRouter) RouteMessage(adapterKey string, rawMsg interface{}) (*UnifiedResponse, error) {
|
||||
a, err := r.GetAdapter(adapterKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -83,18 +140,23 @@ func (r *PlatformRouter) RouteMessage(platform string, rawMsg interface{}) (*Uni
|
||||
return nil, fmt.Errorf("convert to unified: %w", err)
|
||||
}
|
||||
|
||||
// Resolve identity.
|
||||
identity, err := r.mapper.Resolve(platform, unified.SenderID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity not found: %w", err)
|
||||
// Preserve original platform UID before identity mapping.
|
||||
unified.OriginalSenderUID = unified.SenderID
|
||||
unified.OriginalSenderName = unified.SenderName
|
||||
unified.OriginalRawMessage = rawMsg
|
||||
|
||||
// Capture bot's own UID for @mention detection.
|
||||
if selfAware, ok := a.(interface{ SelfID() string }); ok {
|
||||
unified.BotUID = selfAware.SelfID()
|
||||
}
|
||||
|
||||
// Merge identity info into the unified message.
|
||||
unified.SenderID = identity.CyreneUser
|
||||
unified.SenderName = identity.Nickname
|
||||
|
||||
// Apply permission-based filtering.
|
||||
_ = identity // used by permission checks on tools
|
||||
// Resolve identity (nil for unknown users; caller decides routing).
|
||||
// Use platform type (e.g. "qq") for identity resolution, not adapter key.
|
||||
identity := r.mapper.ResolveOrNil(a.PlatformName(), unified.SenderID)
|
||||
if identity != nil {
|
||||
unified.SenderID = identity.CyreneUser
|
||||
unified.SenderName = identity.Nickname
|
||||
}
|
||||
|
||||
// Update channel context.
|
||||
r.updateContext(unified)
|
||||
@@ -108,8 +170,9 @@ func (r *PlatformRouter) RouteMessage(platform string, rawMsg interface{}) (*Uni
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response.Platform = platform
|
||||
response.PlatformHints = r.platformHints(platform)
|
||||
// Use adapter key for response routing so SendResponse finds the correct instance.
|
||||
response.Platform = adapterKey
|
||||
response.PlatformHints = r.platformHints(adapterKey)
|
||||
|
||||
return response, nil
|
||||
}
|
||||
@@ -152,6 +215,11 @@ func (r *PlatformRouter) updateContext(msg *UnifiedMessage) {
|
||||
r.contexts[key] = ctx
|
||||
}
|
||||
ctx.LastUserMsg = msg.Content
|
||||
ctx.LastSenderUID = msg.OriginalSenderUID
|
||||
ctx.RecentSenders = append(ctx.RecentSenders, msg.OriginalSenderUID)
|
||||
if len(ctx.RecentSenders) > 5 {
|
||||
ctx.RecentSenders = ctx.RecentSenders[len(ctx.RecentSenders)-5:]
|
||||
}
|
||||
ctx.MessageCount++
|
||||
}
|
||||
|
||||
@@ -166,3 +234,37 @@ func (r *PlatformRouter) GetContext(platform, channelID string) *ChannelContext
|
||||
defer r.mu.RUnlock()
|
||||
return r.contexts[platform+":"+channelID]
|
||||
}
|
||||
|
||||
// NoteBotReply records that the bot just replied to a specific user in a channel.
|
||||
// Used for conversation continuity: subsequent messages from this user continue the
|
||||
// conversation even without an explicit @mention, within the participant TTL window.
|
||||
func (r *PlatformRouter) NoteBotReply(platform, channelID, recipientUID string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
key := r.channelKey(platform, channelID)
|
||||
ctx, ok := r.contexts[key]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if ctx.ActiveParticipants == nil {
|
||||
ctx.ActiveParticipants = make(map[string]time.Time)
|
||||
}
|
||||
ctx.ActiveParticipants[recipientUID] = time.Now()
|
||||
}
|
||||
|
||||
// IsActiveParticipant checks if a user was recently engaged by the bot.
|
||||
// TTL controls how long the continuity window stays open after the last bot reply.
|
||||
func (r *PlatformRouter) IsActiveParticipant(platform, channelID, uid string) bool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
key := r.channelKey(platform, channelID)
|
||||
ctx, ok := r.contexts[key]
|
||||
if !ok || ctx.ActiveParticipants == nil {
|
||||
return false
|
||||
}
|
||||
t, ok := ctx.ActiveParticipants[uid]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return time.Since(t) < participantTTL
|
||||
}
|
||||
|
||||
@@ -21,15 +21,24 @@ type UnifiedMessage struct {
|
||||
|
||||
RawData interface{} `json:"raw_data,omitempty"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
|
||||
// Routing metadata.
|
||||
RouteType string `json:"route_type,omitempty"` // "normal", "silent", "admin_mention"
|
||||
OriginalSenderUID string `json:"original_sender_uid,omitempty"` // preserved before identity mapping
|
||||
OriginalSenderName string `json:"original_sender_name,omitempty"` // preserved before identity mapping
|
||||
GroupName string `json:"group_name,omitempty"` // resolved group name for group chats
|
||||
OriginalRawMessage interface{} `json:"-"` // preserved for SendMessage wiring
|
||||
BotUID string `json:"-"` // bot's own platform UID, set by router
|
||||
}
|
||||
|
||||
// Attachment represents a file/image/voice attachment.
|
||||
// Attachment represents a file/image/voice/video attachment.
|
||||
type Attachment struct {
|
||||
Type string `json:"type"` // "image", "voice", "file", "video"
|
||||
URL string `json:"url,omitempty"`
|
||||
FileName string `json:"file_name,omitempty"`
|
||||
MimeType string `json:"mime_type,omitempty"`
|
||||
Size int64 `json:"size,omitempty"`
|
||||
Duration int `json:"duration,omitempty"` // video/voice duration in seconds
|
||||
}
|
||||
|
||||
// UnifiedResponse is AI-Core's response converted to unified format.
|
||||
|
||||
@@ -0,0 +1,150 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// BlocklistMode is either "blacklist" or "whitelist".
|
||||
type BlocklistSettings struct {
|
||||
Mode string `json:"mode"` // "blacklist" (default) or "whitelist"
|
||||
GroupIDs []string `json:"group_ids"` // group IDs to block/allow
|
||||
UserIDs []string `json:"user_ids"` // private chat user IDs to block/allow
|
||||
}
|
||||
|
||||
// BlocklistStore manages persistence of blocklist settings.
|
||||
type BlocklistStore struct {
|
||||
mu sync.RWMutex
|
||||
path string
|
||||
settings BlocklistSettings
|
||||
}
|
||||
|
||||
// NewBlocklistStore loads or creates blocklist settings file.
|
||||
func NewBlocklistStore(path string) (*BlocklistStore, error) {
|
||||
s := &BlocklistStore{
|
||||
path: path,
|
||||
settings: BlocklistSettings{
|
||||
Mode: "blacklist",
|
||||
GroupIDs: []string{},
|
||||
UserIDs: []string{},
|
||||
},
|
||||
}
|
||||
if err := s.load(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *BlocklistStore) load() error {
|
||||
data, err := os.ReadFile(s.path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return s.save() // write defaults
|
||||
}
|
||||
return fmt.Errorf("read blocklist file: %w", err)
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if err := json.Unmarshal(data, &s.settings); err != nil {
|
||||
return fmt.Errorf("parse blocklist file: %w", err)
|
||||
}
|
||||
if s.settings.Mode == "" {
|
||||
s.settings.Mode = "blacklist"
|
||||
}
|
||||
if s.settings.GroupIDs == nil {
|
||||
s.settings.GroupIDs = []string{}
|
||||
}
|
||||
if s.settings.UserIDs == nil {
|
||||
s.settings.UserIDs = []string{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BlocklistStore) save() error {
|
||||
s.mu.RLock()
|
||||
data, err := json.MarshalIndent(s.settings, "", " ")
|
||||
s.mu.RUnlock()
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal blocklist: %w", err)
|
||||
}
|
||||
tmpPath := s.path + ".tmp"
|
||||
if err := os.WriteFile(tmpPath, data, 0640); err != nil {
|
||||
return fmt.Errorf("write blocklist: %w", err)
|
||||
}
|
||||
return os.Rename(tmpPath, s.path)
|
||||
}
|
||||
|
||||
// Get returns current blocklist settings.
|
||||
func (s *BlocklistStore) Get() BlocklistSettings {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
cp := BlocklistSettings{
|
||||
Mode: s.settings.Mode,
|
||||
GroupIDs: make([]string, len(s.settings.GroupIDs)),
|
||||
UserIDs: make([]string, len(s.settings.UserIDs)),
|
||||
}
|
||||
copy(cp.GroupIDs, s.settings.GroupIDs)
|
||||
copy(cp.UserIDs, s.settings.UserIDs)
|
||||
return cp
|
||||
}
|
||||
|
||||
// Set updates and persists blocklist settings.
|
||||
func (s *BlocklistStore) Set(bs BlocklistSettings) error {
|
||||
if bs.Mode != "blacklist" && bs.Mode != "whitelist" {
|
||||
return fmt.Errorf("invalid mode: %s (must be blacklist or whitelist)", bs.Mode)
|
||||
}
|
||||
if bs.GroupIDs == nil {
|
||||
bs.GroupIDs = []string{}
|
||||
}
|
||||
if bs.UserIDs == nil {
|
||||
bs.UserIDs = []string{}
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.settings = bs
|
||||
s.mu.Unlock()
|
||||
return s.save()
|
||||
}
|
||||
|
||||
// IsBlocked checks whether a message should be blocked based on channel type and ID.
|
||||
// In blacklist mode: returns true if the id is IN the list.
|
||||
// In whitelist mode: returns true if the id is NOT in the list.
|
||||
// Admin users should call this with isAdmin=true to always bypass.
|
||||
func (s *BlocklistStore) IsBlocked(channelType, channelID, senderID string, isAdmin bool) bool {
|
||||
if isAdmin {
|
||||
return false
|
||||
}
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
switch s.settings.Mode {
|
||||
case "whitelist":
|
||||
// Block if NOT in the whitelist.
|
||||
if channelType == "group" {
|
||||
return !contains(s.settings.GroupIDs, channelID)
|
||||
}
|
||||
return !contains(s.settings.UserIDs, senderID)
|
||||
|
||||
case "blacklist":
|
||||
fallthrough
|
||||
default:
|
||||
// Block if IN the blacklist.
|
||||
if channelType == "group" {
|
||||
return contains(s.settings.GroupIDs, channelID)
|
||||
}
|
||||
return contains(s.settings.UserIDs, senderID)
|
||||
}
|
||||
}
|
||||
|
||||
func contains(list []string, val string) bool {
|
||||
for _, v := range list {
|
||||
if v == val {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1,6 +1,9 @@
|
||||
package config
|
||||
|
||||
import "os"
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Config holds Platform Bridge configuration.
|
||||
type Config struct {
|
||||
@@ -11,9 +14,18 @@ type Config struct {
|
||||
InternalToken string
|
||||
|
||||
// Platform-specific.
|
||||
QQBotPort string // port for QQ OBv11 reverse WebSocket
|
||||
TelegramToken string // Telegram Bot API token
|
||||
TelegramWebhookURL string // public webhook URL for Telegram
|
||||
QQBotPort string // port for QQ OBv11 reverse WebSocket
|
||||
TelegramToken string // Telegram Bot API token
|
||||
TelegramWebhookURL string // public webhook URL for Telegram
|
||||
|
||||
// Silent observation mode.
|
||||
PlatformSilentEnabled bool // PLATFORM_SILENT_ENABLED, default true
|
||||
AdminNickname string // ADMIN_NICKNAME, admin's Cyrene identity nickname (default "开拓者")
|
||||
AdminNicknames []string // ADMIN_NICKNAMES, default ["开拓者"]
|
||||
AdminMentionKeywords []string // ADMIN_MENTION_KEYWORDS, default ["昔涟","Cyrene","管理员"]
|
||||
|
||||
// Message sending.
|
||||
MessageSendIntervalMs int // MSG_SEND_INTERVAL_MS, minimum interval between platform messages (default 2000)
|
||||
}
|
||||
|
||||
func Load() *Config {
|
||||
@@ -48,5 +60,58 @@ func Load() *Config {
|
||||
if v := os.Getenv("TELEGRAM_WEBHOOK_URL"); v != "" {
|
||||
cfg.TelegramWebhookURL = v
|
||||
}
|
||||
// Silent observation defaults.
|
||||
cfg.PlatformSilentEnabled = getEnvBool("PLATFORM_SILENT_ENABLED", true)
|
||||
cfg.AdminNickname = os.Getenv("ADMIN_NICKNAME")
|
||||
if cfg.AdminNickname == "" {
|
||||
cfg.AdminNickname = "开拓者"
|
||||
}
|
||||
cfg.AdminNicknames = getEnvList("ADMIN_NICKNAMES", []string{"开拓者"})
|
||||
cfg.AdminMentionKeywords = getEnvList("ADMIN_MENTION_KEYWORDS", []string{"昔涟", "Cyrene", "管理员"})
|
||||
cfg.MessageSendIntervalMs = getEnvInt("MSG_SEND_INTERVAL_MS", 2000)
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
func getEnvBool(key string, defaultVal bool) bool {
|
||||
v := os.Getenv(key)
|
||||
if v == "" {
|
||||
return defaultVal
|
||||
}
|
||||
return v == "true" || v == "1" || v == "yes"
|
||||
}
|
||||
|
||||
func getEnvInt(key string, defaultVal int) int {
|
||||
v := os.Getenv(key)
|
||||
if v == "" {
|
||||
return defaultVal
|
||||
}
|
||||
n := 0
|
||||
for _, c := range v {
|
||||
if c >= '0' && c <= '9' {
|
||||
n = n*10 + int(c-'0')
|
||||
} else {
|
||||
return defaultVal
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func getEnvList(key string, defaultVal []string) []string {
|
||||
v := os.Getenv(key)
|
||||
if v == "" {
|
||||
return defaultVal
|
||||
}
|
||||
parts := strings.Split(v, ",")
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
result = append(result, p)
|
||||
}
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return defaultVal
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
// PlatformConfig holds persistent configuration for one platform adapter.
|
||||
type PlatformConfig struct {
|
||||
Name string `json:"name"`
|
||||
Platform string `json:"platform"` // base platform type: "qq", "telegram", etc.
|
||||
Enabled bool `json:"enabled"`
|
||||
Label string `json:"label"`
|
||||
Fields map[string]string `json:"fields"`
|
||||
@@ -51,6 +52,12 @@ func (s *Store) load() error {
|
||||
if err := json.Unmarshal(data, &s.configs); err != nil {
|
||||
return fmt.Errorf("parse config file: %w", err)
|
||||
}
|
||||
// Backward compat: old configs without platform field default to Name.
|
||||
for _, c := range s.configs {
|
||||
if c.Platform == "" {
|
||||
c.Platform = c.Name
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/platform-bridge/internal/config"
|
||||
)
|
||||
|
||||
// BlocklistHandler exposes CRUD for blocklist settings.
|
||||
type BlocklistHandler struct {
|
||||
store *config.BlocklistStore
|
||||
}
|
||||
|
||||
func NewBlocklistHandler(store *config.BlocklistStore) *BlocklistHandler {
|
||||
return &BlocklistHandler{store: store}
|
||||
}
|
||||
|
||||
func (h *BlocklistHandler) RegisterRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/v1/settings/blocklist", h.handleBlocklist)
|
||||
}
|
||||
|
||||
func (h *BlocklistHandler) handleBlocklist(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case "GET":
|
||||
writeJSON(w, http.StatusOK, h.store.Get())
|
||||
case "POST", "PUT":
|
||||
var bs config.BlocklistSettings
|
||||
if err := json.NewDecoder(r.Body).Decode(&bs); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, errResp("invalid JSON: "+err.Error()))
|
||||
return
|
||||
}
|
||||
if err := h.store.Set(bs); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, errResp(err.Error()))
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"status": "saved",
|
||||
"settings": h.store.Get(),
|
||||
})
|
||||
default:
|
||||
writeJSON(w, http.StatusMethodNotAllowed, errResp("method not allowed"))
|
||||
}
|
||||
}
|
||||
@@ -8,21 +8,27 @@ import (
|
||||
"git.yeij.top/AskaEth/Cyrene/platform-bridge/internal/config"
|
||||
)
|
||||
|
||||
var knownPlatforms = map[string]bool{
|
||||
var validPlatformTypes = map[string]bool{
|
||||
"qq": true, "telegram": true, "webhook": true,
|
||||
"wechat": true, "feishu": true, "discord": true,
|
||||
}
|
||||
|
||||
// ConfigHandler exposes CRUD endpoints for platform configs.
|
||||
type ConfigHandler struct {
|
||||
store *config.Store
|
||||
router *bridge.PlatformRouter
|
||||
store *config.Store
|
||||
router *bridge.PlatformRouter
|
||||
onChanged func(name, platform string, enabled bool, fields map[string]string)
|
||||
}
|
||||
|
||||
func NewConfigHandler(store *config.Store, router *bridge.PlatformRouter) *ConfigHandler {
|
||||
return &ConfigHandler{store: store, router: router}
|
||||
}
|
||||
|
||||
// SetOnConfigChanged sets a callback invoked after config is saved or deleted.
|
||||
func (h *ConfigHandler) SetOnConfigChanged(fn func(name, platform string, enabled bool, fields map[string]string)) {
|
||||
h.onChanged = fn
|
||||
}
|
||||
|
||||
func (h *ConfigHandler) RegisterRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/v1/configs", h.listConfigs)
|
||||
mux.HandleFunc("/api/v1/configs/", h.handleConfig)
|
||||
@@ -33,6 +39,7 @@ func (h *ConfigHandler) listConfigs(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
type configSummary struct {
|
||||
Name string `json:"name"`
|
||||
Platform string `json:"platform"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Label string `json:"label,omitempty"`
|
||||
Fields map[string]string `json:"fields"`
|
||||
@@ -46,8 +53,13 @@ func (h *ConfigHandler) listConfigs(w http.ResponseWriter, r *http.Request) {
|
||||
if a, err := h.router.GetAdapter(c.Name); err == nil {
|
||||
connected = a.IsConnected()
|
||||
}
|
||||
platform := c.Platform
|
||||
if platform == "" {
|
||||
platform = c.Name
|
||||
}
|
||||
result = append(result, configSummary{
|
||||
Name: c.Name,
|
||||
Platform: platform,
|
||||
Enabled: c.Enabled,
|
||||
Label: c.Label,
|
||||
Fields: c.Fields,
|
||||
@@ -71,6 +83,7 @@ func (h *ConfigHandler) listConfigs(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
result = append(result, configSummary{
|
||||
Name: name,
|
||||
Platform: name,
|
||||
Enabled: false,
|
||||
Fields: map[string]string{},
|
||||
Connected: connected,
|
||||
@@ -92,10 +105,6 @@ func (h *ConfigHandler) handleConfig(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusBadRequest, errResp("missing config name"))
|
||||
return
|
||||
}
|
||||
if !knownPlatforms[name] {
|
||||
writeJSON(w, http.StatusBadRequest, errResp("unknown platform: "+name))
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case "GET":
|
||||
@@ -120,26 +129,37 @@ func (h *ConfigHandler) getConfig(w http.ResponseWriter, r *http.Request, name s
|
||||
connected = a.IsConnected()
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"name": cfg.Name,
|
||||
"enabled": cfg.Enabled,
|
||||
"label": cfg.Label,
|
||||
"fields": cfg.Fields,
|
||||
"updated_at": cfg.UpdatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
"connected": connected,
|
||||
"name": cfg.Name,
|
||||
"platform": cfg.Platform,
|
||||
"enabled": cfg.Enabled,
|
||||
"label": cfg.Label,
|
||||
"fields": cfg.Fields,
|
||||
"updated_at": cfg.UpdatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
"connected": connected,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *ConfigHandler) saveConfig(w http.ResponseWriter, r *http.Request, name string) {
|
||||
var body struct {
|
||||
Enabled *bool `json:"enabled"`
|
||||
Label string `json:"label"`
|
||||
Fields map[string]string `json:"fields"`
|
||||
Platform *string `json:"platform"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
Label string `json:"label"`
|
||||
Fields map[string]string `json:"fields"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, errResp("invalid JSON: "+err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
platform := name
|
||||
if body.Platform != nil && *body.Platform != "" {
|
||||
platform = *body.Platform
|
||||
}
|
||||
if !validPlatformTypes[platform] {
|
||||
writeJSON(w, http.StatusBadRequest, errResp("unknown or missing platform type: "+platform))
|
||||
return
|
||||
}
|
||||
|
||||
enabled := true
|
||||
if body.Enabled != nil {
|
||||
enabled = *body.Enabled
|
||||
@@ -151,29 +171,48 @@ func (h *ConfigHandler) saveConfig(w http.ResponseWriter, r *http.Request, name
|
||||
}
|
||||
|
||||
cfg := config.PlatformConfig{
|
||||
Name: name,
|
||||
Enabled: enabled,
|
||||
Label: body.Label,
|
||||
Fields: fields,
|
||||
Name: name,
|
||||
Platform: platform,
|
||||
Enabled: enabled,
|
||||
Label: body.Label,
|
||||
Fields: fields,
|
||||
}
|
||||
if err := h.store.Set(cfg); err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, errResp(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Trigger hot-reload.
|
||||
if h.onChanged != nil {
|
||||
h.onChanged(name, platform, enabled, fields)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"name": name,
|
||||
"enabled": enabled,
|
||||
"label": body.Label,
|
||||
"fields": fields,
|
||||
"status": "saved",
|
||||
"name": name,
|
||||
"platform": platform,
|
||||
"enabled": enabled,
|
||||
"label": body.Label,
|
||||
"fields": fields,
|
||||
"status": "saved",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *ConfigHandler) deleteConfig(w http.ResponseWriter, r *http.Request, name string) {
|
||||
// Get platform type before deleting (needed for onChanged callback).
|
||||
platform := name
|
||||
if cfg, err := h.store.Get(name); err == nil && cfg.Platform != "" {
|
||||
platform = cfg.Platform
|
||||
}
|
||||
|
||||
if err := h.store.Delete(name); err != nil {
|
||||
writeJSON(w, http.StatusNotFound, errResp(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Trigger hot-reload: disable and clear fields.
|
||||
if h.onChanged != nil {
|
||||
h.onChanged(name, platform, false, nil)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "deleted", "name": name})
|
||||
}
|
||||
|
||||
@@ -4,16 +4,18 @@ import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/platform-bridge/internal/config"
|
||||
"git.yeij.top/AskaEth/Cyrene/platform-bridge/internal/logging"
|
||||
)
|
||||
|
||||
// LogHandler exposes message log retrieval endpoints.
|
||||
type LogHandler struct {
|
||||
logger *logging.Logger
|
||||
store *config.Store
|
||||
}
|
||||
|
||||
func NewLogHandler(logger *logging.Logger) *LogHandler {
|
||||
return &LogHandler{logger: logger}
|
||||
func NewLogHandler(logger *logging.Logger, store *config.Store) *LogHandler {
|
||||
return &LogHandler{logger: logger, store: store}
|
||||
}
|
||||
|
||||
func (h *LogHandler) RegisterRoutes(mux *http.ServeMux) {
|
||||
@@ -27,6 +29,14 @@ func (h *LogHandler) handleLogs(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve platform type from config name (e.g. "qq-home" → "qq").
|
||||
platform := name
|
||||
if h.store != nil {
|
||||
if cfg, err := h.store.Get(name); err == nil && cfg.Platform != "" {
|
||||
platform = cfg.Platform
|
||||
}
|
||||
}
|
||||
|
||||
limit := 100
|
||||
if l := r.URL.Query().Get("limit"); l != "" {
|
||||
if n, err := strconv.Atoi(l); err == nil && n > 0 && n <= 1000 {
|
||||
@@ -34,7 +44,7 @@ func (h *LogHandler) handleLogs(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
entries, err := h.logger.ReadLogs(name, limit)
|
||||
entries, err := h.logger.ReadLogs(platform, limit)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, errResp(err.Error()))
|
||||
return
|
||||
@@ -43,7 +53,7 @@ func (h *LogHandler) handleLogs(w http.ResponseWriter, r *http.Request) {
|
||||
entries = []logging.LogEntry{}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"platform": name,
|
||||
"platform": platform,
|
||||
"total": len(entries),
|
||||
"logs": entries,
|
||||
})
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/platform-bridge/internal/logging"
|
||||
)
|
||||
|
||||
var wsUpgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
|
||||
// LogWSHub broadcasts log entries to connected WebSocket clients.
|
||||
type LogWSHub struct {
|
||||
mu sync.Mutex
|
||||
clients map[*websocket.Conn]chan logging.LogEntry
|
||||
}
|
||||
|
||||
// NewLogWSHub creates a LogWSHub and subscribes to the logger.
|
||||
func NewLogWSHub(logger *logging.Logger) *LogWSHub {
|
||||
h := &LogWSHub{
|
||||
clients: make(map[*websocket.Conn]chan logging.LogEntry),
|
||||
}
|
||||
logger.OnLog(func(entry logging.LogEntry) {
|
||||
h.broadcast(entry)
|
||||
})
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *LogWSHub) broadcast(entry logging.LogEntry) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
for _, ch := range h.clients {
|
||||
select {
|
||||
case ch <- entry:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ServeWS handles WebSocket upgrade and streams log entries to the client.
|
||||
func (h *LogWSHub) ServeWS(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := wsUpgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ch := make(chan logging.LogEntry, 64)
|
||||
h.mu.Lock()
|
||||
h.clients[conn] = ch
|
||||
h.mu.Unlock()
|
||||
|
||||
// Write goroutine: drains ch until it is closed.
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for entry := range ch {
|
||||
data, _ := json.Marshal(entry)
|
||||
if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Read goroutine: detect client disconnect.
|
||||
// (websocket requires a reader to detect close frames.)
|
||||
go func() {
|
||||
for {
|
||||
if _, _, err := conn.ReadMessage(); err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
// Client disconnected — stop broadcasting, close channel.
|
||||
h.mu.Lock()
|
||||
delete(h.clients, conn)
|
||||
h.mu.Unlock()
|
||||
close(ch)
|
||||
}()
|
||||
|
||||
<-done
|
||||
conn.Close()
|
||||
}
|
||||
@@ -18,6 +18,7 @@ type LogEntry struct {
|
||||
ChannelID string `json:"channel_id"`
|
||||
SenderID string `json:"sender_id"`
|
||||
SenderName string `json:"sender_name"`
|
||||
GroupName string `json:"group_name,omitempty"`
|
||||
Content string `json:"content"`
|
||||
ContentType string `json:"content_type"`
|
||||
MessageID string `json:"message_id,omitempty"`
|
||||
@@ -25,11 +26,23 @@ type LogEntry struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// LogListener receives log entries as they are written.
|
||||
type LogListener func(LogEntry)
|
||||
|
||||
// Logger writes message logs to per-platform JSONL files.
|
||||
type Logger struct {
|
||||
mu sync.Mutex
|
||||
dir string
|
||||
files map[string]*os.File
|
||||
mu sync.Mutex
|
||||
dir string
|
||||
files map[string]*os.File
|
||||
listeners []LogListener
|
||||
}
|
||||
|
||||
// OnLog registers a listener that is called for every log entry written.
|
||||
// The listener is called synchronously; avoid heavy work in the callback.
|
||||
func (l *Logger) OnLog(fn LogListener) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
l.listeners = append(l.listeners, fn)
|
||||
}
|
||||
|
||||
// NewLogger creates a Logger, ensuring the log directory exists.
|
||||
@@ -60,12 +73,23 @@ func (l *Logger) Log(entry LogEntry) error {
|
||||
}
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
if _, err := f.Write(append(data, '\n')); err != nil {
|
||||
l.mu.Unlock()
|
||||
return fmt.Errorf("write log: %w", err)
|
||||
}
|
||||
return f.Sync()
|
||||
if err := f.Sync(); err != nil {
|
||||
l.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
listeners := make([]LogListener, len(l.listeners))
|
||||
copy(listeners, l.listeners)
|
||||
l.mu.Unlock()
|
||||
|
||||
// Notify listeners outside the lock.
|
||||
for _, fn := range listeners {
|
||||
fn(entry)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadLogs reads the last N log entries for a platform, newest first.
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/manager"
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
||||
)
|
||||
|
||||
type hostAPI struct {
|
||||
registry *manager.ToolRegistry
|
||||
}
|
||||
|
||||
func newHostAPI(registry *manager.ToolRegistry) *hostAPI {
|
||||
return &hostAPI{registry: registry}
|
||||
}
|
||||
|
||||
func (h *hostAPI) CallLLM(_ context.Context, _ []sdk.LLMMessage) (*sdk.LLMResponse, error) {
|
||||
return nil, fmt.Errorf("LLM call not available in plugin host")
|
||||
}
|
||||
|
||||
func (h *hostAPI) SearchMemory(_ context.Context, _, _ string, _ int) ([]sdk.MemoryEntry, error) {
|
||||
return nil, fmt.Errorf("memory search not available in plugin host")
|
||||
}
|
||||
|
||||
func (h *hostAPI) StoreMemory(_ context.Context, _ sdk.MemoryEntry) error {
|
||||
return fmt.Errorf("memory store not available in plugin host")
|
||||
}
|
||||
|
||||
func (h *hostAPI) Logger() sdk.Logger {
|
||||
return log.Default()
|
||||
}
|
||||
|
||||
func (h *hostAPI) GetConfig(key string) (string, error) {
|
||||
return "", fmt.Errorf("config key not found: %s", key)
|
||||
}
|
||||
|
||||
func (h *hostAPI) SetConfig(_, _ string) error { return nil }
|
||||
|
||||
func (h *hostAPI) PublishEvent(_ context.Context, _ map[string]interface{}) error { return nil }
|
||||
|
||||
func (h *hostAPI) HTTPClient() *http.Client {
|
||||
return http.DefaultClient
|
||||
}
|
||||
@@ -1,112 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
||||
iotquery "git.yeij.top/AskaEth/Cyrene/pkg/plugins/iot_query"
|
||||
)
|
||||
|
||||
type iotClient struct {
|
||||
baseURL string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func newIoTClient(baseURL string) *iotClient {
|
||||
return &iotClient{
|
||||
baseURL: baseURL,
|
||||
httpClient: &http.Client{Timeout: 5 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *iotClient) GetAllDevices(ctx context.Context) ([]sdk.IoTDeviceState, error) {
|
||||
url := c.baseURL + "/api/v1/devices"
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result struct {
|
||||
Devices []sdk.IoTDeviceState `json:"devices"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result.Devices, nil
|
||||
}
|
||||
|
||||
func (c *iotClient) GetDevice(ctx context.Context, deviceID string) (*sdk.IoTDeviceState, error) {
|
||||
url := fmt.Sprintf("%s/api/v1/devices/%s", c.baseURL, deviceID)
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var dev sdk.IoTDeviceState
|
||||
if err := json.NewDecoder(resp.Body).Decode(&dev); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dev, nil
|
||||
}
|
||||
|
||||
// iotControllerAdapter adapts IoTClient to iotcontrol.IoTController.
|
||||
type iotControllerAdapter struct {
|
||||
query iotquery.IoTClient
|
||||
client *http.Client
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func newIoTControllerAdapter(query iotquery.IoTClient, baseURL string) *iotControllerAdapter {
|
||||
return &iotControllerAdapter{
|
||||
query: query,
|
||||
client: &http.Client{Timeout: 5 * time.Second},
|
||||
baseURL: baseURL,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *iotControllerAdapter) GetDevice(ctx context.Context, deviceID string) (*sdk.IoTDeviceState, error) {
|
||||
return a.query.GetDevice(ctx, deviceID)
|
||||
}
|
||||
|
||||
func (a *iotControllerAdapter) SetDeviceProperty(ctx context.Context, deviceID, property string, value interface{}) error {
|
||||
url := fmt.Sprintf("%s/api/v1/devices/%s/property", a.baseURL, deviceID)
|
||||
body, _ := json.Marshal(map[string]interface{}{"property": property, "value": value})
|
||||
req, _ := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(string(body)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := a.client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
msg, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("set property failed: HTTP %d - %s", resp.StatusCode, string(msg))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *iotControllerAdapter) ToggleDevice(ctx context.Context, deviceID string) (*sdk.IoTDeviceState, error) {
|
||||
url := fmt.Sprintf("%s/api/v1/devices/%s/toggle", a.baseURL, deviceID)
|
||||
req, _ := http.NewRequestWithContext(ctx, "POST", url, nil)
|
||||
resp, err := a.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var dev sdk.IoTDeviceState
|
||||
if err := json.NewDecoder(resp.Body).Decode(&dev); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dev, nil
|
||||
}
|
||||
@@ -1,100 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/calculator"
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/crypto"
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/datetime"
|
||||
fileplugin "git.yeij.top/AskaEth/Cyrene/pkg/plugins/file"
|
||||
httpplugin "git.yeij.top/AskaEth/Cyrene/pkg/plugins/http"
|
||||
iotcontrol "git.yeij.top/AskaEth/Cyrene/pkg/plugins/iot_control"
|
||||
iotquery "git.yeij.top/AskaEth/Cyrene/pkg/plugins/iot_query"
|
||||
jsonplugin "git.yeij.top/AskaEth/Cyrene/pkg/plugins/json"
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/manager"
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/markdown"
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/random"
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/sdk"
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/text"
|
||||
webfetch "git.yeij.top/AskaEth/Cyrene/pkg/plugins/web_fetch"
|
||||
websearch "git.yeij.top/AskaEth/Cyrene/pkg/plugins/web_search"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/plugin-manager/internal/config"
|
||||
"git.yeij.top/AskaEth/Cyrene/plugin-manager/internal/handler"
|
||||
)
|
||||
|
||||
func main() {
|
||||
cfg := config.Load()
|
||||
|
||||
var iotAPI iotquery.IoTClient
|
||||
if cfg.IoTSvcURL != "" {
|
||||
iotAPI = newIoTClient(cfg.IoTSvcURL)
|
||||
}
|
||||
|
||||
registry := manager.NewToolRegistry()
|
||||
host := newHostAPI(registry)
|
||||
mgr := manager.NewPluginManager(registry, host)
|
||||
|
||||
builtins := []sdk.Plugin{
|
||||
&calculator.CalculatorPlugin{},
|
||||
&datetime.DatetimePlugin{},
|
||||
&text.TextPlugin{},
|
||||
&crypto.CryptoPlugin{},
|
||||
&random.RandomPlugin{},
|
||||
&markdown.MarkdownPlugin{},
|
||||
&jsonplugin.JSONPlugin{},
|
||||
fileplugin.NewFilePlugin(cfg.DataDir),
|
||||
httpplugin.NewHTTPPlugin(),
|
||||
websearch.NewWebSearchPlugin(),
|
||||
webfetch.NewWebFetchPlugin(),
|
||||
iotquery.NewIoTQueryPlugin(iotAPI),
|
||||
}
|
||||
for _, p := range builtins {
|
||||
if err := mgr.Install(p); err != nil {
|
||||
println("WARN: install plugin failed:", err.Error())
|
||||
}
|
||||
}
|
||||
if iotAPI != nil {
|
||||
ctrlPlugin := iotcontrol.NewIoTControlPlugin(newIoTControllerAdapter(iotAPI, cfg.IoTSvcURL))
|
||||
if err := mgr.Install(ctrlPlugin); err != nil {
|
||||
println("WARN: install plugin failed:", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
errs := mgr.EnableAll(ctx)
|
||||
for _, e := range errs {
|
||||
println("WARN: enable plugin failed:", e.Error())
|
||||
}
|
||||
println("Plugin Manager: all built-in plugins enabled")
|
||||
|
||||
mux := http.NewServeMux()
|
||||
ph := handler.NewPluginHandler(mgr)
|
||||
ph.RegisterRoutes(mux)
|
||||
|
||||
println("Plugin Manager listening on port", cfg.Port)
|
||||
srv := &http.Server{Addr: ":" + cfg.Port, Handler: mux}
|
||||
|
||||
go func() {
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
println("FATAL:", err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
}()
|
||||
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
println("Shutting down Plugin Manager...")
|
||||
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
mgr.Shutdown(shutdownCtx)
|
||||
srv.Shutdown(shutdownCtx)
|
||||
println("Plugin Manager stopped")
|
||||
}
|
||||
@@ -1,3 +0,0 @@
|
||||
module git.yeij.top/AskaEth/Cyrene/plugin-manager
|
||||
|
||||
go 1.26.2
|
||||
@@ -1,32 +0,0 @@
|
||||
package config
|
||||
|
||||
import "os"
|
||||
|
||||
type Config struct {
|
||||
Port string
|
||||
Env string
|
||||
DataDir string
|
||||
IoTSvcURL string
|
||||
}
|
||||
|
||||
func Load() *Config {
|
||||
cfg := &Config{
|
||||
Port: "8094",
|
||||
Env: "development",
|
||||
DataDir: "./data",
|
||||
IoTSvcURL: "http://localhost:8093",
|
||||
}
|
||||
if v := os.Getenv("PORT"); v != "" {
|
||||
cfg.Port = v
|
||||
}
|
||||
if v := os.Getenv("ENV"); v != "" {
|
||||
cfg.Env = v
|
||||
}
|
||||
if v := os.Getenv("DATA_DIR"); v != "" {
|
||||
cfg.DataDir = v
|
||||
}
|
||||
if v := os.Getenv("IOT_SERVICE_URL"); v != "" {
|
||||
cfg.IoTSvcURL = v
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
@@ -1,210 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/plugins/manager"
|
||||
)
|
||||
|
||||
// PluginHandler exposes the Plugin Manager REST API via net/http.
|
||||
type PluginHandler struct {
|
||||
mgr *manager.PluginManager
|
||||
}
|
||||
|
||||
func NewPluginHandler(mgr *manager.PluginManager) *PluginHandler {
|
||||
return &PluginHandler{mgr: mgr}
|
||||
}
|
||||
|
||||
func (h *PluginHandler) RegisterRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/v1/plugins", h.listPlugins)
|
||||
mux.HandleFunc("/api/v1/plugins/", h.pluginRoute)
|
||||
mux.HandleFunc("/api/v1/tools", h.listTools)
|
||||
mux.HandleFunc("/api/v1/tools/", h.toolRoute)
|
||||
mux.HandleFunc("/api/v1/health", h.health)
|
||||
}
|
||||
|
||||
func (h *PluginHandler) health(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{"status": "ok", "service": "plugin-manager"})
|
||||
}
|
||||
|
||||
func (h *PluginHandler) listPlugins(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "GET" {
|
||||
writeJSON(w, http.StatusMethodNotAllowed, errResp("method not allowed"))
|
||||
return
|
||||
}
|
||||
plugins := h.mgr.List()
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{"plugins": plugins, "total": len(plugins)})
|
||||
}
|
||||
|
||||
func (h *PluginHandler) pluginRoute(w http.ResponseWriter, r *http.Request) {
|
||||
path := strings.TrimPrefix(r.URL.Path, "/api/v1/plugins/")
|
||||
parts := strings.SplitN(path, "/", 2)
|
||||
pluginID := parts[0]
|
||||
|
||||
if pluginID == "" {
|
||||
h.listPlugins(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if len(parts) == 1 {
|
||||
switch r.Method {
|
||||
case "GET":
|
||||
h.getPlugin(w, pluginID)
|
||||
case "DELETE":
|
||||
h.uninstallPlugin(w, r, pluginID)
|
||||
default:
|
||||
writeJSON(w, http.StatusMethodNotAllowed, errResp("method not allowed"))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
action := parts[1]
|
||||
switch action {
|
||||
case "enable":
|
||||
if r.Method != "POST" {
|
||||
writeJSON(w, http.StatusMethodNotAllowed, errResp("method not allowed"))
|
||||
return
|
||||
}
|
||||
h.enablePlugin(w, r, pluginID)
|
||||
case "disable":
|
||||
if r.Method != "POST" {
|
||||
writeJSON(w, http.StatusMethodNotAllowed, errResp("method not allowed"))
|
||||
return
|
||||
}
|
||||
h.disablePlugin(w, r, pluginID)
|
||||
case "reload":
|
||||
if r.Method != "POST" {
|
||||
writeJSON(w, http.StatusMethodNotAllowed, errResp("method not allowed"))
|
||||
return
|
||||
}
|
||||
h.reloadPlugin(w, r, pluginID)
|
||||
case "tools":
|
||||
if r.Method != "GET" {
|
||||
writeJSON(w, http.StatusMethodNotAllowed, errResp("method not allowed"))
|
||||
return
|
||||
}
|
||||
h.pluginTools(w, pluginID)
|
||||
default:
|
||||
writeJSON(w, http.StatusNotFound, errResp("not found"))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *PluginHandler) getPlugin(w http.ResponseWriter, id string) {
|
||||
info, ok := h.mgr.Get(id)
|
||||
if !ok {
|
||||
writeJSON(w, http.StatusNotFound, errResp("plugin not found"))
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, info)
|
||||
}
|
||||
|
||||
func (h *PluginHandler) enablePlugin(w http.ResponseWriter, r *http.Request, id string) {
|
||||
if err := h.mgr.Enable(r.Context(), id); err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, errResp(err.Error()))
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "enabled"})
|
||||
}
|
||||
|
||||
func (h *PluginHandler) disablePlugin(w http.ResponseWriter, r *http.Request, id string) {
|
||||
if err := h.mgr.Disable(r.Context(), id); err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, errResp(err.Error()))
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "disabled"})
|
||||
}
|
||||
|
||||
func (h *PluginHandler) reloadPlugin(w http.ResponseWriter, r *http.Request, id string) {
|
||||
if err := h.mgr.Reload(r.Context(), id); err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, errResp(err.Error()))
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "reloaded"})
|
||||
}
|
||||
|
||||
func (h *PluginHandler) uninstallPlugin(w http.ResponseWriter, r *http.Request, id string) {
|
||||
if err := h.mgr.Uninstall(r.Context(), id); err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, errResp(err.Error()))
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "uninstalled"})
|
||||
}
|
||||
|
||||
func (h *PluginHandler) pluginTools(w http.ResponseWriter, id string) {
|
||||
info, ok := h.mgr.Get(id)
|
||||
if !ok {
|
||||
writeJSON(w, http.StatusNotFound, errResp("plugin not found"))
|
||||
return
|
||||
}
|
||||
registry := h.mgr.Registry()
|
||||
tools := make([]interface{}, 0)
|
||||
for _, toolID := range info.Tools {
|
||||
if t, ok := registry.Get(toolID); ok {
|
||||
tools = append(tools, t.Definition())
|
||||
}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{"tools": tools, "total": len(tools)})
|
||||
}
|
||||
|
||||
func (h *PluginHandler) listTools(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "GET" {
|
||||
writeJSON(w, http.StatusMethodNotAllowed, errResp("method not allowed"))
|
||||
return
|
||||
}
|
||||
defs := h.mgr.Registry().Definitions()
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{"tools": defs, "total": len(defs)})
|
||||
}
|
||||
|
||||
func (h *PluginHandler) toolRoute(w http.ResponseWriter, r *http.Request) {
|
||||
path := strings.TrimPrefix(r.URL.Path, "/api/v1/tools/")
|
||||
toolID := path
|
||||
|
||||
if strings.HasSuffix(path, "/execute") {
|
||||
toolID = strings.TrimSuffix(path, "/execute")
|
||||
if r.Method != "POST" {
|
||||
writeJSON(w, http.StatusMethodNotAllowed, errResp("method not allowed"))
|
||||
return
|
||||
}
|
||||
h.executeTool(w, r, toolID)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method != "GET" {
|
||||
writeJSON(w, http.StatusMethodNotAllowed, errResp("method not allowed"))
|
||||
return
|
||||
}
|
||||
tool, ok := h.mgr.Registry().Get(toolID)
|
||||
if !ok {
|
||||
writeJSON(w, http.StatusNotFound, errResp("tool not found"))
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, tool.Definition())
|
||||
}
|
||||
|
||||
func (h *PluginHandler) executeTool(w http.ResponseWriter, r *http.Request, toolID string) {
|
||||
var body struct {
|
||||
Arguments map[string]interface{} `json:"arguments"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, errResp("invalid request body"))
|
||||
return
|
||||
}
|
||||
result, err := h.mgr.Registry().Execute(r.Context(), toolID, body.Arguments)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, errResp(err.Error()))
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
func errResp(msg string) map[string]string {
|
||||
return map[string]string{"error": msg}
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, status int, data interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
@@ -19,7 +19,7 @@ RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o /voice-service ./cmd/m
|
||||
# ========== 运行阶段 ==========
|
||||
FROM alpine:3.21
|
||||
|
||||
RUN apk add --no-cache ca-certificates tzdata && \
|
||||
RUN apk add --no-cache ca-certificates tzdata ffmpeg && \
|
||||
cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \
|
||||
echo "Asia/Shanghai" > /etc/timezone
|
||||
|
||||
|
||||
@@ -0,0 +1,224 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func main() {
|
||||
mode := flag.String("mode", "offline", "测试模式: offline (非实时) 或 realtime (实时)")
|
||||
file := flag.String("file", "", "音频文件路径 (WAV/MP3/OGG/FLAC)")
|
||||
server := flag.String("server", "http://localhost:8093", "Voice-Service 地址")
|
||||
lang := flag.String("lang", "zh", "语言代码")
|
||||
flag.Parse()
|
||||
|
||||
if *file == "" {
|
||||
fmt.Println("用法: test_asr -mode=offline -file=audio.wav [-server=http://localhost:8093]")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
switch *mode {
|
||||
case "offline":
|
||||
testOffline(*server, *file, *lang)
|
||||
case "realtime":
|
||||
testRealtime(*server, *file, *lang)
|
||||
default:
|
||||
fmt.Printf("未知模式: %s (支持: offline, realtime)\n", *mode)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// testOffline 测试非实时语音识别 (HTTP multipart 上传)。
|
||||
func testOffline(server, filePath, lang string) {
|
||||
fmt.Printf("=== 非实时 ASR 测试 ===\n")
|
||||
fmt.Printf("服务器: %s\n", server)
|
||||
fmt.Printf("文件: %s\n", filePath)
|
||||
fmt.Printf("语言: %s\n\n", lang)
|
||||
|
||||
// 读取音频文件
|
||||
audioData, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
fmt.Printf("读取文件失败: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Printf("音频大小: %d bytes\n", len(audioData))
|
||||
|
||||
// 创建 multipart 请求
|
||||
req, err := http.NewRequest("POST", server+"/api/v1/transcribe", nil)
|
||||
if err != nil {
|
||||
fmt.Printf("创建请求失败: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// 使用 multipart form
|
||||
body, contentType, err := createMultipartBody(audioData, filePath, lang)
|
||||
if err != nil {
|
||||
fmt.Printf("创建 multipart body 失败: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
req.Body = body
|
||||
req.Header.Set("Content-Type", contentType)
|
||||
|
||||
start := time.Now()
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
fmt.Printf("请求失败: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
elapsed := time.Since(start)
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
fmt.Printf("状态码: %d\n", resp.StatusCode)
|
||||
fmt.Printf("耗时: %v\n", elapsed)
|
||||
fmt.Printf("响应:\n%s\n", string(respBody))
|
||||
|
||||
if resp.StatusCode == 200 {
|
||||
fmt.Println("\n✅ 非实时语音识别成功!")
|
||||
} else {
|
||||
fmt.Println("\n❌ 非实时语音识别失败")
|
||||
}
|
||||
}
|
||||
|
||||
// testRealtime 测试实时语音识别 (WebSocket 流式)。
|
||||
func testRealtime(server, filePath, lang string) {
|
||||
fmt.Printf("=== 实时 ASR 测试 ===\n")
|
||||
fmt.Printf("服务器: %s\n", server)
|
||||
fmt.Printf("文件: %s\n", filePath)
|
||||
fmt.Printf("语言: %s\n\n", lang)
|
||||
|
||||
// 读取音频文件
|
||||
audioData, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
fmt.Printf("读取文件失败: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Printf("音频大小: %d bytes\n", len(audioData))
|
||||
|
||||
// 推断格式
|
||||
format := inferFormat(filePath)
|
||||
|
||||
// 连接 WebSocket
|
||||
wsURL := fmt.Sprintf("ws://%s/api/v1/stt/stream?format=%s&language=%s",
|
||||
server[7:], format, lang) // 去掉 http:// 前缀
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
fmt.Printf("WebSocket 连接失败: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
fmt.Printf("WebSocket 已连接: %s\n", wsURL)
|
||||
|
||||
// 设置 interrupt 处理
|
||||
interrupt := make(chan os.Signal, 1)
|
||||
signal.Notify(interrupt, os.Interrupt)
|
||||
|
||||
// goroutine: 读取识别结果
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
_, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
fmt.Printf("读取结果错误: %v\n", err)
|
||||
return
|
||||
}
|
||||
fmt.Printf("◀ 结果: %s\n", string(msg))
|
||||
}
|
||||
}()
|
||||
|
||||
// 模拟实时流式发送音频(每 100ms 发送 3200 bytes)
|
||||
chunkSize := 3200
|
||||
totalSent := 0
|
||||
start := time.Now()
|
||||
var elapsed time.Duration
|
||||
|
||||
cancelled := false
|
||||
for i := 0; i < len(audioData); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
if end > len(audioData) {
|
||||
end = len(audioData)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-interrupt:
|
||||
fmt.Println("\n用户中断")
|
||||
cancelled = true
|
||||
default:
|
||||
}
|
||||
if cancelled {
|
||||
break
|
||||
}
|
||||
|
||||
if err := conn.WriteMessage(websocket.BinaryMessage, audioData[i:end]); err != nil {
|
||||
fmt.Printf("发送音频失败: %v\n", err)
|
||||
break
|
||||
}
|
||||
totalSent += end - i
|
||||
fmt.Printf("▶ 发送 %d/%d bytes (%.1f%%)\n", totalSent, len(audioData),
|
||||
float64(totalSent)/float64(len(audioData))*100)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
elapsed = time.Since(start)
|
||||
|
||||
// 发送停止消息
|
||||
conn.WriteMessage(websocket.TextMessage, []byte(`{"action":"stop"}`))
|
||||
|
||||
// 等待最后的结果
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
fmt.Printf("\n总耗时: %v, 总发送: %d bytes\n", elapsed, totalSent)
|
||||
fmt.Println("✅ 实时语音识别测试完成")
|
||||
}
|
||||
|
||||
func inferFormat(filename string) string {
|
||||
ext := ""
|
||||
for i := len(filename) - 1; i >= 0; i-- {
|
||||
if filename[i] == '.' {
|
||||
ext = filename[i+1:]
|
||||
break
|
||||
}
|
||||
}
|
||||
switch ext {
|
||||
case "wav", "wave":
|
||||
return "wav"
|
||||
case "mp3", "mpeg":
|
||||
return "mp3"
|
||||
case "ogg", "opus":
|
||||
return "ogg"
|
||||
case "flac":
|
||||
return "flac"
|
||||
case "m4a", "mp4", "aac":
|
||||
return "m4a"
|
||||
default:
|
||||
return "pcm"
|
||||
}
|
||||
}
|
||||
|
||||
func createMultipartBody(audioData []byte, filename, lang string) (io.ReadCloser, string, error) {
|
||||
boundary := "cyrene-asr-test-boundary"
|
||||
header := fmt.Sprintf("--%s\r\nContent-Disposition: form-data; name=\"audio\"; filename=\"%s\"\r\nContent-Type: application/octet-stream\r\n\r\n",
|
||||
boundary, filename)
|
||||
footer := fmt.Sprintf("\r\n--%s\r\nContent-Disposition: form-data; name=\"language\"\r\n\r\n%s\r\n--%s--\r\n",
|
||||
boundary, lang, boundary)
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
pw.Write([]byte(header))
|
||||
pw.Write(audioData)
|
||||
pw.Write([]byte(footer))
|
||||
pw.Close()
|
||||
}()
|
||||
|
||||
return pr, "multipart/form-data; boundary=" + boundary, nil
|
||||
}
|
||||
@@ -2,9 +2,15 @@ module git.yeij.top/AskaEth/Cyrene/voice-service
|
||||
|
||||
go 1.26.2
|
||||
|
||||
replace git.yeij.top/AskaEth/Cyrene/pkg/logger => ../pkg/logger
|
||||
replace (
|
||||
git.yeij.top/AskaEth/Cyrene/pkg/logger => ../pkg/logger
|
||||
git.yeij.top/AskaEth/Cyrene/pkg/audio => ../pkg/audio
|
||||
git.yeij.top/AskaEth/Cyrene/pkg/dashscope => ../pkg/dashscope
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
git.yeij.top/AskaEth/Cyrene/pkg/audio v0.0.0
|
||||
git.yeij.top/AskaEth/Cyrene/pkg/dashscope v0.0.0
|
||||
git.yeij.top/AskaEth/Cyrene/pkg/logger v0.0.0
|
||||
)
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/dashscope"
|
||||
)
|
||||
|
||||
// DashScopeRESTSTT 使用 DashScope REST API 进行离线语音识别。
|
||||
// 离线模型 (qwen3-asr-flash-2026-02-10) 通过 HTTP REST 端点进行转录,
|
||||
// 无需 session 协商和 Server VAD,延迟更低,适合非实时场景。
|
||||
type DashScopeRESTSTT struct {
|
||||
model string
|
||||
client *dashscope.RESTClient
|
||||
}
|
||||
|
||||
// NewDashScopeRESTSTT 创建 DashScope REST STT 客户端。
|
||||
func NewDashScopeRESTSTT(apiKey, model string) *DashScopeRESTSTT {
|
||||
if model == "" {
|
||||
model = "qwen3-asr-flash-2026-02-10"
|
||||
}
|
||||
return &DashScopeRESTSTT{
|
||||
model: model,
|
||||
client: dashscope.NewRESTClient(apiKey),
|
||||
}
|
||||
}
|
||||
|
||||
// IsAvailable 检查 API Key 是否已配置。
|
||||
func (d *DashScopeRESTSTT) IsAvailable() bool {
|
||||
return d.client.IsAvailable()
|
||||
}
|
||||
|
||||
// Model 返回模型名。
|
||||
func (d *DashScopeRESTSTT) Model() string { return d.model }
|
||||
|
||||
// Transcribe 使用 DashScope REST API 进行离线语音识别。
|
||||
func (d *DashScopeRESTSTT) Transcribe(ctx context.Context, audioData []byte, format, language string) (string, error) {
|
||||
if !d.IsAvailable() {
|
||||
return "", fmt.Errorf("DashScope REST ASR API key 未配置")
|
||||
}
|
||||
return d.client.Transcribe(ctx, d.model, audioData, format, 16000, language)
|
||||
}
|
||||
|
||||
// GetStatus 返回 REST STT 客户端的运行状态。
|
||||
func (d *DashScopeRESTSTT) GetStatus() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"available": d.IsAvailable(),
|
||||
"model": d.model,
|
||||
"protocol": "rest",
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDashScopeRESTSTT_Transcribe_Success(t *testing.T) {
|
||||
client := NewDashScopeRESTSTT("test-key", "qwen3-asr-flash-2026-02-10")
|
||||
|
||||
if !client.IsAvailable() {
|
||||
t.Error("client should be available when apiKey is set")
|
||||
}
|
||||
if client.Model() != "qwen3-asr-flash-2026-02-10" {
|
||||
t.Errorf("unexpected model: %s", client.Model())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDashScopeRESTSTT_NotAvailable(t *testing.T) {
|
||||
client := NewDashScopeRESTSTT("", "")
|
||||
if client.IsAvailable() {
|
||||
t.Error("client should not be available without apiKey")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDashScopeRESTSTT_DefaultModel(t *testing.T) {
|
||||
client := NewDashScopeRESTSTT("key", "")
|
||||
if client.Model() != "qwen3-asr-flash-2026-02-10" {
|
||||
t.Errorf("expected default model, got %s", client.Model())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDashScopeRESTSTT_Transcribe_NoAPIKey(t *testing.T) {
|
||||
client := NewDashScopeRESTSTT("", "")
|
||||
_, err := client.Transcribe(context.Background(), []byte{}, "wav", "zh")
|
||||
if err == nil {
|
||||
t.Error("expected error when API key is not configured")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDashScopeRESTSTT_GetStatus(t *testing.T) {
|
||||
client := NewDashScopeRESTSTT("key", "test-model")
|
||||
status := client.GetStatus()
|
||||
if status["available"] != true {
|
||||
t.Error("status should be available")
|
||||
}
|
||||
if status["model"] != "test-model" {
|
||||
t.Errorf("unexpected model in status: %v", status["model"])
|
||||
}
|
||||
if status["protocol"] != "rest" {
|
||||
t.Errorf("unexpected protocol: %v", status["protocol"])
|
||||
}
|
||||
}
|
||||
@@ -5,12 +5,10 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.yeij.top/AskaEth/Cyrene/pkg/audio"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
@@ -145,7 +143,7 @@ func (d *DashScopeSTT) Transcribe(ctx context.Context, audioData []byte, format
|
||||
}
|
||||
|
||||
// 4. 规范化音频格式并发送
|
||||
pcmData, err := convertToPCM16(audioData, format)
|
||||
pcmData, err := audio.ConvertToPCM16(audioData, format)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("音频格式转换失败: %w", err)
|
||||
}
|
||||
@@ -447,74 +445,3 @@ func (d *DashScopeSTT) GetStatus() map[string]interface{} {
|
||||
"provider": "dashscope",
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeSTTFormat 规范化音频格式字符串。
|
||||
func normalizeSTTFormat(format string) string {
|
||||
switch strings.ToLower(format) {
|
||||
case "pcm", "wav", "mp3", "mpeg", "ogg", "opus", "flac", "m4a", "mp4", "aac", "webm":
|
||||
return strings.ToLower(format)
|
||||
default:
|
||||
return format
|
||||
}
|
||||
}
|
||||
|
||||
// convertToPCM16 将音频数据转换为 16-bit PCM 16000Hz mono。
|
||||
func convertToPCM16(data []byte, format string) ([]byte, error) {
|
||||
normFormat := normalizeSTTFormat(format)
|
||||
switch normFormat {
|
||||
case "pcm":
|
||||
return data, nil
|
||||
case "wav":
|
||||
if len(data) > 44 {
|
||||
return data[44:], nil
|
||||
}
|
||||
return data, nil
|
||||
default:
|
||||
return transcodeToPCM(data, normFormat)
|
||||
}
|
||||
}
|
||||
|
||||
// transcodeToPCM 使用 ffmpeg 将音频数据转码为 PCM 16-bit 16000Hz mono。
|
||||
func transcodeToPCM(data []byte, format string) ([]byte, error) {
|
||||
inFile, err := os.CreateTemp(os.TempDir(), "cyrene-asr-in-*."+format)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建输入临时文件失败: %w", err)
|
||||
}
|
||||
inPath := inFile.Name()
|
||||
defer os.Remove(inPath)
|
||||
if _, err := inFile.Write(data); err != nil {
|
||||
inFile.Close()
|
||||
return nil, fmt.Errorf("写入输入临时文件失败: %w", err)
|
||||
}
|
||||
inFile.Close()
|
||||
|
||||
outFile, err := os.CreateTemp(os.TempDir(), "cyrene-asr-out-*.pcm")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建输出临时文件失败: %w", err)
|
||||
}
|
||||
outPath := outFile.Name()
|
||||
outFile.Close()
|
||||
defer os.Remove(outPath)
|
||||
|
||||
cmd := exec.Command("ffmpeg",
|
||||
"-i", inPath,
|
||||
"-ar", "16000",
|
||||
"-ac", "1",
|
||||
"-c:a", "pcm_s16le",
|
||||
"-f", "s16le",
|
||||
outPath,
|
||||
"-y",
|
||||
)
|
||||
cmd.Stderr = nil
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return nil, fmt.Errorf("音频转码失败 (ffmpeg): %w", err)
|
||||
}
|
||||
|
||||
outData, err := os.ReadFile(outPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取转码结果失败: %w", err)
|
||||
}
|
||||
|
||||
return outData, nil
|
||||
}
|
||||
|
||||
@@ -16,41 +16,59 @@ import (
|
||||
var SupportedLanguages = []string{"zh", "en", "ja", "ko", "auto"}
|
||||
|
||||
// STTService 语音转文字服务。
|
||||
// 优先使用 DashScope API,不可用时回退到本地 Whisper。
|
||||
// 离线转录优先使用 DashScope REST API,失败回退 Whisper。
|
||||
// 流式转录使用 DashScope Realtime WS。
|
||||
type STTService struct {
|
||||
whisperBinary string
|
||||
whisperModel string
|
||||
language string
|
||||
dashscope *DashScopeSTT // 实时 ASR (qwen3-asr-flash-realtime)
|
||||
whisperBinary string
|
||||
whisperModel string
|
||||
language string
|
||||
dashscope *DashScopeSTT // 实时 ASR (qwen3-asr-flash-realtime)
|
||||
dashscopeREST *DashScopeRESTSTT // 离线 ASR (qwen3-asr-flash-2026-02-10)
|
||||
}
|
||||
|
||||
// NewSTTService 创建 STT 服务。
|
||||
func NewSTTService(cfg *config.Config) *STTService {
|
||||
// 实时模型用于所有 WebSocket ASR 请求(支持 one-shot 和 streaming)
|
||||
// 离线模型 (qwen3-asr-flash-2026-02-10) 是 HTTP REST API,暂未实现
|
||||
model := cfg.DashScopeSTTRealtime
|
||||
if model == "" {
|
||||
model = cfg.DashScopeModel
|
||||
realtimeModel := cfg.DashScopeSTTRealtime
|
||||
if realtimeModel == "" {
|
||||
realtimeModel = "qwen3-asr-flash-realtime"
|
||||
}
|
||||
offlineModel := cfg.DashScopeModel
|
||||
if offlineModel == "" {
|
||||
offlineModel = "qwen3-asr-flash-2026-02-10"
|
||||
}
|
||||
return &STTService{
|
||||
whisperBinary: cfg.WhisperBinary,
|
||||
whisperModel: cfg.WhisperModel,
|
||||
language: cfg.WhisperLanguage,
|
||||
dashscope: NewDashScopeSTT(cfg.DashScopeAPIKey, model),
|
||||
dashscope: NewDashScopeSTT(cfg.DashScopeAPIKey, realtimeModel),
|
||||
dashscopeREST: NewDashScopeRESTSTT(cfg.DashScopeAPIKey, offlineModel),
|
||||
}
|
||||
}
|
||||
|
||||
// IsAvailable 检查是否有任一 STT 引擎可用。
|
||||
func (s *STTService) IsAvailable() bool {
|
||||
if s.dashscope.IsAvailable() {
|
||||
if s.dashscopeREST.IsAvailable() || s.dashscope.IsAvailable() {
|
||||
return true
|
||||
}
|
||||
_, err := os.Stat(s.whisperBinary)
|
||||
return err == nil
|
||||
return s.whisperAvailable()
|
||||
}
|
||||
|
||||
// whisperAvailable 检查本地 Whisper 引擎是否真正可用。
|
||||
func (s *STTService) whisperAvailable() bool {
|
||||
if _, err := os.Stat(s.whisperBinary); err != nil {
|
||||
return false
|
||||
}
|
||||
if _, err := os.Stat(s.whisperModel); err != nil {
|
||||
return false
|
||||
}
|
||||
if _, err := exec.LookPath("ffmpeg"); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Transcribe 将音频数据转录为文字。
|
||||
// 优先使用 DashScope,不可用时回退到本地 Whisper。
|
||||
// 优先使用 DashScope REST 离线模型,失败回退到本地 Whisper。
|
||||
func (s *STTService) Transcribe(audioData []byte, format string, language string) (string, error) {
|
||||
if language == "" {
|
||||
language = s.language
|
||||
@@ -59,16 +77,15 @@ func (s *STTService) Transcribe(audioData []byte, format string, language string
|
||||
return "", fmt.Errorf("不支持的语言: %s,支持的语言: %s", language, strings.Join(SupportedLanguages, ", "))
|
||||
}
|
||||
|
||||
// 优先 DashScope
|
||||
if s.dashscope.IsAvailable() {
|
||||
// 优先 DashScope REST 离线模型(低延迟,无需 session 协商)
|
||||
if s.dashscopeREST.IsAvailable() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
text, err := s.dashscope.Transcribe(ctx, audioData, format, language)
|
||||
text, err := s.dashscopeREST.Transcribe(ctx, audioData, format, language)
|
||||
if err == nil {
|
||||
return text, nil
|
||||
}
|
||||
// DashScope 失败,返回具体错误而不是回退到 Whisper
|
||||
return "", fmt.Errorf("语音识别失败: %w", err)
|
||||
fmt.Printf("[stt] DashScope REST 失败,回退 Whisper: %v\n", err)
|
||||
}
|
||||
|
||||
// 回退到本地 Whisper
|
||||
@@ -152,15 +169,21 @@ func (s *STTService) GetStatus() map[string]interface{} {
|
||||
if _, err := os.Stat(s.whisperModel); err == nil {
|
||||
modelExists = true
|
||||
}
|
||||
ffmpegAvailable := false
|
||||
if _, err := exec.LookPath("ffmpeg"); err == nil {
|
||||
ffmpegAvailable = true
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"available": s.IsAvailable(),
|
||||
"primary": "dashscope",
|
||||
"dashscope": s.dashscope.GetStatus(),
|
||||
"available": s.IsAvailable(),
|
||||
"primary": "dashscope_rest",
|
||||
"dashscope_rest": s.dashscopeREST.GetStatus(),
|
||||
"dashscope_ws": s.dashscope.GetStatus(),
|
||||
"whisper": map[string]interface{}{
|
||||
"available": binaryAvailable && modelExists,
|
||||
"available": s.whisperAvailable(),
|
||||
"binary_available": binaryAvailable,
|
||||
"model_loaded": modelExists,
|
||||
"ffmpeg_available": ffmpegAvailable,
|
||||
"model_name": filepath.Base(s.whisperModel),
|
||||
},
|
||||
"default_language": s.language,
|
||||
|
||||
@@ -67,6 +67,18 @@ services:
|
||||
- "4222:4222"
|
||||
- "8222:8222"
|
||||
|
||||
# ========== SearXNG 搜索引擎 ==========
|
||||
searxng:
|
||||
image: searxng/searxng:latest
|
||||
container_name: cyrene_searxng
|
||||
volumes:
|
||||
- ./searxng/settings.yml:/etc/searxng/settings.yml:ro
|
||||
environment:
|
||||
SEARXNG_SETTINGS_PATH: /etc/searxng/settings.yml
|
||||
ports:
|
||||
- "8088:8080"
|
||||
restart: unless-stopped
|
||||
|
||||
volumes:
|
||||
pg_data:
|
||||
redis_data:
|
||||
|
||||
@@ -133,6 +133,26 @@ services:
|
||||
WHISPER_LANGUAGE: "zh"
|
||||
restart: unless-stopped
|
||||
|
||||
platform-bridge:
|
||||
container_name: cyrene_platform_bridge
|
||||
build:
|
||||
context: .
|
||||
dockerfile: ./backend/platform-bridge/Dockerfile
|
||||
ports:
|
||||
- "${QQ_BOT_PORT:-8096}:8096"
|
||||
environment:
|
||||
PORT: "8095"
|
||||
ENV: production
|
||||
GATEWAY_URL: http://gateway:8080
|
||||
AI_CORE_URL: http://ai-core:8081
|
||||
INTERNAL_SERVICE_TOKEN: ${INTERNAL_SERVICE_TOKEN}
|
||||
QQ_BOT_PORT: ${QQ_BOT_PORT:-8096}
|
||||
TELEGRAM_BOT_TOKEN: ${TELEGRAM_BOT_TOKEN:-}
|
||||
TELEGRAM_WEBHOOK_URL: ${TELEGRAM_WEBHOOK_URL:-}
|
||||
depends_on:
|
||||
- gateway
|
||||
restart: unless-stopped
|
||||
|
||||
iot-debug-service:
|
||||
container_name: cyrene_iot_debug_service
|
||||
build:
|
||||
@@ -155,6 +175,7 @@ services:
|
||||
MEMORY_SERVICE_URL: http://memory-service:8091
|
||||
VOICE_SERVICE_URL: http://voice-service:8093
|
||||
IOT_DEBUG_SERVICE_URL: http://iot-debug-service:8083
|
||||
PLATFORM_BRIDGE_URL: http://platform-bridge:8095
|
||||
ADMIN_USERNAME: ${ADMIN_USERNAME:-admin}
|
||||
ADMIN_PASSWORD: ${ADMIN_PASSWORD}
|
||||
RUNNING_IN_DOCKER: "true"
|
||||
|
||||
@@ -3,9 +3,14 @@
|
||||
**Base URL:** `http://<host>:8093` | **Auth:** 无
|
||||
|
||||
语音服务封装两层引擎:
|
||||
- **STT (语音转文字):** DashScope `qwen3-asr-flash-realtime` (主) + 本地 Whisper (备)
|
||||
- **STT (语音转文字):** DashScope REST 离线模型 `qwen3-asr-flash-2026-02-10` (主) → DashScope Realtime WS `qwen3-asr-flash-realtime` (流式) → 本地 Whisper (备)
|
||||
- **TTS (文字转语音):** edge-tts (主) + espeak-ng (备)
|
||||
|
||||
> **引擎分层说明:**
|
||||
> - **离线转录** (`POST /api/v1/transcribe`): 使用 DashScope REST API,无需 session 协商和 Server VAD,延迟更低。失败自动回退 Whisper。
|
||||
> - **流式转录** (`GET /api/v1/stt/stream`): 使用 DashScope Realtime WebSocket,支持实时分片输入和中间结果输出,需客户端发送 PCM 音频。
|
||||
> - **音频转码**: 所有非 PCM 格式通过 ffmpeg 转码为 16kHz mono PCM 后再识别,支持 WAV/MP3/OGG/FLAC/WebM/Opus/AMR 等格式。
|
||||
|
||||
---
|
||||
|
||||
## 目录
|
||||
@@ -24,13 +29,25 @@
|
||||
|
||||
**Content-Type:** `multipart/form-data` | **Max body:** 10 MB
|
||||
|
||||
### 引擎选择
|
||||
|
||||
优先使用 DashScope REST 离线模型 `qwen3-asr-flash-2026-02-10`,失败自动回退本地 Whisper。
|
||||
|
||||
```
|
||||
接收音频 → DashScope REST API (HTTP POST)
|
||||
↓ 失败
|
||||
ffmpeg 转码 PCM → 本地 Whisper 引擎
|
||||
```
|
||||
|
||||
### 表单字段
|
||||
|
||||
| 字段 | 类型 | 必填 | 说明 |
|
||||
|------|------|------|------|
|
||||
| `audio` | file | 是 | 音频文件。格式从扩展名推断:wav/mp3/ogg/flac/m4a |
|
||||
| `audio` | file | 是 | 音频文件。支持格式: wav, mp3, ogg, flac, m4a, aac, webm, opus, amr, pcm |
|
||||
| `language` | string | 否 | 默认 `"zh"`。可选: `zh`, `en`, `ja`, `ko`, `auto` |
|
||||
|
||||
> **转码说明:** 非 PCM 格式(含 Opus/WebM/AMR)通过 ffmpeg 自动转码为 16-bit PCM 16000Hz mono 后识别。需部署环境安装 ffmpeg。
|
||||
|
||||
### 响应 200
|
||||
|
||||
```json
|
||||
@@ -49,7 +66,7 @@
|
||||
| 400 | `{"error":"文件过大或解析失败,最大支持 10MB"}` |
|
||||
| 400 | `{"error":"缺少 audio 文件字段"}` |
|
||||
| 400 | `{"error":"音频文件为空"}` |
|
||||
| 400 | `{"error":"不支持的音频格式: <ext>,支持的格式: WAV, MP3, OGG, FLAC, M4A"}` |
|
||||
| 400 | `{"error":"不支持的语言: <lang>,支持的语言: zh, en, ja, ko, auto"}` |
|
||||
| 405 | `{"error":"method not allowed"}` |
|
||||
| 500 | `{"error":"读取音频文件失败"}` |
|
||||
| 500 | `{"success":false,"error":"<engine error>"}` |
|
||||
@@ -112,12 +129,23 @@
|
||||
"service": "voice-service",
|
||||
"stt": {
|
||||
"available": true,
|
||||
"primary": "dashscope",
|
||||
"dashscope": { "available": true, "model": "qwen3-asr-flash-realtime", "provider": "dashscope" },
|
||||
"primary": "dashscope_rest",
|
||||
"dashscope_rest": {
|
||||
"available": true,
|
||||
"model": "qwen3-asr-flash-2026-02-10",
|
||||
"protocol": "rest"
|
||||
},
|
||||
"dashscope_ws": {
|
||||
"available": true,
|
||||
"model": "qwen3-asr-flash-realtime",
|
||||
"protocol": "websocket",
|
||||
"state": "idle"
|
||||
},
|
||||
"whisper": {
|
||||
"available": true,
|
||||
"binary_available": true,
|
||||
"model_loaded": true,
|
||||
"ffmpeg_available": true,
|
||||
"model_name": "ggml-small.bin"
|
||||
},
|
||||
"default_language": "zh",
|
||||
@@ -138,9 +166,15 @@
|
||||
|
||||
| 字段 | 说明 |
|
||||
|------|------|
|
||||
| `stt.available` | DashScope 或 Whisper 至少一个可用 |
|
||||
| `stt.dashscope.available` | DashScope API Key 已配置 |
|
||||
| `stt.whisper.available` | Whisper 二进制 + 模型文件均存在 |
|
||||
| `stt.available` | DashScope REST / WS 或 Whisper 至少一个可用 |
|
||||
| `stt.primary` | 当前优先引擎: `dashscope_rest` |
|
||||
| `stt.dashscope_rest.available` | DashScope REST API Key 已配置 |
|
||||
| `stt.dashscope_rest.protocol` | 协议类型: `rest` |
|
||||
| `stt.dashscope_ws.available` | DashScope Realtime WS 可用 |
|
||||
| `stt.dashscope_ws.protocol` | 协议类型: `websocket` |
|
||||
| `stt.dashscope_ws.state` | 连接状态: `idle`, `connected`, `error` |
|
||||
| `stt.whisper.available` | Whisper 二进制 + 模型文件 + ffmpeg 均存在 |
|
||||
| `stt.whisper.ffmpeg_available` | ffmpeg 可用于音频转码 |
|
||||
| `tts.available` | 至少一个 TTS 引擎可用 |
|
||||
| `tts.engine` | 当前激活引擎: `edge-tts`, `espeak-ng`, `fallback (silent WAV)`, `none` |
|
||||
|
||||
@@ -183,6 +217,10 @@
|
||||
**Query 参数:** `?language=zh&format=pcm` (language 默认 zh, format 默认 pcm)
|
||||
**Read deadline:** 300s
|
||||
|
||||
> **注意:** 此端点使用 DashScope Realtime WebSocket (`qwen3-asr-flash-realtime`),音频帧必须是 PCM 格式。非 PCM 格式应使用 REST 离线转录 (`POST /api/v1/transcribe`)。
|
||||
>
|
||||
> **Gateway 代理:** Gateway 的 `voice_stream_*` 消息类型通过此端点与前端 VAD 配合,实现端到端流式语音 → STT → LLM 管道。详见 [Gateway WebSocket 文档](../../gateway-api.md#流式语音输入流程-voice_stream_)。
|
||||
|
||||
### 客户端 → 服务端
|
||||
|
||||
**Binary 帧:** 原始 PCM 音频 (16-bit LE, 16000Hz, mono)。每帧通过 `input_audio_buffer.append` 转发到 DashScope。
|
||||
|
||||
+137
-6
@@ -167,11 +167,14 @@ ws://<gateway>/ws/chat?token=<jwt>&session_id=<optional>&client_id=<optional>&de
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "message|voice_input|ping|history",
|
||||
"type": "message|voice_input|voice_stream_start|voice_stream_chunk|voice_stream_end|ping|history",
|
||||
"session_id": "string (可选)",
|
||||
"mode": "text|voice_msg|voice_assistant",
|
||||
"content": "string (纯图片消息可留空,文字+图片时填写提问内容)",
|
||||
"audio_data": "string (voice_input 类型必填, base64)",
|
||||
"audio_data": "string (voice_input / voice_stream_chunk 类型必填, base64)",
|
||||
"format": "string (voice_stream_start 可选, 音频格式: webm, wav, pcm, opus; 默认 webm)",
|
||||
"language": "string (voice_stream_start 可选, 识别语言: zh, en, ja, ko, auto; 默认 zh)",
|
||||
"sequence": 0,
|
||||
"attachments": [
|
||||
{
|
||||
"type": "image",
|
||||
@@ -194,14 +197,18 @@ ws://<gateway>/ws/chat?token=<jwt>&session_id=<optional>&client_id=<optional>&de
|
||||
"timestamp": 1717000000000,
|
||||
"client_id": "string",
|
||||
"device_name": "string",
|
||||
"user_agent": "string"
|
||||
"user_agent": "string",
|
||||
"client_msg_id": "string"
|
||||
}
|
||||
```
|
||||
|
||||
| type | 说明 |
|
||||
|------|------|
|
||||
| `message` | 文字聊天,触发 AI 回复 |
|
||||
| `voice_input` | 语音输入,先转录再作为 message 处理 |
|
||||
| `voice_input` | 语音输入(完整音频),先转录再作为 message 处理 |
|
||||
| `voice_stream_start` | 开启流式语音会话,Gateway 连接 Voice-Service 流式 STT |
|
||||
| `voice_stream_chunk` | 流式语音音频分片 (base64),Gateway 转发至 Voice-Service |
|
||||
| `voice_stream_end` | 结束流式语音,等待最终识别结果,自动触发 LLM 回复 |
|
||||
| `ping` | 心跳,自动回复 pong |
|
||||
| `history` | 请求历史消息 |
|
||||
|
||||
@@ -244,7 +251,9 @@ ws://<gateway>/ws/chat?token=<jwt>&session_id=<optional>&client_id=<optional>&de
|
||||
| `stream_chunk` | 增量文本块 |
|
||||
| `stream_end` | AI 生成结束(含完整 text) |
|
||||
| `stream_segments` | 流式断句(语音) |
|
||||
| `voice_transcript` | 语音转录结果 |
|
||||
| `voice_transcript` | 语音转录结果 (非流式, voice_input) |
|
||||
| `voice_interim` | 流式语音中间识别结果 |
|
||||
| `voice_final` | 流式语音最终识别文本 |
|
||||
| `error` | 错误 |
|
||||
| `history_response` | 历史消息返回 |
|
||||
| `notification` | 推送通知 |
|
||||
@@ -325,7 +334,7 @@ Client Gateway
|
||||
|
||||
---
|
||||
|
||||
### 语音输入流程
|
||||
### 语音输入流程 (非流式)
|
||||
|
||||
```
|
||||
Client Gateway Voice-Service
|
||||
@@ -343,6 +352,128 @@ Client Gateway Voice-Service
|
||||
|<-- ... 正常流式回复 ... | |
|
||||
```
|
||||
|
||||
> **注意:** `voice_input` 为非流式模式,客户端发送完整音频后一次性获取转录结果。适合 MediaRecorder 录音完成后使用。
|
||||
> 推荐使用下方的流式语音输入,配合前端 VAD 实现边说边识别。
|
||||
|
||||
---
|
||||
|
||||
### 流式语音输入流程 (voice_stream_*)
|
||||
|
||||
配合前端 VAD (Voice Activity Detection) 实现自动语音检测和边说边识别。前端逐帧发送音频分片,Gateway 通过 WebSocket 代理到 Voice-Service 流式 STT,实时返回中间结果。
|
||||
|
||||
```
|
||||
Client Gateway Voice-Service
|
||||
| | |
|
||||
|-- {type:"voice_stream_start", | |
|
||||
| format:"webm", language:"zh"} --> | |
|
||||
| |-- WS /api/v1/stt/stream --------> |
|
||||
| |<-- session ready |
|
||||
|<-- {type:"voice_interim", text:""} | |
|
||||
| | |
|
||||
|-- {type:"voice_stream_chunk", | |
|
||||
| audio_data:"<base64>", | |
|
||||
| sequence:0} ------------------> | |
|
||||
| |-- binary audio frame ----------> |
|
||||
| |<-- {type:"result", |
|
||||
| | text:"你好", isFinal:false} |
|
||||
|<-- {type:"voice_interim", | |
|
||||
| text:"你好"} | |
|
||||
| | |
|
||||
|-- ... more chunks ... | |
|
||||
| | |
|
||||
|-- {type:"voice_stream_end"} -----> | |
|
||||
| |-- {action:"stop"} --------------> |
|
||||
| |<-- {type:"result", |
|
||||
| | text:"你好世界", isFinal:true}|
|
||||
|<-- {type:"voice_final", | |
|
||||
| text:"你好世界"} | |
|
||||
| | |
|
||||
| (Gateway 自动将最终文本 | |
|
||||
| 作为 message 发给 AI-Core) | |
|
||||
|<-- {type:"stream_start"} | |
|
||||
|<-- ... 正常流式 LLM 回复 ... | |
|
||||
```
|
||||
|
||||
**消息详情:**
|
||||
|
||||
#### voice_stream_start — 开启流式语音会话
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "voice_stream_start",
|
||||
"format": "webm",
|
||||
"language": "zh"
|
||||
}
|
||||
```
|
||||
|
||||
| 字段 | 类型 | 必填 | 说明 |
|
||||
|------|------|------|------|
|
||||
| `format` | string | 否 | 音频格式,默认 `"webm"`。支持: `webm`, `wav`, `pcm`, `opus` |
|
||||
| `language` | string | 否 | 识别语言,默认 `"zh"`。支持: `zh`, `en`, `ja`, `ko`, `auto` |
|
||||
|
||||
Gateway 收到后连接 Voice-Service 流式 STT WebSocket。成功时返回空 `voice_interim` 确认会话建立;失败返回 `error`。
|
||||
|
||||
#### voice_stream_chunk — 发送音频分片
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "voice_stream_chunk",
|
||||
"audio_data": "<base64 encoded audio>",
|
||||
"sequence": 0
|
||||
}
|
||||
```
|
||||
|
||||
| 字段 | 类型 | 必填 | 说明 |
|
||||
|------|------|------|------|
|
||||
| `audio_data` | string | 是 | Base64 编码的音频数据 |
|
||||
| `sequence` | int | 否 | 分片序号,从 0 递增,用于排序和去重 |
|
||||
|
||||
Gateway 将 audio_data 解码后以 binary 帧转发至 Voice-Service。无直接响应;识别结果通过 `voice_interim` 异步推送。
|
||||
|
||||
#### voice_stream_end — 结束流式语音
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "voice_stream_end"
|
||||
}
|
||||
```
|
||||
|
||||
Gateway 向 Voice-Service 发送 stop 信号,等待最终识别结果。最终文本通过 `voice_final` 返回,并自动触发 LLM 回复流程。
|
||||
|
||||
#### voice_interim — 中间识别结果 (Server → Client)
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "voice_interim",
|
||||
"message_id": "voice_<random>",
|
||||
"text": "中间识别文本",
|
||||
"timestamp": 1717000000000
|
||||
}
|
||||
```
|
||||
|
||||
| 字段 | 说明 |
|
||||
|------|------|
|
||||
| `text` | 当前累积的识别文本,**非最终结果**,会随更多音频输入而更新 |
|
||||
|
||||
> **前端处理:** 收到 `voice_interim` 后应在 UI 中展示实时识别文本(如灰色斜体),收到 `voice_final` 后替换为最终文本。
|
||||
|
||||
#### voice_final — 最终识别结果 (Server → Client)
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "voice_final",
|
||||
"message_id": "voice_<random>",
|
||||
"text": "最终识别文本",
|
||||
"timestamp": 1717000000000
|
||||
}
|
||||
```
|
||||
|
||||
| 字段 | 说明 |
|
||||
|------|------|
|
||||
| `text` | 最终的完整识别文本。空字符串表示未识别到语音 |
|
||||
|
||||
收到 `voice_final` 后,Gateway 自动将 `text` 作为 `message` 类型转发至 AI-Core 触发 LLM 流式回复。随后的 `stream_start` / `review` / `stream_end` 流程与普通文字消息相同。
|
||||
|
||||
---
|
||||
|
||||
## 3. 会话管理
|
||||
|
||||
+32
-24
@@ -8,7 +8,7 @@ setlocal enabledelayedexpansion
|
||||
set "SCRIPT_DIR=%~dp0"
|
||||
set "ETHEND_DIR=%SCRIPT_DIR%ethend"
|
||||
set "ROOT=%SCRIPT_DIR%"
|
||||
if "%ETHEND_PORT%"=="" (set PORT=9090) else (set PORT=%ETHEND_PORT%)
|
||||
if "%ETHEND_PORT%"=="" (set ETHEND_PORT=9090) else (set ETHEND_PORT=%ETHEND_PORT%)
|
||||
set "LOG_DIR=%ETHEND_DIR%\logs"
|
||||
set "LOG_FILE=%LOG_DIR%\sh.log"
|
||||
set "DB_COMPOSE_FILE=%ROOT%docker-compose.dev.db.yml"
|
||||
@@ -63,7 +63,7 @@ echo ethend.bat logs gateway View Gateway log
|
||||
echo ethend.bat build ai-core Build AI-Core only
|
||||
echo ethend.bat db:status Check database
|
||||
echo.
|
||||
echo Web console: http://localhost:%PORT%
|
||||
echo Web console: http://localhost:%ETHEND_PORT%
|
||||
exit /b 0
|
||||
|
||||
:: ==========================================
|
||||
@@ -89,7 +89,7 @@ echo Cyrene ethend
|
||||
echo ==========================================
|
||||
call :check_node
|
||||
for /f "tokens=*" %%i in ('!NODE_CMD! --version') do echo Node.js: %%i
|
||||
echo Port: %PORT%
|
||||
echo Port: %ETHEND_PORT%
|
||||
|
||||
:: Check for --build / --fresh
|
||||
set DO_BUILD=0
|
||||
@@ -110,21 +110,21 @@ if %DO_BUILD%==1 (
|
||||
if %DO_FRESH%==1 (
|
||||
echo.
|
||||
echo [INFO] Force restarting all services...
|
||||
curl -s -X POST "http://localhost:%PORT%/api/services/start-all-fresh" >nul 2>&1
|
||||
curl -s -X POST "http://localhost:%ETHEND_PORT%/api/services/start-all-fresh" >nul 2>&1
|
||||
)
|
||||
|
||||
:: Check if already running
|
||||
curl -s -o nul "http://localhost:%PORT%/api/health" 2>nul
|
||||
curl -s -o nul "http://localhost:%ETHEND_PORT%/api/health" 2>nul
|
||||
if %ERRORLEVEL%==0 (
|
||||
echo.
|
||||
echo [OK] ethend already running: http://localhost:%PORT%
|
||||
echo API: http://localhost:%PORT%/api/health
|
||||
echo [OK] ethend already running: http://localhost:%ETHEND_PORT%
|
||||
echo API: http://localhost:%ETHEND_PORT%/api/health
|
||||
exit /b 0
|
||||
)
|
||||
|
||||
:: Free port
|
||||
for /f "tokens=5" %%a in ('netstat -ano ^| findstr /R ":%PORT% " ^| findstr "LISTENING"') do (
|
||||
echo [WARN] Port %PORT% in use by PID %%a, releasing...
|
||||
for /f "tokens=5" %%a in ('netstat -ano ^| findstr /R ":%ETHEND_PORT% " ^| findstr "LISTENING"') do (
|
||||
echo [WARN] Port %ETHEND_PORT% in use by PID %%a, releasing...
|
||||
taskkill /PID %%a /F >nul 2>&1
|
||||
timeout /t 1 /nobreak >nul
|
||||
)
|
||||
@@ -141,10 +141,10 @@ if not exist "node_modules\" (
|
||||
if not exist "%LOG_DIR%" mkdir "%LOG_DIR%"
|
||||
|
||||
echo.
|
||||
echo [INFO] Starting ethend on port %PORT%
|
||||
echo Web UI: http://localhost:%PORT%
|
||||
echo API: http://localhost:%PORT%/api/health
|
||||
echo WebSocket: ws://localhost:%PORT%/ws
|
||||
echo [INFO] Starting ethend on port %ETHEND_PORT%
|
||||
echo Web UI: http://localhost:%ETHEND_PORT%
|
||||
echo API: http://localhost:%ETHEND_PORT%/api/health
|
||||
echo WebSocket: ws://localhost:%ETHEND_PORT%/ws
|
||||
echo.
|
||||
|
||||
start "Cyrene-ethend" /B !NODE_CMD! src\index.js 1>>"%LOG_FILE%" 2>&1
|
||||
@@ -155,12 +155,12 @@ set MAX_WAIT=30
|
||||
set WAITED=0
|
||||
:health_loop
|
||||
if %WAITED% geq %MAX_WAIT% goto :timeout
|
||||
powershell -Command "try { (Invoke-WebRequest 'http://localhost:%PORT%/api/health' -TimeoutSec 2 -UseBasicParsing).StatusCode -eq 200 } catch { $false }" | findstr "True" >nul 2>&1
|
||||
powershell -Command "try { (Invoke-WebRequest 'http://localhost:%ETHEND_PORT%/api/health' -TimeoutSec 2 -UseBasicParsing).StatusCode -eq 200 } catch { $false }" | findstr "True" >nul 2>&1
|
||||
if %ERRORLEVEL%==0 (
|
||||
echo.
|
||||
echo ==========================================
|
||||
echo ethend is ready!
|
||||
echo Console: http://localhost:%PORT%
|
||||
echo Console: http://localhost:%ETHEND_PORT%
|
||||
echo Log: %LOG_FILE%
|
||||
echo ==========================================
|
||||
echo.
|
||||
@@ -173,12 +173,12 @@ goto :health_loop
|
||||
:timeout
|
||||
echo.
|
||||
echo [WARN] Still starting - waited %MAX_WAIT%s
|
||||
echo Check http://localhost:%PORT%/api/health
|
||||
echo Check http://localhost:%ETHEND_PORT%/api/health
|
||||
exit /b 0
|
||||
|
||||
:: ==========================================
|
||||
:stop
|
||||
for /f "tokens=5" %%a in ('netstat -ano ^| findstr /R ":%PORT% " ^| findstr "LISTENING"') do (
|
||||
for /f "tokens=5" %%a in ('netstat -ano ^| findstr /R ":%ETHEND_PORT% " ^| findstr "LISTENING"') do (
|
||||
echo [INFO] Stopping ethend (PID: %%a)...
|
||||
taskkill /PID %%a /F >nul 2>&1
|
||||
echo [OK] ethend stopped
|
||||
@@ -189,7 +189,7 @@ exit /b 0
|
||||
|
||||
:: ==========================================
|
||||
:status
|
||||
curl -s -o nul "http://localhost:%PORT%/api/health" 2>nul
|
||||
curl -s -o nul "http://localhost:%ETHEND_PORT%/api/health" 2>nul
|
||||
if %ERRORLEVEL% neq 0 (
|
||||
echo [ERROR] ethend is offline
|
||||
exit /b 1
|
||||
@@ -197,7 +197,7 @@ if %ERRORLEVEL% neq 0 (
|
||||
echo [OK] ethend online
|
||||
echo.
|
||||
echo Service Status:
|
||||
curl -s "http://localhost:%PORT%/api/services" 2>nul | !NODE_CMD! -e "const d=require('fs').readFileSync(0,'utf-8');const s=JSON.parse(d);for(const[k,v]of Object.entries(s)){const icon=v.status==='running'?'+':' ';const pid=v.pid?' (PID:'+v.pid+')':'';const upt=v.uptime?' uptime:'+Math.round(v.uptime/1000)+'s':'';console.log(' ['+icon+'] '+v.name.padEnd(18)+v.status+pid+upt);}" 2>nul
|
||||
curl -s "http://localhost:%ETHEND_PORT%/api/services" 2>nul | !NODE_CMD! -e "const d=require('fs').readFileSync(0,'utf-8');const s=JSON.parse(d);for(const[k,v]of Object.entries(s)){const icon=v.status==='running'?'+':' ';const pid=v.pid?' (PID:'+v.pid+')':'';const upt=v.uptime?' uptime:'+Math.round(v.uptime/1000)+'s':'';console.log(' ['+icon+'] '+v.name.padEnd(18)+v.status+pid+upt);}" 2>nul
|
||||
echo.
|
||||
echo Database:
|
||||
powershell -Command "$tcp=New-Object Net.Sockets.TcpClient;try{$tcp.Connect('127.0.0.1',5432);' [OK] PostgreSQL online'}catch{' [--] PostgreSQL offline'};$tcp.Dispose()" 2>nul
|
||||
@@ -230,14 +230,18 @@ if not "%SVC_ID%"=="" goto :build_one
|
||||
|
||||
:build_all
|
||||
echo [INFO] Building all backend services...
|
||||
set SERVICES=memory-service:backend/memory-service tool-engine:backend/tool-engine iot-debug-service:backend/iot-debug-service voice-service:backend/voice-service ai-core:backend/ai-core plugin-manager:backend/plugin-manager platform-bridge:backend/platform-bridge gateway:backend/gateway
|
||||
set SERVICES=memory-service:backend/memory-service iot-debug-service:backend/iot-debug-service voice-service:backend/voice-service ai-core:backend/ai-core plugin-manager:backend/cyrene-plugins platform-bridge:backend/platform-bridge gateway:backend/gateway
|
||||
for %%s in (%SERVICES%) do (
|
||||
for /f "tokens=1,2 delims=:" %%a in ("%%s") do (
|
||||
if exist "%ROOT%%%b" (
|
||||
echo Building %%a...
|
||||
cd /d "%ROOT%%%b"
|
||||
set GOWORK=off
|
||||
go build -o main.exe .\cmd\main.go 2>&1
|
||||
if "%%a"=="plugin-manager" (
|
||||
go build -o main.exe .\cmd\plugin-manager\ 2>&1
|
||||
) else (
|
||||
go build -o main.exe .\cmd\main.go 2>&1
|
||||
)
|
||||
if !ERRORLEVEL!==0 (
|
||||
echo [OK] %%a
|
||||
) else (
|
||||
@@ -251,7 +255,7 @@ echo [OK] Build complete
|
||||
exit /b 0
|
||||
|
||||
:build_one
|
||||
for %%s in (memory-service:backend/memory-service tool-engine:backend/tool-engine iot-debug-service:backend/iot-debug-service voice-service:backend/voice-service ai-core:backend/ai-core plugin-manager:backend/plugin-manager platform-bridge:backend/platform-bridge gateway:backend/gateway) do (
|
||||
for %%s in (memory-service:backend/memory-service iot-debug-service:backend/iot-debug-service voice-service:backend/voice-service ai-core:backend/ai-core plugin-manager:backend/cyrene-plugins platform-bridge:backend/platform-bridge gateway:backend/gateway) do (
|
||||
for /f "tokens=1,2 delims=:" %%a in ("%%s") do (
|
||||
if "%%a"=="%SVC_ID%" (
|
||||
if not exist "%ROOT%%%b" (
|
||||
@@ -261,7 +265,11 @@ for %%s in (memory-service:backend/memory-service tool-engine:backend/tool-engin
|
||||
echo [INFO] Building %%a...
|
||||
cd /d "%ROOT%%%b"
|
||||
set GOWORK=off
|
||||
go build -o main.exe .\cmd\main.go 2>&1
|
||||
if "%%a"=="plugin-manager" (
|
||||
go build -o main.exe .\cmd\plugin-manager\ 2>&1
|
||||
) else (
|
||||
go build -o main.exe .\cmd\main.go 2>&1
|
||||
)
|
||||
if !ERRORLEVEL!==0 (echo [OK] %%a) else (echo [FAIL] %%a)
|
||||
cd /d "%ROOT%"
|
||||
exit /b !ERRORLEVEL!
|
||||
@@ -269,7 +277,7 @@ for %%s in (memory-service:backend/memory-service tool-engine:backend/tool-engin
|
||||
)
|
||||
)
|
||||
echo [ERROR] Unknown service: %SVC_ID%
|
||||
echo Available: memory-service, tool-engine, iot-debug-service, voice-service, ai-core, plugin-manager, platform-bridge, gateway
|
||||
echo Available: memory-service, iot-debug-service, voice-service, ai-core, plugin-manager, platform-bridge, gateway
|
||||
exit /b 1
|
||||
|
||||
:: ==========================================
|
||||
|
||||
@@ -185,11 +185,10 @@ db_stop() {
|
||||
# ========== 服务编译 ==========
|
||||
SERVICES=(
|
||||
"memory-service:backend/memory-service"
|
||||
"tool-engine:backend/tool-engine"
|
||||
"iot-debug-service:backend/iot-debug-service"
|
||||
"voice-service:backend/voice-service"
|
||||
"ai-core:backend/ai-core"
|
||||
"plugin-manager:backend/plugin-manager"
|
||||
"plugin-manager:backend/cyrene-plugins"
|
||||
"platform-bridge:backend/platform-bridge"
|
||||
"gateway:backend/gateway"
|
||||
)
|
||||
@@ -204,7 +203,11 @@ build_service() {
|
||||
$IS_WIN && binary="main.exe"
|
||||
|
||||
cd "$ROOT/$dir"
|
||||
if GOWORK=off go build -o "$binary" ./cmd/main.go 2>&1; then
|
||||
local build_target="./cmd/main.go"
|
||||
if [ "$id" = "plugin-manager" ]; then
|
||||
build_target="./cmd/plugin-manager/"
|
||||
fi
|
||||
if GOWORK=off go build -o "$binary" "$build_target" 2>&1; then
|
||||
echo -e "${GREEN} ✓ $label 编译完成${NC}"
|
||||
return 0
|
||||
else
|
||||
|
||||
+747
-173
File diff suppressed because it is too large
Load Diff
@@ -147,7 +147,7 @@ export const SERVICES = {
|
||||
},
|
||||
'plugin-manager': {
|
||||
name: '插件管理器',
|
||||
cwd: path.join(ROOT, 'backend/plugin-manager'),
|
||||
cwd: path.join(ROOT, 'backend/cyrene-plugins'),
|
||||
command: './main',
|
||||
env: {
|
||||
PORT: '8094',
|
||||
@@ -156,7 +156,7 @@ export const SERVICES = {
|
||||
healthUrl: IS_DOCKER ? 'http://plugin-manager:8094/api/v1/health' : 'http://localhost:8094/api/v1/health',
|
||||
port: 8094,
|
||||
buildCommand: 'go',
|
||||
buildArgs: ['build', '-o', isWin ? 'main.exe' : 'main', './cmd/main.go'],
|
||||
buildArgs: ['build', '-o', isWin ? 'main.exe' : 'main', './cmd/plugin-manager/'],
|
||||
goBin: GO_BIN,
|
||||
},
|
||||
'platform-bridge': {
|
||||
|
||||
+107
-5
@@ -79,6 +79,67 @@ processManager.on('log', (serviceId, stream, text) => {
|
||||
}
|
||||
});
|
||||
|
||||
// ========== 平台桥接实时日志流 ==========
|
||||
|
||||
let logStreamWs = null;
|
||||
let logStreamReconnectTimer = null;
|
||||
|
||||
function connectPlatformBridgeLogStream() {
|
||||
if (logStreamReconnectTimer) { clearTimeout(logStreamReconnectTimer); logStreamReconnectTimer = null; }
|
||||
if (logStreamWs && (logStreamWs.readyState === WebSocket.OPEN || logStreamWs.readyState === WebSocket.CONNECTING)) return;
|
||||
|
||||
const wsUrl = PLATFORM_BRIDGE_URL.replace(/^http/, 'ws') + '/ws/logs';
|
||||
console.log(`[LogStream] 连接 ${wsUrl} ...`);
|
||||
try {
|
||||
logStreamWs = new WebSocket(wsUrl);
|
||||
} catch (err) {
|
||||
console.error(`[LogStream] 创建连接失败: ${err.message}`);
|
||||
scheduleLogStreamReconnect();
|
||||
return;
|
||||
}
|
||||
|
||||
logStreamWs.on('open', () => {
|
||||
console.log('[LogStream] 已连接,实时日志推送中');
|
||||
});
|
||||
|
||||
logStreamWs.on('message', (raw) => {
|
||||
try {
|
||||
const entry = JSON.parse(raw.toString());
|
||||
broadcast('chat-log', entry);
|
||||
} catch {}
|
||||
});
|
||||
|
||||
logStreamWs.on('close', () => {
|
||||
console.log('[LogStream] 连接断开');
|
||||
logStreamWs = null;
|
||||
scheduleLogStreamReconnect();
|
||||
});
|
||||
|
||||
logStreamWs.on('error', (err) => {
|
||||
console.error(`[LogStream] 错误: ${err.message}`);
|
||||
logStreamWs = null;
|
||||
scheduleLogStreamReconnect();
|
||||
});
|
||||
}
|
||||
|
||||
function scheduleLogStreamReconnect() {
|
||||
if (logStreamReconnectTimer) return;
|
||||
logStreamReconnectTimer = setTimeout(() => {
|
||||
logStreamReconnectTimer = null;
|
||||
connectPlatformBridgeLogStream();
|
||||
}, 5000);
|
||||
}
|
||||
|
||||
// 启动时连接,后续 platform-bridge 重启时通过状态变化自动重连。
|
||||
connectPlatformBridgeLogStream();
|
||||
|
||||
// 监听服务状态:platform-bridge 上线后重连。
|
||||
setInterval(() => {
|
||||
if (!logStreamWs || logStreamWs.readyState === WebSocket.CLOSED || logStreamWs.readyState === WebSocket.CLOSING) {
|
||||
connectPlatformBridgeLogStream();
|
||||
}
|
||||
}, 15000);
|
||||
|
||||
// ========== Gateway 代理辅助函数 ==========
|
||||
|
||||
/** 缓存的 JWT token 和过期时间 */
|
||||
@@ -740,6 +801,47 @@ app.get('/api/chat-platforms/logs/:name', async (req, res) => {
|
||||
res.status(result.status).json(result.body);
|
||||
});
|
||||
|
||||
// GET /api/chat-platforms/platforms — 列出所有平台适配器 (含连接状态与能力)
|
||||
app.get('/api/chat-platforms/platforms', async (_req, res) => {
|
||||
const result = await proxyToPlatformBridge('/api/v1/platforms');
|
||||
res.status(result.status).json(result.body);
|
||||
});
|
||||
|
||||
// GET /api/chat-platforms/platforms/:name — 获取单个平台适配器详情
|
||||
app.get('/api/chat-platforms/platforms/:name', async (req, res) => {
|
||||
const result = await proxyToPlatformBridge(`/api/v1/platforms/${req.params.name}`);
|
||||
res.status(result.status).json(result.body);
|
||||
});
|
||||
|
||||
// GET /api/chat-platforms/identities — 列出所有已注册的身份映射
|
||||
app.get('/api/chat-platforms/identities', async (_req, res) => {
|
||||
const result = await proxyToPlatformBridge('/api/v1/identities');
|
||||
res.status(result.status).json(result.body);
|
||||
});
|
||||
|
||||
// GET /api/chat-platforms/health — platform-bridge 整体健康状态
|
||||
app.get('/api/chat-platforms/health', async (_req, res) => {
|
||||
const result = await proxyToPlatformBridge('/health');
|
||||
res.status(result.status).json(result.body);
|
||||
});
|
||||
|
||||
// ---- 黑名单/白名单设置代理 ----
|
||||
|
||||
// GET /api/chat-platforms/settings/blocklist
|
||||
app.get('/api/chat-platforms/settings/blocklist', async (_req, res) => {
|
||||
const result = await proxyToPlatformBridge('/api/v1/settings/blocklist');
|
||||
res.status(result.status).json(result.body);
|
||||
});
|
||||
|
||||
// POST /api/chat-platforms/settings/blocklist
|
||||
app.post('/api/chat-platforms/settings/blocklist', async (req, res) => {
|
||||
const result = await proxyToPlatformBridge('/api/v1/settings/blocklist', {
|
||||
method: 'POST',
|
||||
body: JSON.stringify(req.body),
|
||||
});
|
||||
res.status(result.status).json(result.body);
|
||||
});
|
||||
|
||||
// ---- 多端客户端管理代理 (转发到 Gateway) ----
|
||||
|
||||
// GET /api/clients — 获取已知客户端列表
|
||||
@@ -1184,7 +1286,7 @@ app.get('/api/trace/recent', async (req, res) => {
|
||||
const traces = [];
|
||||
|
||||
// LLM 调用 → 追踪节点
|
||||
const llmCalls = Array.isArray(llmResult.body) ? llmResult.body : [];
|
||||
const llmCalls = Array.isArray(llmResult.body) ? llmResult.body : (llmResult.body?.calls || []);
|
||||
for (const call of llmCalls) {
|
||||
const ts = call.time ? new Date(call.time).getTime() : Date.now();
|
||||
traces.push({
|
||||
@@ -1195,7 +1297,7 @@ app.get('/api/trace/recent', async (req, res) => {
|
||||
hop: 'llm_call',
|
||||
label: `LLM 调用: ${call.model || 'unknown'}`,
|
||||
status: call.success ? 'success' : 'error',
|
||||
durationMs: call.duration_ms || call.Duration || 0,
|
||||
durationMs: call.duration_ms || (call.Duration ? Math.round(call.Duration / 1e6) : 0),
|
||||
detail: call.error || `${call.prompt_tokens || 0}+${call.completion_tokens || 0} tokens`,
|
||||
data: call,
|
||||
});
|
||||
@@ -1214,7 +1316,7 @@ app.get('/api/trace/recent', async (req, res) => {
|
||||
hop: 'tool_call',
|
||||
label: `工具调用: ${tc.tool_name || tc.name || 'unknown'}`,
|
||||
status: tc.error ? 'error' : 'success',
|
||||
durationMs: tc.duration_ms || tc.Duration || 0,
|
||||
durationMs: tc.duration_ms || (tc.Duration ? Math.round(tc.Duration / 1e6) : 0),
|
||||
detail: tc.error || tc.result?.substring?.(0, 100) || '',
|
||||
data: tc,
|
||||
});
|
||||
@@ -1330,7 +1432,7 @@ app.get('/api/trace/session/:sessionId', async (req, res) => {
|
||||
}
|
||||
|
||||
// LLM 调用记录中如果有 session 相关信息也加入
|
||||
const llmCalls = Array.isArray(llmResult.body) ? llmResult.body : [];
|
||||
const llmCalls = Array.isArray(llmResult.body) ? llmResult.body : (llmResult.body?.calls || []);
|
||||
for (const call of llmCalls) {
|
||||
const ts = call.time ? new Date(call.time).getTime() : Date.now();
|
||||
traces.push({
|
||||
@@ -1341,7 +1443,7 @@ app.get('/api/trace/session/:sessionId', async (req, res) => {
|
||||
hop: 'llm_call',
|
||||
label: `LLM: ${call.model || 'unknown'}`,
|
||||
status: call.success ? 'success' : 'error',
|
||||
durationMs: call.duration_ms || call.Duration || 0,
|
||||
durationMs: call.duration_ms || (call.Duration ? Math.round(call.Duration / 1e6) : 0),
|
||||
detail: call.error || `${call.prompt_tokens || 0}→${call.completion_tokens || 0} tokens`,
|
||||
data: call,
|
||||
});
|
||||
|
||||
@@ -86,10 +86,6 @@ function detectDockerServices() {
|
||||
for (const m of ports.matchAll(/:(\d+)->/g)) {
|
||||
detectedPorts.add(parseInt(m[1]));
|
||||
}
|
||||
for (const m of ports.matchAll(/(\d+)\/(tcp|udp)/g)) {
|
||||
detectedPorts.add(parseInt(m[1]));
|
||||
}
|
||||
|
||||
for (const [id, svc] of Object.entries(SERVICES)) {
|
||||
if (svc.port && detectedPorts.has(svc.port)) {
|
||||
result.set(id, {
|
||||
@@ -306,6 +302,7 @@ class ProcessManager extends EventEmitter {
|
||||
env,
|
||||
stdio: ['ignore', 'pipe', 'pipe'],
|
||||
shell: needsShell,
|
||||
windowsHide: true,
|
||||
});
|
||||
|
||||
child.stdout.on('data', (data) => {
|
||||
@@ -363,9 +360,10 @@ class ProcessManager extends EventEmitter {
|
||||
|
||||
const procInfo = this.processes.get(serviceId);
|
||||
if (!procInfo.process) {
|
||||
// 可能已经崩溃了,重置状态
|
||||
procInfo.status = 'stopped';
|
||||
return { success: true, message: `${svc.name} 未在运行` };
|
||||
procInfo.pid = null;
|
||||
procInfo.startTime = null;
|
||||
return { success: true, message: `${svc.name} 已停止` };
|
||||
}
|
||||
|
||||
return new Promise((resolve) => {
|
||||
@@ -434,6 +432,7 @@ class ProcessManager extends EventEmitter {
|
||||
cwd: svc.cwd,
|
||||
env: { ...process.env, GOPROXY: 'https://goproxy.cn,direct', GOWORK: 'off' },
|
||||
stdio: ['ignore', 'pipe', 'pipe'],
|
||||
windowsHide: true,
|
||||
});
|
||||
|
||||
let stdout = '';
|
||||
|
||||
Generated
+156
@@ -8,6 +8,7 @@
|
||||
"name": "cyrene-frontend",
|
||||
"version": "0.1.0",
|
||||
"dependencies": {
|
||||
"@ricky0123/vad-web": "^0.0.30",
|
||||
"react": "^18.3.1",
|
||||
"react-dom": "^18.3.1",
|
||||
"zustand": "^4.5.5"
|
||||
@@ -848,6 +849,78 @@
|
||||
"node": ">= 8"
|
||||
}
|
||||
},
|
||||
"node_modules/@protobufjs/aspromise": {
|
||||
"version": "1.1.2",
|
||||
"resolved": "https://registry.npmmirror.com/@protobufjs/aspromise/-/aspromise-1.1.2.tgz",
|
||||
"integrity": "sha512-j+gKExEuLmKwvz3OgROXtrJ2UG2x8Ch2YZUxahh+s1F2HZ+wAceUNLkvy6zKCPVRkU++ZWQrdxsUeQXmcg4uoQ==",
|
||||
"license": "BSD-3-Clause"
|
||||
},
|
||||
"node_modules/@protobufjs/base64": {
|
||||
"version": "1.1.2",
|
||||
"resolved": "https://registry.npmmirror.com/@protobufjs/base64/-/base64-1.1.2.tgz",
|
||||
"integrity": "sha512-AZkcAA5vnN/v4PDqKyMR5lx7hZttPDgClv83E//FMNhR2TMcLUhfRUBHCmSl0oi9zMgDDqRUJkSxO3wm85+XLg==",
|
||||
"license": "BSD-3-Clause"
|
||||
},
|
||||
"node_modules/@protobufjs/codegen": {
|
||||
"version": "2.0.5",
|
||||
"resolved": "https://registry.npmmirror.com/@protobufjs/codegen/-/codegen-2.0.5.tgz",
|
||||
"integrity": "sha512-zgXFLzW3Ap33e6d0Wlj4MGIm6Ce8O89n/apUaGNB/jx+hw+ruWEp7EwGUshdLKVRCxZW12fp9r40E1mQrf/34g==",
|
||||
"license": "BSD-3-Clause"
|
||||
},
|
||||
"node_modules/@protobufjs/eventemitter": {
|
||||
"version": "1.1.1",
|
||||
"resolved": "https://registry.npmmirror.com/@protobufjs/eventemitter/-/eventemitter-1.1.1.tgz",
|
||||
"integrity": "sha512-vW1GmwMZNnL+gMRaovlh9yZX74kc+TTU3FObkkurpMaRtBfLP3ldjS9KQWlwZgraRE0+dheEEoAxdzcJQ8eXZg==",
|
||||
"license": "BSD-3-Clause"
|
||||
},
|
||||
"node_modules/@protobufjs/fetch": {
|
||||
"version": "1.1.1",
|
||||
"resolved": "https://registry.npmmirror.com/@protobufjs/fetch/-/fetch-1.1.1.tgz",
|
||||
"integrity": "sha512-GpptLrs57adMSuHi3VNj0mAF8dwh36LMaYF6XyJ6JMWlVsc+t42tm1HSEDmOs3A8fC9yyeisgLhsTVQokOZ0zw==",
|
||||
"license": "BSD-3-Clause",
|
||||
"dependencies": {
|
||||
"@protobufjs/aspromise": "^1.1.1"
|
||||
}
|
||||
},
|
||||
"node_modules/@protobufjs/float": {
|
||||
"version": "1.0.2",
|
||||
"resolved": "https://registry.npmmirror.com/@protobufjs/float/-/float-1.0.2.tgz",
|
||||
"integrity": "sha512-Ddb+kVXlXst9d+R9PfTIxh1EdNkgoRe5tOX6t01f1lYWOvJnSPDBlG241QLzcyPdoNTsblLUdujGSE4RzrTZGQ==",
|
||||
"license": "BSD-3-Clause"
|
||||
},
|
||||
"node_modules/@protobufjs/inquire": {
|
||||
"version": "1.1.2",
|
||||
"resolved": "https://registry.npmmirror.com/@protobufjs/inquire/-/inquire-1.1.2.tgz",
|
||||
"integrity": "sha512-pa0vFRuws4wkvaXKK1uXZMAwAX4/t8ANaJo45iw/oQHNQ9q5xUzwgFmVJGXiga2BeN+zpX7Vf9vmsiIa2J+MUw==",
|
||||
"license": "BSD-3-Clause"
|
||||
},
|
||||
"node_modules/@protobufjs/path": {
|
||||
"version": "1.1.2",
|
||||
"resolved": "https://registry.npmmirror.com/@protobufjs/path/-/path-1.1.2.tgz",
|
||||
"integrity": "sha512-6JOcJ5Tm08dOHAbdR3GrvP+yUUfkjG5ePsHYczMFLq3ZmMkAD98cDgcT2iA1lJ9NVwFd4tH/iSSoe44YWkltEA==",
|
||||
"license": "BSD-3-Clause"
|
||||
},
|
||||
"node_modules/@protobufjs/pool": {
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmmirror.com/@protobufjs/pool/-/pool-1.1.0.tgz",
|
||||
"integrity": "sha512-0kELaGSIDBKvcgS4zkjz1PeddatrjYcmMWOlAuAPwAeccUrPHdUqo/J6LiymHHEiJT5NrF1UVwxY14f+fy4WQw==",
|
||||
"license": "BSD-3-Clause"
|
||||
},
|
||||
"node_modules/@protobufjs/utf8": {
|
||||
"version": "1.1.1",
|
||||
"resolved": "https://registry.npmmirror.com/@protobufjs/utf8/-/utf8-1.1.1.tgz",
|
||||
"integrity": "sha512-oOAWABowe8EAbMyWKM0tYDKi8Yaox52D+HWZhAIJqQXbqe0xI/GV7FhLWqlEKreMkfDjshR5FKgi3mnle0h6Eg==",
|
||||
"license": "BSD-3-Clause"
|
||||
},
|
||||
"node_modules/@ricky0123/vad-web": {
|
||||
"version": "0.0.30",
|
||||
"resolved": "https://registry.npmmirror.com/@ricky0123/vad-web/-/vad-web-0.0.30.tgz",
|
||||
"integrity": "sha512-cJyYrh4YeeUBJcbR9Bic/bFDyB9qBkAepvpuWM3vLxnAi7bC3VHzf51UeNdT+OtY4D7MLAgV8iJMc4z41ZnaWg==",
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
"onnxruntime-web": "^1.17.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@rolldown/pluginutils": {
|
||||
"version": "1.0.0-beta.27",
|
||||
"resolved": "https://registry.npmjs.org/@rolldown/pluginutils/-/pluginutils-1.0.0-beta.27.tgz",
|
||||
@@ -1257,6 +1330,15 @@
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@types/node": {
|
||||
"version": "25.9.2",
|
||||
"resolved": "https://registry.npmmirror.com/@types/node/-/node-25.9.2.tgz",
|
||||
"integrity": "sha512-G05zqtJhcDLb8uslf5EjCxXg9G1KQxiV8OS0R26IC//Eoyitzqe8z37I7cqvnZlrlSfgocQRfSn/AHBZJJFyGw==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"undici-types": ">=7.24.0 <7.24.7"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/prop-types": {
|
||||
"version": "15.7.15",
|
||||
"resolved": "https://registry.npmjs.org/@types/prop-types/-/prop-types-15.7.15.tgz",
|
||||
@@ -1704,6 +1786,12 @@
|
||||
"node": ">=8"
|
||||
}
|
||||
},
|
||||
"node_modules/flatbuffers": {
|
||||
"version": "25.9.23",
|
||||
"resolved": "https://registry.npmmirror.com/flatbuffers/-/flatbuffers-25.9.23.tgz",
|
||||
"integrity": "sha512-MI1qs7Lo4Syw0EOzUl0xjs2lsoeqFku44KpngfIduHBYvzm8h2+7K8YMQh1JtVVVrUvhLpNwqVi4DERegUJhPQ==",
|
||||
"license": "Apache-2.0"
|
||||
},
|
||||
"node_modules/fraction.js": {
|
||||
"version": "5.3.4",
|
||||
"resolved": "https://registry.npmjs.org/fraction.js/-/fraction.js-5.3.4.tgz",
|
||||
@@ -1766,6 +1854,12 @@
|
||||
"node": ">=10.13.0"
|
||||
}
|
||||
},
|
||||
"node_modules/guid-typescript": {
|
||||
"version": "1.0.9",
|
||||
"resolved": "https://registry.npmmirror.com/guid-typescript/-/guid-typescript-1.0.9.tgz",
|
||||
"integrity": "sha512-Y8T4vYhEfwJOTbouREvG+3XDsjr8E3kIr7uf+JZ0BYloFsttiHU0WfvANVsR7TxNUJa/WpCnw/Ino/p+DeBhBQ==",
|
||||
"license": "ISC"
|
||||
},
|
||||
"node_modules/hasown": {
|
||||
"version": "2.0.3",
|
||||
"resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.3.tgz",
|
||||
@@ -1903,6 +1997,12 @@
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/long": {
|
||||
"version": "5.3.2",
|
||||
"resolved": "https://registry.npmmirror.com/long/-/long-5.3.2.tgz",
|
||||
"integrity": "sha512-mNAgZ1GmyNhD7AuqnTG3/VQ26o760+ZYBPKjPvugO8+nLbYfX6TVpJPseBvopbdY+qpZ/lKUnmEc1LeZYS3QAA==",
|
||||
"license": "Apache-2.0"
|
||||
},
|
||||
"node_modules/loose-envify": {
|
||||
"version": "1.4.0",
|
||||
"resolved": "https://registry.npmjs.org/loose-envify/-/loose-envify-1.4.0.tgz",
|
||||
@@ -2024,6 +2124,26 @@
|
||||
"node": ">= 6"
|
||||
}
|
||||
},
|
||||
"node_modules/onnxruntime-common": {
|
||||
"version": "1.26.0",
|
||||
"resolved": "https://registry.npmmirror.com/onnxruntime-common/-/onnxruntime-common-1.26.0.tgz",
|
||||
"integrity": "sha512-qVyMR4lcWgbkc4getFV+GQijsTnbg/siteoqcDwa3sI/LxbrMSNw4ePyvCq/ymdQaRomCA7YuWmhzsswxvymdw==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/onnxruntime-web": {
|
||||
"version": "1.26.0",
|
||||
"resolved": "https://registry.npmmirror.com/onnxruntime-web/-/onnxruntime-web-1.26.0.tgz",
|
||||
"integrity": "sha512-LbRr/8zZt2xilI2smrVQGGKINo0U46i8qJp+UXyMBGfqN7KjnH1BiwCwLwyNIVV4i9CKFv7Sf4PwLKWnT8/bEA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"flatbuffers": "^25.1.24",
|
||||
"guid-typescript": "^1.0.9",
|
||||
"long": "^5.2.3",
|
||||
"onnxruntime-common": "1.26.0",
|
||||
"platform": "^1.3.6",
|
||||
"protobufjs": "^7.2.4"
|
||||
}
|
||||
},
|
||||
"node_modules/path-parse": {
|
||||
"version": "1.0.7",
|
||||
"resolved": "https://registry.npmjs.org/path-parse/-/path-parse-1.0.7.tgz",
|
||||
@@ -2071,6 +2191,12 @@
|
||||
"node": ">= 6"
|
||||
}
|
||||
},
|
||||
"node_modules/platform": {
|
||||
"version": "1.3.6",
|
||||
"resolved": "https://registry.npmmirror.com/platform/-/platform-1.3.6.tgz",
|
||||
"integrity": "sha512-fnWVljUchTro6RiCFvCXBbNhJc2NijN7oIQxbwsyL0buWJPG85v81ehlHI9fXrJsMNgTofEoWIQeClKpgxFLrg==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/postcss": {
|
||||
"version": "8.5.14",
|
||||
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.14.tgz",
|
||||
@@ -2234,6 +2360,30 @@
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/protobufjs": {
|
||||
"version": "7.6.2",
|
||||
"resolved": "https://registry.npmmirror.com/protobufjs/-/protobufjs-7.6.2.tgz",
|
||||
"integrity": "sha512-N9EiLovGEQOJSPF26Ij7qUGvahfEnq0eeYZ02aigIedkmz1qZSwjnP9SBITHJuF/6MYbIW4HDN8zdYjsjqJKXQ==",
|
||||
"hasInstallScript": true,
|
||||
"license": "BSD-3-Clause",
|
||||
"dependencies": {
|
||||
"@protobufjs/aspromise": "^1.1.2",
|
||||
"@protobufjs/base64": "^1.1.2",
|
||||
"@protobufjs/codegen": "^2.0.5",
|
||||
"@protobufjs/eventemitter": "^1.1.1",
|
||||
"@protobufjs/fetch": "^1.1.1",
|
||||
"@protobufjs/float": "^1.0.2",
|
||||
"@protobufjs/inquire": "^1.1.2",
|
||||
"@protobufjs/path": "^1.1.2",
|
||||
"@protobufjs/pool": "^1.1.0",
|
||||
"@protobufjs/utf8": "^1.1.1",
|
||||
"@types/node": ">=13.7.0",
|
||||
"long": "^5.3.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=12.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/queue-microtask": {
|
||||
"version": "1.2.3",
|
||||
"resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz",
|
||||
@@ -2623,6 +2773,12 @@
|
||||
"node": ">=14.17"
|
||||
}
|
||||
},
|
||||
"node_modules/undici-types": {
|
||||
"version": "7.24.6",
|
||||
"resolved": "https://registry.npmmirror.com/undici-types/-/undici-types-7.24.6.tgz",
|
||||
"integrity": "sha512-WRNW+sJgj5OBN4/0JpHFqtqzhpbnV0GuB+OozA9gCL7a993SmU+1JBZCzLNxYsbMfIeDL+lTsphD5jN5N+n0zg==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/update-browserslist-db": {
|
||||
"version": "1.2.3",
|
||||
"resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.2.3.tgz",
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
"preview": "vite preview"
|
||||
},
|
||||
"dependencies": {
|
||||
"@ricky0123/vad-web": "^0.0.30",
|
||||
"react": "^18.3.1",
|
||||
"react-dom": "^18.3.1",
|
||||
"zustand": "^4.5.5"
|
||||
|
||||
@@ -36,7 +36,7 @@ function setHashSessionId(sessionId: string | null) {
|
||||
|
||||
export default function App() {
|
||||
const { isLoggedIn, login, register, loading: authLoading, userId } = useAuth();
|
||||
const { send } = useChat();
|
||||
const { send, sendVoiceStreamMessage } = useChat();
|
||||
const { loadSessionsFromServer, ensureMainSession, setCurrentSessionId, setMessages, loadMessagesFromServer, sessions, currentSessionId } = useSessionStore();
|
||||
|
||||
const [authMode, setAuthMode] = useState<'login' | 'register'>('login');
|
||||
@@ -330,41 +330,42 @@ export default function App() {
|
||||
return (
|
||||
<ErrorBoundary>
|
||||
<AppLayout>
|
||||
<PageRouter onSend={send} />
|
||||
<PageRouter onSend={send} onSendVoiceStream={sendVoiceStreamMessage} />
|
||||
</AppLayout>
|
||||
</ErrorBoundary>
|
||||
);
|
||||
}
|
||||
|
||||
type SendFn = (content: string, mode?: import('@/types/chat').ChatMode, attachments?: import('@/types/chat').MessageAttachment[]) => void;
|
||||
type SendVoiceStreamFn = (msg: import('@/types/chat').WSClientMessage) => void;
|
||||
|
||||
function PageRouter({ onSend }: { onSend: SendFn }) {
|
||||
function PageRouter({ onSend, onSendVoiceStream }: { onSend: SendFn; onSendVoiceStream: SendVoiceStreamFn }) {
|
||||
const currentPage = usePageStore((s) => s.currentPage);
|
||||
const isAdmin = isAdminUser(localStorage.getItem('user_id') || '');
|
||||
|
||||
switch (currentPage) {
|
||||
case 'admin-models':
|
||||
if (!isAdmin) return <ChatPage onSend={onSend} />;
|
||||
if (!isAdmin) return <ChatPage onSend={onSend} onSendVoiceStream={onSendVoiceStream} />;
|
||||
return <ModelsAdminPage />;
|
||||
case 'admin-dashboard':
|
||||
if (!isAdmin) return <ChatPage onSend={onSend} />;
|
||||
if (!isAdmin) return <ChatPage onSend={onSend} onSendVoiceStream={onSendVoiceStream} />;
|
||||
return <AdminDashboard />;
|
||||
case 'profile':
|
||||
return <ProfilePage />;
|
||||
case 'chat':
|
||||
default:
|
||||
return <ChatPage onSend={onSend} />;
|
||||
return <ChatPage onSend={onSend} onSendVoiceStream={onSendVoiceStream} />;
|
||||
}
|
||||
}
|
||||
|
||||
function ChatPage({ onSend }: { onSend: SendFn }) {
|
||||
function ChatPage({ onSend, onSendVoiceStream }: { onSend: SendFn; onSendVoiceStream: SendVoiceStreamFn }) {
|
||||
return (
|
||||
<div className="flex flex-col h-full overflow-hidden">
|
||||
<div className="flex-1 min-h-0 overflow-hidden">
|
||||
<ChatContainer />
|
||||
</div>
|
||||
<div className="flex-shrink-0">
|
||||
<ChatInput onSend={onSend} />
|
||||
<ChatInput onSend={onSend} onSendVoiceStream={onSendVoiceStream} />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import { useState, useRef, useCallback, useEffect } from 'react';
|
||||
import type { ChatMode, MessageAttachment } from '@/types/chat';
|
||||
import { useSpeechRecognition } from '@/hooks/useSpeechRecognition';
|
||||
import { useVoiceInput } from '@/hooks/useVoiceInput';
|
||||
import { uploadFile } from '@/api/files';
|
||||
import { useChatStore } from '@/store/chatStore';
|
||||
|
||||
interface ChatInputProps {
|
||||
onSend: (content: string, mode: ChatMode, attachments?: MessageAttachment[]) => void;
|
||||
onSendVoiceStream?: (msg: import('@/types/chat').WSClientMessage) => void;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
@@ -19,7 +21,7 @@ const MAX_IMAGE_SIZE = 10 * 1024 * 1024; // 10MB
|
||||
const SUPPORTED_IMAGE_TYPES = ['image/jpeg', 'image/png', 'image/gif', 'image/webp', 'image/bmp'];
|
||||
const MAX_IMAGES = 5;
|
||||
|
||||
export function ChatInput({ onSend, disabled }: ChatInputProps) {
|
||||
export function ChatInput({ onSend, onSendVoiceStream, disabled }: ChatInputProps) {
|
||||
const [content, setContent] = useState('');
|
||||
const [mode, setMode] = useState<ChatMode>('text');
|
||||
const [pendingImages, setPendingImages] = useState<PendingImage[]>([]);
|
||||
@@ -30,27 +32,50 @@ export function ChatInput({ onSend, disabled }: ChatInputProps) {
|
||||
const isTyping = useChatStore((s) => s.isTyping);
|
||||
|
||||
const {
|
||||
isListening,
|
||||
isSupported,
|
||||
isListening: isSRListening,
|
||||
isSupported: isSRSpported,
|
||||
isFallbackMode,
|
||||
interimText,
|
||||
finalText,
|
||||
error,
|
||||
startListening,
|
||||
stopListening,
|
||||
error: srError,
|
||||
startListening: startSR,
|
||||
stopListening: stopSR,
|
||||
resetText,
|
||||
} = useSpeechRecognition();
|
||||
|
||||
// 当 finalText 更新时,追加到输入框
|
||||
// VAD-based voice input (primary when supported)
|
||||
const {
|
||||
isListening: isVADListening,
|
||||
isSpeaking: isVADSpeaking,
|
||||
isSupported: isVADSupported,
|
||||
interimText: vadInterimText,
|
||||
finalText: vadFinalText,
|
||||
error: vadError,
|
||||
startListening: startVAD,
|
||||
stopListening: stopVAD,
|
||||
} = useVoiceInput({
|
||||
onTranscription: (text: string) => {
|
||||
setContent((prev) => {
|
||||
const trimmed = prev.trimEnd();
|
||||
return (trimmed ? trimmed + ' ' : '') + text;
|
||||
});
|
||||
},
|
||||
sendMessage: onSendVoiceStream,
|
||||
});
|
||||
|
||||
const isListening = isVADSupported ? isVADListening : isSRListening;
|
||||
const voiceError = isVADSupported ? vadError : srError;
|
||||
|
||||
// 当 SR finalText 更新时,追加到输入框 (仅非 VAD 模式)
|
||||
useEffect(() => {
|
||||
if (finalText) {
|
||||
if (!isVADSupported && finalText) {
|
||||
setContent((prev) => {
|
||||
const trimmed = prev.trimEnd();
|
||||
return (trimmed ? trimmed + ' ' : '') + finalText;
|
||||
});
|
||||
resetText();
|
||||
}
|
||||
}, [finalText, resetText]);
|
||||
}, [isVADSupported, finalText, resetText]);
|
||||
|
||||
const handleSend = useCallback(async () => {
|
||||
const trimmed = content.trim();
|
||||
@@ -121,13 +146,13 @@ export function ChatInput({ onSend, disabled }: ChatInputProps) {
|
||||
if (e.key === 'V' && e.ctrlKey && e.shiftKey) {
|
||||
e.preventDefault();
|
||||
if (isListening) {
|
||||
stopListening();
|
||||
isVADSupported ? stopVAD() : stopSR();
|
||||
} else {
|
||||
startListening();
|
||||
isVADSupported ? startVAD() : startSR();
|
||||
}
|
||||
}
|
||||
},
|
||||
[handleSend, isListening, startListening, stopListening]
|
||||
[handleSend, isListening, isVADSupported, startVAD, stopVAD, startSR, stopSR]
|
||||
);
|
||||
|
||||
// 粘贴图片
|
||||
@@ -260,11 +285,11 @@ export function ChatInput({ onSend, disabled }: ChatInputProps) {
|
||||
|
||||
const handleVoiceToggle = useCallback(() => {
|
||||
if (isListening) {
|
||||
stopListening();
|
||||
isVADSupported ? stopVAD() : stopSR();
|
||||
} else {
|
||||
startListening();
|
||||
isVADSupported ? startVAD() : startSR();
|
||||
}
|
||||
}, [isListening, startListening, stopListening]);
|
||||
}, [isListening, isVADSupported, startVAD, stopVAD, startSR, stopSR]);
|
||||
|
||||
return (
|
||||
<div
|
||||
@@ -305,8 +330,17 @@ export function ChatInput({ onSend, disabled }: ChatInputProps) {
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* 实时识别文本提示 */}
|
||||
{isListening && interimText && (
|
||||
{/* VAD 语音状态提示 */}
|
||||
{isVADSupported && isVADListening && (
|
||||
<div className="text-sm text-pink-500 dark:text-pink-400 italic px-1" aria-live="polite">
|
||||
{isVADSpeaking
|
||||
? (vadInterimText || '检测到语音,正在识别...')
|
||||
: '正在聆听...'}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* 实时识别文本提示 (仅非 VAD 模式) */}
|
||||
{!isVADSupported && isListening && interimText && (
|
||||
<div
|
||||
className="interim-text text-sm text-pink-500 dark:text-pink-400 italic px-1"
|
||||
aria-live="polite"
|
||||
@@ -317,12 +351,12 @@ export function ChatInput({ onSend, disabled }: ChatInputProps) {
|
||||
)}
|
||||
|
||||
{/* 错误提示 */}
|
||||
{error && (
|
||||
{voiceError && (
|
||||
<div
|
||||
className="text-xs text-red-500 dark:text-red-400 px-1"
|
||||
role="alert"
|
||||
>
|
||||
⚠️ {error}
|
||||
⚠️ {voiceError}
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -434,18 +468,20 @@ export function ChatInput({ onSend, disabled }: ChatInputProps) {
|
||||
className="flex-1 resize-none rounded-xl border border-pink-200 dark:border-pink-800 bg-white dark:bg-gray-800 px-4 py-2 text-sm text-gray-700 dark:text-gray-200 placeholder-gray-400 focus:outline-none focus:ring-2 focus:ring-pink-400 focus:border-transparent disabled:opacity-50"
|
||||
/>
|
||||
|
||||
{/* 语音输入按钮 (仅浏览器支持时显示) */}
|
||||
{isSupported && (
|
||||
{/* 语音输入按钮 (VAD 或浏览器 SpeechRecognition/MediaRecorder 支持时显示) */}
|
||||
{(isVADSupported || isSRSpported) && (
|
||||
<button
|
||||
onClick={handleVoiceToggle}
|
||||
disabled={disabled || uploading}
|
||||
aria-label={isListening ? '停止语音输入' : '开始语音输入'}
|
||||
aria-pressed={isListening}
|
||||
title={isListening ? '停止聆听 (Ctrl+Shift+V)' : '语音输入 (Ctrl+Shift+V)'}
|
||||
title={isListening ? '停止聆听 (Ctrl+Shift+V)' : isVADSupported ? '语音输入 (自动检测说话)' : '语音输入 (Ctrl+Shift+V)'}
|
||||
className={`p-2 rounded-xl transition-all flex-shrink-0 border-2 ${
|
||||
isListening
|
||||
? 'voice-btn-active bg-red-500 border-red-500 text-white'
|
||||
: 'bg-gray-100 dark:bg-gray-700 border-gray-200 dark:border-gray-600 text-gray-500 hover:text-red-500 hover:border-red-300'
|
||||
: isVADSpeaking
|
||||
? 'bg-yellow-400 border-yellow-400 text-white'
|
||||
: 'bg-gray-100 dark:bg-gray-700 border-gray-200 dark:border-gray-600 text-gray-500 hover:text-red-500 hover:border-red-300'
|
||||
} disabled:opacity-40 disabled:cursor-not-allowed`}
|
||||
>
|
||||
<svg
|
||||
@@ -461,7 +497,7 @@ export function ChatInput({ onSend, disabled }: ChatInputProps) {
|
||||
)}
|
||||
|
||||
{/* 不支持时显示禁用按钮 */}
|
||||
{!isSupported && (
|
||||
{!isVADSupported && !isSRSpported && (
|
||||
<button
|
||||
disabled
|
||||
title="您的浏览器不支持语音识别"
|
||||
@@ -508,7 +544,10 @@ export function ChatInput({ onSend, disabled }: ChatInputProps) {
|
||||
{/* 语音输入状态提示 */}
|
||||
{isListening && (
|
||||
<p className="text-xs text-red-400 text-center animate-pulse">
|
||||
{isFallbackMode ? '🎤 后端语音识别中...' : '🎤 正在聆听...'}
|
||||
{isVADSupported
|
||||
? (isVADSpeaking ? '🔊 检测到语音,正在识别...' : '🎤 正在聆听...')
|
||||
: (isFallbackMode ? '🎤 后端语音识别中...' : '🎤 正在聆听...')
|
||||
}
|
||||
<span className="text-gray-400 ml-2">(Ctrl+Shift+V 停止)</span>
|
||||
</p>
|
||||
)}
|
||||
|
||||
@@ -43,11 +43,19 @@ export function useChat() {
|
||||
[addMessage, setTyping, sendMessage]
|
||||
);
|
||||
|
||||
const sendVoiceStreamMessage = useCallback(
|
||||
(msg: import('@/types/chat').WSClientMessage) => {
|
||||
sendMessage(msg);
|
||||
},
|
||||
[sendMessage]
|
||||
);
|
||||
|
||||
return {
|
||||
messages,
|
||||
isTyping,
|
||||
isConnected,
|
||||
send,
|
||||
sendVoiceStreamMessage,
|
||||
clearMessages,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -0,0 +1,221 @@
|
||||
import { useState, useRef, useCallback, useEffect } from 'react';
|
||||
import { MicVAD, utils } from '@ricky0123/vad-web';
|
||||
import { transcribeAudio } from '@/api/voice';
|
||||
import { useChatStore } from '@/store/chatStore';
|
||||
import type { WSClientMessage } from '@/types/chat';
|
||||
|
||||
interface UseVoiceInputOptions {
|
||||
onTranscription: (text: string) => void;
|
||||
language?: string;
|
||||
/** 提供后启用流式模式:通过 WebSocket 分片发送音频,实时返回中间结果 */
|
||||
sendMessage?: (msg: WSClientMessage) => void;
|
||||
}
|
||||
|
||||
interface UseVoiceInputReturn {
|
||||
isListening: boolean;
|
||||
isSpeaking: boolean;
|
||||
isSupported: boolean;
|
||||
/** 流式模式:voice_interim 中间识别文本 */
|
||||
interimText: string;
|
||||
/** 流式模式:voice_final 最终识别文本 */
|
||||
finalText: string;
|
||||
error: string | null;
|
||||
startListening: () => Promise<void>;
|
||||
stopListening: () => void;
|
||||
}
|
||||
|
||||
export function useVoiceInput({
|
||||
onTranscription,
|
||||
language = 'zh',
|
||||
sendMessage,
|
||||
}: UseVoiceInputOptions): UseVoiceInputReturn {
|
||||
const [isListening, setIsListening] = useState(false);
|
||||
const [isSpeaking, setIsSpeaking] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
const vadRef = useRef<MicVAD | null>(null);
|
||||
const inSpeechRef = useRef(false);
|
||||
const seqRef = useRef(0);
|
||||
|
||||
const isSupported =
|
||||
typeof window !== 'undefined' &&
|
||||
typeof AudioWorkletNode !== 'undefined' &&
|
||||
typeof WebAssembly !== 'undefined';
|
||||
|
||||
const isStreaming = !!sendMessage;
|
||||
|
||||
// Subscribe to streaming results from store
|
||||
const voiceInterimText = useChatStore((s) => s.voiceInterimText);
|
||||
const voiceFinalText = useChatStore((s) => s.voiceFinalText);
|
||||
|
||||
const stopListening = useCallback(() => {
|
||||
if (vadRef.current) {
|
||||
vadRef.current.pause();
|
||||
}
|
||||
inSpeechRef.current = false;
|
||||
setIsListening(false);
|
||||
setIsSpeaking(false);
|
||||
}, []);
|
||||
|
||||
const startListening = useCallback(async () => {
|
||||
if (!isSupported) return;
|
||||
|
||||
setError(null);
|
||||
seqRef.current = 0;
|
||||
|
||||
try {
|
||||
if (vadRef.current) {
|
||||
await vadRef.current.destroy();
|
||||
vadRef.current = null;
|
||||
}
|
||||
|
||||
vadRef.current = await MicVAD.new({
|
||||
onSpeechStart: () => {
|
||||
setIsSpeaking(true);
|
||||
inSpeechRef.current = true;
|
||||
seqRef.current = 0;
|
||||
|
||||
if (isStreaming) {
|
||||
sendMessage!({
|
||||
type: 'voice_stream_start',
|
||||
format: 'wav',
|
||||
language,
|
||||
timestamp: Date.now(),
|
||||
});
|
||||
}
|
||||
},
|
||||
|
||||
onFrameProcessed: async (_probabilities, frame) => {
|
||||
if (!isStreaming || !inSpeechRef.current) return;
|
||||
|
||||
// Accumulate audio frames during speech for the voice_stream_chunk
|
||||
// Send every ~300ms of audio (10 frames at ~30ms each)
|
||||
const frameSeq = seqRef.current++;
|
||||
const CHUNK_INTERVAL = 10; // send every 10 frames
|
||||
|
||||
if (frameSeq % CHUNK_INTERVAL === 0) {
|
||||
try {
|
||||
const wavBuffer = utils.encodeWAV(frame);
|
||||
const base64 = arrayBufferToBase64(wavBuffer);
|
||||
|
||||
sendMessage!({
|
||||
type: 'voice_stream_chunk',
|
||||
audio_data: base64,
|
||||
sequence: Math.floor(frameSeq / CHUNK_INTERVAL),
|
||||
timestamp: Date.now(),
|
||||
});
|
||||
} catch {
|
||||
// Ignore encoding errors for individual chunks
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
onSpeechEnd: async (audio: Float32Array) => {
|
||||
setIsSpeaking(false);
|
||||
inSpeechRef.current = false;
|
||||
|
||||
if (isStreaming) {
|
||||
// Send the last chunk with accumulated audio if any
|
||||
try {
|
||||
const wavBuffer = utils.encodeWAV(audio);
|
||||
const base64 = arrayBufferToBase64(wavBuffer);
|
||||
|
||||
sendMessage!({
|
||||
type: 'voice_stream_chunk',
|
||||
audio_data: base64,
|
||||
sequence: -1,
|
||||
timestamp: Date.now(),
|
||||
});
|
||||
} catch {
|
||||
// Ignore
|
||||
}
|
||||
|
||||
// Signal end of voice stream — gateway returns voice_final
|
||||
sendMessage!({
|
||||
type: 'voice_stream_end',
|
||||
timestamp: Date.now(),
|
||||
});
|
||||
// onTranscription will be called in useEffect when voiceFinalText updates
|
||||
return;
|
||||
}
|
||||
|
||||
// REST mode: send full audio for transcription
|
||||
const wavBuffer = utils.encodeWAV(audio);
|
||||
const wavBlob = new Blob([wavBuffer], { type: 'audio/wav' });
|
||||
|
||||
try {
|
||||
const result = await transcribeAudio(wavBlob, language);
|
||||
if (result.error) {
|
||||
setError(result.error);
|
||||
} else if (result.data?.text) {
|
||||
onTranscription(result.data.text);
|
||||
}
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : '语音识别失败');
|
||||
}
|
||||
},
|
||||
|
||||
onVADMisfire: () => {
|
||||
setIsSpeaking(false);
|
||||
inSpeechRef.current = false;
|
||||
|
||||
if (isStreaming) {
|
||||
sendMessage!({
|
||||
type: 'voice_stream_end',
|
||||
timestamp: Date.now(),
|
||||
});
|
||||
}
|
||||
},
|
||||
|
||||
startOnLoad: true,
|
||||
});
|
||||
|
||||
setIsListening(true);
|
||||
} catch (err) {
|
||||
const message =
|
||||
err instanceof DOMException && err.name === 'NotAllowedError'
|
||||
? '麦克风权限被拒绝'
|
||||
: err instanceof Error
|
||||
? err.message
|
||||
: 'VAD 初始化失败';
|
||||
setError(message);
|
||||
}
|
||||
}, [isSupported, language, isStreaming, sendMessage, onTranscription]);
|
||||
|
||||
// In streaming mode, watch for voice_final from the server
|
||||
useEffect(() => {
|
||||
if (isStreaming && voiceFinalText) {
|
||||
onTranscription(voiceFinalText);
|
||||
useChatStore.getState().setVoiceFinalText('');
|
||||
}
|
||||
}, [isStreaming, voiceFinalText, onTranscription]);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (vadRef.current) {
|
||||
vadRef.current.destroy();
|
||||
vadRef.current = null;
|
||||
}
|
||||
};
|
||||
}, []);
|
||||
|
||||
return {
|
||||
isListening,
|
||||
isSpeaking,
|
||||
isSupported,
|
||||
interimText: isStreaming ? voiceInterimText : '',
|
||||
finalText: isStreaming ? voiceFinalText : '',
|
||||
error,
|
||||
startListening,
|
||||
stopListening,
|
||||
};
|
||||
}
|
||||
|
||||
function arrayBufferToBase64(buffer: ArrayBuffer): string {
|
||||
const bytes = new Uint8Array(buffer);
|
||||
let binary = '';
|
||||
for (let i = 0; i < bytes.byteLength; i++) {
|
||||
binary += String.fromCharCode(bytes[i]);
|
||||
}
|
||||
return btoa(binary);
|
||||
}
|
||||
@@ -254,7 +254,7 @@ function handleServerMessage(msg: WSServerMessage) {
|
||||
client_info: msg.client_info,
|
||||
audioUrl: msg.full_audio_url,
|
||||
segments: msg.segments,
|
||||
metadata: msg.tool_calls ? { tool_calls: msg.tool_calls } : undefined,
|
||||
metadata: msg.metadata || (msg.tool_calls ? { tool_calls: msg.tool_calls } : undefined),
|
||||
});
|
||||
}
|
||||
setTyping(false);
|
||||
@@ -462,6 +462,18 @@ function handleServerMessage(msg: WSServerMessage) {
|
||||
}
|
||||
break;
|
||||
|
||||
case 'voice_interim':
|
||||
if (msg.text !== undefined) {
|
||||
useChatStore.getState().setVoiceInterimText(msg.text);
|
||||
}
|
||||
break;
|
||||
|
||||
case 'voice_final':
|
||||
if (msg.text !== undefined) {
|
||||
useChatStore.getState().setVoiceFinalText(msg.text);
|
||||
}
|
||||
break;
|
||||
|
||||
case 'pong':
|
||||
break;
|
||||
|
||||
|
||||
@@ -21,6 +21,10 @@ interface ChatStore {
|
||||
isLoadingHistory: boolean;
|
||||
historyPage: number;
|
||||
|
||||
// 流式语音识别状态
|
||||
voiceInterimText: string;
|
||||
voiceFinalText: string;
|
||||
|
||||
// 多气泡消息队列:确保气泡依次出现 + 逐字动画
|
||||
messageQueue: Message[];
|
||||
|
||||
@@ -37,6 +41,8 @@ interface ChatStore {
|
||||
clearMessages: () => void;
|
||||
|
||||
setContinuousMode: (enabled: boolean) => void;
|
||||
setVoiceInterimText: (text: string) => void;
|
||||
setVoiceFinalText: (text: string) => void;
|
||||
setBackgroundThinkingStatus: (status: BackgroundThinkingStatus) => void;
|
||||
setIoTDevices: (devices: IoTDevice[]) => void;
|
||||
|
||||
@@ -54,6 +60,8 @@ export const useChatStore = create<ChatStore>((set) => ({
|
||||
backgroundThinkingStatus: 'idle',
|
||||
iotDevices: [],
|
||||
iotDevicesLastUpdated: null,
|
||||
voiceInterimText: '',
|
||||
voiceFinalText: '',
|
||||
hasMoreMessages: false,
|
||||
isLoadingHistory: false,
|
||||
historyPage: 1,
|
||||
@@ -143,6 +151,9 @@ export const useChatStore = create<ChatStore>((set) => ({
|
||||
|
||||
setContinuousMode: (enabled) => set({ continuousMode: enabled }),
|
||||
|
||||
setVoiceInterimText: (text) => set({ voiceInterimText: text }),
|
||||
setVoiceFinalText: (text) => set({ voiceFinalText: text, voiceInterimText: '' }),
|
||||
|
||||
setBackgroundThinkingStatus: (status) => set({ backgroundThinkingStatus: status }),
|
||||
|
||||
setIoTDevices: (devices) =>
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
// 页面路由 Store — 管理当前显示的页面
|
||||
|
||||
import { create } from 'zustand';
|
||||
|
||||
export type PageId = 'chat' | 'admin-models' | 'admin-dashboard' | 'profile';
|
||||
|
||||
interface PageState {
|
||||
currentPage: PageId;
|
||||
setPage: (page: PageId) => void;
|
||||
goToChat: () => void;
|
||||
}
|
||||
|
||||
export const usePageStore = create<PageState>((set) => ({
|
||||
currentPage: 'chat',
|
||||
setPage: (page) => set({ currentPage: page }),
|
||||
goToChat: () => set({ currentPage: 'chat' }),
|
||||
}));
|
||||
@@ -105,11 +105,14 @@ export interface StreamSegment {
|
||||
|
||||
/** WebSocket 客户端消息 */
|
||||
export interface WSClientMessage {
|
||||
type: 'message' | 'voice_input' | 'ping' | 'history';
|
||||
type: 'message' | 'voice_input' | 'voice_stream_start' | 'voice_stream_chunk' | 'voice_stream_end' | 'ping' | 'history';
|
||||
session_id?: string;
|
||||
mode?: ChatMode;
|
||||
content?: string;
|
||||
audio_data?: string; // base64
|
||||
format?: string; // 音频格式 (voice_stream_start): webm, wav, pcm, opus
|
||||
language?: string; // 识别语言 (voice_stream_start): zh, en, ja, ko, auto
|
||||
sequence?: number; // 音频分片序号 (voice_stream_chunk)
|
||||
attachments?: MessageAttachment[];
|
||||
timestamp: number;
|
||||
client_id?: string;
|
||||
@@ -139,7 +142,7 @@ export interface AppNotification extends NotificationData {
|
||||
|
||||
/** WebSocket 服务端消息 */
|
||||
export interface WSServerMessage {
|
||||
type: 'stream_start' | 'response' | 'segment' | 'audio' | 'error' | 'device_update' | 'pong' | 'history_response' | 'stream_chunk' | 'stream_end' | 'background_thinking' | 'notification' | 'multi_message' | 'stream_segments' | 'review' | 'thinking' | 'tool_progress' | 'system_info';
|
||||
type: 'stream_start' | 'response' | 'segment' | 'audio' | 'error' | 'device_update' | 'pong' | 'history_response' | 'stream_chunk' | 'stream_end' | 'background_thinking' | 'notification' | 'multi_message' | 'stream_segments' | 'review' | 'thinking' | 'tool_progress' | 'system_info' | 'voice_interim' | 'voice_final';
|
||||
message_id?: string;
|
||||
text?: string;
|
||||
content?: string;
|
||||
@@ -161,6 +164,7 @@ export interface WSServerMessage {
|
||||
notification?: NotificationData;
|
||||
tool_progress?: ToolProgressInfo;
|
||||
system_info?: SystemInfoPayload;
|
||||
metadata?: Record<string, unknown>;
|
||||
protocol_version?: number;
|
||||
timestamp: number;
|
||||
client_info?: ClientInfo;
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
批量 WEM → WAV 转换,使用 vgmstream-cli + ffmpeg 标准化。
|
||||
输出: 22050Hz, mono, 16-bit PCM WAV
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
VGMSTREAM = r"D:\Project\Code\Uni\Cyrene\scripts\voice\tools\vgmstream\vgmstream-cli.exe"
|
||||
RAW_DIR = r"D:\Project\Code\Uni\Cyrene-Voice-Model\data\raw"
|
||||
CLEANED_DIR = r"D:\Project\Code\Uni\Cyrene-Voice-Model\data\cleaned"
|
||||
|
||||
# 要转换的子目录(按优先级)
|
||||
TARGETS = [
|
||||
"VoBanks27", "VoBanks28", "VoBanks29", "VoBanks30", "VoBanks31",
|
||||
]
|
||||
|
||||
|
||||
def convert_wem_to_wav(wem_path: str, wav_path: str) -> bool:
|
||||
"""vgmstream WEM→临时WAV, ffmpeg 标准化 → 最终WAV (22050Hz mono s16)"""
|
||||
os.makedirs(os.path.dirname(wav_path), exist_ok=True)
|
||||
|
||||
# 跳过已存在且非空的文件
|
||||
if os.path.exists(wav_path) and os.path.getsize(wav_path) > 100:
|
||||
return True
|
||||
|
||||
tmp_path = wav_path + ".tmp.wav"
|
||||
try:
|
||||
# Step 1: vgmstream → temp WAV
|
||||
result = subprocess.run(
|
||||
[VGMSTREAM, "-o", tmp_path, wem_path],
|
||||
capture_output=True, timeout=30,
|
||||
)
|
||||
if result.returncode != 0 or not os.path.exists(tmp_path):
|
||||
return False
|
||||
|
||||
# Step 2: ffmpeg → 标准化 22050Hz mono s16
|
||||
result = subprocess.run(
|
||||
["ffmpeg", "-y", "-i", tmp_path,
|
||||
"-ar", "22050", "-ac", "1", "-sample_fmt", "s16",
|
||||
wav_path],
|
||||
capture_output=True, timeout=30,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
return False
|
||||
|
||||
return os.path.exists(wav_path) and os.path.getsize(wav_path) > 100
|
||||
|
||||
except Exception as e:
|
||||
print(f" FAIL [{os.path.basename(wem_path)}]: {e}")
|
||||
return False
|
||||
finally:
|
||||
# 清理临时文件
|
||||
if os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
|
||||
|
||||
def main():
|
||||
print("=== 批量 WEM → WAV 转换 (VoBanks) ===\n")
|
||||
|
||||
total = 0
|
||||
ok = 0
|
||||
fail = 0
|
||||
|
||||
for target in TARGETS:
|
||||
src_dir = os.path.join(RAW_DIR, target)
|
||||
dst_dir = os.path.join(CLEANED_DIR, target)
|
||||
|
||||
if not os.path.isdir(src_dir):
|
||||
print(f"SKIP: {target} (not found)")
|
||||
continue
|
||||
|
||||
wem_files = sorted(Path(src_dir).glob("*.wem"))
|
||||
if not wem_files:
|
||||
print(f"SKIP: {target} (empty)")
|
||||
continue
|
||||
|
||||
print(f"[{target}] {len(wem_files)} files...")
|
||||
|
||||
for i, wem in enumerate(wem_files):
|
||||
wav = os.path.join(dst_dir, wem.stem + ".wav")
|
||||
if convert_wem_to_wav(str(wem), wav):
|
||||
ok += 1
|
||||
else:
|
||||
fail += 1
|
||||
total += 1
|
||||
|
||||
if (i + 1) % 50 == 0:
|
||||
print(f" {i+1}/{len(wem_files)} (ok:{ok} fail:{fail})")
|
||||
|
||||
print(f" Done: {len(wem_files)} files\n")
|
||||
|
||||
print(f"=== 转换完成: {ok} ok, {fail} fail, {total} total ===")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,73 @@
|
||||
#!/bin/bash
|
||||
# 批量 WEM → WAV 转换 (使用 vgmstream-cli)
|
||||
set -e
|
||||
|
||||
RAW_DIR="D:/Project/Code/Uni/Cyrene-Voice-Model/data/raw"
|
||||
CLEANED_DIR="D:/Project/Code/Uni/Cyrene-Voice-Model/data/cleaned"
|
||||
VGMSTREAM="D:/Project/Code/Uni/Cyrene/scripts/voice/tools/vgmstream/vgmstream-cli.exe"
|
||||
|
||||
if [ ! -f "$VGMSTREAM" ]; then
|
||||
echo "错误: 找不到 vgmstream-cli.exe"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "=== 批量 WEM → WAV 转换 ==="
|
||||
echo ""
|
||||
|
||||
TOTAL=0
|
||||
SUCCESS=0
|
||||
FAILED=0
|
||||
|
||||
while IFS= read -r -d '' wem_file; do
|
||||
TOTAL=$((TOTAL + 1))
|
||||
|
||||
# 输出路径: cleaned/ 目录下保持相同子目录结构,改 .wem 为 .wav
|
||||
rel_path="${wem_file#$RAW_DIR/}"
|
||||
wav_file="${CLEANED_DIR}/${rel_path%.wem}.wav"
|
||||
wav_dir="$(dirname "$wav_file")"
|
||||
|
||||
mkdir -p "$wav_dir"
|
||||
|
||||
# 跳过已转换的
|
||||
if [ -f "$wav_file" ] && [ "$(stat -c%s "$wav_file" 2>/dev/null || echo 0)" -gt 100 ]; then
|
||||
SUCCESS=$((SUCCESS + 1))
|
||||
continue
|
||||
fi
|
||||
|
||||
# 转换
|
||||
if cmd.exe //c "$VGMSTREAM -o \"$wav_file\" \"$wem_file\"" 2>/dev/null; then
|
||||
SUCCESS=$((SUCCESS + 1))
|
||||
else
|
||||
FAILED=$((FAILED + 1))
|
||||
fi
|
||||
|
||||
# 进度显示
|
||||
if [ $((TOTAL % 100)) -eq 0 ]; then
|
||||
echo " 进度: $TOTAL 文件 (成功: $SUCCESS, 失败: $FAILED)"
|
||||
fi
|
||||
done < <(find "$RAW_DIR" -name "*.wem" -print0)
|
||||
|
||||
echo ""
|
||||
echo "=== 转换完成 ==="
|
||||
echo "总计: $TOTAL | 成功: $SUCCESS | 失败: $FAILED"
|
||||
|
||||
# 统计分类
|
||||
echo ""
|
||||
echo "音频时长分布:"
|
||||
find "$CLEANED_DIR" -name "*.wav" | while read wav; do
|
||||
dur=$(ffprobe -v quiet -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 "$wav" 2>/dev/null || echo "0")
|
||||
echo "$dur"
|
||||
done | awk '
|
||||
{ d = $1 + 0 }
|
||||
d < 1 { lt1++ }
|
||||
d < 3 { lt3++ }
|
||||
d < 10 { lt10++ }
|
||||
d < 30 { lt30++ }
|
||||
d >= 30 { gt30++ }
|
||||
END {
|
||||
printf " < 1s: %d\n", lt1
|
||||
printf " 1-3s: %d\n", lt3
|
||||
printf " 3-10s: %d\n", lt10
|
||||
printf " 10-30s: %d\n", lt30
|
||||
printf " > 30s: %d\n", gt30
|
||||
}'
|
||||
@@ -0,0 +1,175 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
通过 Chrome DevTools Protocol 下载文件。
|
||||
需要 Chrome 以 --remote-debugging-port=9222 启动。
|
||||
|
||||
用法:
|
||||
python cdp_download.py <url> <output_path>
|
||||
python cdp_download.py https://github.com/.../vgmstream-win.zip tools/vgmstream.zip
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
from websocket import create_connection
|
||||
|
||||
|
||||
class CDPClient:
|
||||
def __init__(self, ws_url: str):
|
||||
self.ws = create_connection(ws_url, origin="http://127.0.0.1:9222")
|
||||
self._id = 0
|
||||
self._lock = threading.Lock()
|
||||
self._pending = {}
|
||||
self._events = []
|
||||
self._running = True
|
||||
|
||||
# Start background reader
|
||||
self._reader_thread = threading.Thread(target=self._read_loop, daemon=True)
|
||||
self._reader_thread.start()
|
||||
|
||||
def _read_loop(self):
|
||||
while self._running:
|
||||
try:
|
||||
msg = json.loads(self.ws.recv())
|
||||
msg_id = msg.get('id')
|
||||
if msg_id is not None:
|
||||
with self._lock:
|
||||
self._pending[msg_id] = msg
|
||||
else:
|
||||
self._events.append(msg)
|
||||
except Exception:
|
||||
if self._running:
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
break
|
||||
|
||||
def send(self, method: str, params: dict = None) -> dict:
|
||||
self._id += 1
|
||||
msg_id = self._id
|
||||
payload = json.dumps({
|
||||
'id': msg_id,
|
||||
'method': method,
|
||||
'params': params or {}
|
||||
})
|
||||
self.ws.send(payload)
|
||||
|
||||
# Wait for response
|
||||
timeout = 60
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
with self._lock:
|
||||
if msg_id in self._pending:
|
||||
result = self._pending.pop(msg_id)
|
||||
if 'error' in result:
|
||||
raise Exception(f"CDP Error: {result['error']}")
|
||||
return result.get('result', {})
|
||||
time.sleep(0.1)
|
||||
raise TimeoutError(f"CDP command {method} timed out")
|
||||
|
||||
def wait_for_event(self, event_type: str, timeout: float = 60) -> dict:
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
for i, evt in enumerate(self._events):
|
||||
if evt.get('method') == event_type:
|
||||
return self._events.pop(i)['params']
|
||||
time.sleep(0.2)
|
||||
raise TimeoutError(f"Event {event_type} not received within {timeout}s")
|
||||
|
||||
def close(self):
|
||||
self._running = False
|
||||
self.ws.close()
|
||||
|
||||
|
||||
def download_via_cdp(url: str, output_path: str, cdp_url: str = "http://localhost:9222"):
|
||||
"""
|
||||
使用 Chrome CDP 下载文件。
|
||||
"""
|
||||
import urllib.request
|
||||
|
||||
# 1. 创建新标签页
|
||||
print(f"[CDP] 创建标签页...")
|
||||
req = urllib.request.Request(f"{cdp_url}/json/new", method='PUT')
|
||||
resp = urllib.request.urlopen(req)
|
||||
tab = json.loads(resp.read())
|
||||
ws_url = tab['webSocketDebuggerUrl']
|
||||
print(f"[CDP] 标签页: {tab['id']}")
|
||||
|
||||
client = CDPClient(ws_url)
|
||||
|
||||
try:
|
||||
# 2. 启用必要的域
|
||||
print(f"[CDP] 启用 Page...")
|
||||
client.send('Page.enable')
|
||||
|
||||
# 3. 设置下载目录
|
||||
download_dir = os.path.abspath(os.path.dirname(output_path))
|
||||
os.makedirs(download_dir, exist_ok=True)
|
||||
print(f"[CDP] 下载目录: {download_dir}")
|
||||
|
||||
client.send('Browser.setDownloadBehavior', {
|
||||
'behavior': 'allow',
|
||||
'downloadPath': download_dir
|
||||
})
|
||||
|
||||
# 4. 导航到下载 URL
|
||||
print(f"[CDP] 导航到: {url}")
|
||||
client.send('Page.navigate', {'url': url})
|
||||
|
||||
# 5. 等待下载完成
|
||||
print(f"[CDP] 等待下载开始...")
|
||||
will_begin = client.wait_for_event('Browser.downloadWillBegin', timeout=30)
|
||||
guid = will_begin['guid']
|
||||
suggested_name = will_begin.get('suggestedFilename', 'unknown')
|
||||
print(f"[CDP] 下载开始: {suggested_name} (guid={guid})")
|
||||
|
||||
# 等待下载进度完成
|
||||
print(f"[CDP] 等待下载完成...")
|
||||
while True:
|
||||
progress = client.wait_for_event('Browser.downloadProgress', timeout=120)
|
||||
state = progress.get('state', '')
|
||||
if state == 'completed':
|
||||
print(f"[CDP] 下载完成")
|
||||
break
|
||||
elif state == 'canceled':
|
||||
raise Exception("下载被取消")
|
||||
elif state == 'interrupted':
|
||||
print(f"[CDP] 下载中断, 重试...")
|
||||
client.send('Browser.resumeDownload', {'guid': guid})
|
||||
else:
|
||||
received = progress.get('receivedBytes', 0)
|
||||
total = progress.get('totalBytes', 0)
|
||||
if total > 0:
|
||||
print(f"[CDP] 进度: {received}/{total} ({100*received//total}%)")
|
||||
|
||||
# 6. 移动到目标路径
|
||||
downloaded_file = os.path.join(download_dir, suggested_name)
|
||||
if os.path.exists(downloaded_file) and downloaded_file != output_path:
|
||||
if os.path.exists(output_path):
|
||||
os.remove(output_path)
|
||||
os.rename(downloaded_file, output_path)
|
||||
print(f"[CDP] 文件保存到: {output_path}")
|
||||
|
||||
return output_path
|
||||
|
||||
finally:
|
||||
client.close()
|
||||
# 关闭标签页
|
||||
urllib.request.urlopen(urllib.request.Request(
|
||||
f"{cdp_url}/json/close/{tab['id']}", method='PUT'))
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 3:
|
||||
print("用法: python cdp_download.py <url> <output_path>")
|
||||
print("示例: python cdp_download.py https://example.com/file.zip tools/file.zip")
|
||||
sys.exit(1)
|
||||
|
||||
url = sys.argv[1]
|
||||
output = sys.argv[2]
|
||||
download_via_cdp(url, output)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,125 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
将 .wem (Wwise Encoded Media) 文件批量转换为 .wav 格式。
|
||||
使用 ffmpeg 进行转换(需预先安装 ffmpeg)。
|
||||
|
||||
.wem 文件本质上是 RIFF/WAVE 容器,内部编码可能是:
|
||||
- PCM 16-bit (ffmpeg 直接支持)
|
||||
- Wwise ADPCM (ffmpeg 需要额外解码器)
|
||||
- Vorbis (部分 ffmpeg 版本支持)
|
||||
|
||||
用法:
|
||||
python convert_wem.py <input_dir> <output_dir>
|
||||
python convert_wem.py ./wem_output/ ./wav_output/
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def convert_wem_to_wav(wem_path: str, wav_path: str) -> bool:
|
||||
"""使用 ffmpeg 将单个 .wem 文件转为 .wav."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
['ffmpeg', '-y', '-i', wem_path,
|
||||
'-ar', '22050', # 22.05kHz (与 persona.yaml 设定的训练格式一致)
|
||||
'-ac', '1', # mono
|
||||
'-sample_fmt', 's16', # 16-bit
|
||||
wav_path],
|
||||
capture_output=True,
|
||||
timeout=30,
|
||||
)
|
||||
if result.returncode == 0 and os.path.getsize(wav_path) > 100:
|
||||
return True
|
||||
else:
|
||||
# 部分 WEM 是 Vorbis 编码,需要用不同方式
|
||||
return _convert_wem_vorbis(wem_path, wav_path)
|
||||
except Exception as e:
|
||||
print(f" ffmpeg 错误 [{os.path.basename(wem_path)}]: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _convert_wem_vorbis(wem_path: str, wav_path: str) -> bool:
|
||||
"""尝试处理 Vorbis 编码的 WEM 文件 (带 Wwise 头的 Ogg Vorbis)."""
|
||||
try:
|
||||
# 方式: 跳过 RIFF 头 + fmt chunk, 直接取 Vorbis 数据
|
||||
# .wem 的 Vorbis 数据从 "vorb" chunk 开始
|
||||
with open(wem_path, 'rb') as f:
|
||||
data = f.read()
|
||||
|
||||
# 查找 "vorb" 标识
|
||||
vorb_pos = data.find(b'vorb')
|
||||
if vorb_pos == -1:
|
||||
return False
|
||||
|
||||
# 重新封装为标准 Ogg (在 vorb 数据前加 OggS 头)
|
||||
# 简化方法: 用 ffmpeg 的 libvorbis 解码
|
||||
# 如果上面失败了,尝试用 -f s16le 强制读取
|
||||
result = subprocess.run(
|
||||
['ffmpeg', '-y', '-f', 's16le',
|
||||
'-ar', '48000', '-ac', '1',
|
||||
'-i', wem_path,
|
||||
'-ar', '22050', '-ac', '1',
|
||||
wav_path],
|
||||
capture_output=True,
|
||||
timeout=30,
|
||||
)
|
||||
return result.returncode == 0 and os.path.getsize(wav_path) > 100
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def convert_directory(input_dir: str, output_dir: str) -> tuple[int, int]:
|
||||
"""批量转换目录中所有 .wem 文件."""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
wem_files = sorted(Path(input_dir).glob('*.wem'))
|
||||
|
||||
if not wem_files:
|
||||
print(f"在 {input_dir} 中未找到 .wem 文件")
|
||||
return 0, 0
|
||||
|
||||
print(f"找到 {len(wem_files)} 个 .wem 文件,开始转换...")
|
||||
success = 0
|
||||
failed = 0
|
||||
|
||||
for i, wem_path in enumerate(wem_files):
|
||||
wav_name = wem_path.stem + '.wav'
|
||||
wav_path = os.path.join(output_dir, wav_name)
|
||||
|
||||
# 跳过已存在的
|
||||
if os.path.exists(wav_path) and os.path.getsize(wav_path) > 100:
|
||||
success += 1
|
||||
continue
|
||||
|
||||
if convert_wem_to_wav(str(wem_path), wav_path):
|
||||
success += 1
|
||||
else:
|
||||
failed += 1
|
||||
|
||||
if (i + 1) % 50 == 0:
|
||||
print(f" 进度: {i+1}/{len(wem_files)} (成功: {success}, 失败: {failed})")
|
||||
|
||||
print(f"\n转换完成: {success} 成功, {failed} 失败")
|
||||
return success, failed
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="批量转换 .wem → .wav (需要 ffmpeg)")
|
||||
parser.add_argument('input_dir', help='包含 .wem 文件的输入目录')
|
||||
parser.add_argument('output_dir', help='输出目录')
|
||||
parser.add_argument('--single', nargs=2, metavar=('WEM', 'WAV'),
|
||||
help='转换单个文件')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.single:
|
||||
ok = convert_wem_to_wav(args.single[0], args.single[1])
|
||||
print(f"{'OK' if ok else 'FAILED'}: {args.single[0]} -> {args.single[1]}")
|
||||
else:
|
||||
convert_directory(args.input_dir, args.output_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,161 @@
|
||||
#!/bin/bash
|
||||
# ============================================================
|
||||
# 昔涟语音提取管线
|
||||
#
|
||||
# 步骤:
|
||||
# 1. 从 HSR 音频包提取 .wem (本脚本)
|
||||
# 2. 用 ww2ogg/vgmstream 转换 .wem → .wav (需手动安装工具)
|
||||
# 3. 用 ffmpeg 标准化音频格式
|
||||
# ============================================================
|
||||
set -e
|
||||
|
||||
HSR_AUDIO_DIR="D:/MeowG/Honkai:Star_Rail/StarRail_Data/Persistent/Audio/AudioPackage/Windows/Chinese(PRC)"
|
||||
RAW_DIR="D:/Project/Code/Uni/Cyrene-Voice-Model/data/raw"
|
||||
CLEANED_DIR="D:/Project/Code/Uni/Cyrene-Voice-Model/data/cleaned"
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
|
||||
echo "=== 昔涟语音提取管线 ==="
|
||||
echo ""
|
||||
|
||||
# ---- 阶段 1: 从 .pck 提取 .wem ----
|
||||
echo "[阶段 1/4] 提取候选 .pck 文件..."
|
||||
|
||||
# 昔涟是 3.x 角色,语音在以下文件中:
|
||||
# - VoBanks 27-31 (最新角色语音库)
|
||||
# - External_del_4.0_chapter5_* (3.0 主线昔涟出场)
|
||||
# - External_del_4.1_chapter_* (3.1 主线昔涟出场)
|
||||
TARGETS=(
|
||||
"VoBanks27.pck"
|
||||
"VoBanks28.pck"
|
||||
"VoBanks29.pck"
|
||||
"VoBanks30.pck"
|
||||
"VoBanks31.pck"
|
||||
"External_del_4.0_chapter5_0.pck"
|
||||
"External_del_4.0_chapter5_1.pck"
|
||||
"External_del_4.0_chapter5_2.pck"
|
||||
"External_del_4.1_chapter_0.pck"
|
||||
"External_del_4.1_chapter_1.pck"
|
||||
"External_del_4.1_chapter_2.pck"
|
||||
)
|
||||
|
||||
for target in "${TARGETS[@]}"; do
|
||||
pck_path="${HSR_AUDIO_DIR}/${target}"
|
||||
if [ -f "$pck_path" ]; then
|
||||
echo " 提取: $target ($(du -h "$pck_path" | cut -f1))"
|
||||
python3 "${SCRIPT_DIR}/extract_pck.py" "$pck_path" "${RAW_DIR}/${target%.pck}/"
|
||||
else
|
||||
echo " 跳过: $target (文件不存在)"
|
||||
fi
|
||||
done
|
||||
|
||||
echo ""
|
||||
echo "[阶段 1/4] 完成: .wem 文件已提取到 ${RAW_DIR}/"
|
||||
|
||||
# ---- 阶段 2: .wem → .wav 转换 ----
|
||||
echo ""
|
||||
echo "[阶段 2/4] 需要转换 .wem → .wav"
|
||||
echo ""
|
||||
echo " HSR 使用 Wwise 专有编码 (0xFFFF),ffmpeg 无法直接解码。"
|
||||
echo " 请使用以下工具之一进行转换:"
|
||||
echo ""
|
||||
echo " 方案 A — vgmstream CLI (推荐, 最简单):"
|
||||
echo " 下载: https://github.com/vgmstream/vgmstream/releases"
|
||||
echo " 解压后将 vgmstream-cli.exe 放入 scripts/voice/tools/"
|
||||
echo " 然后运行本脚本的 --convert 模式"
|
||||
echo ""
|
||||
echo " 方案 B — AnimeWwise GUI (最强大, 保留原始文件名):"
|
||||
echo " 下载: https://github.com/Escartem/AnimeWwise/releases"
|
||||
echo " 直接打开 GUI, 选择 HSR 目录, 导出昔涟语音"
|
||||
echo ""
|
||||
echo " 方案 C — ww2ogg + revorb (传统方案):"
|
||||
echo " 下载 ww2ogg.exe + revorb.exe + packed_codebooks.bin"
|
||||
echo " 放入 scripts/voice/tools/"
|
||||
echo ""
|
||||
echo " 安装工具后, 运行: $0 --convert"
|
||||
echo ""
|
||||
|
||||
# ---- 阶段 3 (条件): 批量转换 ----
|
||||
if [ "$1" = "--convert" ]; then
|
||||
echo "[阶段 3/4] 转换 .wem → .wav..."
|
||||
|
||||
TOOLS_DIR="${SCRIPT_DIR}/tools"
|
||||
|
||||
# 优先使用 vgmstream
|
||||
if [ -f "${TOOLS_DIR}/vgmstream-cli.exe" ]; then
|
||||
echo " 使用 vgmstream-cli..."
|
||||
find "${RAW_DIR}" -name "*.wem" | while read wem; do
|
||||
wav="${wem%.wem}.wav"
|
||||
if [ ! -f "$wav" ]; then
|
||||
"${TOOLS_DIR}/vgmstream-cli.exe" -o "$wav" "$wem" 2>/dev/null
|
||||
fi
|
||||
done
|
||||
elif [ -f "${TOOLS_DIR}/ww2ogg.exe" ]; then
|
||||
echo " 使用 ww2ogg + ffmpeg..."
|
||||
find "${RAW_DIR}" -name "*.wem" | while read wem; do
|
||||
ogg="${wem%.wem}.ogg"
|
||||
wav="${wem%.wem}.wav"
|
||||
if [ ! -f "$wav" ]; then
|
||||
"${TOOLS_DIR}/ww2ogg.exe" "$wem" -o "$ogg" --pcb "${TOOLS_DIR}/packed_codebooks.bin" 2>/dev/null
|
||||
ffmpeg -y -i "$ogg" -ar 22050 -ac 1 -sample_fmt s16 "$wav" 2>/dev/null
|
||||
rm -f "$ogg"
|
||||
fi
|
||||
done
|
||||
else
|
||||
echo " 错误: 未找到转换工具, 请先安装 vgmstream 或 ww2ogg"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "[阶段 3/4] 完成"
|
||||
|
||||
# ---- 阶段 4: 音频标准化 + 分类 ----
|
||||
echo ""
|
||||
echo "[阶段 4/4] 标准化 + 分类..."
|
||||
|
||||
# 按音频时长初步分类 (语音通常 1-15 秒)
|
||||
mkdir -p "${CLEANED_DIR}/daily" "${CLEANED_DIR}/battle" \
|
||||
"${CLEANED_DIR}/emotional" "${CLEANED_DIR}/story"
|
||||
|
||||
find "${RAW_DIR}" -name "*.wav" | while read wav; do
|
||||
# 获取时长
|
||||
duration=$(ffprobe -v quiet -show_entries format=duration \
|
||||
-of default=noprint_wrappers=1:nokey=1 "$wav" 2>/dev/null || echo "0")
|
||||
dur_float=$(echo "$duration" | awk '{print int($1 * 1000)}')
|
||||
|
||||
basename=$(basename "$wav" .wav)
|
||||
parent=$(basename "$(dirname "$wav")")
|
||||
|
||||
# 分类逻辑
|
||||
if [ "$dur_float" -lt 500 ]; then
|
||||
# < 0.5s: 可能是战斗短语音 / 语气词
|
||||
target_dir="${CLEANED_DIR}/battle"
|
||||
elif [ "$dur_float" -gt 15000 ]; then
|
||||
# > 15s: 可能是剧情长对话
|
||||
target_dir="${CLEANED_DIR}/story"
|
||||
elif echo "$parent" | grep -qi "chapter"; then
|
||||
target_dir="${CLEANED_DIR}/story"
|
||||
elif echo "$parent" | grep -qi "vobanks"; then
|
||||
# VoBanks 包含战斗 + 日常语音, 需要人工筛选
|
||||
target_dir="${CLEANED_DIR}/daily"
|
||||
else
|
||||
target_dir="${CLEANED_DIR}/daily"
|
||||
fi
|
||||
|
||||
# 用 ffmpeg 标准化: 22.05kHz mono 16bit
|
||||
ffmpeg -y -i "$wav" -ar 22050 -ac 1 -sample_fmt s16 \
|
||||
"${target_dir}/${parent}_${basename}.wav" 2>/dev/null
|
||||
|
||||
echo -n "."
|
||||
done
|
||||
|
||||
echo ""
|
||||
echo "[阶段 4/4] 完成: 音频已分类到 ${CLEANED_DIR}/"
|
||||
echo ""
|
||||
echo "文件分布:"
|
||||
echo " 日常对话: $(ls "${CLEANED_DIR}/daily/" 2>/dev/null | wc -l) 个"
|
||||
echo " 战斗语音: $(ls "${CLEANED_DIR}/battle/" 2>/dev/null | wc -l) 个"
|
||||
echo " 情感表达: $(ls "${CLEANED_DIR}/emotional/" 2>/dev/null | wc -l) 个"
|
||||
echo " 剧情对话: $(ls "${CLEANED_DIR}/story/" 2>/dev/null | wc -l) 个"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "=== 管线完成 ==="
|
||||
@@ -0,0 +1,99 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
从 Honkai: Star Rail 的 .pck (AKPK/Wwise SoundBank) 文件中提取 .wem 音频。
|
||||
|
||||
用法:
|
||||
python extract_pck.py <input.pck> <output_dir>
|
||||
python extract_pck.py VoBanks31.pck ./output/
|
||||
python extract_pck.py --all <pck_dir> <output_dir>
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import struct
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def find_wem_files(data: bytes) -> list[tuple[int, int, str]]:
|
||||
"""
|
||||
扫描数据中的所有 RIFF/WAVE 块 (.wem 文件).
|
||||
返回 [(offset, size, riff_type), ...] 列表.
|
||||
"""
|
||||
results = []
|
||||
pos = 0
|
||||
data_len = len(data)
|
||||
while pos < data_len - 12:
|
||||
if data[pos:pos + 4] == b'RIFF':
|
||||
chunk_size = struct.unpack_from('<I', data, pos + 4)[0]
|
||||
riff_type = data[pos + 8:pos + 12]
|
||||
# .wem 文件的 riff_type 是 b'WAVE'
|
||||
if riff_type == b'WAVE' and chunk_size > 100:
|
||||
total_size = chunk_size + 8 # RIFF header + data
|
||||
if pos + total_size <= data_len:
|
||||
results.append((pos, total_size, riff_type.decode('ascii', errors='replace')))
|
||||
# 跳过已匹配的块
|
||||
pos += 8 + chunk_size
|
||||
continue
|
||||
pos += 1
|
||||
return results
|
||||
|
||||
|
||||
def extract_pck(pck_path: str, output_dir: str, prefix: str = "") -> list[str]:
|
||||
"""
|
||||
从单个 .pck 文件提取所有 .wem 文件到 output_dir.
|
||||
返回提取的文件路径列表.
|
||||
"""
|
||||
pck_name = Path(pck_path).stem
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
with open(pck_path, 'rb') as f:
|
||||
data = f.read()
|
||||
|
||||
print(f"[{pck_name}] 文件大小: {len(data):,} bytes, 扫描 WEM 块...")
|
||||
wem_entries = find_wem_files(data)
|
||||
print(f"[{pck_name}] 找到 {len(wem_entries)} 个音频文件")
|
||||
|
||||
extracted = []
|
||||
for i, (offset, size, riff_type) in enumerate(wem_entries):
|
||||
# 使用 offset 作为唯一 ID (Wwise 文件 ID 就是 offset 的某种映射)
|
||||
wem_data = data[offset:offset + size]
|
||||
|
||||
if prefix:
|
||||
filename = f"{prefix}_{i:04d}_{offset:08x}.wem"
|
||||
else:
|
||||
filename = f"{pck_name}_{i:04d}_{offset:08x}.wem"
|
||||
|
||||
out_path = os.path.join(output_dir, filename)
|
||||
with open(out_path, 'wb') as f:
|
||||
f.write(wem_data)
|
||||
extracted.append(out_path)
|
||||
|
||||
total_mb = sum(wem_entries[i][1] for i in range(len(wem_entries))) / 1024 / 1024
|
||||
print(f"[{pck_name}] 提取完成: {len(extracted)} 文件, {total_mb:.1f} MB")
|
||||
return extracted
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="从 HSR .pck 文件提取 .wem 音频")
|
||||
parser.add_argument('input', help='输入 .pck 文件路径,或 --all 模式下的目录路径')
|
||||
parser.add_argument('output', help='输出目录')
|
||||
parser.add_argument('--all', action='store_true', help='批量模式:提取目录中所有 .pck 文件')
|
||||
parser.add_argument('--prefix', default='', help='输出文件名前缀')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.all:
|
||||
pck_dir = Path(args.input)
|
||||
pck_files = sorted(pck_dir.glob('*.pck'))
|
||||
print(f"批量模式: 找到 {len(pck_files)} 个 .pck 文件")
|
||||
total_extracted = 0
|
||||
for pck_file in pck_files:
|
||||
extracted = extract_pck(str(pck_file), args.output, args.prefix)
|
||||
total_extracted += len(extracted)
|
||||
print(f"\n总计: {total_extracted} 个音频文件")
|
||||
else:
|
||||
extract_pck(args.input, args.output, args.prefix)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user