feat: 第五轮开发 - 14项未来路线图功能完整实现

W1-W14 全部完成:
- W1: 消息搜索 (ILIKE全文检索 + SearchModal)
- W2: 对话导出 (JSON/Markdown/TXT三格式)
- W3: 记忆时间线 DevTools 可视化
- W4: 通知推送系统 (WebSocket + Browser Notification API)
- W5: 定时提醒 (30s轮询 + 重复提醒 + WebSocket推送)
- W6: 每日简报 (08:00自动生成: 天气+新闻+提醒+AI摘要)
- W7: IoT场景自动化 (规则引擎 10s轮询 + 条件评估 + 场景执行)
- W8: 语音输入 (浏览器 Speech Recognition API)
- W9: STT服务 (voice-service + whisper.cpp)
- W10: TTS服务 (浏览器 Speech Synthesis + edge-tts三档回退)
- W11: 文件管理 (上传/下载/缩略图/纯Go bilinear缩放)
- W12: 知识库RAG (PostgreSQL tsvector + 文档分块 + 检索)
- W13: 多模态 (图片上传+分析: Vision API + 本地Go分析回退)
- W14: PWA (Service Worker + 离线页 + install prompt)

总计: 6个Go微服务 + 10+前端组件 + 10+ PostgreSQL表 + 4个后台调度器
This commit is contained in:
2026-05-19 12:01:09 +08:00
parent 78e3f450c2
commit bcf4d4e621
69 changed files with 14599 additions and 150 deletions
+83 -10
View File
@@ -12,6 +12,8 @@ import (
"github.com/gin-gonic/gin"
"github.com/joho/godotenv"
"github.com/yourname/cyrene-ai/gateway/internal/config"
"github.com/yourname/cyrene-ai/gateway/internal/engine"
"github.com/yourname/cyrene-ai/gateway/internal/handler"
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
"github.com/yourname/cyrene-ai/gateway/internal/router"
"github.com/yourname/cyrene-ai/gateway/internal/store"
@@ -27,15 +29,78 @@ func main() {
// 加载配置
cfg := config.Load()
// 确保上传目录存在
if err := os.MkdirAll("./uploads", 0755); err != nil {
log.Printf("⚠ 创建上传目录失败: %v", err)
}
// 初始化数据库持久化存储 (降级:连接失败不崩溃)
var sessionStore *store.SessionStore
var reminderStore *store.ReminderStore
var briefingStore *store.BriefingStore
var automationStore *store.AutomationStore
var fileStore *store.FileStore
var knowledgeStore *store.KnowledgeStore
var ruleEngine *engine.RuleEngine
databaseURL := cfg.DatabaseURL()
if s, err := store.NewSessionStore(databaseURL); err != nil {
log.Printf("⚠ 会话持久化存储初始化失败 (数据库不可用): %v", err)
log.Println("⚠ Gateway 将以仅内存模式运行 — 会话数据在重启后丢失")
log.Printf("⚠ 会话持久化存储初始化失败 (数据库不可用): %v", err)
log.Println("⚠ Gateway 将以仅内存模式运行 — 会话数据在重启后丢失")
} else {
sessionStore = s
log.Println("✅ 会话持久化存储已启用 (PostgreSQL)")
sessionStore = s
log.Println("✅ 会话持久化存储已启用 (PostgreSQL)")
// 初始化提醒存储(复用同一数据库连接)
if rs, err := store.NewReminderStore(s.DB()); err != nil {
log.Printf("⚠ 提醒存储初始化失败: %v", err)
} else {
reminderStore = rs
log.Println("✅ 提醒持久化存储已启用 (PostgreSQL)")
}
// 初始化简报存储(复用同一数据库连接)
if bs, err := store.NewBriefingStore(s.DB()); err != nil {
log.Printf("⚠ 简报存储初始化失败: %v", err)
} else {
briefingStore = bs
log.Println("✅ 简报持久化存储已启用 (PostgreSQL)")
}
// 初始化自动化存储(复用同一数据库连接)
if as, err := store.NewAutomationStore(s.DB()); err != nil {
log.Printf("⚠ 自动化存储初始化失败: %v", err)
} else {
automationStore = as
log.Println("✅ 自动化持久化存储已启用 (PostgreSQL)")
}
// 初始化文件存储(复用同一数据库连接)
if fs, err := store.NewFileStore(s.DB()); err != nil {
log.Printf("⚠ 文件存储初始化失败: %v", err)
} else {
fileStore = fs
log.Println("✅ 文件持久化存储已启用 (PostgreSQL)")
}
// 初始化知识库存储(复用同一数据库连接)
if ks, err := store.NewKnowledgeStore(s.DB()); err != nil {
log.Printf("⚠ 知识库存储初始化失败: %v", err)
} else {
knowledgeStore = ks
log.Println("✅ 知识库持久化存储已启用 (PostgreSQL)")
}
}
// 初始化 WebSocket Hub
hub := ws.NewHub()
hub.SetStore(sessionStore)
hub.SetIdleTimeout(cfg.SessionIdleTimeoutMin)
// 初始化规则引擎 (需要 Hub)
if automationStore != nil {
ruleEngine = engine.NewRuleEngine(automationStore, hub)
ruleEngine.Start()
log.Println("✅ 规则引擎已启动")
}
// 初始化Gin
@@ -49,10 +114,7 @@ func main() {
r.Use(middleware.RequestLogging())
r.Use(gin.Recovery())
// 初始化WebSocket Hub
hub := ws.NewHub()
hub.SetStore(sessionStore)
hub.SetIdleTimeout(cfg.SessionIdleTimeoutMin)
// 启动 WebSocket Hub
go hub.Run()
// 启动闲置会话清理 (标记超时会话为 idle,不删除)
@@ -62,8 +124,19 @@ func main() {
hub.StartIoTBroadcast(cfg.IoTDebugServiceURL)
// 注册路由
router.Setup(r, hub, cfg, sessionStore)
router.Setup(r, hub, cfg, sessionStore, reminderStore, briefingStore, automationStore, fileStore, ruleEngine, knowledgeStore, nil)
// 启动提醒调度器
if reminderStore != nil {
handler.StartReminderScheduler(reminderStore, hub)
}
// 启动简报调度器
if briefingStore != nil && reminderStore != nil {
briefingHandler := handler.NewBriefingHandler(cfg, hub, briefingStore, reminderStore)
handler.StartBriefingScheduler(briefingHandler, briefingStore, cfg.BriefingTime)
}
// 启动服务
srv := &http.Server{
Addr: ":" + cfg.Port,
+15 -1
View File
@@ -45,6 +45,9 @@ type Config struct {
// IoT 调试服务
IoTDebugServiceURL string
// Voice 语音识别服务
VoiceServiceURL string
// Tool-Engine 工具引擎服务
ToolEngineURL string
@@ -61,6 +64,12 @@ type Config struct {
// Webhook (第三方平台接入)
WebhookAPIKey string
// Internal Service Token (内部服务间认证)
InternalServiceToken string
// 每日简报时间 (HH:MM 格式)
BriefingTime string
}
// Load 从环境变量加载配置
@@ -95,6 +104,8 @@ func Load() *Config {
IoTDebugServiceURL: getEnv("IOT_DEBUG_SERVICE_URL", "http://localhost:8083"),
VoiceServiceURL: getEnv("VOICE_SERVICE_URL", "http://localhost:8093"),
ToolEngineURL: getEnv("TOOL_ENGINE_URL", "http://localhost:8092"),
LLMAPIURL: getEnv("LLM_API_URL", "https://api.openai.com/v1"),
@@ -104,7 +115,10 @@ func Load() *Config {
WSMaxConnections: getEnvInt("WS_MAX_CONNECTIONS", 1000),
SessionIdleTimeoutMin: getEnvInt("SESSION_IDLE_TIMEOUT_MIN", 30),
WebhookAPIKey: getEnv("WEBHOOK_API_KEY", ""),
WebhookAPIKey: getEnv("WEBHOOK_API_KEY", ""),
InternalServiceToken: getEnv("INTERNAL_SERVICE_TOKEN", "cyrene-internal-token-change-me"),
BriefingTime: getEnv("BRIEFING_TIME", "08:00"),
}
}
@@ -0,0 +1,505 @@
package engine
import (
"bytes"
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"log"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/yourname/cyrene-ai/gateway/internal/store"
"github.com/yourname/cyrene-ai/gateway/internal/ws"
)
// TriggerConfig 触发器配置
type TriggerConfig struct {
Cron string `json:"cron,omitempty"`
Time string `json:"time,omitempty"`
Days []string `json:"days,omitempty"`
DeviceID string `json:"device_id,omitempty"`
Property string `json:"property,omitempty"`
Operator string `json:"operator,omitempty"`
Value float64 `json:"value,omitempty"`
}
// Condition 条件定义
type Condition struct {
Type string `json:"type"`
Start string `json:"start,omitempty"`
End string `json:"end,omitempty"`
DeviceID string `json:"device_id,omitempty"`
Property string `json:"property,omitempty"`
Operator string `json:"operator,omitempty"`
Value float64 `json:"value,omitempty"`
}
// Action 动作定义
type Action struct {
Type string `json:"type"`
DeviceID string `json:"device_id,omitempty"`
Property string `json:"property,omitempty"`
Value interface{} `json:"value,omitempty"`
Title string `json:"title,omitempty"`
Body string `json:"body,omitempty"`
}
// RuleEngine 规则引擎
type RuleEngine struct {
store *store.AutomationStore
hub *ws.Hub
iotServiceURL string
httpClient *http.Client
lastTriggered map[string]time.Time
mu sync.RWMutex
stopCh chan struct{}
running bool
}
// NewRuleEngine 创建规则引擎
func NewRuleEngine(as *store.AutomationStore, hub *ws.Hub) *RuleEngine {
iotServiceURL := os.Getenv("IOT_SERVICE_URL")
if iotServiceURL == "" {
iotServiceURL = "http://localhost:8083"
}
return &RuleEngine{
store: as,
hub: hub,
iotServiceURL: iotServiceURL,
httpClient: &http.Client{
Timeout: 5 * time.Second,
},
lastTriggered: make(map[string]time.Time),
stopCh: make(chan struct{}),
}
}
// Start 启动后台规则评估 goroutine
func (e *RuleEngine) Start() {
e.mu.Lock()
if e.running {
e.mu.Unlock()
return
}
e.running = true
e.mu.Unlock()
go e.loop()
log.Printf("[RuleEngine] 规则引擎已启动 (IoT服务地址: %s)", e.iotServiceURL)
}
// Stop 停止规则引擎
func (e *RuleEngine) Stop() {
e.mu.Lock()
defer e.mu.Unlock()
if !e.running {
return
}
close(e.stopCh)
e.running = false
log.Println("[RuleEngine] 规则引擎已停止")
}
// loop 规则引擎主循环
func (e *RuleEngine) loop() {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
// 首次立即评估
e.evaluateAllRules()
for {
select {
case <-e.stopCh:
return
case <-ticker.C:
e.evaluateAllRules()
}
}
}
// evaluateAllRules 评估所有启用的规则
func (e *RuleEngine) evaluateAllRules() {
rules, err := e.store.GetEnabledRules()
if err != nil {
log.Printf("[RuleEngine] 获取启用的规则失败: %v", err)
return
}
if len(rules) == 0 {
return
}
for _, rule := range rules {
if e.evaluateRule(&rule) {
e.ExecuteRuleActions(&rule)
e.store.MarkRuleTriggered(rule.ID)
e.mu.Lock()
e.lastTriggered[rule.ID] = time.Now()
e.mu.Unlock()
}
}
}
// evaluateRule 评估单条规则是否应触发
func (e *RuleEngine) evaluateRule(rule *store.AutomationRule) bool {
// 防重复触发:同一规则在 1 分钟内不重复触发
e.mu.RLock()
lastTime, exists := e.lastTriggered[rule.ID]
e.mu.RUnlock()
if exists && time.Since(lastTime) < 1*time.Minute {
return false
}
// 解析 trigger_config
var triggerCfg TriggerConfig
if rule.TriggerConfig != nil {
if err := json.Unmarshal(*rule.TriggerConfig, &triggerCfg); err != nil {
log.Printf("[RuleEngine] 解析触发器配置失败: rule=%s err=%v", rule.ID, err)
return false
}
}
// 评估触发器
triggered := false
switch rule.TriggerType {
case "schedule":
triggered = e.evaluateScheduleTrigger(triggerCfg)
case "device_state":
triggered = e.evaluateDeviceStateTrigger(triggerCfg)
case "manual":
// 不自动触发
return false
default:
return false
}
if !triggered {
return false
}
// 评估 conditions
var conditions []Condition
if rule.Conditions != nil {
if err := json.Unmarshal(*rule.Conditions, &conditions); err != nil {
log.Printf("[RuleEngine] 解析条件失败: rule=%s err=%v", rule.ID, err)
return false
}
}
for _, cond := range conditions {
if !e.evaluateCondition(cond) {
return false
}
}
return true
}
// evaluateScheduleTrigger 评估定时触发器
func (e *RuleEngine) evaluateScheduleTrigger(cfg TriggerConfig) bool {
now := time.Now()
// 检查 days (星期)
if len(cfg.Days) > 0 {
weekday := strings.ToLower(now.Weekday().String()[:3])
found := false
for _, d := range cfg.Days {
if strings.ToLower(strings.TrimSpace(d)) == weekday {
found = true
break
}
}
if !found {
return false
}
}
// 检查 time
if cfg.Time != "" {
currentTime := now.Format("15:04")
return currentTime == cfg.Time
}
return false
}
// evaluateDeviceStateTrigger 评估设备状态触发器
func (e *RuleEngine) evaluateDeviceStateTrigger(cfg TriggerConfig) bool {
if cfg.DeviceID == "" || cfg.Property == "" || cfg.Operator == "" {
return false
}
// 从 IoT 服务获取设备状态
devices, err := e.fetchIoTDevices()
if err != nil {
log.Printf("[RuleEngine] 获取设备状态失败: %v", err)
return false
}
// 查找目标设备
for _, d := range devices {
if d.ID != cfg.DeviceID {
continue
}
var actualValue float64
switch cfg.Property {
case "temperature":
actualValue = d.Temperature
case "value":
actualValue = d.Value
case "brightness":
actualValue = float64(d.Brightness)
case "position":
actualValue = float64(d.Position)
case "battery":
actualValue = float64(d.Battery)
default:
// 尝试从 properties 中获取
if props, ok := d.Properties[cfg.Property]; ok {
if v, ok := props.(float64); ok {
actualValue = v
}
} else {
return false
}
}
return compareValues(actualValue, cfg.Operator, cfg.Value)
}
return false
}
// evaluateCondition 评估单个条件
func (e *RuleEngine) evaluateCondition(cond Condition) bool {
switch cond.Type {
case "time_range":
if cond.Start == "" || cond.End == "" {
return true
}
now := time.Now()
currentTime := now.Format("15:04")
return currentTime >= cond.Start && currentTime <= cond.End
case "device_state":
if cond.DeviceID == "" || cond.Property == "" || cond.Operator == "" {
return true
}
devices, err := e.fetchIoTDevices()
if err != nil {
return true // 无法获取设备状态时不阻塞
}
for _, d := range devices {
if d.ID != cond.DeviceID {
continue
}
var actualValue float64
switch cond.Property {
case "temperature":
actualValue = d.Temperature
case "value":
actualValue = d.Value
case "brightness":
actualValue = float64(d.Brightness)
case "position":
actualValue = float64(d.Position)
case "battery":
actualValue = float64(d.Battery)
default:
if props, ok := d.Properties[cond.Property]; ok {
if v, ok := props.(float64); ok {
actualValue = v
}
} else {
return true
}
}
return compareValues(actualValue, cond.Operator, cond.Value)
}
return true
}
return true
}
// ExecuteRuleActions 执行规则的动作
func (e *RuleEngine) ExecuteRuleActions(rule *store.AutomationRule) {
var actions []Action
if rule.Actions != nil {
if err := json.Unmarshal(*rule.Actions, &actions); err != nil {
log.Printf("[RuleEngine] 解析动作失败: rule=%s err=%v", rule.ID, err)
return
}
}
log.Printf("[RuleEngine] 执行规则 %s (%s) 的 %d 个动作", rule.ID, rule.Name, len(actions))
for _, action := range actions {
switch action.Type {
case "set_device":
e.executeSetDevice(action)
case "notify":
e.executeNotify(action, rule.UserID)
default:
log.Printf("[RuleEngine] 未知动作类型: %s", action.Type)
}
}
}
// ExecuteScene 手动触发场景
func (e *RuleEngine) ExecuteScene(sceneID, userID string) error {
rules, err := e.store.GetSceneRules(sceneID)
if err != nil {
return fmt.Errorf("获取场景规则失败: %w", err)
}
log.Printf("[RuleEngine] 执行场景 %s,共 %d 条关联规则", sceneID, len(rules))
for _, rule := range rules {
if rule.Enabled {
e.ExecuteRuleActions(&rule)
e.store.MarkRuleTriggered(rule.ID)
e.mu.Lock()
e.lastTriggered[rule.ID] = time.Now()
e.mu.Unlock()
}
}
return nil
}
// executeSetDevice 执行设备控制动作
func (e *RuleEngine) executeSetDevice(action Action) {
url := fmt.Sprintf("%s/api/v1/devices/%s/set", e.iotServiceURL, action.DeviceID)
body := map[string]interface{}{
"property": action.Property,
"value": action.Value,
}
bodyBytes, _ := json.Marshal(body)
resp, err := e.httpClient.Post(url, "application/json", bytes.NewReader(bodyBytes))
if err != nil {
log.Printf("[RuleEngine] 设备控制请求失败: device=%s property=%s err=%v",
action.DeviceID, action.Property, err)
return
}
defer resp.Body.Close()
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
log.Printf("[RuleEngine] 设备控制成功: device=%s property=%s value=%v",
action.DeviceID, action.Property, action.Value)
} else {
log.Printf("[RuleEngine] 设备控制失败: device=%s property=%s status=%d",
action.DeviceID, action.Property, resp.StatusCode)
}
}
// executeNotify 执行通知动作
func (e *RuleEngine) executeNotify(action Action, userID string) {
notif := ws.NotificationInfo{
ID: "notif_" + randomID(),
Type: "info",
Title: action.Title,
Body: action.Body,
Timestamp: time.Now().UTC().Format(time.RFC3339),
}
msg := ws.ServerMessage{
Type: "notification",
MessageID: "notif_" + randomID(),
Timestamp: time.Now().UnixMilli(),
Notification: &notif,
}
data, err := json.Marshal(msg)
if err != nil {
log.Printf("[RuleEngine] 序列化通知失败: %v", err)
return
}
e.hub.SendToUser(userID, data)
log.Printf("[RuleEngine] 通知已发送: user=%s title=%s", userID, action.Title)
}
// ========== 辅助方法 ==========
// IotDevice 设备信息(从 IoT 服务返回)
type IotDevice struct {
ID string `json:"id"`
Name string `json:"name"`
Type string `json:"type"`
Status string `json:"status"`
Brightness int `json:"brightness,omitempty"`
Color string `json:"color,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Mode string `json:"mode,omitempty"`
Position int `json:"position,omitempty"`
Value float64 `json:"value,omitempty"`
Unit string `json:"unit,omitempty"`
Battery int `json:"battery,omitempty"`
Properties map[string]interface{} `json:"properties,omitempty"`
}
// fetchIoTDevices 从 IoT 调试服务获取设备列表
func (e *RuleEngine) fetchIoTDevices() ([]IotDevice, error) {
resp, err := e.httpClient.Get(e.iotServiceURL + "/api/v1/devices")
if err != nil {
return nil, fmt.Errorf("请求IoT服务失败: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("IoT服务返回状态码 %d", resp.StatusCode)
}
var result struct {
Devices []IotDevice `json:"devices"`
Total int `json:"total"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("解析IoT设备列表失败: %w", err)
}
return result.Devices, nil
}
// compareValues 比较两个值
func compareValues(actual float64, operator string, expected float64) bool {
switch operator {
case "eq":
return actual == expected
case "neq":
return actual != expected
case "gt":
return actual > expected
case "gte":
return actual >= expected
case "lt":
return actual < expected
case "lte":
return actual <= expected
default:
return false
}
}
// randomID 使用 crypto/rand 生成随机 ID
func randomID() string {
b := make([]byte, 8)
rand.Read(b)
return hex.EncodeToString(b)
}
@@ -0,0 +1,483 @@
package handler
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"log"
"net/http"
"github.com/gin-gonic/gin"
"github.com/yourname/cyrene-ai/gateway/internal/engine"
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
"github.com/yourname/cyrene-ai/gateway/internal/store"
)
// AutomationHandler 自动化处理器
type AutomationHandler struct {
store *store.AutomationStore
engine *engine.RuleEngine
}
// NewAutomationHandler 创建自动化处理器
func NewAutomationHandler(s *store.AutomationStore, e *engine.RuleEngine) *AutomationHandler {
return &AutomationHandler{store: s, engine: e}
}
// ========== 请求/响应体 ==========
// CreateRuleRequest 创建规则请求
type CreateRuleRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
TriggerType string `json:"trigger_type" binding:"required"`
TriggerConfig json.RawMessage `json:"trigger_config"`
Conditions json.RawMessage `json:"conditions"`
Actions json.RawMessage `json:"actions" binding:"required"`
Enabled *bool `json:"enabled"`
}
// UpdateRuleRequest 更新规则请求
type UpdateRuleRequest struct {
Name *string `json:"name"`
Description *string `json:"description"`
TriggerType *string `json:"trigger_type"`
TriggerConfig *json.RawMessage `json:"trigger_config"`
Conditions *json.RawMessage `json:"conditions"`
Actions *json.RawMessage `json:"actions"`
Enabled *bool `json:"enabled"`
}
// CreateSceneRequest 创建场景请求
type CreateSceneRequest struct {
Name string `json:"name" binding:"required"`
Icon string `json:"icon"`
RuleIDs json.RawMessage `json:"rule_ids"`
}
// UpdateSceneRequest 更新场景请求
type UpdateSceneRequest struct {
Name *string `json:"name"`
Icon *string `json:"icon"`
RuleIDs *json.RawMessage `json:"rule_ids"`
}
// ========== Rule Handlers ==========
// ListRules 获取用户的所有规则
// GET /api/v1/automation/rules
func (h *AutomationHandler) ListRules(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "未认证"})
return
}
rules, err := h.store.GetRulesByUser(userID)
if err != nil {
log.Printf("[automation] 获取规则列表失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取规则列表失败"})
return
}
c.JSON(http.StatusOK, gin.H{
"rules": rules,
"count": len(rules),
})
}
// CreateRule 创建规则
// POST /api/v1/automation/rules
func (h *AutomationHandler) CreateRule(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "未认证"})
return
}
var req CreateRuleRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
return
}
triggerConfig := ensureRawMessage(req.TriggerConfig)
conditions := ensureRawMessage(req.Conditions)
actions := ensureRawMessage(req.Actions)
enabled := true
if req.Enabled != nil {
enabled = *req.Enabled
}
rule := &store.AutomationRule{
ID: randomHexID(),
UserID: userID,
Name: req.Name,
Description: req.Description,
TriggerType: req.TriggerType,
TriggerConfig: triggerConfig,
Conditions: conditions,
Actions: actions,
Enabled: enabled,
}
if err := h.store.CreateRule(rule); err != nil {
log.Printf("[automation] 创建规则失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建规则失败"})
return
}
c.JSON(http.StatusCreated, gin.H{
"success": true,
"rule": rule,
})
}
// GetRule 获取单条规则
func (h *AutomationHandler) GetRule(c *gin.Context) {
id := c.Param("id")
rule, err := h.store.GetRule(id)
if err != nil {
log.Printf("[automation] 获取规则失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取规则失败"})
return
}
if rule == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "规则不存在"})
return
}
c.JSON(http.StatusOK, gin.H{"rule": rule})
}
// UpdateRule 更新规则
// PUT /api/v1/automation/rules/:id
func (h *AutomationHandler) UpdateRule(c *gin.Context) {
userID := middleware.GetUserID(c)
id := c.Param("id")
// 先获取规则验证所有权
existing, err := h.store.GetRule(id)
if err != nil {
log.Printf("[automation] 获取规则失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取规则失败"})
return
}
if existing == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "规则不存在"})
return
}
if existing.UserID != userID {
c.JSON(http.StatusForbidden, gin.H{"error": "无权修改此规则"})
return
}
var req UpdateRuleRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
return
}
// 只更新提供的字段
if req.Name != nil {
existing.Name = *req.Name
}
if req.Description != nil {
existing.Description = *req.Description
}
if req.TriggerType != nil {
existing.TriggerType = *req.TriggerType
}
if req.TriggerConfig != nil {
existing.TriggerConfig = req.TriggerConfig
}
if req.Conditions != nil {
existing.Conditions = req.Conditions
}
if req.Actions != nil {
existing.Actions = req.Actions
}
if req.Enabled != nil {
existing.Enabled = *req.Enabled
}
if err := h.store.UpdateRule(existing); err != nil {
log.Printf("[automation] 更新规则失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "更新规则失败"})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"rule": existing,
})
}
// DeleteRule 删除规则
// DELETE /api/v1/automation/rules/:id
func (h *AutomationHandler) DeleteRule(c *gin.Context) {
userID := middleware.GetUserID(c)
id := c.Param("id")
existing, err := h.store.GetRule(id)
if err != nil {
log.Printf("[automation] 获取规则失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取规则失败"})
return
}
if existing == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "规则不存在"})
return
}
if existing.UserID != userID {
c.JSON(http.StatusForbidden, gin.H{"error": "无权删除此规则"})
return
}
if err := h.store.DeleteRule(id); err != nil {
log.Printf("[automation] 删除规则失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除规则失败"})
return
}
c.JSON(http.StatusOK, gin.H{"success": true})
}
// TriggerRule 手动触发单条规则
// POST /api/v1/automation/rules/:id/trigger
func (h *AutomationHandler) TriggerRule(c *gin.Context) {
userID := middleware.GetUserID(c)
id := c.Param("id")
rule, err := h.store.GetRule(id)
if err != nil {
log.Printf("[automation] 获取规则失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取规则失败"})
return
}
if rule == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "规则不存在"})
return
}
if rule.UserID != userID {
c.JSON(http.StatusForbidden, gin.H{"error": "无权触发此规则"})
return
}
h.engine.ExecuteRuleActions(rule)
h.store.MarkRuleTriggered(id)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "规则已触发",
})
}
// ========== Scene Handlers ==========
// ListScenes 获取所有场景
// GET /api/v1/automation/scenes
func (h *AutomationHandler) ListScenes(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "未认证"})
return
}
scenes, err := h.store.GetScenesByUser(userID)
if err != nil {
log.Printf("[automation] 获取场景列表失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取场景列表失败"})
return
}
c.JSON(http.StatusOK, gin.H{
"scenes": scenes,
"count": len(scenes),
})
}
// CreateScene 创建场景
// POST /api/v1/automation/scenes
func (h *AutomationHandler) CreateScene(c *gin.Context) {
userID := middleware.GetUserID(c)
if userID == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "未认证"})
return
}
var req CreateSceneRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
return
}
ruleIDs := ensureRawMessage(req.RuleIDs)
scene := &store.AutomationScene{
ID: randomHexID(),
UserID: userID,
Name: req.Name,
Icon: req.Icon,
RuleIDs: ruleIDs,
}
if err := h.store.CreateScene(scene); err != nil {
log.Printf("[automation] 创建场景失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建场景失败"})
return
}
c.JSON(http.StatusCreated, gin.H{
"success": true,
"scene": scene,
})
}
// GetScene 获取单个场景
func (h *AutomationHandler) GetScene(c *gin.Context) {
id := c.Param("id")
scene, err := h.store.GetScene(id)
if err != nil {
log.Printf("[automation] 获取场景失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取场景失败"})
return
}
if scene == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "场景不存在"})
return
}
c.JSON(http.StatusOK, gin.H{"scene": scene})
}
// UpdateScene 更新场景
// PUT /api/v1/automation/scenes/:id
func (h *AutomationHandler) UpdateScene(c *gin.Context) {
userID := middleware.GetUserID(c)
id := c.Param("id")
existing, err := h.store.GetScene(id)
if err != nil {
log.Printf("[automation] 获取场景失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取场景失败"})
return
}
if existing == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "场景不存在"})
return
}
if existing.UserID != userID {
c.JSON(http.StatusForbidden, gin.H{"error": "无权修改此场景"})
return
}
var req UpdateSceneRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
return
}
if req.Name != nil {
existing.Name = *req.Name
}
if req.Icon != nil {
existing.Icon = *req.Icon
}
if req.RuleIDs != nil {
existing.RuleIDs = req.RuleIDs
}
if err := h.store.UpdateScene(existing); err != nil {
log.Printf("[automation] 更新场景失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "更新场景失败"})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"scene": existing,
})
}
// DeleteScene 删除场景
// DELETE /api/v1/automation/scenes/:id
func (h *AutomationHandler) DeleteScene(c *gin.Context) {
userID := middleware.GetUserID(c)
id := c.Param("id")
existing, err := h.store.GetScene(id)
if err != nil {
log.Printf("[automation] 获取场景失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取场景失败"})
return
}
if existing == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "场景不存在"})
return
}
if existing.UserID != userID {
c.JSON(http.StatusForbidden, gin.H{"error": "无权删除此场景"})
return
}
if err := h.store.DeleteScene(id); err != nil {
log.Printf("[automation] 删除场景失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除场景失败"})
return
}
c.JSON(http.StatusOK, gin.H{"success": true})
}
// ExecuteScene 手动执行场景
// POST /api/v1/automation/scenes/:id/execute
func (h *AutomationHandler) ExecuteScene(c *gin.Context) {
id := c.Param("id")
// 验证场景存在
scene, err := h.store.GetScene(id)
if err != nil {
log.Printf("[automation] 获取场景失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取场景失败"})
return
}
if scene == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "场景不存在"})
return
}
userID := middleware.GetUserID(c)
if err := h.engine.ExecuteScene(id, userID); err != nil {
log.Printf("[automation] 执行场景失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "执行场景失败"})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "场景已执行",
})
}
// ========== 辅助函数 ==========
// randomHexID 使用 crypto/rand 生成 16 字节 hex ID
func randomHexID() string {
b := make([]byte, 16)
rand.Read(b)
return hex.EncodeToString(b)
}
// ensureRawMessage 确保 json.RawMessage 非空
func ensureRawMessage(raw json.RawMessage) *json.RawMessage {
if len(raw) == 0 {
return nil
}
result := make(json.RawMessage, len(raw))
copy(result, raw)
return &result
}
@@ -0,0 +1,709 @@
package handler
import (
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/yourname/cyrene-ai/gateway/internal/config"
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
"github.com/yourname/cyrene-ai/gateway/internal/store"
"github.com/yourname/cyrene-ai/gateway/internal/ws"
)
// BriefingHandler 每日简报处理器
type BriefingHandler struct {
cfg *config.Config
hub *ws.Hub
briefingStore *store.BriefingStore
reminderStore *store.ReminderStore
httpClient *http.Client
}
// NewBriefingHandler 创建简报处理器
func NewBriefingHandler(cfg *config.Config, hub *ws.Hub, bs *store.BriefingStore, rs *store.ReminderStore) *BriefingHandler {
return &BriefingHandler{
cfg: cfg,
hub: hub,
briefingStore: bs,
reminderStore: rs,
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
}
}
// GenerateBriefingRequest 手动生成简报请求体
type GenerateBriefingRequest struct {
UserID string `json:"user_id" binding:"required"`
}
// GetBriefing 获取指定日期简报
// GET /api/v1/briefings?user_id=xxx&date=2024-01-01
func (h *BriefingHandler) GetBriefing(c *gin.Context) {
authUserID := middleware.GetUserID(c)
userID := c.Query("user_id")
date := c.Query("date")
if !strings.HasPrefix(authUserID, "admin_") || userID == "" {
userID = authUserID
}
if userID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少user_id参数"})
return
}
if date == "" {
date = time.Now().Format("2006-01-02")
}
briefing, err := h.briefingStore.GetBriefingByDate(userID, date)
if err != nil {
log.Printf("[briefing] 查询简报失败: user=%s date=%s err=%v", userID, date, err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询简报失败: " + err.Error()})
return
}
if briefing == nil {
c.JSON(http.StatusOK, gin.H{
"briefing": nil,
"message": "当日简报尚未生成",
})
return
}
c.JSON(http.StatusOK, gin.H{"briefing": briefing})
}
// GetLatestBriefings 获取最近简报列表
// GET /api/v1/briefings/latest?user_id=xxx&limit=7
func (h *BriefingHandler) GetLatestBriefings(c *gin.Context) {
authUserID := middleware.GetUserID(c)
userID := c.Query("user_id")
if !strings.HasPrefix(authUserID, "admin_") || userID == "" {
userID = authUserID
}
if userID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少user_id参数"})
return
}
limit := 7
if l := c.Query("limit"); l != "" {
if parsed, err := parseInt(l); err == nil && parsed > 0 && parsed <= 30 {
limit = parsed
}
}
briefings, err := h.briefingStore.GetLatestBriefings(userID, limit)
if err != nil {
log.Printf("[briefing] 查询简报列表失败: user=%s err=%v", userID, err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询简报列表失败: " + err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"briefings": briefings,
"total": len(briefings),
})
}
// Generate 手动触发生成今日简报
// POST /api/v1/briefings/generate
func (h *BriefingHandler) Generate(c *gin.Context) {
authUserID := middleware.GetUserID(c)
var req GenerateBriefingRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
return
}
// 非管理员只能为自己生成
if !strings.HasPrefix(authUserID, "admin_") {
req.UserID = authUserID
}
if req.UserID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少user_id"})
return
}
result, err := h.GenerateDailyBriefing(req.UserID)
if err != nil {
log.Printf("[briefing] 生成简报失败: user=%s err=%v", req.UserID, err)
c.JSON(http.StatusInternalServerError, gin.H{
"error": "生成简报失败: " + err.Error(),
"success": false,
})
return
}
// 生成后推送通知
h.pushBriefingNotification(req.UserID, result)
c.JSON(http.StatusOK, gin.H{
"success": true,
"briefing": result,
"message": "简报已生成并推送",
})
}
// GenerateDailyBriefing 生成每日简报(核心逻辑)
func (h *BriefingHandler) GenerateDailyBriefing(userID string) (*store.Briefing, error) {
today := time.Now().Format("2006-01-02")
briefing := &store.Briefing{
ID: "brief_" + generateID(),
UserID: userID,
Date: today,
Status: "pending",
Weather: &store.WeatherData{},
News: []store.NewsItem{},
Reminders: []store.BriefReminder{},
}
// 1. 获取天气数据
log.Printf("[briefing] 获取天气数据...")
weather, err := h.fetchWeather("Shanghai")
if err != nil {
log.Printf("[briefing] 天气获取失败 (降级): %v", err)
weather = &store.WeatherData{
Location: "未知",
Temp: 0,
Condition: "获取天气失败",
Icon: "❓",
}
}
briefing.Weather = weather
log.Printf("[briefing] 天气: %s %.1f°C %s", weather.Location, weather.Temp, weather.Condition)
// 2. 获取今日待办提醒
log.Printf("[briefing] 获取待办提醒...")
reminders, err := h.reminderStore.GetRemindersByUser(userID, "pending", 10, 0)
if err != nil {
log.Printf("[briefing] 获取提醒失败: %v", err)
} else {
now := time.Now()
endOfDay := time.Date(now.Year(), now.Month(), now.Day(), 23, 59, 59, 0, now.Location())
for _, r := range reminders {
if r.RemindAt.Before(endOfDay) || r.RemindAt.Equal(endOfDay) {
briefing.Reminders = append(briefing.Reminders, store.BriefReminder{
ID: r.ID,
Title: r.Title,
RemindAt: r.RemindAt.Format(time.RFC3339),
})
}
}
}
log.Printf("[briefing] 今日待办: %d 项", len(briefing.Reminders))
// 3. 获取新闻摘要(通过 tool-engine web_search
log.Printf("[briefing] 获取新闻摘要...")
news, err := h.fetchNews()
if err != nil {
log.Printf("[briefing] 新闻获取失败 (降级): %v", err)
}
briefing.News = news
log.Printf("[briefing] 新闻: %d 条", len(news))
// 4. 生成 AI 摘要
log.Printf("[briefing] 生成 AI 摘要...")
summary, err := h.generateAISummary(briefing)
if err != nil {
log.Printf("[briefing] AI 摘要生成失败 (降级): %v", err)
summary = h.buildFallbackSummary(briefing)
}
briefing.Summary = summary
// 5. 标记为已生成
now := time.Now()
briefing.Status = "generated"
briefing.GeneratedAt = &now
// 6. 持久化
if err := h.briefingStore.CreateOrUpdateBriefing(briefing); err != nil {
return nil, fmt.Errorf("保存简报失败: %w", err)
}
log.Printf("[briefing] 简报已生成: user=%s date=%s", userID, today)
return briefing, nil
}
// fetchWeather 通过 wttr.in API 获取天气数据
func (h *BriefingHandler) fetchWeather(location string) (*store.WeatherData, error) {
url := fmt.Sprintf("https://wttr.in/%s?format=j1", location)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, fmt.Errorf("构建天气请求失败: %w", err)
}
req.Header.Set("User-Agent", "Cyrene-AI/1.0")
resp, err := h.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("请求天气API失败: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("天气API返回状态码 %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取天气响应失败: %w", err)
}
// 解析 wttr.in JSON 响应
var wttrResp struct {
CurrentCondition []struct {
TempC string `json:"temp_C"`
WeatherDesc []struct {
Value string `json:"value"`
} `json:"weatherDesc"`
WeatherIconURL []struct {
Value string `json:"value"`
} `json:"weatherIconUrl"`
} `json:"current_condition"`
NearestArea []struct {
AreaName []struct {
Value string `json:"value"`
} `json:"areaName"`
} `json:"nearest_area"`
}
if err := json.Unmarshal(body, &wttrResp); err != nil {
return nil, fmt.Errorf("解析天气数据失败: %w", err)
}
wd := &store.WeatherData{
Location: location,
}
if len(wttrResp.NearestArea) > 0 && len(wttrResp.NearestArea[0].AreaName) > 0 {
wd.Location = wttrResp.NearestArea[0].AreaName[0].Value
}
if len(wttrResp.CurrentCondition) > 0 {
cc := wttrResp.CurrentCondition[0]
wd.Temp = parseFloat(cc.TempC)
if len(cc.WeatherDesc) > 0 {
wd.Condition = cc.WeatherDesc[0].Value
}
// 根据天气描述转 emoji
wd.Icon = weatherEmoji(wd.Condition)
}
return wd, nil
}
// fetchNews 通过 tool-engine web_search 搜索今日新闻
func (h *BriefingHandler) fetchNews() ([]store.NewsItem, error) {
if h.cfg.ToolEngineURL == "" {
return nil, fmt.Errorf("ToolEngine URL 未配置")
}
today := time.Now().Format("2006年01月02日")
query := fmt.Sprintf("%s 今日要闻 热点新闻", today)
reqBody, _ := json.Marshal(map[string]interface{}{
"arguments": map[string]interface{}{
"query": query,
"limit": 5,
},
})
url := fmt.Sprintf("%s/api/v1/tools/web_search/execute", h.cfg.ToolEngineURL)
req, err := http.NewRequest("POST", url, bytes.NewReader(reqBody))
if err != nil {
return nil, fmt.Errorf("构建新闻搜索请求失败: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := h.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("请求新闻搜索失败: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取新闻搜索结果失败: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("新闻搜索返回状态码 %d: %s", resp.StatusCode, string(body))
}
// 解析 tool-engine 单个工具执行响应: {id, output, error?}
var result struct {
ID string `json:"id"`
Output string `json:"output"`
Error string `json:"error,omitempty"`
}
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("解析新闻搜索结果失败: %w", err)
}
if result.Error != "" {
log.Printf("[briefing] 新闻搜索失败: %s", result.Error)
// 返回降级新闻
return []store.NewsItem{
{
Title: "未能获取今日新闻",
URL: "",
Source: "系统",
Summary: "新闻搜索服务暂时不可用,请稍后再试。",
},
}, nil
}
var news []store.NewsItem
// 尝试解析搜索结果为结构化数据
var searchResults []struct {
Title string `json:"title"`
URL string `json:"url"`
Snippet string `json:"snippet"`
Source string `json:"source"`
}
if err := json.Unmarshal([]byte(result.Output), &searchResults); err != nil {
// 如果不是 JSON 数组,当做纯文本处理
news = append(news, store.NewsItem{
Title: "今日新闻",
URL: "",
Source: "搜索引擎",
Summary: truncateStr(result.Output, 200),
})
} else {
for _, sr := range searchResults {
news = append(news, store.NewsItem{
Title: sr.Title,
URL: sr.URL,
Source: sr.Source,
Summary: sr.Snippet,
})
}
}
if len(news) == 0 {
news = []store.NewsItem{
{
Title: "未能获取今日新闻",
URL: "",
Source: "系统",
Summary: "新闻搜索服务暂时不可用,请稍后再试。",
},
}
}
// 限制最多 5 条
if len(news) > 5 {
news = news[:5]
}
return news, nil
}
// generateAISummary 通过 AI-Core 生成人性化摘要
func (h *BriefingHandler) generateAISummary(b *store.Briefing) (string, error) {
if h.cfg.AICoreURL == "" {
return "", fmt.Errorf("AI-Core URL 未配置")
}
// 构建提示词
prompt := h.buildSummaryPrompt(b)
reqBody, _ := json.Marshal(map[string]interface{}{
"messages": []map[string]interface{}{
{
"role": "system",
"content": "你是昔涟,一个温柔贴心的AI助手。请用温暖、亲切的语气回复,像朋友一样关心用户。回复使用中文。",
},
{
"role": "user",
"content": prompt,
},
},
"max_tokens": 500,
"temperature": 0.7,
})
url := fmt.Sprintf("%s/api/v1/chat/completions", h.cfg.AICoreURL)
req, err := http.NewRequest("POST", url, bytes.NewReader(reqBody))
if err != nil {
return "", fmt.Errorf("构建 AI 请求失败: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := h.httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("请求 AI-Core 失败: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("读取 AI 响应失败: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("AI-Core 返回状态码 %d: %s", resp.StatusCode, string(body))
}
// 解析 OpenAI 兼容响应
var aiResp struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
}
if err := json.Unmarshal(body, &aiResp); err != nil {
return "", fmt.Errorf("解析 AI 响应失败: %w", err)
}
if len(aiResp.Choices) == 0 {
return "", fmt.Errorf("AI 返回空响应")
}
return aiResp.Choices[0].Message.Content, nil
}
// buildSummaryPrompt 构建 AI 摘要提示词
func (h *BriefingHandler) buildSummaryPrompt(b *store.Briefing) string {
var sb strings.Builder
today := time.Now().Format("2006年01月02日")
sb.WriteString(fmt.Sprintf("今天是%s。请根据以下信息,用昔涟温柔的语气为用户生成一份简短的每日简报(控制在200字以内):\n\n", today))
// 天气
if b.Weather != nil && b.Weather.Condition != "" {
sb.WriteString(fmt.Sprintf("☁️ 天气:%s%.0f°C%s\n", b.Weather.Location, b.Weather.Temp, b.Weather.Condition))
}
// 待办
if len(b.Reminders) > 0 {
sb.WriteString("📋 今日待办:\n")
for _, r := range b.Reminders {
sb.WriteString(fmt.Sprintf(" - %s\n", r.Title))
}
}
// 新闻
if len(b.News) > 0 {
sb.WriteString("📰 今日新闻:\n")
for _, n := range b.News {
if n.Summary != "" {
sb.WriteString(fmt.Sprintf(" - %s: %s\n", n.Title, n.Summary))
} else {
sb.WriteString(fmt.Sprintf(" - %s\n", n.Title))
}
}
}
sb.WriteString("\n请用昔涟的语气回复,包含:1) 温馨问候 2) 天气提醒 3) 待办提醒 4) 新闻简要 5) 结语祝福。简洁自然即可。")
return sb.String()
}
// buildFallbackSummary 降级摘要(不依赖 AI
func (h *BriefingHandler) buildFallbackSummary(b *store.Briefing) string {
today := time.Now().Format("2006年01月02日")
var sb strings.Builder
sb.WriteString(fmt.Sprintf("早上好!今天是%s ☀️\n\n", today))
if b.Weather != nil && b.Weather.Condition != "" {
sb.WriteString(fmt.Sprintf("今日%s天气:%s%.0f°C。", b.Weather.Location, b.Weather.Condition, b.Weather.Temp))
if b.Weather.Temp < 10 {
sb.WriteString("天气有点凉,记得多穿件衣服哦~")
} else if b.Weather.Temp > 30 {
sb.WriteString("天气比较热,注意防暑降温哦~")
} else {
sb.WriteString("天气不错,适合出门走走呢~")
}
sb.WriteString("\n\n")
}
if len(b.Reminders) > 0 {
sb.WriteString(fmt.Sprintf("你今天有 %d 项待办事项,记得按时完成哦!\n", len(b.Reminders)))
} else {
sb.WriteString("今天没有待办事项,可以轻松一下~\n")
}
if len(b.News) > 0 && b.News[0].Title != "未能获取今日新闻" {
sb.WriteString(fmt.Sprintf("\n今日热点:%s。", b.News[0].Title))
}
sb.WriteString("\n\n祝你度过美好的一天!🌸")
return sb.String()
}
// pushBriefingNotification 推送简报通知到用户
func (h *BriefingHandler) pushBriefingNotification(userID string, b *store.Briefing) {
bodyPreview := truncateStr(b.Summary, 100)
notif := &ws.NotificationInfo{
ID: "briefing_" + b.ID,
Type: "info",
Title: fmt.Sprintf("📋 今日简报 (%s)", b.Date),
Body: bodyPreview,
Timestamp: time.Now().UTC().Format(time.RFC3339),
Data: map[string]interface{}{
"briefing_id": b.ID,
"date": b.Date,
"type": "daily_briefing",
},
}
msg := ws.ServerMessage{
Type: "notification",
MessageID: "briefing_" + b.ID,
Timestamp: time.Now().UnixMilli(),
Notification: notif,
}
data, err := json.Marshal(msg)
if err != nil {
log.Printf("[briefing] 序列化简报通知失败: %v", err)
return
}
h.hub.SendToUser(userID, data)
// 更新简报状态为已送达
now := time.Now()
b.Status = "delivered"
b.DeliveredAt = &now
if err := h.briefingStore.CreateOrUpdateBriefing(b); err != nil {
log.Printf("[briefing] 更新简报送达状态失败: %v", err)
}
log.Printf("[briefing] 简报通知已推送: user=%s date=%s", userID, b.Date)
}
// StartBriefingScheduler 启动简报调度器
// briefingTime 格式: "HH:MM",默认 "08:00"
func StartBriefingScheduler(handler *BriefingHandler, briefingStore *store.BriefingStore, briefingTime string) {
if briefingTime == "" {
briefingTime = "08:00"
}
go func() {
// 每 30 秒检查一次是否到达简报时间
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
log.Printf("[BriefingScheduler] 简报调度器已启动 (简报时间: %s)", briefingTime)
// 记录今天是否已触发
lastTriggeredDate := ""
for range ticker.C {
now := time.Now()
currentTime := now.Format("15:04")
currentDate := now.Format("2006-01-02")
// 检查是否到达简报时间且今天尚未触发
if currentTime == briefingTime && currentDate != lastTriggeredDate {
log.Printf("[BriefingScheduler] 触发每日简报生成: %s", currentDate)
lastTriggeredDate = currentDate
// 获取所有用户
users, err := briefingStore.GetAllUsers()
if err != nil {
log.Printf("[BriefingScheduler] 获取用户列表失败: %v", err)
continue
}
// 如果没有从 reminders 表获取到用户,也尝试从 briefings 表获取
if len(users) == 0 {
users, _ = briefingStore.GetUsersWithBriefings()
}
if len(users) == 0 {
log.Println("[BriefingScheduler] 没有找到用户,跳过简报生成")
continue
}
for _, userID := range users {
log.Printf("[BriefingScheduler] 为用户 %s 生成简报...", userID)
result, err := handler.GenerateDailyBriefing(userID)
if err != nil {
log.Printf("[BriefingScheduler] 生成简报失败: user=%s err=%v", userID, err)
continue
}
handler.pushBriefingNotification(userID, result)
}
log.Printf("[BriefingScheduler] 每日简报已生成完毕,共 %d 个用户", len(users))
}
}
}()
}
// ========== 辅助函数 ==========
// weatherEmoji 根据天气描述返回对应 emoji
func weatherEmoji(condition string) string {
c := strings.ToLower(condition)
switch {
case strings.Contains(c, "sunny") || strings.Contains(c, "clear") || strings.Contains(c, "晴"):
return "☀️"
case strings.Contains(c, "partly cloudy") || strings.Contains(c, "多云"):
return "⛅"
case strings.Contains(c, "cloudy") || strings.Contains(c, "阴"):
return "☁️"
case strings.Contains(c, "rain") || strings.Contains(c, "drizzle") || strings.Contains(c, "雨"):
return "🌧️"
case strings.Contains(c, "thunder") || strings.Contains(c, "雷"):
return "⛈️"
case strings.Contains(c, "snow") || strings.Contains(c, "雪"):
return "❄️"
case strings.Contains(c, "fog") || strings.Contains(c, "mist") || strings.Contains(c, "雾"):
return "🌫️"
case strings.Contains(c, "wind") || strings.Contains(c, "风"):
return "💨"
default:
return "🌤️"
}
}
// parseFloat 安全解析浮点数
func parseFloat(s string) float64 {
var f float64
fmt.Sscanf(s, "%f", &f)
return f
}
// parseInt 安全解析整数
func parseInt(s string) (int, error) {
var n int
_, err := fmt.Sscanf(s, "%d", &n)
return n, err
}
// truncateStr 截断字符串
func truncateStr(s string, maxLen int) string {
runes := []rune(s)
if len(runes) <= maxLen {
return s
}
return string(runes[:maxLen]) + "..."
}
@@ -128,12 +128,15 @@ func (h *ChatHandler) handleChatMessage(client *ws.Client, msg ws.ClientMessage)
h.hub.UpdateSessionState(client.SessionID, "thinking")
// 构建 AI-Core 请求
aiReq := map[string]string{
aiReq := map[string]interface{}{
"user_id": client.UserID,
"session_id": client.SessionID,
"message": msg.Content,
"mode": mode,
}
if len(msg.Attachments) > 0 {
aiReq["attachments"] = msg.Attachments
}
reqBody, err := json.Marshal(aiReq)
if err != nil {
log.Printf("[chat] 序列化请求失败: %v", err)
@@ -148,12 +151,16 @@ func (h *ChatHandler) handleChatMessage(client *ws.Client, msg ws.ClientMessage)
}
// 缓存用户消息(在 goroutine 前完成,避免竞态)
h.hub.CacheMessage(client.UserID, client.SessionID, ws.Message{
userMsg := ws.Message{
ID: "msg_" + generateID(),
Role: "user",
Content: msg.Content,
Timestamp: time.Now().UnixMilli(),
})
}
if len(msg.Attachments) > 0 {
userMsg.Attachments = msg.Attachments
}
h.hub.CacheMessage(client.UserID, client.SessionID, userMsg)
// 在 goroutine 中进行 AI-Core 调用和流式发送,避免阻塞 ReadPump
go h.streamResponse(client, mode, reqBody, msg.Content)
@@ -0,0 +1,706 @@
package handler
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"image"
"image/color"
"image/jpeg"
"image/png"
"io"
"log"
"net/http"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
"github.com/yourname/cyrene-ai/gateway/internal/store"
)
// FileHandler 文件管理处理器
type FileHandler struct {
store *store.FileStore
uploadDir string
}
// NewFileHandler 创建文件处理器
func NewFileHandler(s *store.FileStore) *FileHandler {
return &FileHandler{
store: s,
uploadDir: "./uploads",
}
}
// ========== 允许的文件类型 ==========
var allowedMimeTypes = map[string]bool{
// 图片
"image/jpeg": true,
"image/png": true,
"image/gif": true,
"image/webp": true,
"image/svg+xml": true,
// 文档
"application/pdf": true,
"text/plain": true,
"text/markdown": true,
"application/msword": true,
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": true,
// 音频
"audio/mpeg": true,
"audio/wav": true,
"audio/ogg": true,
"audio/webm": true,
// 视频
"video/mp4": true,
"video/webm": true,
}
var allowedExtensions = map[string]string{
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".png": "image/png",
".gif": "image/gif",
".webp": "image/webp",
".svg": "image/svg+xml",
".pdf": "application/pdf",
".txt": "text/plain",
".md": "text/markdown",
".doc": "application/msword",
".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
".mp3": "audio/mpeg",
".wav": "audio/wav",
".ogg": "audio/ogg",
".mp4": "video/mp4",
}
const maxFileSize = 20 * 1024 * 1024 // 20MB
// ========== POST /api/v1/files/upload ==========
// Upload 处理文件上传
func (h *FileHandler) Upload(c *gin.Context) {
userID := middleware.GetUserID(c)
// Nil store guard — 数据库不可用时拒绝上传
if h.store == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "文件存储不可用,数据库未连接", "errorType": "service_unavailable"})
return
}
// 获取上传的文件
file, header, err := c.Request.FormFile("file")
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "未找到上传文件", "errorType": "missing_file"})
return
}
defer file.Close()
// 检查文件大小
if header.Size > maxFileSize {
c.JSON(http.StatusBadRequest, gin.H{
"error": "文件大小超过限制 (最大 20MB)",
"errorType": "file_too_large",
})
return
}
// 验证文件类型
mimeType := header.Header.Get("Content-Type")
ext := strings.ToLower(filepath.Ext(header.Filename))
if mimeType == "" || mimeType == "application/octet-stream" {
// 尝试从扩展名推断
if inferred, ok := allowedExtensions[ext]; ok {
mimeType = inferred
}
}
if !allowedMimeTypes[mimeType] {
// 再尝试通过扩展名检查
if _, ok := allowedExtensions[ext]; !ok {
c.JSON(http.StatusBadRequest, gin.H{
"error": "不支持的文件类型: " + mimeType,
"errorType": "unsupported_type",
})
return
}
mimeType = allowedExtensions[ext]
}
// 安全化文件名:移除路径分隔符和特殊字符
safeFilename := sanitizeFilename(header.Filename)
// 生成文件ID (crypto/rand UUID v4)
fileID := generateUUID()
// 创建按日期组织的目录
dateDir := time.Now().Format("2006-01-02")
storedDir := filepath.Join(h.uploadDir, dateDir)
if err := os.MkdirAll(storedDir, 0755); err != nil {
log.Printf("[FileHandler] 创建上传目录失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建上传目录失败", "errorType": "server_error"})
return
}
// 存储路径:以UUID+原始扩展名保存
storedFilename := fileID + ext
storedPath := filepath.Join(storedDir, storedFilename)
// 计算 SHA256 hash
hasher := sha256.New()
teeReader := io.TeeReader(file, hasher)
// 保存到磁盘
dst, err := os.Create(storedPath)
if err != nil {
log.Printf("[FileHandler] 创建文件失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存文件失败", "errorType": "server_error"})
return
}
defer dst.Close()
written, err := io.Copy(dst, teeReader)
if err != nil {
os.Remove(storedPath)
log.Printf("[FileHandler] 写入文件失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "写入文件失败", "errorType": "server_error"})
return
}
hash := hex.EncodeToString(hasher.Sum(nil))
// 去重检查:如果存在相同hash的文件,复用已有记录
if existing, err := h.store.GetFileByHash(hash); err == nil && existing != nil {
// 删除刚保存的重复文件
os.Remove(storedPath)
log.Printf("[FileHandler] 文件去重: 复用已有文件 %s (hash=%s)", existing.ID, hash[:16])
c.JSON(http.StatusOK, gin.H{
"id": existing.ID,
"filename": existing.Filename,
"mime_type": existing.MimeType,
"size": existing.Size,
"url": fmt.Sprintf("/api/v1/files/%s/download", existing.ID),
"dedup": true,
})
return
}
// 创建数据库记录
fileRecord := &store.File{
ID: fileID,
UserID: userID,
Filename: safeFilename,
StoredPath: storedPath,
MimeType: mimeType,
Size: written,
Hash: hash,
IsPublic: false,
CreatedAt: time.Now(),
}
if err := h.store.CreateFile(fileRecord); err != nil {
os.Remove(storedPath)
log.Printf("[FileHandler] 创建文件记录失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建文件记录失败", "errorType": "db_error"})
return
}
log.Printf("[FileHandler] 文件上传成功: %s (%s, %d bytes, hash=%s)", fileID, safeFilename, written, hash[:16])
c.JSON(http.StatusCreated, gin.H{
"id": fileID,
"filename": safeFilename,
"mime_type": mimeType,
"size": written,
"url": fmt.Sprintf("/api/v1/files/%s/download", fileID),
})
}
// ========== GET /api/v1/files ==========
// List 列出用户的所有文件 (支持分页)
func (h *FileHandler) List(c *gin.Context) {
userID := middleware.GetUserID(c)
if h.store == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "文件存储不可用", "errorType": "service_unavailable"})
return
}
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
files, total, err := h.store.GetUserFiles(userID, page, limit)
if err != nil {
log.Printf("[FileHandler] 查询文件列表失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文件列表失败", "errorType": "db_error"})
return
}
// 转换为响应格式
result := make([]gin.H, 0, len(files))
for _, f := range files {
item := gin.H{
"id": f.ID,
"user_id": f.UserID,
"filename": f.Filename,
"mime_type": f.MimeType,
"size": f.Size,
"hash": f.Hash,
"is_public": f.IsPublic,
"created_at": f.CreatedAt.UnixMilli(),
"url": fmt.Sprintf("/api/v1/files/%s/download", f.ID),
}
// 图片类型添加缩略图URL
if isImageType(f.MimeType) {
item["thumbnail_url"] = fmt.Sprintf("/api/v1/files/%s/thumbnail", f.ID)
}
result = append(result, item)
}
c.JSON(http.StatusOK, gin.H{
"files": result,
"total": total,
"page": page,
"limit": limit,
})
}
// ========== GET /api/v1/files/:id ==========
// Get 获取文件元数据
func (h *FileHandler) Get(c *gin.Context) {
userID := middleware.GetUserID(c)
fileID := c.Param("id")
if h.store == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "文件存储不可用", "errorType": "service_unavailable"})
return
}
f, err := h.store.GetFile(fileID)
if err != nil {
log.Printf("[FileHandler] 查询文件失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文件失败", "errorType": "db_error"})
return
}
if f == nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "文件不存在",
"errorType": "file_not_found",
})
return
}
// 检查权限
if f.UserID != userID && !f.IsPublic {
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此文件", "errorType": "access_denied"})
return
}
item := gin.H{
"id": f.ID,
"user_id": f.UserID,
"filename": f.Filename,
"mime_type": f.MimeType,
"size": f.Size,
"hash": f.Hash,
"is_public": f.IsPublic,
"created_at": f.CreatedAt.UnixMilli(),
"url": fmt.Sprintf("/api/v1/files/%s/download", f.ID),
}
if isImageType(f.MimeType) {
item["thumbnail_url"] = fmt.Sprintf("/api/v1/files/%s/thumbnail", f.ID)
}
c.JSON(http.StatusOK, item)
}
// ========== GET /api/v1/files/:id/download ==========
// Download 下载文件
func (h *FileHandler) Download(c *gin.Context) {
userID := middleware.GetUserID(c)
fileID := c.Param("id")
if h.store == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "文件存储不可用", "errorType": "service_unavailable"})
return
}
f, err := h.store.GetFile(fileID)
if err != nil {
log.Printf("[FileHandler] 查询文件失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文件失败", "errorType": "db_error"})
return
}
if f == nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "文件不存在",
"errorType": "file_not_found",
})
return
}
// 检查权限
if f.UserID != userID && !f.IsPublic {
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此文件", "errorType": "access_denied"})
return
}
// 检查文件在磁盘上是否存在
if _, err := os.Stat(f.StoredPath); os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{
"error": "文件实体不存在(可能已被清理)",
"errorType": "file_missing",
})
return
}
// 设置 Content-Disposition
c.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, f.Filename))
c.Header("Content-Type", f.MimeType)
c.File(f.StoredPath)
}
// ========== DELETE /api/v1/files/:id ==========
// Delete 删除文件
func (h *FileHandler) Delete(c *gin.Context) {
userID := middleware.GetUserID(c)
fileID := c.Param("id")
if h.store == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "文件存储不可用", "errorType": "service_unavailable"})
return
}
f, err := h.store.GetFile(fileID)
if err != nil {
log.Printf("[FileHandler] 查询文件失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文件失败", "errorType": "db_error"})
return
}
if f == nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "文件不存在",
"errorType": "file_not_found",
})
return
}
// 检查权限
if f.UserID != userID {
c.JSON(http.StatusForbidden, gin.H{"error": "无权删除此文件", "errorType": "access_denied"})
return
}
// 删除磁盘上的文件(忽略错误,可能已被删除)
if err := os.Remove(f.StoredPath); err != nil && !os.IsNotExist(err) {
log.Printf("[FileHandler] 删除磁盘文件失败 (stored_path=%s): %v", f.StoredPath, err)
}
// 删除数据库记录
if err := h.store.DeleteFile(fileID); err != nil {
log.Printf("[FileHandler] 删除文件记录失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除文件记录失败", "errorType": "db_error"})
return
}
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
}
// ========== GET /api/v1/files/:id/thumbnail ==========
// Thumbnail 返回文件缩略图
func (h *FileHandler) Thumbnail(c *gin.Context) {
userID := middleware.GetUserID(c)
fileID := c.Param("id")
if h.store == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "文件存储不可用", "errorType": "service_unavailable"})
return
}
f, err := h.store.GetFile(fileID)
if err != nil {
log.Printf("[FileHandler] 查询文件失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文件失败", "errorType": "db_error"})
return
}
if f == nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "文件不存在",
"errorType": "file_not_found",
})
return
}
if f.UserID != userID && !f.IsPublic {
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此文件", "errorType": "access_denied"})
return
}
// 如果是图片,生成缩略图
if isImageType(f.MimeType) && f.MimeType != "image/svg+xml" {
if thumbData, contentType, err := generateThumbnail(f.StoredPath, f.MimeType); err == nil {
c.Header("Content-Type", contentType)
c.Header("Cache-Control", "public, max-age=86400")
c.Data(http.StatusOK, contentType, thumbData)
return
} else {
log.Printf("[FileHandler] 生成缩略图失败: %v", err)
}
}
// 非图片文件或缩略图生成失败,返回占位图标 SVG
placeholder := generatePlaceholderSVG(f.MimeType)
c.Header("Content-Type", "image/svg+xml")
c.Header("Cache-Control", "public, max-age=86400")
c.Data(http.StatusOK, "image/svg+xml", []byte(placeholder))
}
// ========== 辅助函数 ==========
// isImageType 判断是否图片类型
func isImageType(mimeType string) bool {
return strings.HasPrefix(mimeType, "image/")
}
// sanitizeFilename 安全化文件名:移除路径分隔符、特殊字符
var unsafeChars = regexp.MustCompile(`[\\/:*?"<>|]`)
func sanitizeFilename(name string) string {
// 移除路径分隔符和Windows特殊字符
name = unsafeChars.ReplaceAllString(name, "_")
// 限制长度
if len(name) > 255 {
ext := filepath.Ext(name)
base := name[:255-len(ext)]
name = base + ext
}
// 为空时给默认名
if name == "" || name == "." || name == ".." {
name = "unnamed_file"
}
return name
}
// generateThumbnail 使用 Go 标准库生成缩略图 (最大 300x300)
func generateThumbnail(filePath, mimeType string) ([]byte, string, error) {
// 打开源文件
srcFile, err := os.Open(filePath)
if err != nil {
return nil, "", fmt.Errorf("打开文件失败: %w", err)
}
defer srcFile.Close()
// 解码图像
var srcImg image.Image
switch mimeType {
case "image/jpeg":
srcImg, err = jpeg.Decode(srcFile)
case "image/png":
srcImg, err = png.Decode(srcFile)
default:
// 尝试通用解码
srcImg, _, err = image.Decode(srcFile)
}
if err != nil {
return nil, "", fmt.Errorf("解码图像失败: %w", err)
}
// 计算缩略图尺寸
bounds := srcImg.Bounds()
srcW := bounds.Dx()
srcH := bounds.Dy()
maxDim := 300
newW, newH := srcW, srcH
if srcW > maxDim || srcH > maxDim {
if srcW > srcH {
newW = maxDim
newH = srcH * maxDim / srcW
} else {
newH = maxDim
newW = srcW * maxDim / srcH
}
}
if newW < 1 {
newW = 1
}
if newH < 1 {
newH = 1
}
// 使用标准库双线性缩放
thumbImg := image.NewRGBA(image.Rect(0, 0, newW, newH))
scaleBilinear(thumbImg, srcImg)
// 编码为 JPEG 输出
var buf strings.Builder
errWriter := &stringWriter{&buf}
if err := jpeg.Encode(errWriter, thumbImg, &jpeg.Options{Quality: 80}); err != nil {
return nil, "", fmt.Errorf("编码缩略图失败: %w", err)
}
return []byte(buf.String()), "image/jpeg", nil
}
// stringWriter 实现 io.Writer 到 strings.Builder
type stringWriter struct {
b *strings.Builder
}
func (w *stringWriter) Write(p []byte) (int, error) {
return w.b.Write(p)
}
// generatePlaceholderSVG 为非图片文件生成占位图标 SVG
func generatePlaceholderSVG(mimeType string) string {
var icon, color string
switch {
case strings.HasPrefix(mimeType, "audio/"):
icon = "🎵"
color = "#8B5CF6" // purple
case strings.HasPrefix(mimeType, "video/"):
icon = "🎬"
color = "#EF4444" // red
case strings.HasPrefix(mimeType, "application/pdf"):
icon = "📄"
color = "#F59E0B" // amber
case strings.HasPrefix(mimeType, "text/"):
icon = "📝"
color = "#3B82F6" // blue
default:
icon = "📎"
color = "#6B7280" // gray
}
return fmt.Sprintf(`<svg xmlns="http://www.w3.org/2000/svg" width="300" height="300" viewBox="0 0 300 300">
<rect width="300" height="300" fill="%s" opacity="0.1"/>
<text x="150" y="160" text-anchor="middle" font-size="64" fill="%s">%s</text>
<text x="150" y="220" text-anchor="middle" font-size="16" fill="%s" opacity="0.7">%s</text>
</svg>`, color, color, icon, color, getMimeTypeShort(mimeType))
}
// getMimeTypeShort 获取MIME类型简称
func getMimeTypeShort(mimeType string) string {
parts := strings.SplitN(mimeType, "/", 2)
if len(parts) < 2 {
return mimeType
}
return strings.ToUpper(parts[1])
}
// generateUUID 使用 crypto/rand 生成 UUID v4 格式的字符串
func generateUUID() string {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
// 降级方案:基于时间戳 + 随机数的唯一标识
b = make([]byte, 16)
ts := time.Now().UnixNano()
for i := 0; i < 8; i++ {
b[i] = byte(ts >> (i * 8))
}
// 用简单 PRNG 填充剩余字节
for i := 8; i < 16; i++ {
b[i] = byte((ts * int64(i+1)) % 256)
}
}
// 设置 UUID v4 版本位 (version = 4)
b[6] = (b[6] & 0x0f) | 0x40
// 设置 UUID variant 位 (variant = 10xx)
b[8] = (b[8] & 0x3f) | 0x80
return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x",
b[0:4], b[4:6], b[6:8], b[8:10], b[10:16])
}
// scaleBilinear 使用双线性插值将 src 图像缩放到 dst 的尺寸 (纯标准库实现)
func scaleBilinear(dst *image.RGBA, src image.Image) {
dstBounds := dst.Bounds()
srcBounds := src.Bounds()
dstW := dstBounds.Dx()
dstH := dstBounds.Dy()
srcW := srcBounds.Dx()
srcH := srcBounds.Dy()
// 计算缩放比例
scaleX := float64(srcW) / float64(dstW)
scaleY := float64(srcH) / float64(dstH)
for dy := 0; dy < dstH; dy++ {
for dx := 0; dx < dstW; dx++ {
// 源图像中的浮点坐标
sx := float64(dx)*scaleX + float64(srcBounds.Min.X)
sy := float64(dy)*scaleY + float64(srcBounds.Min.Y)
// 四个邻近像素的整数坐标
x0 := int(sx)
y0 := int(sy)
x1 := x0 + 1
y1 := y0 + 1
// 边界限制
if x0 < srcBounds.Min.X {
x0 = srcBounds.Min.X
}
if x1 >= srcBounds.Max.X {
x1 = srcBounds.Max.X - 1
}
if x0 >= srcBounds.Max.X {
x0 = srcBounds.Max.X - 1
}
if x1 < srcBounds.Min.X {
x1 = srcBounds.Min.X
}
if y0 < srcBounds.Min.Y {
y0 = srcBounds.Min.Y
}
if y1 >= srcBounds.Max.Y {
y1 = srcBounds.Max.Y - 1
}
if y0 >= srcBounds.Max.Y {
y0 = srcBounds.Max.Y - 1
}
if y1 < srcBounds.Min.Y {
y1 = srcBounds.Min.Y
}
// 插值权重
fracX := sx - float64(x0)
fracY := sy - float64(y0)
// 四个角的 RGBA 值
r00, g00, b00, a00 := src.At(x0, y0).RGBA()
r10, g10, b10, a10 := src.At(x1, y0).RGBA()
r01, g01, b01, a01 := src.At(x0, y1).RGBA()
r11, g11, b11, a11 := src.At(x1, y1).RGBA()
// 双线性插值 (在 0-65535 范围内进行)
interp := func(c00, c10, c01, c11 uint32) uint8 {
top := float64(c00)*(1-fracX) + float64(c10)*fracX
bot := float64(c01)*(1-fracX) + float64(c11)*fracX
val := top*(1-fracY) + bot*fracY
return uint8(val / 256)
}
r := interp(r00, r10, r01, r11)
g := interp(g00, g10, g01, g11)
b := interp(b00, b10, b01, b11)
a := interp(a00, a10, a01, a11)
dst.SetRGBA(dx+int(dstBounds.Min.X), dy+int(dstBounds.Min.Y), color.RGBA{r, g, b, a})
}
}
}
@@ -0,0 +1,718 @@
package handler
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"image"
"image/color"
_ "image/gif"
_ "image/jpeg"
_ "image/png"
"io"
"log"
"net/http"
"os"
"sort"
"strings"
"github.com/gin-gonic/gin"
"github.com/yourname/cyrene-ai/gateway/internal/config"
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
"github.com/yourname/cyrene-ai/gateway/internal/store"
)
// ImageHandler 图片分析处理器
type ImageHandler struct {
cfg *config.Config
fileStore *store.FileStore
}
// NewImageHandler 创建图片分析处理器
func NewImageHandler(cfg *config.Config, fileStore *store.FileStore) *ImageHandler {
return &ImageHandler{
cfg: cfg,
fileStore: fileStore,
}
}
// ImageAnalysis 图片分析结果
type ImageAnalysis struct {
Format string `json:"format"`
Width int `json:"width"`
Height int `json:"height"`
FileSize int64 `json:"file_size"`
Description string `json:"description"`
TopColors []ColorInfo `json:"top_colors,omitempty"`
EXIF map[string]string `json:"exif,omitempty"`
AnalyzedBy string `json:"analyzed_by"` // "openai_vision" | "local"
}
// ColorInfo 颜色信息
type ColorInfo struct {
Hex string `json:"hex"`
Percent float64 `json:"percent"`
}
// AnalyzeRequestBody 分析请求体
type AnalyzeRequestBody struct {
FileID string `json:"file_id"`
}
// ========== POST /api/v1/images/analyze ==========
// Analyze 分析上传的图片 (multipart/form-data 或 JSON)
func (h *ImageHandler) Analyze(c *gin.Context) {
userID := middleware.GetUserID(c)
// 尝试 JSON body: {"file_id": "xxx"}
contentType := c.GetHeader("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
var body AnalyzeRequestBody
if err := c.ShouldBindJSON(&body); err != nil || body.FileID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少 file_id 字段", "errorType": "invalid_request"})
return
}
h.analyzeByFileID(c, userID, body.FileID)
return
}
// 尝试 multipart/form-data: 直接上传图片分析
file, header, err := c.Request.FormFile("file")
if err != nil {
// 也尝试 "image" 字段名
file, header, err = c.Request.FormFile("image")
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "未找到图片文件 (使用 file 或 image 字段)", "errorType": "missing_file"})
return
}
}
defer file.Close()
h.analyzeUploadedFile(c, userID, file, header.Filename, header.Size)
}
// ========== GET /api/v1/images/analyze/:file_id ==========
// AnalyzeByID 对已上传的文件进行分析
func (h *ImageHandler) AnalyzeByID(c *gin.Context) {
userID := middleware.GetUserID(c)
fileID := c.Param("file_id")
if fileID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少 file_id", "errorType": "invalid_request"})
return
}
h.analyzeByFileID(c, userID, fileID)
}
// analyzeByFileID 根据文件ID分析已存储的图片
func (h *ImageHandler) analyzeByFileID(c *gin.Context, userID, fileID string) {
if h.fileStore == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "文件存储不可用", "errorType": "service_unavailable"})
return
}
f, err := h.fileStore.GetFile(fileID)
if err != nil {
log.Printf("[ImageHandler] 查询文件失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文件失败", "errorType": "db_error"})
return
}
if f == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在", "errorType": "file_not_found"})
return
}
if f.UserID != userID && !f.IsPublic {
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此文件", "errorType": "access_denied"})
return
}
if !isImageType(f.MimeType) {
c.JSON(http.StatusBadRequest, gin.H{"error": "文件不是图片类型: " + f.MimeType, "errorType": "unsupported_type"})
return
}
result, err := h.analyzeImage(f.StoredPath, f.MimeType, f.Size)
if err != nil {
log.Printf("[ImageHandler] 图片分析失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "图片分析失败: " + err.Error(), "errorType": "analysis_error"})
return
}
c.JSON(http.StatusOK, result)
}
// analyzeUploadedFile 分析直接上传的图片文件
func (h *ImageHandler) analyzeUploadedFile(c *gin.Context, userID string, file io.Reader, filename string, fileSize int64) {
// 检查文件大小 (10MB 限制)
const maxImageSize = 10 * 1024 * 1024
if fileSize > maxImageSize {
c.JSON(http.StatusBadRequest, gin.H{"error": "图片大小超过限制 (最大 10MB)", "errorType": "file_too_large"})
return
}
// 读取文件到内存
data, err := io.ReadAll(file)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "读取图片失败", "errorType": "read_error"})
return
}
// 检测格式
_, format, err := image.DecodeConfig(bytes.NewReader(data))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "无法解码图片: " + err.Error(), "errorType": "decode_error"})
return
}
mimeType := "image/" + format
supportedFormats := map[string]bool{
"image/jpeg": true,
"image/png": true,
"image/gif": true,
}
if !supportedFormats[mimeType] {
// 允许所有 image/* 格式,但只对常见格式做深入分析
}
// 写入临时文件进行分析
tmpFile, err := os.CreateTemp("", "cyrene-image-*."+format)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建临时文件失败", "errorType": "server_error"})
return
}
defer os.Remove(tmpFile.Name())
defer tmpFile.Close()
if _, err := tmpFile.Write(data); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "写入临时文件失败", "errorType": "server_error"})
return
}
result, err := h.analyzeImage(tmpFile.Name(), mimeType, int64(len(data)))
if err != nil {
log.Printf("[ImageHandler] 图片分析失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "图片分析失败: " + err.Error(), "errorType": "analysis_error"})
return
}
c.JSON(http.StatusOK, result)
}
// analyzeImage 核心分析逻辑:先尝试 OpenAI Vision,失败则降级到本地分析
func (h *ImageHandler) analyzeImage(filePath, mimeType string, fileSize int64) (*ImageAnalysis, error) {
// 如果配置了 OpenAI API Key,尝试使用 Vision API
apiKey := h.cfg.LLMAPIKey
if apiKey != "" {
result, err := h.analyzeWithOpenAIVision(filePath, mimeType)
if err == nil {
return result, nil
}
log.Printf("[ImageHandler] OpenAI Vision 分析失败,降级到本地分析: %v", err)
}
// 降级到本地分析
return analyzeImageLocally(filePath, mimeType, fileSize)
}
// analyzeWithOpenAIVision 使用 OpenAI Vision API 分析图片
func (h *ImageHandler) analyzeWithOpenAIVision(filePath, mimeType string) (*ImageAnalysis, error) {
// 读取图片并编码为 base64
data, err := os.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("读取图片文件失败: %w", err)
}
base64Data := base64.StdEncoding.EncodeToString(data)
dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data)
// 获取本地基本信息
localInfo, err := analyzeImageLocally(filePath, mimeType, int64(len(data)))
if err != nil {
localInfo = &ImageAnalysis{}
}
// 构建 OpenAI Vision API 请求
reqBody := map[string]interface{}{
"model": h.cfg.LLMModel,
"messages": []map[string]interface{}{
{
"role": "user",
"content": []map[string]interface{}{
{
"type": "text",
"text": "请详细描述这张图片的内容。用中文回答。请描述:1) 图片中的主要物体/人物 2) 场景/环境 3) 颜色和色调 4) 文字内容(如果有)5) 整体氛围和风格。请尽可能详细。",
},
{
"type": "image_url",
"image_url": map[string]string{
"url": dataURL,
},
},
},
},
},
"max_tokens": 500,
}
jsonBody, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("序列化请求失败: %w", err)
}
apiURL := strings.TrimRight(h.cfg.LLMAPIURL, "/") + "/chat/completions"
httpReq, err := http.NewRequest("POST", apiURL, bytes.NewReader(jsonBody))
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+h.cfg.LLMAPIKey)
httpClient := &http.Client{}
resp, err := httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("API 请求失败: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API 返回错误 (%d): %s", resp.StatusCode, string(body))
}
var result struct {
Choices []struct {
Message struct {
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
}
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("解析响应失败: %w", err)
}
var description string
if len(result.Choices) > 0 {
description = result.Choices[0].Message.Content
}
return &ImageAnalysis{
Format: localInfo.Format,
Width: localInfo.Width,
Height: localInfo.Height,
FileSize: localInfo.FileSize,
Description: description,
TopColors: localInfo.TopColors,
EXIF: localInfo.EXIF,
AnalyzedBy: "openai_vision",
}, nil
}
// analyzeImageLocally 使用 Go 标准库进行本地图片分析
func analyzeImageLocally(filePath, mimeType string, fileSize int64) (*ImageAnalysis, error) {
// 1. 读取文件
data, err := os.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("读取文件失败: %w", err)
}
// 2. 解码图片
img, format, err := image.Decode(bytes.NewReader(data))
if err != nil {
return nil, fmt.Errorf("解码图片失败: %w", err)
}
// 3. 获取尺寸
bounds := img.Bounds()
width := bounds.Dx()
height := bounds.Dy()
// 4. 计算颜色直方图 (采样像素)
topColors := computeColorHistogram(img, 5)
// 5. 读取 EXIF (简单实现: 仅 JPEG)
exif := extractEXIF(data, format)
// 6. 生成描述文本
description := generateLocalDescription(format, width, height, fileSize, topColors)
return &ImageAnalysis{
Format: format,
Width: width,
Height: height,
FileSize: fileSize,
Description: description,
TopColors: topColors,
EXIF: exif,
AnalyzedBy: "local",
}, nil
}
// computeColorHistogram 计算颜色直方图,返回 top N 颜色
func computeColorHistogram(img image.Image, topN int) []ColorInfo {
bounds := img.Bounds()
width := bounds.Dx()
height := bounds.Dy()
// 采样间隔:每 step 个像素采样一个
step := 1
totalPixels := width * height
if totalPixels > 10000 {
step = (width * height) / 10000
if step < 1 {
step = 1
}
}
colorCount := make(map[string]int)
sampledCount := 0
for y := bounds.Min.Y; y < bounds.Max.Y; y += step {
for x := bounds.Min.X; x < bounds.Max.X; x += step {
r, g, b, _ := img.At(x, y).RGBA()
// 量化到 8-bit 并聚类(每 32 级一分组,减少颜色种类)
qr := int(r>>8) / 32
qg := int(g>>8) / 32
qb := int(b>>8) / 32
key := fmt.Sprintf("%02d_%02d_%02d", qr, qg, qb)
colorCount[key]++
sampledCount++
}
}
if sampledCount == 0 {
return nil
}
// 排序取 topN
type kv struct {
key string
count int
}
var sorted []kv
for k, v := range colorCount {
sorted = append(sorted, kv{k, v})
}
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].count > sorted[j].count
})
result := make([]ColorInfo, 0, topN)
for i := 0; i < topN && i < len(sorted); i++ {
var qr, qg, qb int
fmt.Sscanf(sorted[i].key, "%d_%d_%d", &qr, &qg, &qb)
// 量化组的中间值
r := qr*32 + 16
g := qg*32 + 16
b := qb*32 + 16
hex := fmt.Sprintf("#%02X%02X%02X", r, g, b)
pct := float64(sorted[i].count) / float64(sampledCount) * 100
result = append(result, ColorInfo{
Hex: hex,
Percent: pct,
})
}
return result
}
// extractEXIF 简单提取 JPEG EXIF 信息
func extractEXIF(data []byte, format string) map[string]string {
if format != "jpeg" {
return nil
}
exif := make(map[string]string)
// 查找 EXIF 标记 (0xFFE1)
for i := 0; i < len(data)-4; i++ {
if data[i] == 0xFF && data[i+1] == 0xE1 {
if i+10 >= len(data) {
break
}
// 验证 EXIF 标识 "Exif\0\0"
if string(data[i+4:i+10]) != "Exif\x00\x00" {
continue
}
exifStart := i + 10
if exifStart+8 >= len(data) {
break
}
// 判断字节序
var bigEndian bool
if data[exifStart] == 'M' && data[exifStart+1] == 'M' {
bigEndian = true
} else if data[exifStart] == 'I' && data[exifStart+1] == 'I' {
bigEndian = false
} else {
break
}
// 读取 IFD0
tiffStart := exifStart
readUint16 := func(offset int) uint16 {
if offset+2 > len(data) {
return 0
}
if bigEndian {
return uint16(data[offset])<<8 | uint16(data[offset+1])
}
return uint16(data[offset+1])<<8 | uint16(data[offset])
}
ifd0Offset := int(readUint16(tiffStart + 4))
if ifd0Offset < 8 {
break
}
ifd0Addr := tiffStart + ifd0Offset
if ifd0Addr+2 >= len(data) {
break
}
numEntries := int(readUint16(ifd0Addr))
entryAddr := ifd0Addr + 2
// 常见 EXIF 标签
tagNames := map[uint16]string{
0x010F: "Make",
0x0110: "Model",
0x0112: "Orientation",
0x0132: "DateTime",
0x829A: "ExposureTime",
0x829D: "FNumber",
0x8827: "ISO",
0x9003: "DateTimeOriginal",
0x920A: "FocalLength",
}
for j := 0; j < numEntries && entryAddr+12 <= len(data); j++ {
tag := readUint16(entryAddr)
dataType := readUint16(entryAddr + 2)
dataCount := int(readUint16(entryAddr + 4))
entryAddr += 12
if name, ok := tagNames[tag]; ok {
valueLen := dataCount
switch dataType {
case 2: // ASCII
valueLen = dataCount
case 3, 4: // SHORT, LONG
valueLen = dataCount * 2
case 5: // RATIONAL
valueLen = dataCount * 8
}
if valueLen <= 4 {
// 值在 tag 自身中
valData := data[entryAddr-4 : entryAddr]
valStr := extractASCIIValue(valData, dataType, dataCount, bigEndian)
if valStr != "" {
exif[name] = valStr
}
}
}
}
break // 只处理第一个 EXIF 块
}
}
if len(exif) == 0 {
return nil
}
return exif
}
// extractASCIIValue 从 EXIF 数据中提取 ASCII 值
func extractASCIIValue(data []byte, dataType uint16, count int, bigEndian bool) string {
switch dataType {
case 2: // ASCII string
s := string(data)
if idx := strings.IndexByte(s, 0); idx >= 0 {
s = s[:idx]
}
return s
case 3: // SHORT
if len(data) >= 2 {
var val uint16
if bigEndian {
val = uint16(data[0])<<8 | uint16(data[1])
} else {
val = uint16(data[1])<<8 | uint16(data[0])
}
return fmt.Sprintf("%d", val)
}
case 5: // RATIONAL
// 简化处理:返回原始字节
return ""
}
return ""
}
// generateLocalDescription 生成本地图片描述文本
func generateLocalDescription(format string, width, height int, fileSize int64, topColors []ColorInfo) string {
var sb strings.Builder
formatNames := map[string]string{
"jpeg": "JPEG",
"jpg": "JPEG",
"png": "PNG",
"gif": "GIF",
"webp": "WebP",
"bmp": "BMP",
}
formatName := strings.ToUpper(format)
if name, ok := formatNames[strings.ToLower(format)]; ok {
formatName = name
}
sb.WriteString(fmt.Sprintf("这是一张 %s 格式的图片,", formatName))
sb.WriteString(fmt.Sprintf("分辨率为 %d×%d 像素,", width, height))
sb.WriteString(fmt.Sprintf("文件大小为 %s。", formatFileSize(fileSize)))
// 判断大致比例
ratio := float64(width) / float64(height)
if ratio > 1.8 {
sb.WriteString("图片呈宽幅横幅比例。")
} else if ratio < 0.6 {
sb.WriteString("图片呈竖幅比例。")
} else if ratio > 1.2 {
sb.WriteString("图片接近横向画幅。")
} else if ratio < 0.8 {
sb.WriteString("图片接近纵向画幅。")
} else {
sb.WriteString("图片接近正方形比例。")
}
// 描述主要颜色
if len(topColors) > 0 {
sb.WriteString(" 主要色调为")
for i, c := range topColors {
if i > 0 {
if i == len(topColors)-1 {
sb.WriteString(" 和 ")
} else {
sb.WriteString("、")
}
}
colorName := getColorName(c.Hex)
sb.WriteString(fmt.Sprintf("%s(%s, %.0f%%)", colorName, c.Hex, c.Percent))
}
sb.WriteString("。")
}
return sb.String()
}
// formatFileSize 格式化文件大小
func formatFileSize(size int64) string {
if size < 1024 {
return fmt.Sprintf("%d B", size)
}
if size < 1024*1024 {
return fmt.Sprintf("%.1f KB", float64(size)/1024)
}
return fmt.Sprintf("%.1f MB", float64(size)/(1024*1024))
}
// getColorName 根据 hex 颜色获取中文颜色名
func getColorName(hex string) string {
if len(hex) < 7 {
return hex
}
var r, g, b uint8
fmt.Sscanf(hex, "#%02X%02X%02X", &r, &g, &b)
// 灰度判断
if absDiff(r, g) < 20 && absDiff(g, b) < 20 && absDiff(r, b) < 20 {
if r < 40 {
return "黑色"
}
if r < 100 {
return "深灰色"
}
if r < 180 {
return "灰色"
}
if r < 230 {
return "浅灰色"
}
return "白色"
}
// HSL 近似判断色调
maxC := max(r, max(g, b))
minC := min(r, min(g, b))
delta := maxC - minC
if delta < 30 {
if maxC < 60 {
return "暗色"
}
if maxC > 200 {
return "浅色"
}
return "中性色"
}
var hue string
switch {
case r == maxC:
if g >= b {
hue = "红色"
} else {
hue = "品红色"
}
case g == maxC:
if b >= r {
hue = "绿色"
} else {
hue = "黄绿色"
}
default:
if r >= g {
hue = "紫红色"
} else {
hue = "蓝色"
}
}
// 亮度修饰
if maxC < 80 {
hue = "深" + hue
} else if minC > 200 {
hue = "浅" + hue
}
return hue
}
func absDiff(a, b uint8) int {
if a > b {
return int(a - b)
}
return int(b - a)
}
func max(a, b uint8) uint8 {
if a > b {
return a
}
return b
}
func min(a, b uint8) uint8 {
if a < b {
return a
}
return b
}
// ========== color.RGBA → string 辅助 ==========
var _ = color.RGBA{} // 确保 color 包被使用
@@ -0,0 +1,589 @@
package handler
import (
"log"
"net/http"
"os"
"strings"
"github.com/gin-gonic/gin"
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
"github.com/yourname/cyrene-ai/gateway/internal/store"
)
// KnowledgeHandler 知识库处理器
type KnowledgeHandler struct {
store *store.KnowledgeStore
fileStore *store.FileStore
}
// NewKnowledgeHandler 创建知识库处理器
func NewKnowledgeHandler(s *store.KnowledgeStore, fs *store.FileStore) *KnowledgeHandler {
return &KnowledgeHandler{store: s, fileStore: fs}
}
// checkStore 检查知识库存储是否可用,不可用时返回 true(调用方应 return
func (h *KnowledgeHandler) checkStore(c *gin.Context) bool {
if h.store == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": "知识库服务不可用(数据库未连接)",
"errorType": "service_unavailable",
})
return true
}
return false
}
// ========== 请求/响应类型 ==========
type createKBRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
}
type updateKBRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
}
type addDocRequest struct {
Title string `json:"title" binding:"required"`
Content string `json:"content"`
SourceType string `json:"source_type"`
FileID string `json:"file_id"`
}
type searchRequest struct {
Query string `json:"query" binding:"required"`
KBIDs []string `json:"kb_ids"`
Limit int `json:"limit"`
}
// ========== POST /api/v1/knowledge/bases — 创建知识库 ==========
func (h *KnowledgeHandler) CreateKB(c *gin.Context) {
if h.checkStore(c) {
return
}
userID := middleware.GetUserID(c)
var req createKBRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请提供知识库名称", "errorType": "invalid_request"})
return
}
kb := &store.KnowledgeBase{
ID: store.GenerateUUID(),
UserID: userID,
Name: req.Name,
Description: req.Description,
}
if err := h.store.CreateKB(kb); err != nil {
log.Printf("[KnowledgeHandler] 创建知识库失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建知识库失败", "errorType": "db_error"})
return
}
c.JSON(http.StatusCreated, kb)
}
// ========== GET /api/v1/knowledge/bases — 列出用户的知识库 ==========
func (h *KnowledgeHandler) ListKBs(c *gin.Context) {
if h.checkStore(c) {
return
}
userID := middleware.GetUserID(c)
kbs, err := h.store.GetKBsByUser(userID)
if err != nil {
log.Printf("[KnowledgeHandler] 查询知识库列表失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库列表失败", "errorType": "db_error"})
return
}
c.JSON(http.StatusOK, gin.H{"knowledge_bases": kbs, "total": len(kbs)})
}
// ========== GET /api/v1/knowledge/bases/:id — 获取知识库详情 ==========
func (h *KnowledgeHandler) GetKB(c *gin.Context) {
if h.checkStore(c) {
return
}
userID := middleware.GetUserID(c)
kbID := c.Param("id")
kb, err := h.store.GetKB(kbID)
if err != nil {
log.Printf("[KnowledgeHandler] 查询知识库失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库失败", "errorType": "db_error"})
return
}
if kb == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "知识库不存在", "errorType": "not_found"})
return
}
if kb.UserID != userID {
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此知识库", "errorType": "access_denied"})
return
}
// 获取文档列表
docs, err := h.store.GetDocumentsByKB(kbID)
if err != nil {
log.Printf("[KnowledgeHandler] 查询文档列表失败: %v", err)
docs = []store.KnowledgeDocument{}
}
c.JSON(http.StatusOK, gin.H{
"knowledge_base": kb,
"documents": docs,
})
}
// ========== PUT /api/v1/knowledge/bases/:id — 更新知识库 ==========
func (h *KnowledgeHandler) UpdateKB(c *gin.Context) {
if h.checkStore(c) {
return
}
userID := middleware.GetUserID(c)
kbID := c.Param("id")
var req updateKBRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请提供知识库名称", "errorType": "invalid_request"})
return
}
kb, err := h.store.GetKB(kbID)
if err != nil {
log.Printf("[KnowledgeHandler] 查询知识库失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库失败", "errorType": "db_error"})
return
}
if kb == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "知识库不存在", "errorType": "not_found"})
return
}
if kb.UserID != userID {
c.JSON(http.StatusForbidden, gin.H{"error": "无权修改此知识库", "errorType": "access_denied"})
return
}
if err := h.store.UpdateKB(kbID, req.Name, req.Description); err != nil {
log.Printf("[KnowledgeHandler] 更新知识库失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "更新知识库失败", "errorType": "db_error"})
return
}
c.JSON(http.StatusOK, gin.H{"status": "updated"})
}
// ========== DELETE /api/v1/knowledge/bases/:id — 删除知识库 ==========
func (h *KnowledgeHandler) DeleteKB(c *gin.Context) {
if h.checkStore(c) {
return
}
userID := middleware.GetUserID(c)
kbID := c.Param("id")
kb, err := h.store.GetKB(kbID)
if err != nil {
log.Printf("[KnowledgeHandler] 查询知识库失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库失败", "errorType": "db_error"})
return
}
if kb == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "知识库不存在", "errorType": "not_found"})
return
}
if kb.UserID != userID {
c.JSON(http.StatusForbidden, gin.H{"error": "无权删除此知识库", "errorType": "access_denied"})
return
}
if err := h.store.DeleteKB(kbID); err != nil {
log.Printf("[KnowledgeHandler] 删除知识库失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除知识库失败", "errorType": "db_error"})
return
}
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
}
// ========== POST /api/v1/knowledge/bases/:id/documents — 添加文档 ==========
func (h *KnowledgeHandler) AddDocument(c *gin.Context) {
if h.checkStore(c) {
return
}
userID := middleware.GetUserID(c)
kbID := c.Param("id")
// 检查知识库是否存在且属于当前用户
kb, err := h.store.GetKB(kbID)
if err != nil {
log.Printf("[KnowledgeHandler] 查询知识库失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库失败", "errorType": "db_error"})
return
}
if kb == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "知识库不存在", "errorType": "not_found"})
return
}
if kb.UserID != userID {
c.JSON(http.StatusForbidden, gin.H{"error": "无权操作此知识库", "errorType": "access_denied"})
return
}
var req addDocRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请提供文档标题", "errorType": "invalid_request"})
return
}
if req.SourceType == "" {
req.SourceType = "text"
}
var content string
var sourceRef string
var contentType string
switch req.SourceType {
case "text":
content = req.Content
contentType = "text/plain"
if content == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "请提供文档内容", "errorType": "invalid_request"})
return
}
case "file":
if req.FileID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "请提供文件ID", "errorType": "invalid_request"})
return
}
sourceRef = req.FileID
// 从 FileStore 读取文件内容
if h.fileStore != nil {
f, err := h.fileStore.GetFile(req.FileID)
if err != nil || f == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在", "errorType": "not_found"})
return
}
if f.UserID != userID {
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此文件", "errorType": "access_denied"})
return
}
// 注意:这里只支持文本类型文件的读取
// 对于二进制文件,需要更复杂的解析逻辑
sourceRef = f.Filename
// 从磁盘读取文件内容
content, contentType, _ = readFileContent(f.StoredPath, f.MimeType)
if content == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "无法读取文件内容,仅支持文本文件", "errorType": "unsupported_file"})
return
}
} else {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "文件存储不可用", "errorType": "service_unavailable"})
return
}
case "url":
sourceRef = req.FileID
content = req.Content
contentType = "text/html"
default:
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的来源类型: " + req.SourceType, "errorType": "invalid_request"})
return
}
if content == "" {
content = req.Content
}
if contentType == "" {
contentType = "text/plain"
}
doc := &store.KnowledgeDocument{
ID: store.GenerateUUID(),
KBID: kbID,
UserID: userID,
Title: req.Title,
SourceType: req.SourceType,
SourceRef: sourceRef,
ContentType: contentType,
RawContent: content,
}
if err := h.store.AddDocument(doc); err != nil {
log.Printf("[KnowledgeHandler] 添加文档失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "添加文档失败", "errorType": "db_error"})
return
}
// 自动分块
chunkCount, err := h.store.ChunkDocument(doc.ID)
if err != nil {
log.Printf("[KnowledgeHandler] 文档分块失败: %v", err)
// 分块失败不影响文档创建
}
doc.ChunkCount = chunkCount
c.JSON(http.StatusCreated, doc)
}
// ========== GET /api/v1/knowledge/bases/:id/documents — 列出知识库中的文档 ==========
func (h *KnowledgeHandler) ListDocuments(c *gin.Context) {
if h.checkStore(c) {
return
}
userID := middleware.GetUserID(c)
kbID := c.Param("id")
// 检查权限
kb, err := h.store.GetKB(kbID)
if err != nil {
log.Printf("[KnowledgeHandler] 查询知识库失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库失败", "errorType": "db_error"})
return
}
if kb == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "知识库不存在", "errorType": "not_found"})
return
}
if kb.UserID != userID {
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此知识库", "errorType": "access_denied"})
return
}
docs, err := h.store.GetDocumentsByKB(kbID)
if err != nil {
log.Printf("[KnowledgeHandler] 查询文档列表失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文档列表失败", "errorType": "db_error"})
return
}
c.JSON(http.StatusOK, gin.H{"documents": docs, "total": len(docs)})
}
// ========== GET /api/v1/knowledge/documents/:id — 获取文档详情 ==========
func (h *KnowledgeHandler) GetDocument(c *gin.Context) {
if h.checkStore(c) {
return
}
userID := middleware.GetUserID(c)
docID := c.Param("id")
doc, err := h.store.GetDocument(docID)
if err != nil {
log.Printf("[KnowledgeHandler] 查询文档失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文档失败", "errorType": "db_error"})
return
}
if doc == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "文档不存在", "errorType": "not_found"})
return
}
if doc.UserID != userID {
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此文档", "errorType": "access_denied"})
return
}
// 获取分块
chunks, err := h.store.GetChunksByDocID(docID)
if err != nil {
log.Printf("[KnowledgeHandler] 查询分块失败: %v", err)
chunks = []store.KnowledgeChunk{}
}
c.JSON(http.StatusOK, gin.H{
"document": doc,
"chunks": chunks,
})
}
// ========== DELETE /api/v1/knowledge/documents/:id — 删除文档 ==========
func (h *KnowledgeHandler) DeleteDocument(c *gin.Context) {
if h.checkStore(c) {
return
}
userID := middleware.GetUserID(c)
docID := c.Param("id")
doc, err := h.store.GetDocument(docID)
if err != nil {
log.Printf("[KnowledgeHandler] 查询文档失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文档失败", "errorType": "db_error"})
return
}
if doc == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "文档不存在", "errorType": "not_found"})
return
}
if doc.UserID != userID {
c.JSON(http.StatusForbidden, gin.H{"error": "无权删除此文档", "errorType": "access_denied"})
return
}
if err := h.store.DeleteDocument(docID); err != nil {
log.Printf("[KnowledgeHandler] 删除文档失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除文档失败", "errorType": "db_error"})
return
}
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
}
// ========== POST /api/v1/knowledge/search — 搜索知识库 ==========
func (h *KnowledgeHandler) Search(c *gin.Context) {
if h.checkStore(c) {
return
}
userID := middleware.GetUserID(c)
var req searchRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请提供搜索关键词", "errorType": "invalid_request"})
return
}
if req.Limit <= 0 {
req.Limit = 5
}
if req.Limit > 50 {
req.Limit = 50
}
var results []store.SearchChunkResult
var err error
if len(req.KBIDs) == 0 {
// 搜索所有知识库
results, err = h.store.SearchAllKBs(userID, req.Query, req.Limit)
} else {
// 搜索指定知识库,需要验证权限
results = []store.SearchChunkResult{}
for _, kbID := range req.KBIDs {
kb, checkErr := h.store.GetKB(kbID)
if checkErr != nil || kb == nil || kb.UserID != userID {
continue // 跳过无权限或不存在的知识库
}
kbResults, searchErr := h.store.SearchChunks(kbID, req.Query, req.Limit)
if searchErr != nil {
log.Printf("[KnowledgeHandler] 搜索知识库 %s 失败: %v", kbID, searchErr)
continue
}
results = append(results, kbResults...)
}
// 限制总结果数
if len(results) > req.Limit {
results = results[:req.Limit]
}
}
if err != nil {
log.Printf("[KnowledgeHandler] 搜索失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "搜索失败", "errorType": "db_error"})
return
}
// 生成高亮片段
for i := range results {
results[i].Headline = generateHeadline(results[i].Content, req.Query, 200)
}
c.JSON(http.StatusOK, gin.H{
"chunks": results,
"total": len(results),
"query": req.Query,
})
}
// ========== 辅助函数 ==========
// generateHeadline 生成高亮片段,提取查询关键词周围的文本
func generateHeadline(content, query string, maxLen int) string {
if maxLen <= 0 {
maxLen = 200
}
runes := []rune(content)
if len(runes) <= maxLen {
return content
}
// 查找查询关键词位置
queryRunes := []rune(query)
pos := -1
for i := 0; i <= len(runes)-len(queryRunes); i++ {
match := true
for j := 0; j < len(queryRunes); j++ {
if runes[i+j] != queryRunes[j] {
match = false
break
}
}
if match {
pos = i
break
}
}
if pos < 0 {
// 没有找到精确匹配,返回前 maxLen 个字符
return string(runes[:maxLen]) + "..."
}
// 以匹配位置为中心,截取上下文
half := maxLen / 2
start := pos - half
if start < 0 {
start = 0
}
end := start + maxLen
if end > len(runes) {
end = len(runes)
start = end - maxLen
if start < 0 {
start = 0
}
}
result := string(runes[start:end])
if start > 0 {
result = "..." + result
}
if end < len(runes) {
result = result + "..."
}
return result
}
// readFileContent 从磁盘读取文件内容 (仅支持文本类型)
func readFileContent(path, mimeType string) (content string, contentType string, err error) {
// 只支持文本类型
if !strings.HasPrefix(mimeType, "text/") && mimeType != "application/json" {
return "", "", nil
}
data, err := os.ReadFile(path)
if err != nil {
return "", "", err
}
return string(data), mimeType, nil
}
@@ -0,0 +1,162 @@
package handler
import (
"encoding/json"
"log"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/yourname/cyrene-ai/gateway/internal/config"
"github.com/yourname/cyrene-ai/gateway/internal/ws"
)
// NotificationHandler 通知推送处理器
type NotificationHandler struct {
cfg *config.Config
hub *ws.Hub
}
// NewNotificationHandler 创建通知处理器
func NewNotificationHandler(cfg *config.Config, hub *ws.Hub) *NotificationHandler {
return &NotificationHandler{cfg: cfg, hub: hub}
}
// PushNotificationRequest 推送通知请求体
type PushNotificationRequest struct {
UserID string `json:"user_id" binding:"required"`
Type string `json:"type" binding:"required,oneof=info warning success thinking reminder"`
Title string `json:"title" binding:"required"`
Body string `json:"body" binding:"required"`
Data map[string]interface{} `json:"data,omitempty"`
}
// Push 推送通知到指定用户 (需要 JWT 认证)
// POST /api/v1/notifications/push
func (h *NotificationHandler) Push(c *gin.Context) {
var req PushNotificationRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
return
}
// 生成通知
notif := h.buildNotification(req)
// 序列化 WS 消息
msg := ws.ServerMessage{
Type: "notification",
MessageID: "notif_" + generateID(),
Timestamp: time.Now().UnixMilli(),
Notification: notif,
}
data, err := json.Marshal(msg)
if err != nil {
log.Printf("[notification] 序列化通知失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "内部错误"})
return
}
// 通过 Hub 推送给指定用户
h.hub.SendToUser(req.UserID, data)
log.Printf("[notification] 通知已推送: user=%s type=%s title=%s", req.UserID, req.Type, req.Title)
c.JSON(http.StatusOK, gin.H{
"success": true,
"notification": gin.H{
"id": notif.ID,
"type": notif.Type,
"title": notif.Title,
"user_id": req.UserID,
"timestamp": notif.Timestamp,
"delivered": h.hub.UserClientCount(req.UserID) > 0,
},
})
}
// InternalNotify 内部服务推送通知 (使用内部 service token)
// POST /api/v1/internal/notify
func (h *NotificationHandler) InternalNotify(c *gin.Context) {
var req PushNotificationRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
return
}
// 生成通知
notif := h.buildNotification(req)
// 序列化 WS 消息
msg := ws.ServerMessage{
Type: "notification",
MessageID: "notif_" + generateID(),
Timestamp: time.Now().UnixMilli(),
Notification: notif,
}
data, err := json.Marshal(msg)
if err != nil {
log.Printf("[notification] 序列化通知失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "内部错误"})
return
}
// 通过 Hub 推送给指定用户
h.hub.SendToUser(req.UserID, data)
log.Printf("[notification] 内部通知已推送: user=%s type=%s title=%s", req.UserID, req.Type, req.Title)
c.JSON(http.StatusOK, gin.H{
"success": true,
"notification": gin.H{
"id": notif.ID,
"type": notif.Type,
"title": notif.Title,
"user_id": req.UserID,
"timestamp": notif.Timestamp,
"delivered": h.hub.UserClientCount(req.UserID) > 0,
},
})
}
// InternalNotifyAuth 内部服务认证中间件
func (h *NotificationHandler) InternalNotifyAuth() gin.HandlerFunc {
return func(c *gin.Context) {
token := c.GetHeader("X-Internal-Token")
if token == "" {
token = c.GetHeader("Authorization")
if len(token) > 7 && token[:7] == "Bearer " {
token = token[7:]
}
}
if token != h.cfg.InternalServiceToken {
c.JSON(http.StatusUnauthorized, gin.H{"error": "内部认证失败"})
c.Abort()
return
}
c.Next()
}
}
// buildNotification 构建 NotificationInfo
func (h *NotificationHandler) buildNotification(req PushNotificationRequest) *ws.NotificationInfo {
notifID := "notif_" + generateID()
now := time.Now().UTC().Format(time.RFC3339)
if req.Data == nil {
req.Data = make(map[string]interface{})
}
return &ws.NotificationInfo{
ID: notifID,
Type: req.Type,
Title: req.Title,
Body: req.Body,
Timestamp: now,
Data: req.Data,
}
}
@@ -0,0 +1,340 @@
package handler
import (
"encoding/json"
"log"
"net/http"
"strconv"
"time"
"github.com/gin-gonic/gin"
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
"github.com/yourname/cyrene-ai/gateway/internal/store"
"github.com/yourname/cyrene-ai/gateway/internal/ws"
)
// ReminderHandler 提醒处理器
type ReminderHandler struct {
store *store.ReminderStore
hub *ws.Hub
}
// NewReminderHandler 创建提醒处理器
func NewReminderHandler(s *store.ReminderStore, hub *ws.Hub) *ReminderHandler {
return &ReminderHandler{store: s, hub: hub}
}
// CreateReminderRequest 创建提醒请求体
type CreateReminderRequest struct {
Title string `json:"title" binding:"required"`
Description string `json:"description"`
RemindAt string `json:"remind_at" binding:"required"` // ISO 8601 格式
RepeatType string `json:"repeat_type"` // none, daily, weekly, monthly
SessionID string `json:"session_id"`
}
// UpdateReminderRequest 更新提醒请求体
type UpdateReminderRequest struct {
Title string `json:"title"`
Description string `json:"description"`
RemindAt string `json:"remind_at"`
Status string `json:"status"` // pending, completed, cancelled
RepeatType string `json:"repeat_type"` // none, daily, weekly, monthly
SessionID string `json:"session_id"`
}
// List 获取提醒列表
// GET /api/v1/reminders?user_id=xxx&status=pending&limit=50&offset=0
func (h *ReminderHandler) List(c *gin.Context) {
userID := c.Query("user_id")
if userID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少 user_id 参数"})
return
}
status := c.Query("status")
limit := 50
offset := 0
if l, ok := c.GetQuery("limit"); ok {
if v, err := strconv.Atoi(l); err == nil && v > 0 {
limit = v
}
}
if o, ok := c.GetQuery("offset"); ok {
if v, err := strconv.Atoi(o); err == nil && v >= 0 {
offset = v
}
}
reminders, err := h.store.GetRemindersByUser(userID, status, limit, offset)
if err != nil {
log.Printf("[reminder] 获取提醒列表失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取提醒列表失败"})
return
}
c.JSON(http.StatusOK, gin.H{
"reminders": reminders,
"count": len(reminders),
})
}
// Create 创建新提醒
// POST /api/v1/reminders
func (h *ReminderHandler) Create(c *gin.Context) {
var req CreateReminderRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
return
}
// 从 JWT 获取 userID
userID := middleware.GetUserID(c)
if userID == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "未认证"})
return
}
// 解析时间
remindAt, err := time.Parse(time.RFC3339, req.RemindAt)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "时间格式无效,请使用 ISO 8601 格式 (例如 2024-01-01T15:00:00Z)"})
return
}
// 默认值
repeatType := req.RepeatType
if repeatType == "" {
repeatType = "none"
}
reminder := &store.Reminder{
ID: generateID(),
UserID: userID,
Title: req.Title,
Description: req.Description,
RemindAt: remindAt,
Status: "pending",
RepeatType: repeatType,
SessionID: req.SessionID,
Notified: false,
}
if err := h.store.CreateReminder(reminder); err != nil {
log.Printf("[reminder] 创建提醒失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建提醒失败"})
return
}
log.Printf("[reminder] 提醒已创建: id=%s user=%s title=%s remind_at=%s repeat=%s",
reminder.ID, userID, reminder.Title, remindAt.Format(time.RFC3339), repeatType)
c.JSON(http.StatusCreated, gin.H{
"success": true,
"reminder": reminder,
})
}
// Update 更新提醒
// PUT /api/v1/reminders/:id
func (h *ReminderHandler) Update(c *gin.Context) {
id := c.Param("id")
if id == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少提醒 ID"})
return
}
userID := middleware.GetUserID(c)
var req UpdateReminderRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
return
}
// 先获取已有提醒
reminders, err := h.store.GetRemindersByUser(userID, "", 100, 0)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取提醒失败"})
return
}
var existing *store.Reminder
for i := range reminders {
if reminders[i].ID == id {
existing = &reminders[i]
break
}
}
if existing == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "提醒不存在"})
return
}
// 更新字段
if req.Title != "" {
existing.Title = req.Title
}
if req.Description != "" {
existing.Description = req.Description
}
if req.RemindAt != "" {
remindAt, err := time.Parse(time.RFC3339, req.RemindAt)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "时间格式无效"})
return
}
existing.RemindAt = remindAt
}
if req.Status != "" {
existing.Status = req.Status
if req.Status == "completed" || req.Status == "cancelled" {
now := time.Now()
existing.CompletedAt = &now
}
}
if req.RepeatType != "" {
existing.RepeatType = req.RepeatType
}
if req.SessionID != "" {
existing.SessionID = req.SessionID
}
if err := h.store.UpdateReminder(id, existing); err != nil {
log.Printf("[reminder] 更新提醒失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "更新提醒失败"})
return
}
log.Printf("[reminder] 提醒已更新: id=%s", id)
c.JSON(http.StatusOK, gin.H{
"success": true,
"reminder": existing,
})
}
// Delete 删除提醒
// DELETE /api/v1/reminders/:id
func (h *ReminderHandler) Delete(c *gin.Context) {
id := c.Param("id")
if id == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少提醒 ID"})
return
}
if err := h.store.DeleteReminder(id); err != nil {
log.Printf("[reminder] 删除提醒失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除提醒失败"})
return
}
log.Printf("[reminder] 提醒已删除: id=%s", id)
c.JSON(http.StatusOK, gin.H{
"success": true,
})
}
// ========== 提醒调度器 ==========
// StartReminderScheduler 启动提醒调度器,每 30 秒检查一次到期提醒
func StartReminderScheduler(s *store.ReminderStore, hub *ws.Hub) {
go func() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
log.Println("[ReminderScheduler] 提醒调度器已启动 (检查间隔: 30秒)")
for range ticker.C {
checkAndNotify(s, hub)
}
}()
}
// checkAndNotify 检查到期提醒并推送通知
func checkAndNotify(s *store.ReminderStore, hub *ws.Hub) {
reminders, err := s.GetDueReminders()
if err != nil {
log.Printf("[ReminderScheduler] 获取到期提醒失败: %v", err)
return
}
if len(reminders) == 0 {
return
}
for _, r := range reminders {
// 1. 构建 WebSocket 通知消息
notif := &ws.NotificationInfo{
ID: "reminder_" + r.ID,
Type: "reminder",
Title: r.Title,
Body: r.Description,
Timestamp: time.Now().UTC().Format(time.RFC3339),
Data: map[string]interface{}{
"reminder_id": r.ID,
"session_id": r.SessionID,
},
}
msg := ws.ServerMessage{
Type: "notification",
MessageID: "reminder_" + r.ID,
Timestamp: time.Now().UnixMilli(),
Notification: notif,
}
data, err := json.Marshal(msg)
if err != nil {
log.Printf("[ReminderScheduler] 序列化通知失败: %v", err)
continue
}
// 2. 通过 Hub 向用户推送
hub.SendToUser(r.UserID, data)
// 3. 标记为已通知
if err := s.MarkNotified(r.ID); err != nil {
log.Printf("[ReminderScheduler] 标记已通知失败: id=%s err=%v", r.ID, err)
}
// 4. 处理重复提醒
if r.RepeatType != "" && r.RepeatType != "none" {
nextTime := calculateNextRemindAt(r.RemindAt, r.RepeatType)
r.RemindAt = nextTime
r.Notified = false
if err := s.UpdateReminder(r.ID, &r); err != nil {
log.Printf("[ReminderScheduler] 更新重复提醒失败: id=%s err=%v", r.ID, err)
} else {
log.Printf("[ReminderScheduler] 重复提醒已更新: id=%s next=%s", r.ID, nextTime.Format(time.RFC3339))
}
} else {
// 非重复提醒:标记为已完成
now := time.Now()
r.Status = "completed"
r.CompletedAt = &now
if err := s.UpdateReminder(r.ID, &r); err != nil {
log.Printf("[ReminderScheduler] 标记提醒完成失败: id=%s err=%v", r.ID, err)
}
}
log.Printf("[ReminderScheduler] 提醒已推送: user=%s title=%s id=%s", r.UserID, r.Title, r.ID)
}
}
// calculateNextRemindAt 计算下一次提醒时间
func calculateNextRemindAt(current time.Time, repeatType string) time.Time {
switch repeatType {
case "daily":
return current.Add(24 * time.Hour)
case "weekly":
return current.Add(7 * 24 * time.Hour)
case "monthly":
return current.AddDate(0, 1, 0)
default:
return current
}
}
@@ -1,8 +1,11 @@
package handler
import (
"encoding/json"
"fmt"
"log"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
@@ -305,6 +308,270 @@ func (h *SessionHandler) GetSession(c *gin.Context) {
c.JSON(http.StatusOK, session)
}
// ========== GET /api/v1/messages/search?q=xxx&user_id=xxx&limit=50&offset=0 — 全文搜索消息 ==========
// SearchMessages 全文搜索消息
func (h *SessionHandler) SearchMessages(c *gin.Context) {
query := c.Query("q")
if query == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少搜索关键词参数 q", "errorType": "missing_query"})
return
}
userID := c.Query("user_id")
if userID == "" {
userID = middleware.GetUserID(c)
}
limit := 50
offset := 0
if l := c.Query("limit"); l != "" {
parsed := 0
for _, ch := range l {
if ch < '0' || ch > '9' {
break
}
parsed = parsed*10 + int(ch-'0')
}
if parsed > 0 && parsed <= 200 {
limit = parsed
}
}
if o := c.Query("offset"); o != "" {
parsed := 0
for _, ch := range o {
if ch < '0' || ch > '9' {
break
}
parsed = parsed*10 + int(ch-'0')
}
if parsed >= 0 {
offset = parsed
}
}
if !h.useDB {
c.JSON(http.StatusOK, gin.H{
"results": []gin.H{},
"total": 0,
"query": query,
})
return
}
results, total, err := h.store.SearchMessages(userID, query, limit, offset)
if err != nil {
log.Printf("[SessionHandler] 搜索消息失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "搜索失败", "errorType": "db_error"})
return
}
// 转换为统一格式
items := make([]gin.H, 0, len(results))
for _, r := range results {
items = append(items, gin.H{
"message_id": r.MessageID,
"session_id": r.SessionID,
"session_title": r.SessionTitle,
"role": r.Role,
"content": r.Content,
"created_at": r.CreatedAt.UnixMilli(),
})
}
c.JSON(http.StatusOK, gin.H{
"results": items,
"total": total,
"query": query,
"limit": limit,
"offset": offset,
})
}
// ========== GET /api/v1/sessions/:id/export?format=json|markdown|txt — 导出会话 ==========
// ExportSession 导出会话为指定格式
func (h *SessionHandler) ExportSession(c *gin.Context) {
sessionID := c.Param("id")
format := c.Query("format")
if format == "" {
format = "json"
}
// 验证格式
switch format {
case "json", "markdown", "txt":
// valid
default:
c.JSON(http.StatusBadRequest, gin.H{
"error": "不支持的导出格式",
"errorType": "invalid_format",
"hint": "支持的格式: json, markdown, txt",
})
return
}
if !h.useDB {
c.JSON(http.StatusServiceUnavailable, gin.H{
"error": "会话存储不可用",
"errorType": "store_unavailable",
"hint": "数据库连接未建立,无法导出会话",
})
return
}
// 获取会话信息
session, err := h.store.GetSession(sessionID)
if err != nil {
log.Printf("[SessionHandler] 查询会话失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询会话失败", "errorType": "db_error"})
return
}
if session == nil {
c.JSON(http.StatusNotFound, gin.H{
"error": "会话不存在",
"errorType": "session_not_found",
})
return
}
// 获取所有消息 (不限制数量,导出全部)
messages, err := h.store.GetMessages(sessionID, 0)
if err != nil {
log.Printf("[SessionHandler] 查询消息失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询消息失败", "errorType": "db_error"})
return
}
if messages == nil {
messages = []store.Message{}
}
now := time.Now()
switch format {
case "json":
h.exportJSON(c, session, messages, now)
case "markdown":
h.exportMarkdown(c, session, messages, now)
case "txt":
h.exportTXT(c, session, messages, now)
}
}
// exportJSON 导出 JSON 格式
func (h *SessionHandler) exportJSON(c *gin.Context, session *store.Session, messages []store.Message, now time.Time) {
type msgOut struct {
Role string `json:"role"`
Content string `json:"content"`
CreatedAt int64 `json:"created_at"`
}
type sessionOut struct {
ID string `json:"id"`
Title string `json:"title"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
}
type export struct {
Session sessionOut `json:"session"`
Messages []msgOut `json:"messages"`
}
msgs := make([]msgOut, 0, len(messages))
for _, m := range messages {
msgs = append(msgs, msgOut{
Role: m.Role,
Content: m.Content,
CreatedAt: m.CreatedAt.UnixMilli(),
})
}
data := export{
Session: sessionOut{
ID: session.ID,
Title: session.Title,
CreatedAt: session.CreatedAt.UnixMilli(),
UpdatedAt: session.UpdatedAt.UnixMilli(),
},
Messages: msgs,
}
jsonBytes, err := json.MarshalIndent(data, "", " ")
if err != nil {
log.Printf("[SessionHandler] JSON序列化失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "导出失败", "errorType": "serialization_error"})
return
}
c.Header("Content-Type", "application/json; charset=utf-8")
c.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="session_%s.json"`, session.ID))
c.Data(http.StatusOK, "application/json; charset=utf-8", jsonBytes)
}
// exportMarkdown 导出 Markdown 格式
func (h *SessionHandler) exportMarkdown(c *gin.Context, session *store.Session, messages []store.Message, now time.Time) {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("# 对话导出: %s\n", session.Title))
sb.WriteString(fmt.Sprintf("**会话 ID**: %s\n", session.ID))
sb.WriteString(fmt.Sprintf("**导出时间**: %s\n", now.Format("2006-01-02 15:04:05")))
sb.WriteString(fmt.Sprintf("**消息数量**: %d\n", len(messages)))
sb.WriteString("\n---\n\n")
for _, m := range messages {
timeStr := m.CreatedAt.Format("2006-01-02 15:04:05")
switch m.Role {
case "user":
sb.WriteString(fmt.Sprintf("### 👤 用户 (%s)\n\n", timeStr))
case "assistant":
sb.WriteString(fmt.Sprintf("### 🤖 昔涟 (%s)\n\n", timeStr))
case "system":
sb.WriteString(fmt.Sprintf("### ⚙️ 系统 (%s)\n\n", timeStr))
default:
sb.WriteString(fmt.Sprintf("### %s (%s)\n\n", m.Role, timeStr))
}
sb.WriteString(m.Content)
sb.WriteString("\n\n---\n\n")
}
content := sb.String()
c.Header("Content-Type", "text/markdown; charset=utf-8")
c.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="session_%s.md"`, session.ID))
c.Data(http.StatusOK, "text/markdown; charset=utf-8", []byte(content))
}
// exportTXT 导出纯文本格式
func (h *SessionHandler) exportTXT(c *gin.Context, session *store.Session, messages []store.Message, now time.Time) {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("对话导出: %s\n", session.Title))
sb.WriteString(fmt.Sprintf("会话 ID: %s\n", session.ID))
sb.WriteString(fmt.Sprintf("导出时间: %s\n", now.Format("2006-01-02 15:04:05")))
sb.WriteString(fmt.Sprintf("消息数量: %d\n", len(messages)))
sb.WriteString(strings.Repeat("=", 50))
sb.WriteString("\n\n")
for _, m := range messages {
timeStr := m.CreatedAt.Format("2006-01-02 15:04:05")
roleLabel := m.Role
switch m.Role {
case "user":
roleLabel = "用户"
case "assistant":
roleLabel = "昔涟"
case "system":
roleLabel = "系统"
}
sb.WriteString(fmt.Sprintf("[%s] %s:\n%s\n\n", timeStr, roleLabel, m.Content))
}
content := sb.String()
c.Header("Content-Type", "text/plain; charset=utf-8")
c.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="session_%s.txt"`, session.ID))
c.Data(http.StatusOK, "text/plain; charset=utf-8", []byte(content))
}
// 简单的工具函数
func randomID(n int) string {
const letters = "abcdefghijklmnopqrstuvwxyz0123456789"
@@ -0,0 +1,179 @@
package handler
import (
"bytes"
"encoding/json"
"io"
"log"
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
// VoiceHandler 语音处理器 — 代理到 voice-service
type VoiceHandler struct {
voiceServiceURL string
client *http.Client
}
// NewVoiceHandler 创建语音处理器
func NewVoiceHandler(voiceServiceURL string) *VoiceHandler {
return &VoiceHandler{
voiceServiceURL: voiceServiceURL,
client: &http.Client{},
}
}
// Transcribe POST /api/v1/voice/transcribe
// 代理 multipart/form-data 请求到 voice-service
func (h *VoiceHandler) Transcribe(c *gin.Context) {
// 限制上传大小 (10MB)
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, 10<<20)
// 读取原始请求体
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败或文件过大,最大支持 10MB"})
return
}
// 构建代理请求
url := strings.TrimRight(h.voiceServiceURL, "/") + "/api/v1/transcribe"
proxyReq, err := http.NewRequest("POST", url, bytes.NewReader(bodyBytes))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "构建代理请求失败"})
return
}
proxyReq.Header.Set("Content-Type", c.GetHeader("Content-Type"))
resp, err := h.client.Do(proxyReq)
if err != nil {
log.Printf("[voice] Voice-Service 不可达 (Transcribe): %v", err)
c.JSON(http.StatusBadGateway, gin.H{
"error": "Voice-Service 不可达: " + err.Error(),
"errorType": "voice_service_unreachable",
"hint": "Voice-Service 服务未启动或不可达,请先在「服务管理」面板中启动 Voice-Service",
})
return
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
// 透传状态码和响应
c.Data(resp.StatusCode, resp.Header.Get("Content-Type"), respBody)
}
// TTSSynthesize POST /api/v1/voice/tts
// 代理 JSON 请求到 voice-service TTS 合成
func (h *VoiceHandler) TTSSynthesize(c *gin.Context) {
// 读取 JSON 请求体
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"})
return
}
// 构建代理请求
url := strings.TrimRight(h.voiceServiceURL, "/") + "/api/v1/tts/synthesize"
proxyReq, err := http.NewRequest("POST", url, bytes.NewReader(bodyBytes))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "构建代理请求失败"})
return
}
proxyReq.Header.Set("Content-Type", "application/json")
resp, err := h.client.Do(proxyReq)
if err != nil {
log.Printf("[voice] Voice-Service 不可达 (TTS): %v", err)
c.JSON(http.StatusBadGateway, gin.H{
"error": "Voice-Service 不可达: " + err.Error(),
"errorType": "voice_service_unreachable",
"hint": "Voice-Service 服务未启动或不可达",
})
return
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "读取 Voice-Service 响应失败"})
return
}
// 透传状态码、Content-Type 和响应体
c.Data(resp.StatusCode, resp.Header.Get("Content-Type"), respBody)
}
// TTSVoices GET /api/v1/voice/tts/voices
// 代理请求到 voice-service 获取可用语音列表
func (h *VoiceHandler) TTSVoices(c *gin.Context) {
url := strings.TrimRight(h.voiceServiceURL, "/") + "/api/v1/tts/voices"
resp, err := h.client.Get(url)
if err != nil {
log.Printf("[voice] Voice-Service 不可达 (Voices): %v", err)
c.JSON(http.StatusBadGateway, gin.H{
"error": "Voice-Service 不可达: " + err.Error(),
"errorType": "voice_service_unreachable",
})
return
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
// 解析并透传
var data interface{}
json.Unmarshal(respBody, &data)
c.JSON(resp.StatusCode, data)
}
// TTSStatus GET /api/v1/voice/tts/status
// 代理请求到 voice-service 获取 TTS 状态
func (h *VoiceHandler) TTSStatus(c *gin.Context) {
url := strings.TrimRight(h.voiceServiceURL, "/") + "/api/v1/tts/status"
resp, err := h.client.Get(url)
if err != nil {
log.Printf("[voice] Voice-Service 不可达 (TTS Status): %v", err)
c.JSON(http.StatusBadGateway, gin.H{
"error": "Voice-Service 不可达: " + err.Error(),
"errorType": "voice_service_unreachable",
})
return
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
var data interface{}
json.Unmarshal(respBody, &data)
c.JSON(resp.StatusCode, data)
}
// VoiceStatus GET /api/v1/voice/status
// 代理请求到 voice-service 获取完整状态(STT + TTS
func (h *VoiceHandler) VoiceStatus(c *gin.Context) {
url := strings.TrimRight(h.voiceServiceURL, "/") + "/api/v1/status"
resp, err := h.client.Get(url)
if err != nil {
log.Printf("[voice] Voice-Service 不可达 (Status): %v", err)
c.JSON(http.StatusBadGateway, gin.H{
"error": "Voice-Service 不可达: " + err.Error(),
"errorType": "voice_service_unreachable",
})
return
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
var data interface{}
json.Unmarshal(respBody, &data)
c.JSON(resp.StatusCode, data)
}
+120 -1
View File
@@ -7,6 +7,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/yourname/cyrene-ai/gateway/internal/config"
"github.com/yourname/cyrene-ai/gateway/internal/engine"
"github.com/yourname/cyrene-ai/gateway/internal/handler"
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
"github.com/yourname/cyrene-ai/gateway/internal/store"
@@ -14,7 +15,7 @@ import (
)
// Setup 注册所有路由
func Setup(r *gin.Engine, hub *ws.Hub, cfg *config.Config, sessionStore *store.SessionStore) {
func Setup(r *gin.Engine, hub *ws.Hub, cfg *config.Config, sessionStore *store.SessionStore, reminderStore *store.ReminderStore, briefingStore *store.BriefingStore, automationStore *store.AutomationStore, fileStore *store.FileStore, ruleEngine *engine.RuleEngine, knowledgeStore *store.KnowledgeStore, imageHandler *handler.ImageHandler) {
// 限流器
rateLimiter := middleware.NewRateLimiter(10, 20) // 每秒10个请求,突发20
@@ -24,6 +25,16 @@ func Setup(r *gin.Engine, hub *ws.Hub, cfg *config.Config, sessionStore *store.S
memoryHandler := handler.NewMemoryHandler(cfg.MemoryServiceURL)
chatHandler := handler.NewChatHandler(cfg, hub)
webhookHandler := handler.NewWebhookHandler(cfg, hub)
notificationHandler := handler.NewNotificationHandler(cfg, hub)
reminderHandler := handler.NewReminderHandler(reminderStore, hub)
briefingHandler := handler.NewBriefingHandler(cfg, hub, briefingStore, reminderStore)
voiceHandler := handler.NewVoiceHandler(cfg.VoiceServiceURL)
fileHandler := handler.NewFileHandler(fileStore)
automationHandler := handler.NewAutomationHandler(automationStore, ruleEngine)
knowledgeHandler := handler.NewKnowledgeHandler(knowledgeStore, fileStore)
if imageHandler == nil {
imageHandler = handler.NewImageHandler(cfg, fileStore)
}
// ========== 公开路由 ==========
api := r.Group("/api/v1")
@@ -62,8 +73,12 @@ func Setup(r *gin.Engine, hub *ws.Hub, cfg *config.Config, sessionStore *store.S
sessions.DELETE("/:id", sessionHandler.Delete) // DELETE /api/v1/sessions/:id
sessions.GET("/:id/messages", sessionHandler.GetMessages) // GET /api/v1/sessions/:id/messages?limit=50
sessions.DELETE("/:id/messages", sessionHandler.ClearMessages) // DELETE /api/v1/sessions/:id/messages
sessions.GET("/:id/export", sessionHandler.ExportSession) // GET /api/v1/sessions/:id/export?format=json|markdown|txt
}
// 消息搜索
protected.GET("/messages/search", sessionHandler.SearchMessages) // GET /api/v1/messages/search?q=xxx&user_id=xxx&limit=50&offset=0
// 记忆管理
memory := protected.Group("/memory")
{
@@ -73,6 +88,103 @@ func Setup(r *gin.Engine, hub *ws.Hub, cfg *config.Config, sessionStore *store.S
memory.DELETE("", memoryHandler.Delete)
}
// 通知推送 (需要认证)
notifications := protected.Group("/notifications")
{
notifications.POST("/push", notificationHandler.Push)
}
// 提醒管理 (需要认证)
reminders := protected.Group("/reminders")
{
reminders.GET("", reminderHandler.List) // GET /api/v1/reminders?user_id=xxx&status=pending&limit=50
reminders.POST("", reminderHandler.Create) // POST /api/v1/reminders
reminders.PUT("/:id", reminderHandler.Update) // PUT /api/v1/reminders/:id
reminders.DELETE("/:id", reminderHandler.Delete) // DELETE /api/v1/reminders/:id
}
// 每日简报 (需要认证)
briefings := protected.Group("/briefings")
{
briefings.GET("", briefingHandler.GetBriefing) // GET /api/v1/briefings?user_id=xxx&date=2024-01-01
briefings.GET("/latest", briefingHandler.GetLatestBriefings) // GET /api/v1/briefings/latest?user_id=xxx&limit=7
briefings.POST("/generate", briefingHandler.Generate) // POST /api/v1/briefings/generate
}
// 语音识别 + TTS (需要认证)
voice := protected.Group("/voice")
{
voice.POST("/transcribe", voiceHandler.Transcribe)
voice.POST("/tts", voiceHandler.TTSSynthesize)
voice.GET("/tts/voices", voiceHandler.TTSVoices)
voice.GET("/tts/status", voiceHandler.TTSStatus)
voice.GET("/status", voiceHandler.VoiceStatus)
}
// 文件管理 (需要认证)
files := protected.Group("/files")
{
files.POST("/upload", fileHandler.Upload)
files.GET("", fileHandler.List)
files.GET("/:id", fileHandler.Get)
files.GET("/:id/download", fileHandler.Download)
files.GET("/:id/thumbnail", fileHandler.Thumbnail)
files.DELETE("/:id", fileHandler.Delete)
}
// 自动化 (需要认证)
automation := protected.Group("/automation")
{
// 规则
rules := automation.Group("/rules")
{
rules.GET("", automationHandler.ListRules) // GET /api/v1/automation/rules
rules.POST("", automationHandler.CreateRule) // POST /api/v1/automation/rules
rules.GET("/:id", automationHandler.GetRule) // GET /api/v1/automation/rules/:id
rules.PUT("/:id", automationHandler.UpdateRule) // PUT /api/v1/automation/rules/:id
rules.DELETE("/:id", automationHandler.DeleteRule) // DELETE /api/v1/automation/rules/:id
rules.POST("/:id/trigger", automationHandler.TriggerRule) // POST /api/v1/automation/rules/:id/trigger
}
// 场景
scenes := automation.Group("/scenes")
{
scenes.GET("", automationHandler.ListScenes) // GET /api/v1/automation/scenes
scenes.POST("", automationHandler.CreateScene) // POST /api/v1/automation/scenes
scenes.GET("/:id", automationHandler.GetScene) // GET /api/v1/automation/scenes/:id
scenes.PUT("/:id", automationHandler.UpdateScene) // PUT /api/v1/automation/scenes/:id
scenes.DELETE("/:id", automationHandler.DeleteScene) // DELETE /api/v1/automation/scenes/:id
scenes.POST("/:id/execute", automationHandler.ExecuteScene) // POST /api/v1/automation/scenes/:id/execute
}
}
// 知识库管理 (需要认证)
knowledge := protected.Group("/knowledge")
{
// 知识库 CRUD
knowledge.POST("/bases", knowledgeHandler.CreateKB) // POST /api/v1/knowledge/bases
knowledge.GET("/bases", knowledgeHandler.ListKBs) // GET /api/v1/knowledge/bases
knowledge.GET("/bases/:id", knowledgeHandler.GetKB) // GET /api/v1/knowledge/bases/:id
knowledge.PUT("/bases/:id", knowledgeHandler.UpdateKB) // PUT /api/v1/knowledge/bases/:id
knowledge.DELETE("/bases/:id", knowledgeHandler.DeleteKB) // DELETE /api/v1/knowledge/bases/:id
// 文档管理
knowledge.POST("/bases/:id/documents", knowledgeHandler.AddDocument) // POST /api/v1/knowledge/bases/:id/documents
knowledge.GET("/bases/:id/documents", knowledgeHandler.ListDocuments) // GET /api/v1/knowledge/bases/:id/documents
knowledge.GET("/documents/:id", knowledgeHandler.GetDocument) // GET /api/v1/knowledge/documents/:id
knowledge.DELETE("/documents/:id", knowledgeHandler.DeleteDocument) // DELETE /api/v1/knowledge/documents/:id
// 搜索
knowledge.POST("/search", knowledgeHandler.Search) // POST /api/v1/knowledge/search
}
// 图片分析 (需要认证)
images := protected.Group("/images")
{
images.POST("/analyze", imageHandler.Analyze) // POST /api/v1/images/analyze
images.GET("/analyze/:file_id", imageHandler.AnalyzeByID) // GET /api/v1/images/analyze/:file_id
}
// Admin 路由 (需要管理员权限)
admin := protected.Group("/admin")
admin.Use(adminAuth())
@@ -83,6 +195,13 @@ func Setup(r *gin.Engine, hub *ws.Hub, cfg *config.Config, sessionStore *store.S
}
}
// ========== 内部服务路由 (使用 Internal Service Token 认证) ==========
internal := r.Group("/api/v1/internal")
internal.Use(notificationHandler.InternalNotifyAuth())
{
internal.POST("/notify", notificationHandler.InternalNotify)
}
// ========== WebSocket路由 ==========
// WebSocket升级在HTTP层,token通过query参数或Header传递
wsGroup := r.Group("/ws")
@@ -0,0 +1,365 @@
package store
import (
"database/sql"
"encoding/json"
"fmt"
"log"
"time"
)
// AutomationRule 自动化规则模型
type AutomationRule struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Name string `json:"name"`
Description string `json:"description"`
TriggerType string `json:"trigger_type"`
TriggerConfig *json.RawMessage `json:"trigger_config"`
Conditions *json.RawMessage `json:"conditions"`
Actions *json.RawMessage `json:"actions"`
Enabled bool `json:"enabled"`
LastTriggeredAt *time.Time `json:"last_triggered_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// AutomationScene 自动化场景模型
type AutomationScene struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Name string `json:"name"`
Icon string `json:"icon"`
RuleIDs *json.RawMessage `json:"rule_ids"`
CreatedAt time.Time `json:"created_at"`
}
// AutomationStore 自动化持久化存储
type AutomationStore struct {
db *sql.DB
}
// NewAutomationStore 使用已有数据库连接初始化自动化存储并自动建表
func NewAutomationStore(db *sql.DB) (*AutomationStore, error) {
store := &AutomationStore{db: db}
if err := store.migrate(); err != nil {
return nil, fmt.Errorf("自动化表迁移失败: %w", err)
}
log.Println("[AutomationStore] 自动化持久化存储已初始化")
return store, nil
}
// migrate 自动创建表结构
func (s *AutomationStore) migrate() error {
queries := []string{
`CREATE TABLE IF NOT EXISTS automation_rules (
id VARCHAR(64) PRIMARY KEY,
user_id VARCHAR(64) NOT NULL,
name VARCHAR(255) NOT NULL,
description TEXT DEFAULT '',
trigger_type VARCHAR(32) NOT NULL,
trigger_config JSONB DEFAULT '{}',
conditions JSONB DEFAULT '[]',
actions JSONB NOT NULL DEFAULT '[]',
enabled BOOLEAN DEFAULT TRUE,
last_triggered_at TIMESTAMP,
created_at TIMESTAMP DEFAULT NOW(),
updated_at TIMESTAMP DEFAULT NOW()
)`,
`CREATE INDEX IF NOT EXISTS idx_automation_rules_user_id ON automation_rules(user_id)`,
`CREATE INDEX IF NOT EXISTS idx_automation_rules_enabled ON automation_rules(enabled)`,
`CREATE TABLE IF NOT EXISTS automation_scenes (
id VARCHAR(64) PRIMARY KEY,
user_id VARCHAR(64) NOT NULL,
name VARCHAR(255) NOT NULL,
icon VARCHAR(64) DEFAULT '',
rule_ids JSONB DEFAULT '[]',
created_at TIMESTAMP DEFAULT NOW()
)`,
`CREATE INDEX IF NOT EXISTS idx_automation_scenes_user_id ON automation_scenes(user_id)`,
}
for _, q := range queries {
if _, err := s.db.Exec(q); err != nil {
return fmt.Errorf("迁移SQL执行失败: %w\nSQL: %s", err, q)
}
}
return nil
}
// ========== Rule CRUD ==========
// CreateRule 创建新规则
func (s *AutomationStore) CreateRule(rule *AutomationRule) error {
now := time.Now()
rule.CreatedAt = now
rule.UpdatedAt = now
triggerConfig := jsonNull(rule.TriggerConfig)
conditions := jsonNull(rule.Conditions)
actions := jsonNull(rule.Actions)
_, err := s.db.Exec(
`INSERT INTO automation_rules (id, user_id, name, description, trigger_type, trigger_config, conditions, actions, enabled, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)`,
rule.ID, rule.UserID, rule.Name, rule.Description, rule.TriggerType,
triggerConfig, conditions, actions, rule.Enabled, rule.CreatedAt, rule.UpdatedAt,
)
if err != nil {
return fmt.Errorf("创建规则失败: %w", err)
}
return nil
}
// GetRulesByUser 获取用户的所有规则
func (s *AutomationStore) GetRulesByUser(userID string) ([]AutomationRule, error) {
rows, err := s.db.Query(
`SELECT id, user_id, name, description, trigger_type, trigger_config, conditions, actions, enabled, last_triggered_at, created_at, updated_at
FROM automation_rules WHERE user_id = $1
ORDER BY created_at DESC`,
userID,
)
if err != nil {
return nil, fmt.Errorf("查询用户规则失败: %w", err)
}
defer rows.Close()
var rules []AutomationRule
for rows.Next() {
var r AutomationRule
if err := rows.Scan(&r.ID, &r.UserID, &r.Name, &r.Description, &r.TriggerType,
&r.TriggerConfig, &r.Conditions, &r.Actions, &r.Enabled, &r.LastTriggeredAt,
&r.CreatedAt, &r.UpdatedAt); err != nil {
return nil, fmt.Errorf("扫描规则行失败: %w", err)
}
rules = append(rules, r)
}
if rules == nil {
rules = []AutomationRule{}
}
return rules, rows.Err()
}
// GetRule 获取单个规则
func (s *AutomationStore) GetRule(id string) (*AutomationRule, error) {
var r AutomationRule
err := s.db.QueryRow(
`SELECT id, user_id, name, description, trigger_type, trigger_config, conditions, actions, enabled, last_triggered_at, created_at, updated_at
FROM automation_rules WHERE id = $1`,
id,
).Scan(&r.ID, &r.UserID, &r.Name, &r.Description, &r.TriggerType,
&r.TriggerConfig, &r.Conditions, &r.Actions, &r.Enabled, &r.LastTriggeredAt,
&r.CreatedAt, &r.UpdatedAt)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, fmt.Errorf("查询规则失败: %w", err)
}
return &r, nil
}
// UpdateRule 更新规则
func (s *AutomationStore) UpdateRule(rule *AutomationRule) error {
triggerConfig := jsonNull(rule.TriggerConfig)
conditions := jsonNull(rule.Conditions)
actions := jsonNull(rule.Actions)
_, err := s.db.Exec(
`UPDATE automation_rules SET name = $1, description = $2, trigger_type = $3,
trigger_config = $4, conditions = $5, actions = $6, enabled = $7, updated_at = NOW()
WHERE id = $8`,
rule.Name, rule.Description, rule.TriggerType,
triggerConfig, conditions, actions, rule.Enabled, rule.ID,
)
if err != nil {
return fmt.Errorf("更新规则失败: %w", err)
}
return nil
}
// DeleteRule 删除规则
func (s *AutomationStore) DeleteRule(id string) error {
_, err := s.db.Exec(`DELETE FROM automation_rules WHERE id = $1`, id)
if err != nil {
return fmt.Errorf("删除规则失败: %w", err)
}
return nil
}
// GetEnabledRules 获取所有启用的规则(供引擎使用)
func (s *AutomationStore) GetEnabledRules() ([]AutomationRule, error) {
rows, err := s.db.Query(
`SELECT id, user_id, name, description, trigger_type, trigger_config, conditions, actions, enabled, last_triggered_at, created_at, updated_at
FROM automation_rules WHERE enabled = TRUE
ORDER BY created_at ASC`,
)
if err != nil {
return nil, fmt.Errorf("查询启用的规则失败: %w", err)
}
defer rows.Close()
var rules []AutomationRule
for rows.Next() {
var r AutomationRule
if err := rows.Scan(&r.ID, &r.UserID, &r.Name, &r.Description, &r.TriggerType,
&r.TriggerConfig, &r.Conditions, &r.Actions, &r.Enabled, &r.LastTriggeredAt,
&r.CreatedAt, &r.UpdatedAt); err != nil {
return nil, fmt.Errorf("扫描规则行失败: %w", err)
}
rules = append(rules, r)
}
if rules == nil {
rules = []AutomationRule{}
}
return rules, rows.Err()
}
// MarkRuleTriggered 更新 last_triggered_at
func (s *AutomationStore) MarkRuleTriggered(id string) error {
_, err := s.db.Exec(
`UPDATE automation_rules SET last_triggered_at = NOW(), updated_at = NOW() WHERE id = $1`,
id,
)
if err != nil {
return fmt.Errorf("标记规则触发失败: %w", err)
}
return nil
}
// ========== Scene CRUD ==========
// CreateScene 创建新场景
func (s *AutomationStore) CreateScene(scene *AutomationScene) error {
ruleIDs := jsonNull(scene.RuleIDs)
_, err := s.db.Exec(
`INSERT INTO automation_scenes (id, user_id, name, icon, rule_ids, created_at)
VALUES ($1, $2, $3, $4, $5, NOW())`,
scene.ID, scene.UserID, scene.Name, scene.Icon, ruleIDs,
)
if err != nil {
return fmt.Errorf("创建场景失败: %w", err)
}
return nil
}
// GetScenesByUser 获取用户的所有场景
func (s *AutomationStore) GetScenesByUser(userID string) ([]AutomationScene, error) {
rows, err := s.db.Query(
`SELECT id, user_id, name, icon, rule_ids, created_at
FROM automation_scenes WHERE user_id = $1
ORDER BY created_at DESC`,
userID,
)
if err != nil {
return nil, fmt.Errorf("查询用户场景失败: %w", err)
}
defer rows.Close()
var scenes []AutomationScene
for rows.Next() {
var sc AutomationScene
if err := rows.Scan(&sc.ID, &sc.UserID, &sc.Name, &sc.Icon, &sc.RuleIDs, &sc.CreatedAt); err != nil {
return nil, fmt.Errorf("扫描场景行失败: %w", err)
}
scenes = append(scenes, sc)
}
if scenes == nil {
scenes = []AutomationScene{}
}
return scenes, rows.Err()
}
// GetScene 获取单个场景
func (s *AutomationStore) GetScene(id string) (*AutomationScene, error) {
var sc AutomationScene
err := s.db.QueryRow(
`SELECT id, user_id, name, icon, rule_ids, created_at
FROM automation_scenes WHERE id = $1`,
id,
).Scan(&sc.ID, &sc.UserID, &sc.Name, &sc.Icon, &sc.RuleIDs, &sc.CreatedAt)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, fmt.Errorf("查询场景失败: %w", err)
}
return &sc, nil
}
// UpdateScene 更新场景
func (s *AutomationStore) UpdateScene(scene *AutomationScene) error {
ruleIDs := jsonNull(scene.RuleIDs)
_, err := s.db.Exec(
`UPDATE automation_scenes SET name = $1, icon = $2, rule_ids = $3 WHERE id = $4`,
scene.Name, scene.Icon, ruleIDs, scene.ID,
)
if err != nil {
return fmt.Errorf("更新场景失败: %w", err)
}
return nil
}
// DeleteScene 删除场景
func (s *AutomationStore) DeleteScene(id string) error {
_, err := s.db.Exec(`DELETE FROM automation_scenes WHERE id = $1`, id)
if err != nil {
return fmt.Errorf("删除场景失败: %w", err)
}
return nil
}
// GetSceneRules 根据 scene 的 rule_ids 取出所有关联的 rules
func (s *AutomationStore) GetSceneRules(sceneID string) ([]AutomationRule, error) {
sc, err := s.GetScene(sceneID)
if err != nil {
return nil, err
}
if sc == nil {
return []AutomationRule{}, nil
}
var ruleIDs []string
if sc.RuleIDs != nil {
if err := json.Unmarshal(*sc.RuleIDs, &ruleIDs); err != nil {
return nil, fmt.Errorf("解析场景规则ID失败: %w", err)
}
}
if len(ruleIDs) == 0 {
return []AutomationRule{}, nil
}
// 构建 IN 查询
var rules []AutomationRule
for _, rid := range ruleIDs {
r, err := s.GetRule(rid)
if err != nil {
return nil, fmt.Errorf("查询场景关联规则失败: %w", err)
}
if r != nil {
rules = append(rules, *r)
}
}
if rules == nil {
rules = []AutomationRule{}
}
return rules, nil
}
// jsonNull 将 *json.RawMessage 转为可写入数据库的 JSON 或 null
func jsonNull(raw *json.RawMessage) interface{} {
if raw == nil {
return nil
}
return []byte(*raw)
}
@@ -0,0 +1,321 @@
package store
import (
"database/sql"
"encoding/json"
"fmt"
"log"
"time"
)
// Briefing 每日简报模型
type Briefing struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Date string `json:"date"` // YYYY-MM-DD
Weather *WeatherData `json:"weather"`
News []NewsItem `json:"news"`
Reminders []BriefReminder `json:"reminders"`
Summary string `json:"summary"`
Status string `json:"status"` // pending, generated, delivered
GeneratedAt *time.Time `json:"generated_at,omitempty"`
DeliveredAt *time.Time `json:"delivered_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
// WeatherData 天气数据
type WeatherData struct {
Location string `json:"location"`
Temp float64 `json:"temp"`
Condition string `json:"condition"`
Icon string `json:"icon"`
}
// NewsItem 新闻条目
type NewsItem struct {
Title string `json:"title"`
URL string `json:"url"`
Source string `json:"source"`
Summary string `json:"summary"`
}
// BriefReminder 简报中的提醒摘要
type BriefReminder struct {
ID string `json:"id"`
Title string `json:"title"`
RemindAt string `json:"remind_at"`
}
// BriefingStore 每日简报持久化存储
type BriefingStore struct {
db *sql.DB
}
// NewBriefingStore 使用已有数据库连接初始化简报存储并自动建表
func NewBriefingStore(db *sql.DB) (*BriefingStore, error) {
store := &BriefingStore{db: db}
if err := store.migrate(); err != nil {
return nil, fmt.Errorf("简报表迁移失败: %w", err)
}
log.Println("[BriefingStore] 简报持久化存储已初始化")
return store, nil
}
// migrate 自动创建简报表结构
func (s *BriefingStore) migrate() error {
queries := []string{
`CREATE TABLE IF NOT EXISTS daily_briefings (
id VARCHAR(36) PRIMARY KEY,
user_id VARCHAR(255) NOT NULL,
date DATE NOT NULL,
weather JSONB DEFAULT '{}',
news JSONB DEFAULT '[]',
reminders JSONB DEFAULT '[]',
summary TEXT DEFAULT '',
status VARCHAR(20) DEFAULT 'pending',
generated_at TIMESTAMPTZ,
delivered_at TIMESTAMPTZ,
created_at TIMESTAMPTZ DEFAULT NOW(),
UNIQUE(user_id, date)
)`,
`CREATE INDEX IF NOT EXISTS idx_briefings_user_id ON daily_briefings(user_id)`,
`CREATE INDEX IF NOT EXISTS idx_briefings_date ON daily_briefings(date)`,
`CREATE INDEX IF NOT EXISTS idx_briefings_user_date ON daily_briefings(user_id, date)`,
}
for _, q := range queries {
if _, err := s.db.Exec(q); err != nil {
return fmt.Errorf("迁移SQL执行失败: %w\nSQL: %s", err, q)
}
}
return nil
}
// CreateOrUpdateBriefing upsert 简报
func (s *BriefingStore) CreateOrUpdateBriefing(b *Briefing) error {
weatherJSON, err := json.Marshal(b.Weather)
if err != nil {
return fmt.Errorf("序列化天气数据失败: %w", err)
}
newsJSON, err := json.Marshal(b.News)
if err != nil {
return fmt.Errorf("序列化新闻数据失败: %w", err)
}
remindersJSON, err := json.Marshal(b.Reminders)
if err != nil {
return fmt.Errorf("序列化提醒数据失败: %w", err)
}
_, err = s.db.Exec(
`INSERT INTO daily_briefings (id, user_id, date, weather, news, reminders, summary, status, generated_at, delivered_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
ON CONFLICT (user_id, date) DO UPDATE SET
weather = EXCLUDED.weather,
news = EXCLUDED.news,
reminders = EXCLUDED.reminders,
summary = EXCLUDED.summary,
status = EXCLUDED.status,
generated_at = EXCLUDED.generated_at,
delivered_at = EXCLUDED.delivered_at`,
b.ID, b.UserID, b.Date, string(weatherJSON), string(newsJSON), string(remindersJSON),
b.Summary, b.Status, b.GeneratedAt, b.DeliveredAt,
)
if err != nil {
return fmt.Errorf("upsert 简报失败: %w", err)
}
return nil
}
// GetBriefingByDate 获取指定日期简报
func (s *BriefingStore) GetBriefingByDate(userID, date string) (*Briefing, error) {
row := s.db.QueryRow(
`SELECT id, user_id, date::TEXT, weather, news, reminders, summary, status, generated_at, delivered_at, created_at
FROM daily_briefings WHERE user_id = $1 AND date = $2::DATE`,
userID, date,
)
b, err := s.scanBriefing(row)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, fmt.Errorf("查询简报失败: %w", err)
}
return b, nil
}
// GetLatestBriefings 获取最近简报列表
func (s *BriefingStore) GetLatestBriefings(userID string, limit int) ([]Briefing, error) {
if limit <= 0 {
limit = 7
}
rows, err := s.db.Query(
`SELECT id, user_id, date::TEXT, weather, news, reminders, summary, status, generated_at, delivered_at, created_at
FROM daily_briefings WHERE user_id = $1
ORDER BY date DESC LIMIT $2`,
userID, limit,
)
if err != nil {
return nil, fmt.Errorf("查询简报列表失败: %w", err)
}
defer rows.Close()
var briefings []Briefing
for rows.Next() {
var (
id, uid, date, summary, status string
weatherRaw, newsRaw, remindersRaw []byte
generatedAt, deliveredAt, createdAt sql.NullTime
)
if err := rows.Scan(&id, &uid, &date, &weatherRaw, &newsRaw, &remindersRaw,
&summary, &status, &generatedAt, &deliveredAt, &createdAt); err != nil {
return nil, fmt.Errorf("扫描简报行失败: %w", err)
}
b := Briefing{
ID: id,
UserID: uid,
Date: date,
Summary: summary,
Status: status,
}
if weatherRaw != nil {
var w WeatherData
if err := json.Unmarshal(weatherRaw, &w); err == nil {
b.Weather = &w
}
}
if newsRaw != nil {
json.Unmarshal(newsRaw, &b.News)
}
if remindersRaw != nil {
json.Unmarshal(remindersRaw, &b.Reminders)
}
if generatedAt.Valid {
b.GeneratedAt = &generatedAt.Time
}
if deliveredAt.Valid {
b.DeliveredAt = &deliveredAt.Time
}
b.CreatedAt = createdAt.Time
// 确保切片不为 nil
if b.News == nil {
b.News = []NewsItem{}
}
if b.Reminders == nil {
b.Reminders = []BriefReminder{}
}
if b.Weather == nil {
b.Weather = &WeatherData{}
}
briefings = append(briefings, b)
}
if briefings == nil {
briefings = []Briefing{}
}
return briefings, rows.Err()
}
// GetUsersWithBriefings 获取拥有简报的所有用户 ID 列表(用于调度器)
func (s *BriefingStore) GetUsersWithBriefings() ([]string, error) {
rows, err := s.db.Query(`SELECT DISTINCT user_id FROM daily_briefings`)
if err != nil {
return nil, fmt.Errorf("查询简报用户列表失败: %w", err)
}
defer rows.Close()
var userIDs []string
for rows.Next() {
var uid string
if err := rows.Scan(&uid); err != nil {
return nil, fmt.Errorf("扫描用户ID失败: %w", err)
}
userIDs = append(userIDs, uid)
}
if userIDs == nil {
userIDs = []string{}
}
return userIDs, rows.Err()
}
// GetAllUsers 获取所有用户 ID(从 reminders 表获取,作为降级方案)
func (s *BriefingStore) GetAllUsers() ([]string, error) {
rows, err := s.db.Query(`SELECT DISTINCT user_id FROM reminders`)
if err != nil {
return nil, fmt.Errorf("查询用户列表失败: %w", err)
}
defer rows.Close()
var userIDs []string
for rows.Next() {
var uid string
if err := rows.Scan(&uid); err != nil {
return nil, fmt.Errorf("扫描用户ID失败: %w", err)
}
userIDs = append(userIDs, uid)
}
if userIDs == nil {
userIDs = []string{}
}
return userIDs, rows.Err()
}
// scanBriefing 扫描单行简报
func (s *BriefingStore) scanBriefing(row *sql.Row) (*Briefing, error) {
var (
id, uid, date, summary, status string
weatherRaw, newsRaw, remindersRaw []byte
generatedAt, deliveredAt, createdAt sql.NullTime
)
if err := row.Scan(&id, &uid, &date, &weatherRaw, &newsRaw, &remindersRaw,
&summary, &status, &generatedAt, &deliveredAt, &createdAt); err != nil {
return nil, err
}
b := &Briefing{
ID: id,
UserID: uid,
Date: date,
Summary: summary,
Status: status,
}
if weatherRaw != nil {
var w WeatherData
if err := json.Unmarshal(weatherRaw, &w); err == nil {
b.Weather = &w
}
}
if b.Weather == nil {
b.Weather = &WeatherData{}
}
if newsRaw != nil {
json.Unmarshal(newsRaw, &b.News)
}
if b.News == nil {
b.News = []NewsItem{}
}
if remindersRaw != nil {
json.Unmarshal(remindersRaw, &b.Reminders)
}
if b.Reminders == nil {
b.Reminders = []BriefReminder{}
}
if generatedAt.Valid {
b.GeneratedAt = &generatedAt.Time
}
if deliveredAt.Valid {
b.DeliveredAt = &deliveredAt.Time
}
b.CreatedAt = createdAt.Time
return b, nil
}
@@ -0,0 +1,172 @@
package store
import (
"database/sql"
"fmt"
"log"
"time"
)
// File 文件元数据模型
type File struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Filename string `json:"filename"`
StoredPath string `json:"stored_path"`
MimeType string `json:"mime_type"`
Size int64 `json:"size"`
Hash string `json:"hash"`
IsPublic bool `json:"is_public"`
CreatedAt time.Time `json:"created_at"`
}
// FileStore 文件元数据持久化存储
type FileStore struct {
db *sql.DB
}
// NewFileStore 使用已有数据库连接初始化文件存储并自动建表
func NewFileStore(db *sql.DB) (*FileStore, error) {
store := &FileStore{db: db}
if err := store.migrate(); err != nil {
return nil, fmt.Errorf("文件表迁移失败: %w", err)
}
log.Println("[FileStore] 文件持久化存储已初始化")
return store, nil
}
// migrate 自动创建文件表结构
func (s *FileStore) migrate() error {
queries := []string{
`CREATE TABLE IF NOT EXISTS files (
id VARCHAR(36) PRIMARY KEY,
user_id VARCHAR(255) NOT NULL,
filename VARCHAR(500) NOT NULL,
stored_path VARCHAR(1000) NOT NULL,
mime_type VARCHAR(255) NOT NULL DEFAULT 'application/octet-stream',
size BIGINT NOT NULL DEFAULT 0,
hash VARCHAR(64) NOT NULL DEFAULT '',
is_public BOOLEAN DEFAULT FALSE,
created_at TIMESTAMPTZ DEFAULT NOW()
)`,
`CREATE INDEX IF NOT EXISTS idx_files_user_id ON files(user_id)`,
`CREATE INDEX IF NOT EXISTS idx_files_hash ON files(hash)`,
`CREATE INDEX IF NOT EXISTS idx_files_created_at ON files(user_id, created_at DESC)`,
}
for _, q := range queries {
if _, err := s.db.Exec(q); err != nil {
return fmt.Errorf("迁移SQL执行失败: %w\nSQL: %s", err, q)
}
}
return nil
}
// CreateFile 创建文件元数据记录
func (s *FileStore) CreateFile(f *File) error {
_, err := s.db.Exec(
`INSERT INTO files (id, user_id, filename, stored_path, mime_type, size, hash, is_public, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
f.ID, f.UserID, f.Filename, f.StoredPath, f.MimeType, f.Size, f.Hash, f.IsPublic, f.CreatedAt,
)
if err != nil {
return fmt.Errorf("创建文件记录失败: %w", err)
}
return nil
}
// GetFile 根据ID获取文件元数据
func (s *FileStore) GetFile(id string) (*File, error) {
var f File
err := s.db.QueryRow(
`SELECT id, user_id, filename, stored_path, mime_type, size, hash, is_public, created_at
FROM files WHERE id = $1`,
id,
).Scan(&f.ID, &f.UserID, &f.Filename, &f.StoredPath, &f.MimeType, &f.Size, &f.Hash, &f.IsPublic, &f.CreatedAt)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, fmt.Errorf("查询文件失败: %w", err)
}
return &f, nil
}
// GetUserFiles 获取用户的所有文件,支持分页
func (s *FileStore) GetUserFiles(userID string, page, limit int) ([]File, int, error) {
if page <= 0 {
page = 1
}
if limit <= 0 || limit > 100 {
limit = 20
}
offset := (page - 1) * limit
// 获取总数
var total int
if err := s.db.QueryRow(
`SELECT COUNT(*) FROM files WHERE user_id = $1`,
userID,
).Scan(&total); err != nil {
return nil, 0, fmt.Errorf("查询文件总数失败: %w", err)
}
// 分页查询
rows, err := s.db.Query(
`SELECT id, user_id, filename, stored_path, mime_type, size, hash, is_public, created_at
FROM files WHERE user_id = $1
ORDER BY created_at DESC
LIMIT $2 OFFSET $3`,
userID, limit, offset,
)
if err != nil {
return nil, 0, fmt.Errorf("查询用户文件失败: %w", err)
}
defer rows.Close()
var files []File
for rows.Next() {
var f File
if err := rows.Scan(&f.ID, &f.UserID, &f.Filename, &f.StoredPath, &f.MimeType, &f.Size, &f.Hash, &f.IsPublic, &f.CreatedAt); err != nil {
return nil, 0, fmt.Errorf("扫描文件行失败: %w", err)
}
files = append(files, f)
}
if files == nil {
files = []File{}
}
return files, total, rows.Err()
}
// GetFileByHash 根据SHA256哈希查找文件(用于去重)
func (s *FileStore) GetFileByHash(hash string) (*File, error) {
if hash == "" {
return nil, nil
}
var f File
err := s.db.QueryRow(
`SELECT id, user_id, filename, stored_path, mime_type, size, hash, is_public, created_at
FROM files WHERE hash = $1
ORDER BY created_at ASC LIMIT 1`,
hash,
).Scan(&f.ID, &f.UserID, &f.Filename, &f.StoredPath, &f.MimeType, &f.Size, &f.Hash, &f.IsPublic, &f.CreatedAt)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, fmt.Errorf("按哈希查询文件失败: %w", err)
}
return &f, nil
}
// DeleteFile 删除文件元数据记录
func (s *FileStore) DeleteFile(id string) error {
_, err := s.db.Exec(`DELETE FROM files WHERE id = $1`, id)
if err != nil {
return fmt.Errorf("删除文件记录失败: %w", err)
}
return nil
}
@@ -0,0 +1,822 @@
package store
import (
"crypto/rand"
"database/sql"
"fmt"
"log"
"strings"
"time"
"unicode/utf8"
)
// ========== 模型定义 ==========
// KnowledgeBase 知识库
type KnowledgeBase struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Name string `json:"name"`
Description string `json:"description"`
DocumentCount int `json:"document_count"`
ChunkCount int `json:"chunk_count"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// KnowledgeDocument 知识库文档
type KnowledgeDocument struct {
ID string `json:"id"`
KBID string `json:"kb_id"`
UserID string `json:"user_id"`
Title string `json:"title"`
SourceType string `json:"source_type"` // "file", "text", "url"
SourceRef string `json:"source_ref"` // 文件 ID 或 URL
ContentType string `json:"content_type"` // "text/plain", "text/markdown", "text/html"
RawContent string `json:"raw_content"`
ChunkCount int `json:"chunk_count"`
CreatedAt time.Time `json:"created_at"`
}
// KnowledgeChunk 文档分块
type KnowledgeChunk struct {
ID string `json:"id"`
DocID string `json:"doc_id"`
KBID string `json:"kb_id"`
ChunkIndex int `json:"chunk_index"`
Content string `json:"content"`
TokenCount int `json:"token_count"`
CreatedAt time.Time `json:"created_at"`
}
// SearchChunkResult 搜索结果的块,包含额外上下文信息
type SearchChunkResult struct {
KnowledgeChunk
Relevance float64 `json:"relevance"`
DocumentTitle string `json:"document_title"`
KBName string `json:"kb_name"`
Headline string `json:"headline"`
}
// ========== KnowledgeStore ==========
// KnowledgeStore 知识库持久化存储
type KnowledgeStore struct {
db *sql.DB
}
// NewKnowledgeStore 使用已有数据库连接初始化知识库存储并自动建表
func NewKnowledgeStore(db *sql.DB) (*KnowledgeStore, error) {
store := &KnowledgeStore{db: db}
if err := store.migrate(); err != nil {
return nil, fmt.Errorf("知识库表迁移失败: %w", err)
}
log.Println("[KnowledgeStore] 知识库持久化存储已初始化")
return store, nil
}
// migrate 自动创建知识库相关表结构
func (s *KnowledgeStore) migrate() error {
queries := []string{
// 知识库表
`CREATE TABLE IF NOT EXISTS knowledge_bases (
id VARCHAR(64) PRIMARY KEY,
user_id VARCHAR(64) NOT NULL,
name VARCHAR(255) NOT NULL,
description TEXT DEFAULT '',
document_count INT DEFAULT 0,
chunk_count INT DEFAULT 0,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW()
)`,
`CREATE INDEX IF NOT EXISTS idx_kb_user_id ON knowledge_bases(user_id)`,
// 文档表
`CREATE TABLE IF NOT EXISTS knowledge_documents (
id VARCHAR(64) PRIMARY KEY,
kb_id VARCHAR(64) NOT NULL REFERENCES knowledge_bases(id) ON DELETE CASCADE,
user_id VARCHAR(64) NOT NULL,
title VARCHAR(512) NOT NULL,
source_type VARCHAR(32) DEFAULT 'text',
source_ref VARCHAR(1024) DEFAULT '',
content_type VARCHAR(64) DEFAULT 'text/plain',
raw_content TEXT DEFAULT '',
chunk_count INT DEFAULT 0,
created_at TIMESTAMPTZ DEFAULT NOW()
)`,
`CREATE INDEX IF NOT EXISTS idx_kd_kb_id ON knowledge_documents(kb_id)`,
`CREATE INDEX IF NOT EXISTS idx_kd_user_id ON knowledge_documents(user_id)`,
// 分块表
`CREATE TABLE IF NOT EXISTS knowledge_chunks (
id VARCHAR(64) PRIMARY KEY,
doc_id VARCHAR(64) NOT NULL REFERENCES knowledge_documents(id) ON DELETE CASCADE,
kb_id VARCHAR(64) NOT NULL,
chunk_index INT NOT NULL,
content TEXT NOT NULL,
token_count INT DEFAULT 0,
tsv TSVECTOR,
created_at TIMESTAMPTZ DEFAULT NOW()
)`,
`CREATE INDEX IF NOT EXISTS idx_kc_doc_id ON knowledge_chunks(doc_id)`,
`CREATE INDEX IF NOT EXISTS idx_kc_kb_id ON knowledge_chunks(kb_id)`,
}
for _, q := range queries {
if _, err := s.db.Exec(q); err != nil {
return fmt.Errorf("迁移SQL执行失败: %w\nSQL: %s", err, q)
}
}
// 尝试创建 GIN 索引(可能因权限或扩展问题失败,但不影响功能)
_, err := s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_kc_tsv_gin ON knowledge_chunks USING GIN(tsv)`)
if err != nil {
log.Printf("[KnowledgeStore] ⚠ GIN索引创建失败(将使用ILIKE降级搜索): %v", err)
}
return nil
}
// ========== 知识库 CRUD ==========
// CreateKB 创建知识库
func (s *KnowledgeStore) CreateKB(kb *KnowledgeBase) error {
now := time.Now()
if kb.CreatedAt.IsZero() {
kb.CreatedAt = now
}
if kb.UpdatedAt.IsZero() {
kb.UpdatedAt = now
}
_, err := s.db.Exec(
`INSERT INTO knowledge_bases (id, user_id, name, description, document_count, chunk_count, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
kb.ID, kb.UserID, kb.Name, kb.Description, kb.DocumentCount, kb.ChunkCount, kb.CreatedAt, kb.UpdatedAt,
)
if err != nil {
return fmt.Errorf("创建知识库失败: %w", err)
}
return nil
}
// GetKBsByUser 获取用户的所有知识库
func (s *KnowledgeStore) GetKBsByUser(userID string) ([]KnowledgeBase, error) {
rows, err := s.db.Query(
`SELECT id, user_id, name, description, document_count, chunk_count, created_at, updated_at
FROM knowledge_bases WHERE user_id = $1
ORDER BY updated_at DESC`,
userID,
)
if err != nil {
return nil, fmt.Errorf("查询知识库列表失败: %w", err)
}
defer rows.Close()
var kbs []KnowledgeBase
for rows.Next() {
var kb KnowledgeBase
if err := rows.Scan(&kb.ID, &kb.UserID, &kb.Name, &kb.Description,
&kb.DocumentCount, &kb.ChunkCount, &kb.CreatedAt, &kb.UpdatedAt); err != nil {
return nil, fmt.Errorf("扫描知识库行失败: %w", err)
}
kbs = append(kbs, kb)
}
if kbs == nil {
kbs = []KnowledgeBase{}
}
return kbs, rows.Err()
}
// GetKB 获取单个知识库
func (s *KnowledgeStore) GetKB(id string) (*KnowledgeBase, error) {
var kb KnowledgeBase
err := s.db.QueryRow(
`SELECT id, user_id, name, description, document_count, chunk_count, created_at, updated_at
FROM knowledge_bases WHERE id = $1`,
id,
).Scan(&kb.ID, &kb.UserID, &kb.Name, &kb.Description,
&kb.DocumentCount, &kb.ChunkCount, &kb.CreatedAt, &kb.UpdatedAt)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, fmt.Errorf("查询知识库失败: %w", err)
}
return &kb, nil
}
// UpdateKB 更新知识库名称和描述
func (s *KnowledgeStore) UpdateKB(id string, name, description string) error {
_, err := s.db.Exec(
`UPDATE knowledge_bases SET name = $1, description = $2, updated_at = NOW() WHERE id = $3`,
name, description, id,
)
if err != nil {
return fmt.Errorf("更新知识库失败: %w", err)
}
return nil
}
// DeleteKB 删除知识库(级联删除文档和块)
func (s *KnowledgeStore) DeleteKB(id string) error {
_, err := s.db.Exec(`DELETE FROM knowledge_bases WHERE id = $1`, id)
if err != nil {
return fmt.Errorf("删除知识库失败: %w", err)
}
return nil
}
// updateKBStats 更新知识库的统计计数
func (s *KnowledgeStore) updateKBStats(kbID string) error {
_, err := s.db.Exec(
`UPDATE knowledge_bases SET
document_count = (SELECT COUNT(*) FROM knowledge_documents WHERE kb_id = $1),
chunk_count = (SELECT COUNT(*) FROM knowledge_chunks WHERE kb_id = $1),
updated_at = NOW()
WHERE id = $1`,
kbID,
)
return err
}
// ========== 文档 CRUD ==========
// AddDocument 添加文档,返回创建的文档
func (s *KnowledgeStore) AddDocument(doc *KnowledgeDocument) error {
if doc.CreatedAt.IsZero() {
doc.CreatedAt = time.Now()
}
_, err := s.db.Exec(
`INSERT INTO knowledge_documents (id, kb_id, user_id, title, source_type, source_ref, content_type, raw_content, chunk_count, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)`,
doc.ID, doc.KBID, doc.UserID, doc.Title, doc.SourceType, doc.SourceRef,
doc.ContentType, doc.RawContent, doc.ChunkCount, doc.CreatedAt,
)
if err != nil {
return fmt.Errorf("添加文档失败: %w", err)
}
// 更新知识库统计
if err := s.updateKBStats(doc.KBID); err != nil {
log.Printf("[KnowledgeStore] 更新知识库统计失败: %v", err)
}
return nil
}
// GetDocument 获取单个文档
func (s *KnowledgeStore) GetDocument(id string) (*KnowledgeDocument, error) {
var doc KnowledgeDocument
err := s.db.QueryRow(
`SELECT id, kb_id, user_id, title, source_type, source_ref, content_type, raw_content, chunk_count, created_at
FROM knowledge_documents WHERE id = $1`,
id,
).Scan(&doc.ID, &doc.KBID, &doc.UserID, &doc.Title, &doc.SourceType, &doc.SourceRef,
&doc.ContentType, &doc.RawContent, &doc.ChunkCount, &doc.CreatedAt)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, fmt.Errorf("查询文档失败: %w", err)
}
return &doc, nil
}
// GetDocumentsByKB 获取知识库中的所有文档
func (s *KnowledgeStore) GetDocumentsByKB(kbID string) ([]KnowledgeDocument, error) {
rows, err := s.db.Query(
`SELECT id, kb_id, user_id, title, source_type, source_ref, content_type, raw_content, chunk_count, created_at
FROM knowledge_documents WHERE kb_id = $1
ORDER BY created_at DESC`,
kbID,
)
if err != nil {
return nil, fmt.Errorf("查询文档列表失败: %w", err)
}
defer rows.Close()
var docs []KnowledgeDocument
for rows.Next() {
var doc KnowledgeDocument
if err := rows.Scan(&doc.ID, &doc.KBID, &doc.UserID, &doc.Title, &doc.SourceType, &doc.SourceRef,
&doc.ContentType, &doc.RawContent, &doc.ChunkCount, &doc.CreatedAt); err != nil {
return nil, fmt.Errorf("扫描文档行失败: %w", err)
}
docs = append(docs, doc)
}
if docs == nil {
docs = []KnowledgeDocument{}
}
return docs, rows.Err()
}
// UpdateDocumentChunkCount 更新文档的分块计数
func (s *KnowledgeStore) UpdateDocumentChunkCount(docID string, count int) error {
_, err := s.db.Exec(
`UPDATE knowledge_documents SET chunk_count = $1 WHERE id = $2`,
count, docID,
)
return err
}
// DeleteDocument 删除文档(级联删除块)
func (s *KnowledgeStore) DeleteDocument(id string) error {
// 先获取 kb_id 以便后续更新统计
var kbID string
err := s.db.QueryRow(`SELECT kb_id FROM knowledge_documents WHERE id = $1`, id).Scan(&kbID)
if err != nil {
if err == sql.ErrNoRows {
return nil
}
return fmt.Errorf("查询文档失败: %w", err)
}
_, err = s.db.Exec(`DELETE FROM knowledge_documents WHERE id = $1`, id)
if err != nil {
return fmt.Errorf("删除文档失败: %w", err)
}
// 更新知识库统计
if err := s.updateKBStats(kbID); err != nil {
log.Printf("[KnowledgeStore] 更新知识库统计失败: %v", err)
}
return nil
}
// ========== 分块操作 ==========
// AddChunk 添加单个分块
func (s *KnowledgeStore) AddChunk(chunk *KnowledgeChunk) error {
if chunk.CreatedAt.IsZero() {
chunk.CreatedAt = time.Now()
}
// 尝试使用 to_tsvector('chinese', content) 设置 tsv
// 如果中文分词不可用,使用 simple 配置
_, err := s.db.Exec(
`INSERT INTO knowledge_chunks (id, doc_id, kb_id, chunk_index, content, token_count, tsv, created_at)
VALUES ($1, $2, $3, $4, $5, $6,
CASE WHEN (SELECT count(*) FROM pg_ts_config WHERE cfgname = 'chinese') > 0
THEN to_tsvector('chinese', $5)
ELSE to_tsvector('simple', $5)
END,
$7)`,
chunk.ID, chunk.DocID, chunk.KBID, chunk.ChunkIndex, chunk.Content, chunk.TokenCount, chunk.CreatedAt,
)
if err != nil {
// 降级:不使用 tsv
_, err = s.db.Exec(
`INSERT INTO knowledge_chunks (id, doc_id, kb_id, chunk_index, content, token_count, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)`,
chunk.ID, chunk.DocID, chunk.KBID, chunk.ChunkIndex, chunk.Content, chunk.TokenCount, chunk.CreatedAt,
)
if err != nil {
return fmt.Errorf("添加分块失败: %w", err)
}
}
return nil
}
// DeleteChunksByDocID 删除文档的所有分块
func (s *KnowledgeStore) DeleteChunksByDocID(docID string) error {
_, err := s.db.Exec(`DELETE FROM knowledge_chunks WHERE doc_id = $1`, docID)
return err
}
// GetChunksByDocID 获取文档的所有分块
func (s *KnowledgeStore) GetChunksByDocID(docID string) ([]KnowledgeChunk, error) {
rows, err := s.db.Query(
`SELECT id, doc_id, kb_id, chunk_index, content, token_count, created_at
FROM knowledge_chunks WHERE doc_id = $1
ORDER BY chunk_index ASC`,
docID,
)
if err != nil {
return nil, fmt.Errorf("查询分块失败: %w", err)
}
defer rows.Close()
var chunks []KnowledgeChunk
for rows.Next() {
var c KnowledgeChunk
if err := rows.Scan(&c.ID, &c.DocID, &c.KBID, &c.ChunkIndex, &c.Content, &c.TokenCount, &c.CreatedAt); err != nil {
return nil, fmt.Errorf("扫描分块行失败: %w", err)
}
chunks = append(chunks, c)
}
if chunks == nil {
chunks = []KnowledgeChunk{}
}
return chunks, rows.Err()
}
// ========== 分块逻辑 ==========
// ChunkDocument 将文档分块并存储
func (s *KnowledgeStore) ChunkDocument(docID string) (int, error) {
// 获取文档
doc, err := s.GetDocument(docID)
if err != nil {
return 0, err
}
if doc == nil {
return 0, fmt.Errorf("文档不存在: %s", docID)
}
// 删除旧的分块
if err := s.DeleteChunksByDocID(docID); err != nil {
return 0, fmt.Errorf("删除旧分块失败: %w", err)
}
// 分块
chunks := splitTextIntoChunks(doc.RawContent, 500, 50)
// 存储分块
for i, content := range chunks {
chunk := &KnowledgeChunk{
ID: generateUUIDv4(),
DocID: docID,
KBID: doc.KBID,
ChunkIndex: i,
Content: content,
TokenCount: estimateTokenCount(content),
}
if err := s.AddChunk(chunk); err != nil {
return 0, fmt.Errorf("存储分块 %d 失败: %w", i, err)
}
}
// 更新文档的分块计数
if err := s.UpdateDocumentChunkCount(docID, len(chunks)); err != nil {
log.Printf("[KnowledgeStore] 更新文档分块计数失败: %v", err)
}
// 更新知识库统计
if err := s.updateKBStats(doc.KBID); err != nil {
log.Printf("[KnowledgeStore] 更新知识库统计失败: %v", err)
}
return len(chunks), nil
}
// ========== 搜索 ==========
// SearchChunks 在指定知识库中搜索
func (s *KnowledgeStore) SearchChunks(kbID, query string, limit int) ([]SearchChunkResult, error) {
if limit <= 0 {
limit = 5
}
// 尝试使用 PostgreSQL 全文搜索
results, err := s.searchWithFullText(kbID, query, limit)
if err != nil {
log.Printf("[KnowledgeStore] 全文搜索失败,降级为ILIKE: %v", err)
// 降级为 ILIKE
results, err = s.searchWithILike(kbID, query, limit)
if err != nil {
return nil, err
}
}
if results == nil {
results = []SearchChunkResult{}
}
return results, nil
}
// SearchAllKBs 在用户的所有知识库中搜索
func (s *KnowledgeStore) SearchAllKBs(userID, query string, limit int) ([]SearchChunkResult, error) {
if limit <= 0 {
limit = 5
}
results, err := s.searchAllWithILike(userID, query, limit)
if err != nil {
return nil, err
}
if results == nil {
results = []SearchChunkResult{}
}
return results, nil
}
// searchWithFullText 使用 PostgreSQL ts_rank + plainto_tsquery 搜索
func (s *KnowledgeStore) searchWithFullText(kbID, query string, limit int) ([]SearchChunkResult, error) {
rows, err := s.db.Query(
`SELECT kc.id, kc.doc_id, kc.kb_id, kc.chunk_index, kc.content, kc.token_count, kc.created_at,
ts_rank(kc.tsv, plainto_tsquery('chinese', $2)) AS relevance,
kd.title AS document_title,
kb.name AS kb_name
FROM knowledge_chunks kc
JOIN knowledge_documents kd ON kc.doc_id = kd.id
JOIN knowledge_bases kb ON kc.kb_id = kb.id
WHERE kc.kb_id = $1 AND kc.tsv @@ plainto_tsquery('chinese', $2)
ORDER BY relevance DESC
LIMIT $3`,
kbID, query, limit,
)
if err != nil {
return nil, err
}
defer rows.Close()
return scanSearchResults(rows)
}
// searchWithILike 使用 ILIKE 降级搜索
func (s *KnowledgeStore) searchWithILike(kbID, query string, limit int) ([]SearchChunkResult, error) {
// 构建 ILIKE 模式
keywords := tokenizeQuery(query)
if len(keywords) == 0 {
return []SearchChunkResult{}, nil
}
// 对每个关键词构建 ILIKE 条件
conditions := make([]string, len(keywords))
args := []interface{}{kbID}
placeholderIdx := 2
for i, kw := range keywords {
conditions[i] = fmt.Sprintf("kc.content ILIKE $%d", placeholderIdx)
args = append(args, "%"+kw+"%")
placeholderIdx++
}
args = append(args, limit)
querySQL := fmt.Sprintf(
`SELECT kc.id, kc.doc_id, kc.kb_id, kc.chunk_index, kc.content, kc.token_count, kc.created_at,
0.0 AS relevance,
kd.title AS document_title,
kb.name AS kb_name
FROM knowledge_chunks kc
JOIN knowledge_documents kd ON kc.doc_id = kd.id
JOIN knowledge_bases kb ON kc.kb_id = kb.id
WHERE kc.kb_id = $1 AND (%s)
LIMIT $%d`,
strings.Join(conditions, " AND "),
placeholderIdx,
)
rows, err := s.db.Query(querySQL, args...)
if err != nil {
return nil, fmt.Errorf("ILIKE搜索失败: %w", err)
}
defer rows.Close()
return scanSearchResults(rows)
}
// searchAllWithILike 跨所有用户知识库使用 ILIKE 搜索
func (s *KnowledgeStore) searchAllWithILike(userID, query string, limit int) ([]SearchChunkResult, error) {
keywords := tokenizeQuery(query)
if len(keywords) == 0 {
return []SearchChunkResult{}, nil
}
conditions := make([]string, len(keywords))
args := []interface{}{userID}
placeholderIdx := 2
for i, kw := range keywords {
conditions[i] = fmt.Sprintf("kc.content ILIKE $%d", placeholderIdx)
args = append(args, "%"+kw+"%")
placeholderIdx++
}
args = append(args, limit)
querySQL := fmt.Sprintf(
`SELECT kc.id, kc.doc_id, kc.kb_id, kc.chunk_index, kc.content, kc.token_count, kc.created_at,
0.0 AS relevance,
kd.title AS document_title,
kb.name AS kb_name
FROM knowledge_chunks kc
JOIN knowledge_documents kd ON kc.doc_id = kd.id
JOIN knowledge_bases kb ON kc.kb_id = kb.id
WHERE kb.user_id = $1 AND (%s)
ORDER BY kc.created_at DESC
LIMIT $%d`,
strings.Join(conditions, " AND "),
placeholderIdx,
)
rows, err := s.db.Query(querySQL, args...)
if err != nil {
return nil, fmt.Errorf("全知识库ILIKE搜索失败: %w", err)
}
defer rows.Close()
return scanSearchResults(rows)
}
// scanSearchResults 扫描搜索结果
func scanSearchResults(rows *sql.Rows) ([]SearchChunkResult, error) {
var results []SearchChunkResult
for rows.Next() {
var r SearchChunkResult
if err := rows.Scan(&r.ID, &r.DocID, &r.KBID, &r.ChunkIndex, &r.Content,
&r.TokenCount, &r.CreatedAt, &r.Relevance, &r.DocumentTitle, &r.KBName); err != nil {
return nil, fmt.Errorf("扫描搜索结果行失败: %w", err)
}
// 生成高亮片段
r.Headline = r.Content
results = append(results, r)
}
if results == nil {
results = []SearchChunkResult{}
}
return results, rows.Err()
}
// ========== 文本分块函数 ==========
// splitTextIntoChunks 将文本按 maxLen 分块,块之间有 overlap 字符重叠
func splitTextIntoChunks(text string, maxLen int, overlap int) []string {
if text == "" {
return nil
}
// 按段落分割
paragraphs := strings.Split(text, "\n\n")
var chunks []string
var currentChunk strings.Builder
for _, para := range paragraphs {
para = strings.TrimSpace(para)
if para == "" {
continue
}
paraLen := utf8.RuneCountInString(para)
if paraLen <= maxLen {
// 如果当前块 + 段落不超过 maxLen,追加到当前块
if utf8.RuneCountInString(currentChunk.String()) == 0 {
currentChunk.WriteString(para)
} else if utf8.RuneCountInString(currentChunk.String())+1+paraLen <= maxLen {
currentChunk.WriteString("\n\n")
currentChunk.WriteString(para)
} else {
// 保存当前块,开始新块
chunks = append(chunks, currentChunk.String())
currentChunk.Reset()
currentChunk.WriteString(para)
}
} else {
// 段落超过 maxLen,需要按句子分割
// 先保存当前块
if currentChunk.Len() > 0 {
chunks = append(chunks, currentChunk.String())
currentChunk.Reset()
}
// 按句子分割
sentences := splitIntoSentences(para)
for _, sent := range sentences {
sent = strings.TrimSpace(sent)
if sent == "" {
continue
}
sentLen := utf8.RuneCountInString(sent)
if sentLen <= maxLen {
if utf8.RuneCountInString(currentChunk.String()) == 0 {
currentChunk.WriteString(sent)
} else if utf8.RuneCountInString(currentChunk.String())+sentLen <= maxLen {
currentChunk.WriteString(sent)
} else {
chunks = append(chunks, currentChunk.String())
currentChunk.Reset()
currentChunk.WriteString(sent)
}
} else {
// 句子超过 maxLen,按 maxLen 截断
if currentChunk.Len() > 0 {
chunks = append(chunks, currentChunk.String())
currentChunk.Reset()
}
// 按 maxLen 截断,带 overlap
runes := []rune(sent)
start := 0
for start < len(runes) {
end := start + maxLen
if end > len(runes) {
end = len(runes)
}
chunks = append(chunks, string(runes[start:end]))
if end >= len(runes) {
break
}
// 下一块从 end-overlap 开始
start = end - overlap
if start <= 0 {
start = end
}
}
}
}
}
}
// 保存最后一个块
if currentChunk.Len() > 0 {
chunks = append(chunks, currentChunk.String())
}
return chunks
}
// splitIntoSentences 按句子分割文本(中文。!?和英文标点)
func splitIntoSentences(text string) []string {
var sentences []string
runes := []rune(text)
var current strings.Builder
for i := 0; i < len(runes); i++ {
current.WriteRune(runes[i])
// 检查句子结束标志
if runes[i] == '。' || runes[i] == '' || runes[i] == '' ||
runes[i] == '!' || runes[i] == '?' ||
(runes[i] == '\n' && i+1 < len(runes) && runes[i+1] != '\n') {
sentences = append(sentences, current.String())
current.Reset()
}
}
// 剩余内容
if current.Len() > 0 {
remaining := strings.TrimSpace(current.String())
if remaining != "" {
sentences = append(sentences, remaining)
}
}
return sentences
}
// estimateTokenCount 估算 token 数量(中文按每个字符1.2个token,英文按每4个字符1个token
func estimateTokenCount(text string) int {
runes := []rune(text)
total := 0
for _, r := range runes {
if r >= 0x4e00 && r <= 0x9fff {
// 中文字符,约1.2个token
total += 1
}
}
// 非中文字符粗略估算:字符数/4
nonChinese := len(runes) - total
total = int(float64(total)*1.2) + nonChinese/4
if total < 1 {
total = 1
}
return total
}
// tokenizeQuery 将查询字符串分词(简单按空格和标点分割)
func tokenizeQuery(query string) []string {
// 按空格、中文标点、英文标点分割
query = strings.TrimSpace(query)
if query == "" {
return nil
}
// 先用空格分割
parts := strings.Fields(query)
var tokens []string
for _, part := range parts {
part = strings.Trim(part, "。!?!?,.;::;、()()[]{}《》\"'")
if part != "" {
tokens = append(tokens, part)
}
}
return tokens
}
// GenerateUUID 使用 crypto/rand 生成 UUID v4 格式的字符串(导出供其他包使用)
func GenerateUUID() string {
return generateUUIDv4()
}
// generateUUIDv4 使用 crypto/rand 生成 UUID v4 格式的字符串
func generateUUIDv4() string {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
// 降级方案:基于时间戳 + 简单随机
ts := time.Now().UnixNano()
for i := 0; i < 16; i++ {
b[i] = byte((ts >> (i * 4)) & 0xFF)
}
}
// 设置 UUID v4 版本位 (version = 4)
b[6] = (b[6] & 0x0f) | 0x40
// 设置 UUID variant 位 (variant = 10xx)
b[8] = (b[8] & 0x3f) | 0x80
return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x",
b[0:4], b[4:6], b[6:8], b[8:10], b[10:16])
}
@@ -0,0 +1,195 @@
package store
import (
"database/sql"
"fmt"
"log"
"time"
)
// Reminder 提醒模型
type Reminder struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Title string `json:"title"`
Description string `json:"description"`
RemindAt time.Time `json:"remind_at"`
Status string `json:"status"` // pending, completed, cancelled
CreatedAt time.Time `json:"created_at"`
CompletedAt *time.Time `json:"completed_at,omitempty"`
RepeatType string `json:"repeat_type"` // none, daily, weekly, monthly
SessionID string `json:"session_id"`
Notified bool `json:"notified"`
}
// ReminderStore 提醒持久化存储
type ReminderStore struct {
db *sql.DB
}
// NewReminderStore 使用已有数据库连接初始化提醒存储并自动建表
func NewReminderStore(db *sql.DB) (*ReminderStore, error) {
store := &ReminderStore{db: db}
if err := store.migrate(); err != nil {
return nil, fmt.Errorf("提醒表迁移失败: %w", err)
}
log.Println("[ReminderStore] 提醒持久化存储已初始化")
return store, nil
}
// migrate 自动创建提醒表结构
func (s *ReminderStore) migrate() error {
queries := []string{
`CREATE TABLE IF NOT EXISTS reminders (
id VARCHAR(36) PRIMARY KEY,
user_id VARCHAR(255) NOT NULL,
title VARCHAR(500) NOT NULL,
description TEXT DEFAULT '',
remind_at TIMESTAMPTZ NOT NULL,
status VARCHAR(20) DEFAULT 'pending',
created_at TIMESTAMPTZ DEFAULT NOW(),
completed_at TIMESTAMPTZ,
repeat_type VARCHAR(20) DEFAULT '',
session_id VARCHAR(36) DEFAULT '',
notified BOOLEAN DEFAULT FALSE
)`,
`CREATE INDEX IF NOT EXISTS idx_reminders_user_id ON reminders(user_id)`,
`CREATE INDEX IF NOT EXISTS idx_reminders_remind_at ON reminders(remind_at)`,
`CREATE INDEX IF NOT EXISTS idx_reminders_status ON reminders(status)`,
`CREATE INDEX IF NOT EXISTS idx_reminders_due ON reminders(remind_at, status, notified)`,
}
for _, q := range queries {
if _, err := s.db.Exec(q); err != nil {
return fmt.Errorf("迁移SQL执行失败: %w\nSQL: %s", err, q)
}
}
return nil
}
// CreateReminder 创建新提醒
func (s *ReminderStore) CreateReminder(r *Reminder) error {
_, err := s.db.Exec(
`INSERT INTO reminders (id, user_id, title, description, remind_at, status, repeat_type, session_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
r.ID, r.UserID, r.Title, r.Description, r.RemindAt, r.Status, r.RepeatType, r.SessionID,
)
if err != nil {
return fmt.Errorf("创建提醒失败: %w", err)
}
return nil
}
// GetRemindersByUser 获取用户的提醒列表(可按状态筛选,按 remind_at 升序)
func (s *ReminderStore) GetRemindersByUser(userID, status string, limit, offset int) ([]Reminder, error) {
if limit <= 0 {
limit = 50
}
if offset < 0 {
offset = 0
}
var rows *sql.Rows
var err error
if status != "" {
rows, err = s.db.Query(
`SELECT id, user_id, title, description, remind_at, status, created_at, completed_at, repeat_type, session_id, notified
FROM reminders WHERE user_id = $1 AND status = $2
ORDER BY remind_at ASC LIMIT $3 OFFSET $4`,
userID, status, limit, offset,
)
} else {
rows, err = s.db.Query(
`SELECT id, user_id, title, description, remind_at, status, created_at, completed_at, repeat_type, session_id, notified
FROM reminders WHERE user_id = $1
ORDER BY remind_at ASC LIMIT $2 OFFSET $3`,
userID, limit, offset,
)
}
if err != nil {
return nil, fmt.Errorf("查询提醒列表失败: %w", err)
}
defer rows.Close()
var reminders []Reminder
for rows.Next() {
var r Reminder
if err := rows.Scan(&r.ID, &r.UserID, &r.Title, &r.Description, &r.RemindAt,
&r.Status, &r.CreatedAt, &r.CompletedAt, &r.RepeatType, &r.SessionID, &r.Notified); err != nil {
return nil, fmt.Errorf("扫描提醒行失败: %w", err)
}
reminders = append(reminders, r)
}
if reminders == nil {
reminders = []Reminder{}
}
return reminders, rows.Err()
}
// GetDueReminders 获取所有到期且未通知的提醒
func (s *ReminderStore) GetDueReminders() ([]Reminder, error) {
rows, err := s.db.Query(
`SELECT id, user_id, title, description, remind_at, status, created_at, completed_at, repeat_type, session_id, notified
FROM reminders
WHERE remind_at <= NOW() AND status = 'pending' AND notified = FALSE
ORDER BY remind_at ASC`,
)
if err != nil {
return nil, fmt.Errorf("查询到期提醒失败: %w", err)
}
defer rows.Close()
var reminders []Reminder
for rows.Next() {
var r Reminder
if err := rows.Scan(&r.ID, &r.UserID, &r.Title, &r.Description, &r.RemindAt,
&r.Status, &r.CreatedAt, &r.CompletedAt, &r.RepeatType, &r.SessionID, &r.Notified); err != nil {
return nil, fmt.Errorf("扫描到期提醒行失败: %w", err)
}
reminders = append(reminders, r)
}
if reminders == nil {
reminders = []Reminder{}
}
return reminders, rows.Err()
}
// MarkNotified 标记提醒为已通知
func (s *ReminderStore) MarkNotified(id string) error {
_, err := s.db.Exec(
`UPDATE reminders SET notified = TRUE WHERE id = $1`,
id,
)
if err != nil {
return fmt.Errorf("标记提醒已通知失败: %w", err)
}
return nil
}
// UpdateReminder 更新提醒字段
func (s *ReminderStore) UpdateReminder(id string, r *Reminder) error {
_, err := s.db.Exec(
`UPDATE reminders SET title = $1, description = $2, remind_at = $3, status = $4,
completed_at = $5, repeat_type = $6, session_id = $7, notified = $8
WHERE id = $9`,
r.Title, r.Description, r.RemindAt, r.Status, r.CompletedAt, r.RepeatType, r.SessionID, r.Notified, id,
)
if err != nil {
return fmt.Errorf("更新提醒失败: %w", err)
}
return nil
}
// DeleteReminder 删除提醒
func (s *ReminderStore) DeleteReminder(id string) error {
_, err := s.db.Exec(`DELETE FROM reminders WHERE id = $1`, id)
if err != nil {
return fmt.Errorf("删除提醒失败: %w", err)
}
return nil
}
@@ -253,6 +253,70 @@ func (s *SessionStore) ClearSessionMessages(sessionID string) error {
return nil
}
// SearchResult 搜索结果
type SearchResult struct {
MessageID int `json:"message_id"`
SessionID string `json:"session_id"`
SessionTitle string `json:"session_title"`
Role string `json:"role"`
Content string `json:"content"`
CreatedAt time.Time `json:"created_at"`
}
// SearchMessages 全文搜索消息 (使用 ILIKE 进行模糊匹配)
// 返回搜索结果列表、总数和可能的错误
func (s *SessionStore) SearchMessages(userID, query string, limit, offset int) ([]SearchResult, int, error) {
if limit <= 0 {
limit = 50
}
if offset < 0 {
offset = 0
}
// 获取匹配总数
var total int
countSQL := `SELECT COUNT(*) FROM messages m
JOIN sessions s ON m.session_id = s.id
WHERE s.user_id = $1 AND m.content ILIKE '%' || $2 || '%'`
if err := s.db.QueryRow(countSQL, userID, query).Scan(&total); err != nil {
return nil, 0, fmt.Errorf("搜索计数失败: %w", err)
}
// 分页查询,关联 sessions 获取会话标题
rows, err := s.db.Query(
`SELECT m.id, m.session_id, COALESCE(s.title, '') AS session_title, m.role, m.content, m.created_at
FROM messages m
JOIN sessions s ON m.session_id = s.id
WHERE s.user_id = $1 AND m.content ILIKE '%' || $2 || '%'
ORDER BY m.created_at DESC
LIMIT $3 OFFSET $4`,
userID, query, limit, offset,
)
if err != nil {
return nil, 0, fmt.Errorf("搜索消息失败: %w", err)
}
defer rows.Close()
var results []SearchResult
for rows.Next() {
var r SearchResult
if err := rows.Scan(&r.MessageID, &r.SessionID, &r.SessionTitle, &r.Role, &r.Content, &r.CreatedAt); err != nil {
return nil, 0, fmt.Errorf("扫描搜索结果行失败: %w", err)
}
results = append(results, r)
}
if results == nil {
results = []SearchResult{}
}
return results, total, rows.Err()
}
// DB 返回底层数据库连接,供其他 store 复用
func (s *SessionStore) DB() *sql.DB {
return s.db
}
// Close 关闭数据库连接
func (s *SessionStore) Close() error {
if s.db != nil {
+5 -4
View File
@@ -32,10 +32,11 @@ type SessionMessage struct {
// Message 完整对话消息(用于缓存)
type Message struct {
ID string `json:"id"`
Role string `json:"role"`
Content string `json:"content"`
Timestamp int64 `json:"timestamp"`
ID string `json:"id"`
Role string `json:"role"`
Content string `json:"content"`
Attachments []MessageAttachment `json:"attachments,omitempty"`
Timestamp int64 `json:"timestamp"`
}
const maxRecentMessages = 20
+45 -21
View File
@@ -1,32 +1,56 @@
package ws
// MessageAttachment 消息附件 (图片等)
type MessageAttachment struct {
Type string `json:"type"` // image
URL string `json:"url"` // 图片 URL 或 data URL
ThumbnailURL string `json:"thumbnail_url,omitempty"`
Filename string `json:"filename,omitempty"`
Width int `json:"width,omitempty"`
Height int `json:"height,omitempty"`
Size int64 `json:"size,omitempty"` // 文件大小 bytes
Description string `json:"description,omitempty"` // AI 对图片的描述
}
// 客户端 → 服务端消息
type ClientMessage struct {
Type string `json:"type"` // message | voice_input | ping | history
SessionID string `json:"session_id"`
Mode string `json:"mode"` // text | voice_msg | voice_assistant
Content string `json:"content"`
AudioData string `json:"audio_data,omitempty"` // base64
Timestamp int64 `json:"timestamp"`
Type string `json:"type"` // message | voice_input | ping | history
SessionID string `json:"session_id"`
Mode string `json:"mode"` // text | voice_msg | voice_assistant
Content string `json:"content"`
AudioData string `json:"audio_data,omitempty"` // base64
Attachments []MessageAttachment `json:"attachments,omitempty"` // 图片等附件
Timestamp int64 `json:"timestamp"`
}
// 服务端 → 客户端消息
type ServerMessage struct {
Type string `json:"type"` // response | segment | audio | error | device_update | pong | history_response | stream_chunk | stream_end | background_thinking
MessageID string `json:"message_id"`
Text string `json:"text,omitempty"`
Content string `json:"content,omitempty"` // stream_chunk 的增量文本
Role string `json:"role,omitempty"` // stream 消息的角色
SessionID string `json:"session_id,omitempty"` // 会话 ID
Segments []VoiceSegment `json:"segments,omitempty"` // 断句数组
FullAudioURL string `json:"full_audio_url,omitempty"`
ResponseMode string `json:"response_mode"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
Error string `json:"error,omitempty"`
Timestamp int64 `json:"timestamp"`
Messages []Message `json:"messages,omitempty"` // 历史消息列表
Devices []IotDeviceInfo `json:"devices,omitempty"` // IoT 设备状态
ThinkingStatus string `json:"thinking_status,omitempty"` // 后台思考状态
Type string `json:"type"` // response | segment | audio | error | device_update | pong | history_response | stream_chunk | stream_end | background_thinking | notification
MessageID string `json:"message_id"`
Text string `json:"text,omitempty"`
Content string `json:"content,omitempty"` // stream_chunk 的增量文本
Role string `json:"role,omitempty"` // stream 消息的角色
SessionID string `json:"session_id,omitempty"` // 会话 ID
Segments []VoiceSegment `json:"segments,omitempty"` // 断句数组
FullAudioURL string `json:"full_audio_url,omitempty"`
ResponseMode string `json:"response_mode"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
Error string `json:"error,omitempty"`
Timestamp int64 `json:"timestamp"`
Messages []Message `json:"messages,omitempty"` // 历史消息列表
Devices []IotDeviceInfo `json:"devices,omitempty"` // IoT 设备状态
ThinkingStatus string `json:"thinking_status,omitempty"` // 后台思考状态
Notification *NotificationInfo `json:"notification,omitempty"` // 通知推送
}
// NotificationInfo 通知推送信息
type NotificationInfo struct {
ID string `json:"id"`
Type string `json:"type"` // info | warning | success | thinking | reminder
Title string `json:"title"`
Body string `json:"body"`
Timestamp string `json:"timestamp"`
Data map[string]interface{} `json:"data,omitempty"`
}
// IotDeviceInfo IoT 设备信息(用于 WebSocket 推送)
+1
View File
@@ -6,4 +6,5 @@ use (
./iot-debug-service
./memory-service
./tool-engine
./voice-service
)
+76
View File
@@ -0,0 +1,76 @@
package main
import (
"log"
"net/http"
"os"
"os/signal"
"syscall"
"github.com/yourname/cyrene-ai/voice-service/internal/config"
"github.com/yourname/cyrene-ai/voice-service/internal/handler"
"github.com/yourname/cyrene-ai/voice-service/internal/service"
)
func main() {
log.SetFlags(log.LstdFlags | log.Lshortfile)
log.Println("🎤 Voice-Service (STT + TTS) 启动中...")
// 加载配置
cfg := config.Load()
log.Printf("配置: 端口=%s, WhisperBinary=%s, WhisperModel=%s, Language=%s",
cfg.Port, cfg.WhisperBinary, cfg.WhisperModel, cfg.WhisperLanguage)
// 初始化 STT 服务
sttSvc := service.NewSTTService(cfg)
// 检查 whisper 引擎是否可用
if !sttSvc.IsAvailable() {
log.Printf("⚠️ Whisper 引擎未安装 (%s)STT 功能不可用", cfg.WhisperBinary)
log.Printf(" 请运行: bash scripts/setup-whisper.sh")
} else {
log.Println("✅ Whisper 引擎已就绪")
}
// 初始化 TTS 服务
ttsSvc := service.NewTTSService()
if !ttsSvc.IsAvailable() {
log.Println("⚠️ TTS 引擎不可用 (请安装: pip install edge-tts)")
} else {
ttsStatus := ttsSvc.GetEngineStatus()
log.Printf("✅ TTS 引擎已就绪 (引擎: %s)", ttsStatus["engine"])
}
// 初始化 HTTP 处理器
sttHandler := handler.NewSTTHandler(sttSvc, cfg)
sttHandler.SetTTSService(ttsSvc)
ttsHandler := handler.NewTTSHandler(ttsSvc)
// 注册路由
mux := http.NewServeMux()
sttHandler.RegisterRoutes(mux)
ttsHandler.RegisterRoutes(mux)
// 启动 HTTP 服务
srv := &http.Server{
Addr: ":" + cfg.Port,
Handler: mux,
}
go func() {
log.Printf("🚀 Voice-Service 已启动在端口 %s", cfg.Port)
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("服务启动失败: %v", err)
}
}()
// 优雅关闭
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
log.Println("正在关闭 Voice-Service...")
srv.Close()
log.Println("Voice-Service 已关闭")
}
+3
View File
@@ -0,0 +1,3 @@
module github.com/yourname/cyrene-ai/voice-service
go 1.26.2
@@ -0,0 +1,30 @@
package config
import "os"
// Config STT 语音识别服务配置
type Config struct {
Port string
WhisperBinary string
WhisperModel string
WhisperLanguage string
MaxAudioSize int64 // 字节
}
// Load 从环境变量加载配置
func Load() *Config {
return &Config{
Port: getEnv("PORT", "8093"),
WhisperBinary: getEnv("WHISPER_BINARY", "./whisper.cpp/main"),
WhisperModel: getEnv("WHISPER_MODEL", "./whisper.cpp/models/ggml-small.bin"),
WhisperLanguage: getEnv("WHISPER_LANGUAGE", "zh"),
MaxAudioSize: 10 * 1024 * 1024, // 10MB
}
}
func getEnv(key, fallback string) string {
if v := os.Getenv(key); v != "" {
return v
}
return fallback
}
@@ -0,0 +1,201 @@
package handler
import (
"encoding/json"
"io"
"log"
"net/http"
"path/filepath"
"strings"
"time"
"github.com/yourname/cyrene-ai/voice-service/internal/config"
"github.com/yourname/cyrene-ai/voice-service/internal/service"
)
// STTHandler HTTP API 处理器
type STTHandler struct {
svc *service.STTService
ttsSvc *service.TTSService
cfg *config.Config
}
// NewSTTHandler 创建 STT 处理器(可选传入 TTSService 用于组合状态)
func NewSTTHandler(svc *service.STTService, cfg *config.Config) *STTHandler {
return &STTHandler{svc: svc, cfg: cfg}
}
// SetTTSService 设置 TTS 服务引用,用于组合状态端点
func (h *STTHandler) SetTTSService(ttsSvc *service.TTSService) {
h.ttsSvc = ttsSvc
}
// RegisterRoutes 注册所有路由到 mux
func (h *STTHandler) RegisterRoutes(mux *http.ServeMux) {
mux.HandleFunc("/api/v1/transcribe", h.handleTranscribe)
mux.HandleFunc("/api/v1/health", h.handleHealth)
mux.HandleFunc("/api/v1/status", h.handleStatus)
}
// handleTranscribe POST /api/v1/transcribe
// 接受 multipart/form-data,字段 audio (文件) 和 language (可选)
func (h *STTHandler) handleTranscribe(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
// 限制上传大小
r.Body = http.MaxBytesReader(w, r.Body, h.cfg.MaxAudioSize)
if err := r.ParseMultipartForm(h.cfg.MaxAudioSize); err != nil {
writeError(w, http.StatusBadRequest, "文件过大或解析失败,最大支持 10MB")
return
}
// 获取上传的文件
file, header, err := r.FormFile("audio")
if err != nil {
writeError(w, http.StatusBadRequest, "缺少 audio 文件字段")
return
}
defer file.Close()
// 读取文件内容
audioData, err := io.ReadAll(file)
if err != nil {
writeError(w, http.StatusInternalServerError, "读取音频文件失败")
return
}
if len(audioData) == 0 {
writeError(w, http.StatusBadRequest, "音频文件为空")
return
}
// 获取语言参数 (可选)
language := r.FormValue("language")
// 推断音频格式
format := inferFormat(header.Filename)
if !isSupportedFormat(format) {
writeError(w, http.StatusBadRequest, "不支持的音频格式: "+format+",支持的格式: WAV, MP3, OGG, FLAC, M4A")
return
}
// 执行转录
startTime := time.Now()
text, err := h.svc.Transcribe(audioData, format, language)
durationMs := time.Since(startTime).Milliseconds()
if err != nil {
log.Printf("[stt-handler] 转录失败: %v", err)
writeJSON(w, http.StatusInternalServerError, map[string]interface{}{
"success": false,
"error": err.Error(),
})
return
}
actualLang := language
if actualLang == "" {
actualLang = h.cfg.WhisperLanguage
}
writeJSON(w, http.StatusOK, map[string]interface{}{
"success": true,
"text": text,
"language": actualLang,
"duration_ms": durationMs,
})
}
// handleHealth GET /api/v1/health
func (h *STTHandler) handleHealth(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
sttStatus := h.svc.GetStatus()
healthStatus := "ok"
if !sttStatus["available"].(bool) {
healthStatus = "degraded"
}
resp := map[string]interface{}{
"status": healthStatus,
"service": "voice-service",
"stt": sttStatus,
}
// 如果有 TTS 服务,也包含 TTS 状态
if h.ttsSvc != nil {
resp["tts"] = h.ttsSvc.GetEngineStatus()
}
writeJSON(w, http.StatusOK, resp)
}
// handleStatus GET /api/v1/status
func (h *STTHandler) handleStatus(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
resp := map[string]interface{}{
"service": "voice-service",
"stt": h.svc.GetStatus(),
}
// 如果有 TTS 服务,也包含 TTS 状态
if h.ttsSvc != nil {
resp["tts"] = h.ttsSvc.GetEngineStatus()
}
writeJSON(w, http.StatusOK, resp)
}
// --- 辅助函数 ---
func writeJSON(w http.ResponseWriter, status int, data interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
json.NewEncoder(w).Encode(data)
}
func writeError(w http.ResponseWriter, status int, message string) {
writeJSON(w, status, map[string]interface{}{
"error": message,
})
}
// inferFormat 根据文件名推断音频格式
func inferFormat(filename string) string {
ext := strings.ToLower(filepath.Ext(filename))
switch ext {
case ".wav", ".wave":
return "wav"
case ".mp3", ".mpeg":
return "mp3"
case ".ogg", ".opus":
return "ogg"
case ".flac":
return "flac"
case ".m4a", ".mp4", ".aac":
return "m4a"
default:
return ext
}
}
// isSupportedFormat 检查是否支持的音频格式
func isSupportedFormat(format string) bool {
switch format {
case "wav", "mp3", "ogg", "flac", "m4a":
return true
default:
return false
}
}
@@ -0,0 +1,117 @@
package handler
import (
"encoding/json"
"log"
"net/http"
"github.com/yourname/cyrene-ai/voice-service/internal/service"
)
// TTSHandler TTS HTTP API 处理器
type TTSHandler struct {
svc *service.TTSService
}
// NewTTSHandler 创建 TTS 处理器
func NewTTSHandler(svc *service.TTSService) *TTSHandler {
return &TTSHandler{svc: svc}
}
// RegisterRoutes 注册 TTS 路由
func (h *TTSHandler) RegisterRoutes(mux *http.ServeMux) {
mux.HandleFunc("/api/v1/tts/synthesize", h.handleSynthesize)
mux.HandleFunc("/api/v1/tts/voices", h.handleVoices)
mux.HandleFunc("/api/v1/tts/status", h.handleStatus)
}
// TTSSynthesizeRequest TTS 合成请求体
type TTSSynthesizeRequest struct {
Text string `json:"text"`
Voice string `json:"voice"`
Rate string `json:"rate"`
}
// handleSynthesize POST /api/v1/tts/synthesize
func (h *TTSHandler) handleSynthesize(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
// 解析 JSON 请求体
var req TTSSynthesizeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "请求体解析失败: "+err.Error())
return
}
if req.Text == "" {
writeError(w, http.StatusBadRequest, "text 字段不能为空")
return
}
// 检查 TTS 引擎是否可用
if !h.svc.IsAvailable() {
log.Printf("[tts-handler] TTS 引擎不可用")
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusServiceUnavailable)
json.NewEncoder(w).Encode(map[string]interface{}{
"error": "TTS 引擎不可用,请安装 edge-tts (pip install edge-tts) 或 espeak-ng",
"code": "TTS_UNAVAILABLE",
"install": "pip install edge-tts",
})
return
}
// 调用合成
audioData, format, err := h.svc.Synthesize(req.Text, req.Voice, req.Rate)
if err != nil {
log.Printf("[tts-handler] TTS 合成失败: %v", err)
writeError(w, http.StatusInternalServerError, "TTS 合成失败: "+err.Error())
return
}
// 返回音频流
contentType := "audio/mpeg"
if format == "wav" {
contentType = "audio/wav"
}
w.Header().Set("Content-Type", contentType)
w.Header().Set("Content-Disposition", "inline; filename=synthesized."+format)
w.WriteHeader(http.StatusOK)
w.Write(audioData)
}
// handleVoices GET /api/v1/tts/voices
func (h *TTSHandler) handleVoices(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
voices := h.svc.GetVoices()
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"voices": voices,
"count": len(voices),
})
}
// handleStatus GET /api/v1/tts/status
func (h *TTSHandler) handleStatus(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
return
}
status := h.svc.GetEngineStatus()
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"service": "voice-service",
"tts": status,
})
}
@@ -0,0 +1,175 @@
package service
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"github.com/yourname/cyrene-ai/voice-service/internal/config"
)
// SupportedLanguages STT 支持的语言列表
var SupportedLanguages = []string{"zh", "en", "ja", "ko", "auto"}
// STTService 语音转文字服务
type STTService struct {
whisperBinary string
whisperModel string
language string
}
// NewSTTService 创建 STT 服务
func NewSTTService(cfg *config.Config) *STTService {
return &STTService{
whisperBinary: cfg.WhisperBinary,
whisperModel: cfg.WhisperModel,
language: cfg.WhisperLanguage,
}
}
// IsAvailable 检查 whisper binary 是否存在
func (s *STTService) IsAvailable() bool {
_, err := os.Stat(s.whisperBinary)
return err == nil
}
// Transcribe 将音频数据转录为文字
// audioData: 音频文件的二进制数据
// format: 音频格式 (wav, mp3, ogg, flac, m4a)
// language: 转录语言 (zh, en, ja, ko, auto),为空则使用默认语言
func (s *STTService) Transcribe(audioData []byte, format string, language string) (string, error) {
if !s.IsAvailable() {
return "", fmt.Errorf("STT 引擎未安装,请运行 scripts/setup-whisper.sh")
}
// 如果未指定语言,使用默认语言
if language == "" {
language = s.language
}
// 验证语言是否支持
if !isSupportedLanguage(language) {
return "", fmt.Errorf("不支持的语言: %s,支持的语言: %s", language, strings.Join(SupportedLanguages, ", "))
}
// 将音频数据写入临时文件
ext := normalizeExt(format)
tmpFile, err := os.CreateTemp("/tmp", "cyrene-stt-*"+ext)
if err != nil {
return "", fmt.Errorf("创建临时文件失败: %w", err)
}
tmpPath := tmpFile.Name()
defer os.Remove(tmpPath)
if _, err := tmpFile.Write(audioData); err != nil {
tmpFile.Close()
return "", fmt.Errorf("写入临时文件失败: %w", err)
}
tmpFile.Close()
// 如果不是 WAV 格式,尝试用 ffmpeg 转换
inputPath := tmpPath
if format != "wav" && format != "" {
convertedPath := tmpPath + ".wav"
if err := convertToWav(tmpPath, convertedPath); err == nil {
defer os.Remove(convertedPath)
inputPath = convertedPath
}
// 转换失败则仍使用原始文件(whisper.cpp 也支持其他格式)
}
// 调用 whisper.cpp
outputTxt := inputPath + ".txt"
cmd := exec.Command(s.whisperBinary,
"-m", s.whisperModel,
"-l", language,
"-f", inputPath,
"-otxt",
"-of", strings.TrimSuffix(inputPath, filepath.Ext(inputPath)),
)
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
os.Remove(outputTxt)
return "", fmt.Errorf("whisper 转录失败: %w", err)
}
// 读取输出文本
defer os.Remove(outputTxt)
txtData, err := os.ReadFile(outputTxt)
if err != nil {
return "", fmt.Errorf("读取转录结果失败: %w", err)
}
text := strings.TrimSpace(string(txtData))
return text, nil
}
// GetStatus 返回服务状态
func (s *STTService) GetStatus() map[string]interface{} {
binaryAvailable := s.IsAvailable()
modelExists := false
if _, err := os.Stat(s.whisperModel); err == nil {
modelExists = true
}
modelName := filepath.Base(s.whisperModel)
return map[string]interface{}{
"available": binaryAvailable && modelExists,
"binary_available": binaryAvailable,
"model_loaded": modelExists,
"binary_path": s.whisperBinary,
"model_path": s.whisperModel,
"model_name": modelName,
"default_language": s.language,
"supported_languages": SupportedLanguages,
}
}
// normalizeExt 规范化文件扩展名
func normalizeExt(format string) string {
switch strings.ToLower(format) {
case "wav":
return ".wav"
case "mp3", "mpeg":
return ".mp3"
case "ogg", "opus":
return ".ogg"
case "flac":
return ".flac"
case "m4a", "mp4", "aac":
return ".m4a"
default:
return ".wav"
}
}
// isSupportedLanguage 检查语言是否支持
func isSupportedLanguage(lang string) bool {
for _, l := range SupportedLanguages {
if l == lang {
return true
}
}
return false
}
// convertToWav 使用 ffmpeg 将音频转换为 WAV 格式
func convertToWav(inputPath, outputPath string) error {
cmd := exec.Command("ffmpeg",
"-i", inputPath,
"-ar", "16000",
"-ac", "1",
"-c:a", "pcm_s16le",
outputPath,
"-y",
)
cmd.Stderr = nil
return cmd.Run()
}
@@ -0,0 +1,294 @@
package service
import (
"fmt"
"os"
"os/exec"
"strings"
)
// TTSVoice 表示一个可用的 TTS 语音
type TTSVoice struct {
Name string `json:"name"`
DisplayName string `json:"display_name"`
Gender string `json:"gender"`
Locale string `json:"locale"`
}
// BuiltinVoices 内置的 edge-tts 中文语音列表
var BuiltinVoices = []TTSVoice{
{Name: "zh-CN-XiaoxiaoNeural", DisplayName: "晓晓 (女声)", Gender: "Female", Locale: "zh-CN"},
{Name: "zh-CN-YunxiNeural", DisplayName: "云希 (男声)", Gender: "Male", Locale: "zh-CN"},
{Name: "zh-CN-XiaoyiNeural", DisplayName: "晓伊 (女声)", Gender: "Female", Locale: "zh-CN"},
}
// TTSService 文字转语音服务
type TTSService struct{}
// NewTTSService 创建 TTS 服务
func NewTTSService() *TTSService {
return &TTSService{}
}
// IsAvailable 检查 TTS 引擎是否可用
// 优先级: edge-tts > espeak-ng > 纯 Go fallback
func (s *TTSService) IsAvailable() bool {
return s.edgeTTSAvailable() || s.espeakAvailable()
}
// edgeTTSAvailable 检查 edge-tts 是否可用
func (s *TTSService) edgeTTSAvailable() bool {
_, err := exec.LookPath("edge-tts")
return err == nil
}
// espeakAvailable 检查 espeak-ng 是否可用
func (s *TTSService) espeakAvailable() bool {
_, err := exec.LookPath("espeak-ng")
return err == nil
}
// Synthesize 将文字合成为音频
// text: 要合成的文字
// voice: 语音名称 (zh-CN-XiaoxiaoNeural 等)
// rate: 语速调整 ("+0%", "+20%", "-20%" 等)
// 返回: 音频数据, 音频格式 (mp3/wav), 错误
func (s *TTSService) Synthesize(text string, voice string, rate string) ([]byte, string, error) {
if text == "" {
return nil, "", fmt.Errorf("文字内容为空")
}
// 方案 A: edge-tts (推荐)
if s.edgeTTSAvailable() {
return s.synthesizeEdgeTTS(text, voice, rate)
}
// 方案 B: espeak-ng
if s.espeakAvailable() {
return s.synthesizeEspeak(text, voice)
}
// 方案 C: 纯 Go fallback
return s.synthesizeFallback()
}
// synthesizeEdgeTTS 使用 edge-tts 合成语音
func (s *TTSService) synthesizeEdgeTTS(text string, voice string, rate string) ([]byte, string, error) {
if voice == "" {
voice = "zh-CN-XiaoxiaoNeural"
}
if rate == "" {
rate = "+0%"
}
// 写入文本到临时文件
tmpText, err := os.CreateTemp("/tmp", "cyrene-tts-text-*.txt")
if err != nil {
return nil, "", fmt.Errorf("创建临时文本文件失败: %w", err)
}
tmpTextPath := tmpText.Name()
defer os.Remove(tmpTextPath)
if _, err := tmpText.WriteString(text); err != nil {
tmpText.Close()
return nil, "", fmt.Errorf("写入临时文本失败: %w", err)
}
tmpText.Close()
// 输出音频文件
tmpOutput, err := os.CreateTemp("/tmp", "cyrene-tts-output-*.mp3")
if err != nil {
return nil, "", fmt.Errorf("创建临时输出文件失败: %w", err)
}
tmpOutputPath := tmpOutput.Name()
tmpOutput.Close()
defer os.Remove(tmpOutputPath)
// 构建 edge-tts 命令
cmd := exec.Command("edge-tts",
"--voice", voice,
"--rate="+rate,
"--text", text,
"--write-media", tmpOutputPath,
)
output, err := cmd.CombinedOutput()
if err != nil {
return nil, "", fmt.Errorf("edge-tts 合成失败: %w\n输出: %s", err, string(output))
}
// 读取生成的音频
audioData, err := os.ReadFile(tmpOutputPath)
if err != nil {
return nil, "", fmt.Errorf("读取合成的音频失败: %w", err)
}
if len(audioData) == 0 {
return nil, "", fmt.Errorf("edge-tts 生成的音频为空")
}
return audioData, "mp3", nil
}
// synthesizeEspeak 使用 espeak-ng 合成语音
func (s *TTSService) synthesizeEspeak(text string, voice string) ([]byte, string, error) {
if voice == "" {
voice = "zh"
}
// 输出 WAV 文件
tmpOutput, err := os.CreateTemp("/tmp", "cyrene-tts-espeak-*.wav")
if err != nil {
return nil, "", fmt.Errorf("创建临时输出文件失败: %w", err)
}
tmpOutputPath := tmpOutput.Name()
tmpOutput.Close()
defer os.Remove(tmpOutputPath)
cmd := exec.Command("espeak-ng",
"-v", voice,
"-w", tmpOutputPath,
text,
)
output, err := cmd.CombinedOutput()
if err != nil {
return nil, "", fmt.Errorf("espeak-ng 合成失败: %w\n输出: %s", err, string(output))
}
audioData, err := os.ReadFile(tmpOutputPath)
if err != nil {
return nil, "", fmt.Errorf("读取合成的音频失败: %w", err)
}
if len(audioData) == 0 {
return nil, "", fmt.Errorf("espeak-ng 生成的音频为空")
}
return audioData, "wav", nil
}
// synthesizeFallback 生成静默 WAV 作为降级方案
// 生成 1 秒 16kHz 16-bit mono 静默 PCM WAV
func (s *TTSService) synthesizeFallback() ([]byte, string, error) {
// 1 秒 @ 16kHz mono 16-bit = 32000 字节采样数据
sampleRate := 16000
numChannels := 1
bitsPerSample := 16
durationSec := 1
dataSize := sampleRate * numChannels * (bitsPerSample / 8) * durationSec
// WAV header 44 bytes + data
wav := make([]byte, 44+dataSize)
// RIFF header
copy(wav[0:4], "RIFF")
writeUint32LE(wav[4:8], uint32(36+dataSize))
copy(wav[8:12], "WAVE")
// fmt chunk
copy(wav[12:16], "fmt ")
writeUint32LE(wav[16:20], 16) // chunk size
writeUint16LE(wav[20:22], 1) // PCM
writeUint16LE(wav[22:24], uint16(numChannels)) // channels
writeUint32LE(wav[24:28], uint32(sampleRate)) // sample rate
writeUint32LE(wav[28:32], uint32(sampleRate*numChannels*bitsPerSample/8)) // byte rate
writeUint16LE(wav[32:34], uint16(numChannels*bitsPerSample/8)) // block align
writeUint16LE(wav[34:36], uint16(bitsPerSample)) // bits per sample
// data chunk
copy(wav[36:40], "data")
writeUint32LE(wav[40:44], uint32(dataSize))
// 采样数据全是 0 (静默)
return wav, "wav", nil
}
func writeUint16LE(buf []byte, v uint16) {
buf[0] = byte(v)
buf[1] = byte(v >> 8)
}
func writeUint32LE(buf []byte, v uint32) {
buf[0] = byte(v)
buf[1] = byte(v >> 8)
buf[2] = byte(v >> 16)
buf[3] = byte(v >> 24)
}
// GetVoices 返回可用语音列表
func (s *TTSService) GetVoices() []TTSVoice {
// 检查 edge-tts 是否可用,尝试获取完整语音列表
if s.edgeTTSAvailable() {
cmd := exec.Command("edge-tts", "--list-voices")
output, err := cmd.Output()
if err == nil {
voices := s.parseEdgeTTSVoices(string(output))
if len(voices) > 0 {
return voices
}
}
}
return BuiltinVoices
}
// parseEdgeTTSVoices 解析 edge-tts --list-voices 输出
// 简单解析:查找包含 "zh-CN" 的语音
func (s *TTSService) parseEdgeTTSVoices(output string) []TTSVoice {
var voices []TTSVoice
for _, line := range strings.Split(output, "\n") {
line = strings.TrimSpace(line)
if !strings.Contains(line, "zh-CN") {
continue
}
voice := TTSVoice{
Name: "",
Gender: "Unknown",
Locale: "zh-CN",
}
// 简单解析 "Name: zh-CN-XiaoxiaoNeural" 和 "Gender: Female" 格式
for _, field := range strings.Split(line, ",") {
field = strings.TrimSpace(field)
if strings.HasPrefix(field, "Name:") {
voice.Name = strings.TrimSpace(strings.TrimPrefix(field, "Name:"))
}
if strings.HasPrefix(field, "Gender:") {
voice.Gender = strings.TrimSpace(strings.TrimPrefix(field, "Gender:"))
}
}
if voice.Name != "" {
voice.DisplayName = voice.Name
voices = append(voices, voice)
}
}
if len(voices) == 0 {
return nil
}
return voices
}
// GetEngineStatus 返回 TTS 引擎状态
func (s *TTSService) GetEngineStatus() map[string]interface{} {
status := map[string]interface{}{
"available": s.IsAvailable(),
"edge_tts": s.edgeTTSAvailable(),
"espeak_ng": s.espeakAvailable(),
"engine": "none",
"default_voice": "zh-CN-XiaoxiaoNeural",
"builtin_voices": len(BuiltinVoices),
}
if s.edgeTTSAvailable() {
status["engine"] = "edge-tts"
} else if s.espeakAvailable() {
status["engine"] = "espeak-ng"
} else {
status["engine"] = "fallback (silent WAV)"
}
return status
}