feat: 第五轮开发 - 14项未来路线图功能完整实现
W1-W14 全部完成: - W1: 消息搜索 (ILIKE全文检索 + SearchModal) - W2: 对话导出 (JSON/Markdown/TXT三格式) - W3: 记忆时间线 DevTools 可视化 - W4: 通知推送系统 (WebSocket + Browser Notification API) - W5: 定时提醒 (30s轮询 + 重复提醒 + WebSocket推送) - W6: 每日简报 (08:00自动生成: 天气+新闻+提醒+AI摘要) - W7: IoT场景自动化 (规则引擎 10s轮询 + 条件评估 + 场景执行) - W8: 语音输入 (浏览器 Speech Recognition API) - W9: STT服务 (voice-service + whisper.cpp) - W10: TTS服务 (浏览器 Speech Synthesis + edge-tts三档回退) - W11: 文件管理 (上传/下载/缩略图/纯Go bilinear缩放) - W12: 知识库RAG (PostgreSQL tsvector + 文档分块 + 检索) - W13: 多模态 (图片上传+分析: Vision API + 本地Go分析回退) - W14: PWA (Service Worker + 离线页 + install prompt) 总计: 6个Go微服务 + 10+前端组件 + 10+ PostgreSQL表 + 4个后台调度器
This commit is contained in:
@@ -0,0 +1,483 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/engine"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/store"
|
||||
)
|
||||
|
||||
// AutomationHandler 自动化处理器
|
||||
type AutomationHandler struct {
|
||||
store *store.AutomationStore
|
||||
engine *engine.RuleEngine
|
||||
}
|
||||
|
||||
// NewAutomationHandler 创建自动化处理器
|
||||
func NewAutomationHandler(s *store.AutomationStore, e *engine.RuleEngine) *AutomationHandler {
|
||||
return &AutomationHandler{store: s, engine: e}
|
||||
}
|
||||
|
||||
// ========== 请求/响应体 ==========
|
||||
|
||||
// CreateRuleRequest 创建规则请求
|
||||
type CreateRuleRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
TriggerType string `json:"trigger_type" binding:"required"`
|
||||
TriggerConfig json.RawMessage `json:"trigger_config"`
|
||||
Conditions json.RawMessage `json:"conditions"`
|
||||
Actions json.RawMessage `json:"actions" binding:"required"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// UpdateRuleRequest 更新规则请求
|
||||
type UpdateRuleRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Description *string `json:"description"`
|
||||
TriggerType *string `json:"trigger_type"`
|
||||
TriggerConfig *json.RawMessage `json:"trigger_config"`
|
||||
Conditions *json.RawMessage `json:"conditions"`
|
||||
Actions *json.RawMessage `json:"actions"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// CreateSceneRequest 创建场景请求
|
||||
type CreateSceneRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Icon string `json:"icon"`
|
||||
RuleIDs json.RawMessage `json:"rule_ids"`
|
||||
}
|
||||
|
||||
// UpdateSceneRequest 更新场景请求
|
||||
type UpdateSceneRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Icon *string `json:"icon"`
|
||||
RuleIDs *json.RawMessage `json:"rule_ids"`
|
||||
}
|
||||
|
||||
// ========== Rule Handlers ==========
|
||||
|
||||
// ListRules 获取用户的所有规则
|
||||
// GET /api/v1/automation/rules
|
||||
func (h *AutomationHandler) ListRules(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
if userID == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "未认证"})
|
||||
return
|
||||
}
|
||||
|
||||
rules, err := h.store.GetRulesByUser(userID)
|
||||
if err != nil {
|
||||
log.Printf("[automation] 获取规则列表失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取规则列表失败"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"rules": rules,
|
||||
"count": len(rules),
|
||||
})
|
||||
}
|
||||
|
||||
// CreateRule 创建规则
|
||||
// POST /api/v1/automation/rules
|
||||
func (h *AutomationHandler) CreateRule(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
if userID == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "未认证"})
|
||||
return
|
||||
}
|
||||
|
||||
var req CreateRuleRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
triggerConfig := ensureRawMessage(req.TriggerConfig)
|
||||
conditions := ensureRawMessage(req.Conditions)
|
||||
actions := ensureRawMessage(req.Actions)
|
||||
|
||||
enabled := true
|
||||
if req.Enabled != nil {
|
||||
enabled = *req.Enabled
|
||||
}
|
||||
|
||||
rule := &store.AutomationRule{
|
||||
ID: randomHexID(),
|
||||
UserID: userID,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
TriggerType: req.TriggerType,
|
||||
TriggerConfig: triggerConfig,
|
||||
Conditions: conditions,
|
||||
Actions: actions,
|
||||
Enabled: enabled,
|
||||
}
|
||||
|
||||
if err := h.store.CreateRule(rule); err != nil {
|
||||
log.Printf("[automation] 创建规则失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建规则失败"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"success": true,
|
||||
"rule": rule,
|
||||
})
|
||||
}
|
||||
|
||||
// GetRule 获取单条规则
|
||||
func (h *AutomationHandler) GetRule(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
rule, err := h.store.GetRule(id)
|
||||
if err != nil {
|
||||
log.Printf("[automation] 获取规则失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取规则失败"})
|
||||
return
|
||||
}
|
||||
if rule == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "规则不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"rule": rule})
|
||||
}
|
||||
|
||||
// UpdateRule 更新规则
|
||||
// PUT /api/v1/automation/rules/:id
|
||||
func (h *AutomationHandler) UpdateRule(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
id := c.Param("id")
|
||||
|
||||
// 先获取规则验证所有权
|
||||
existing, err := h.store.GetRule(id)
|
||||
if err != nil {
|
||||
log.Printf("[automation] 获取规则失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取规则失败"})
|
||||
return
|
||||
}
|
||||
if existing == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "规则不存在"})
|
||||
return
|
||||
}
|
||||
if existing.UserID != userID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权修改此规则"})
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateRuleRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 只更新提供的字段
|
||||
if req.Name != nil {
|
||||
existing.Name = *req.Name
|
||||
}
|
||||
if req.Description != nil {
|
||||
existing.Description = *req.Description
|
||||
}
|
||||
if req.TriggerType != nil {
|
||||
existing.TriggerType = *req.TriggerType
|
||||
}
|
||||
if req.TriggerConfig != nil {
|
||||
existing.TriggerConfig = req.TriggerConfig
|
||||
}
|
||||
if req.Conditions != nil {
|
||||
existing.Conditions = req.Conditions
|
||||
}
|
||||
if req.Actions != nil {
|
||||
existing.Actions = req.Actions
|
||||
}
|
||||
if req.Enabled != nil {
|
||||
existing.Enabled = *req.Enabled
|
||||
}
|
||||
|
||||
if err := h.store.UpdateRule(existing); err != nil {
|
||||
log.Printf("[automation] 更新规则失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "更新规则失败"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"rule": existing,
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteRule 删除规则
|
||||
// DELETE /api/v1/automation/rules/:id
|
||||
func (h *AutomationHandler) DeleteRule(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
id := c.Param("id")
|
||||
|
||||
existing, err := h.store.GetRule(id)
|
||||
if err != nil {
|
||||
log.Printf("[automation] 获取规则失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取规则失败"})
|
||||
return
|
||||
}
|
||||
if existing == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "规则不存在"})
|
||||
return
|
||||
}
|
||||
if existing.UserID != userID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权删除此规则"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.DeleteRule(id); err != nil {
|
||||
log.Printf("[automation] 删除规则失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除规则失败"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// TriggerRule 手动触发单条规则
|
||||
// POST /api/v1/automation/rules/:id/trigger
|
||||
func (h *AutomationHandler) TriggerRule(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
id := c.Param("id")
|
||||
|
||||
rule, err := h.store.GetRule(id)
|
||||
if err != nil {
|
||||
log.Printf("[automation] 获取规则失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取规则失败"})
|
||||
return
|
||||
}
|
||||
if rule == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "规则不存在"})
|
||||
return
|
||||
}
|
||||
if rule.UserID != userID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权触发此规则"})
|
||||
return
|
||||
}
|
||||
|
||||
h.engine.ExecuteRuleActions(rule)
|
||||
h.store.MarkRuleTriggered(id)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "规则已触发",
|
||||
})
|
||||
}
|
||||
|
||||
// ========== Scene Handlers ==========
|
||||
|
||||
// ListScenes 获取所有场景
|
||||
// GET /api/v1/automation/scenes
|
||||
func (h *AutomationHandler) ListScenes(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
if userID == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "未认证"})
|
||||
return
|
||||
}
|
||||
|
||||
scenes, err := h.store.GetScenesByUser(userID)
|
||||
if err != nil {
|
||||
log.Printf("[automation] 获取场景列表失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取场景列表失败"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"scenes": scenes,
|
||||
"count": len(scenes),
|
||||
})
|
||||
}
|
||||
|
||||
// CreateScene 创建场景
|
||||
// POST /api/v1/automation/scenes
|
||||
func (h *AutomationHandler) CreateScene(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
if userID == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "未认证"})
|
||||
return
|
||||
}
|
||||
|
||||
var req CreateSceneRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ruleIDs := ensureRawMessage(req.RuleIDs)
|
||||
|
||||
scene := &store.AutomationScene{
|
||||
ID: randomHexID(),
|
||||
UserID: userID,
|
||||
Name: req.Name,
|
||||
Icon: req.Icon,
|
||||
RuleIDs: ruleIDs,
|
||||
}
|
||||
|
||||
if err := h.store.CreateScene(scene); err != nil {
|
||||
log.Printf("[automation] 创建场景失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建场景失败"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"success": true,
|
||||
"scene": scene,
|
||||
})
|
||||
}
|
||||
|
||||
// GetScene 获取单个场景
|
||||
func (h *AutomationHandler) GetScene(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
scene, err := h.store.GetScene(id)
|
||||
if err != nil {
|
||||
log.Printf("[automation] 获取场景失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取场景失败"})
|
||||
return
|
||||
}
|
||||
if scene == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "场景不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"scene": scene})
|
||||
}
|
||||
|
||||
// UpdateScene 更新场景
|
||||
// PUT /api/v1/automation/scenes/:id
|
||||
func (h *AutomationHandler) UpdateScene(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
id := c.Param("id")
|
||||
|
||||
existing, err := h.store.GetScene(id)
|
||||
if err != nil {
|
||||
log.Printf("[automation] 获取场景失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取场景失败"})
|
||||
return
|
||||
}
|
||||
if existing == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "场景不存在"})
|
||||
return
|
||||
}
|
||||
if existing.UserID != userID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权修改此场景"})
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateSceneRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name != nil {
|
||||
existing.Name = *req.Name
|
||||
}
|
||||
if req.Icon != nil {
|
||||
existing.Icon = *req.Icon
|
||||
}
|
||||
if req.RuleIDs != nil {
|
||||
existing.RuleIDs = req.RuleIDs
|
||||
}
|
||||
|
||||
if err := h.store.UpdateScene(existing); err != nil {
|
||||
log.Printf("[automation] 更新场景失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "更新场景失败"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"scene": existing,
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteScene 删除场景
|
||||
// DELETE /api/v1/automation/scenes/:id
|
||||
func (h *AutomationHandler) DeleteScene(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
id := c.Param("id")
|
||||
|
||||
existing, err := h.store.GetScene(id)
|
||||
if err != nil {
|
||||
log.Printf("[automation] 获取场景失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取场景失败"})
|
||||
return
|
||||
}
|
||||
if existing == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "场景不存在"})
|
||||
return
|
||||
}
|
||||
if existing.UserID != userID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权删除此场景"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.DeleteScene(id); err != nil {
|
||||
log.Printf("[automation] 删除场景失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除场景失败"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// ExecuteScene 手动执行场景
|
||||
// POST /api/v1/automation/scenes/:id/execute
|
||||
func (h *AutomationHandler) ExecuteScene(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
// 验证场景存在
|
||||
scene, err := h.store.GetScene(id)
|
||||
if err != nil {
|
||||
log.Printf("[automation] 获取场景失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取场景失败"})
|
||||
return
|
||||
}
|
||||
if scene == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "场景不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
userID := middleware.GetUserID(c)
|
||||
|
||||
if err := h.engine.ExecuteScene(id, userID); err != nil {
|
||||
log.Printf("[automation] 执行场景失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "执行场景失败"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "场景已执行",
|
||||
})
|
||||
}
|
||||
|
||||
// ========== 辅助函数 ==========
|
||||
|
||||
// randomHexID 使用 crypto/rand 生成 16 字节 hex ID
|
||||
func randomHexID() string {
|
||||
b := make([]byte, 16)
|
||||
rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// ensureRawMessage 确保 json.RawMessage 非空
|
||||
func ensureRawMessage(raw json.RawMessage) *json.RawMessage {
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
result := make(json.RawMessage, len(raw))
|
||||
copy(result, raw)
|
||||
return &result
|
||||
}
|
||||
@@ -0,0 +1,709 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/config"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/store"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/ws"
|
||||
)
|
||||
|
||||
// BriefingHandler 每日简报处理器
|
||||
type BriefingHandler struct {
|
||||
cfg *config.Config
|
||||
hub *ws.Hub
|
||||
briefingStore *store.BriefingStore
|
||||
reminderStore *store.ReminderStore
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewBriefingHandler 创建简报处理器
|
||||
func NewBriefingHandler(cfg *config.Config, hub *ws.Hub, bs *store.BriefingStore, rs *store.ReminderStore) *BriefingHandler {
|
||||
return &BriefingHandler{
|
||||
cfg: cfg,
|
||||
hub: hub,
|
||||
briefingStore: bs,
|
||||
reminderStore: rs,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateBriefingRequest 手动生成简报请求体
|
||||
type GenerateBriefingRequest struct {
|
||||
UserID string `json:"user_id" binding:"required"`
|
||||
}
|
||||
|
||||
// GetBriefing 获取指定日期简报
|
||||
// GET /api/v1/briefings?user_id=xxx&date=2024-01-01
|
||||
func (h *BriefingHandler) GetBriefing(c *gin.Context) {
|
||||
authUserID := middleware.GetUserID(c)
|
||||
userID := c.Query("user_id")
|
||||
date := c.Query("date")
|
||||
|
||||
if !strings.HasPrefix(authUserID, "admin_") || userID == "" {
|
||||
userID = authUserID
|
||||
}
|
||||
|
||||
if userID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少user_id参数"})
|
||||
return
|
||||
}
|
||||
if date == "" {
|
||||
date = time.Now().Format("2006-01-02")
|
||||
}
|
||||
|
||||
briefing, err := h.briefingStore.GetBriefingByDate(userID, date)
|
||||
if err != nil {
|
||||
log.Printf("[briefing] 查询简报失败: user=%s date=%s err=%v", userID, date, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询简报失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if briefing == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"briefing": nil,
|
||||
"message": "当日简报尚未生成",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"briefing": briefing})
|
||||
}
|
||||
|
||||
// GetLatestBriefings 获取最近简报列表
|
||||
// GET /api/v1/briefings/latest?user_id=xxx&limit=7
|
||||
func (h *BriefingHandler) GetLatestBriefings(c *gin.Context) {
|
||||
authUserID := middleware.GetUserID(c)
|
||||
userID := c.Query("user_id")
|
||||
|
||||
if !strings.HasPrefix(authUserID, "admin_") || userID == "" {
|
||||
userID = authUserID
|
||||
}
|
||||
|
||||
if userID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少user_id参数"})
|
||||
return
|
||||
}
|
||||
|
||||
limit := 7
|
||||
if l := c.Query("limit"); l != "" {
|
||||
if parsed, err := parseInt(l); err == nil && parsed > 0 && parsed <= 30 {
|
||||
limit = parsed
|
||||
}
|
||||
}
|
||||
|
||||
briefings, err := h.briefingStore.GetLatestBriefings(userID, limit)
|
||||
if err != nil {
|
||||
log.Printf("[briefing] 查询简报列表失败: user=%s err=%v", userID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询简报列表失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"briefings": briefings,
|
||||
"total": len(briefings),
|
||||
})
|
||||
}
|
||||
|
||||
// Generate 手动触发生成今日简报
|
||||
// POST /api/v1/briefings/generate
|
||||
func (h *BriefingHandler) Generate(c *gin.Context) {
|
||||
authUserID := middleware.GetUserID(c)
|
||||
var req GenerateBriefingRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 非管理员只能为自己生成
|
||||
if !strings.HasPrefix(authUserID, "admin_") {
|
||||
req.UserID = authUserID
|
||||
}
|
||||
|
||||
if req.UserID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少user_id"})
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.GenerateDailyBriefing(req.UserID)
|
||||
if err != nil {
|
||||
log.Printf("[briefing] 生成简报失败: user=%s err=%v", req.UserID, err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "生成简报失败: " + err.Error(),
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成后推送通知
|
||||
h.pushBriefingNotification(req.UserID, result)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"briefing": result,
|
||||
"message": "简报已生成并推送",
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateDailyBriefing 生成每日简报(核心逻辑)
|
||||
func (h *BriefingHandler) GenerateDailyBriefing(userID string) (*store.Briefing, error) {
|
||||
today := time.Now().Format("2006-01-02")
|
||||
|
||||
briefing := &store.Briefing{
|
||||
ID: "brief_" + generateID(),
|
||||
UserID: userID,
|
||||
Date: today,
|
||||
Status: "pending",
|
||||
Weather: &store.WeatherData{},
|
||||
News: []store.NewsItem{},
|
||||
Reminders: []store.BriefReminder{},
|
||||
}
|
||||
|
||||
// 1. 获取天气数据
|
||||
log.Printf("[briefing] 获取天气数据...")
|
||||
weather, err := h.fetchWeather("Shanghai")
|
||||
if err != nil {
|
||||
log.Printf("[briefing] 天气获取失败 (降级): %v", err)
|
||||
weather = &store.WeatherData{
|
||||
Location: "未知",
|
||||
Temp: 0,
|
||||
Condition: "获取天气失败",
|
||||
Icon: "❓",
|
||||
}
|
||||
}
|
||||
briefing.Weather = weather
|
||||
log.Printf("[briefing] 天气: %s %.1f°C %s", weather.Location, weather.Temp, weather.Condition)
|
||||
|
||||
// 2. 获取今日待办提醒
|
||||
log.Printf("[briefing] 获取待办提醒...")
|
||||
reminders, err := h.reminderStore.GetRemindersByUser(userID, "pending", 10, 0)
|
||||
if err != nil {
|
||||
log.Printf("[briefing] 获取提醒失败: %v", err)
|
||||
} else {
|
||||
now := time.Now()
|
||||
endOfDay := time.Date(now.Year(), now.Month(), now.Day(), 23, 59, 59, 0, now.Location())
|
||||
for _, r := range reminders {
|
||||
if r.RemindAt.Before(endOfDay) || r.RemindAt.Equal(endOfDay) {
|
||||
briefing.Reminders = append(briefing.Reminders, store.BriefReminder{
|
||||
ID: r.ID,
|
||||
Title: r.Title,
|
||||
RemindAt: r.RemindAt.Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
log.Printf("[briefing] 今日待办: %d 项", len(briefing.Reminders))
|
||||
|
||||
// 3. 获取新闻摘要(通过 tool-engine web_search)
|
||||
log.Printf("[briefing] 获取新闻摘要...")
|
||||
news, err := h.fetchNews()
|
||||
if err != nil {
|
||||
log.Printf("[briefing] 新闻获取失败 (降级): %v", err)
|
||||
}
|
||||
briefing.News = news
|
||||
log.Printf("[briefing] 新闻: %d 条", len(news))
|
||||
|
||||
// 4. 生成 AI 摘要
|
||||
log.Printf("[briefing] 生成 AI 摘要...")
|
||||
summary, err := h.generateAISummary(briefing)
|
||||
if err != nil {
|
||||
log.Printf("[briefing] AI 摘要生成失败 (降级): %v", err)
|
||||
summary = h.buildFallbackSummary(briefing)
|
||||
}
|
||||
briefing.Summary = summary
|
||||
|
||||
// 5. 标记为已生成
|
||||
now := time.Now()
|
||||
briefing.Status = "generated"
|
||||
briefing.GeneratedAt = &now
|
||||
|
||||
// 6. 持久化
|
||||
if err := h.briefingStore.CreateOrUpdateBriefing(briefing); err != nil {
|
||||
return nil, fmt.Errorf("保存简报失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[briefing] 简报已生成: user=%s date=%s", userID, today)
|
||||
return briefing, nil
|
||||
}
|
||||
|
||||
// fetchWeather 通过 wttr.in API 获取天气数据
|
||||
func (h *BriefingHandler) fetchWeather(location string) (*store.WeatherData, error) {
|
||||
url := fmt.Sprintf("https://wttr.in/%s?format=j1", location)
|
||||
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("构建天气请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("User-Agent", "Cyrene-AI/1.0")
|
||||
|
||||
resp, err := h.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求天气API失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("天气API返回状态码 %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取天气响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析 wttr.in JSON 响应
|
||||
var wttrResp struct {
|
||||
CurrentCondition []struct {
|
||||
TempC string `json:"temp_C"`
|
||||
WeatherDesc []struct {
|
||||
Value string `json:"value"`
|
||||
} `json:"weatherDesc"`
|
||||
WeatherIconURL []struct {
|
||||
Value string `json:"value"`
|
||||
} `json:"weatherIconUrl"`
|
||||
} `json:"current_condition"`
|
||||
NearestArea []struct {
|
||||
AreaName []struct {
|
||||
Value string `json:"value"`
|
||||
} `json:"areaName"`
|
||||
} `json:"nearest_area"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &wttrResp); err != nil {
|
||||
return nil, fmt.Errorf("解析天气数据失败: %w", err)
|
||||
}
|
||||
|
||||
wd := &store.WeatherData{
|
||||
Location: location,
|
||||
}
|
||||
|
||||
if len(wttrResp.NearestArea) > 0 && len(wttrResp.NearestArea[0].AreaName) > 0 {
|
||||
wd.Location = wttrResp.NearestArea[0].AreaName[0].Value
|
||||
}
|
||||
|
||||
if len(wttrResp.CurrentCondition) > 0 {
|
||||
cc := wttrResp.CurrentCondition[0]
|
||||
wd.Temp = parseFloat(cc.TempC)
|
||||
|
||||
if len(cc.WeatherDesc) > 0 {
|
||||
wd.Condition = cc.WeatherDesc[0].Value
|
||||
}
|
||||
|
||||
// 根据天气描述转 emoji
|
||||
wd.Icon = weatherEmoji(wd.Condition)
|
||||
}
|
||||
|
||||
return wd, nil
|
||||
}
|
||||
|
||||
// fetchNews 通过 tool-engine web_search 搜索今日新闻
|
||||
func (h *BriefingHandler) fetchNews() ([]store.NewsItem, error) {
|
||||
if h.cfg.ToolEngineURL == "" {
|
||||
return nil, fmt.Errorf("ToolEngine URL 未配置")
|
||||
}
|
||||
|
||||
today := time.Now().Format("2006年01月02日")
|
||||
query := fmt.Sprintf("%s 今日要闻 热点新闻", today)
|
||||
|
||||
reqBody, _ := json.Marshal(map[string]interface{}{
|
||||
"arguments": map[string]interface{}{
|
||||
"query": query,
|
||||
"limit": 5,
|
||||
},
|
||||
})
|
||||
|
||||
url := fmt.Sprintf("%s/api/v1/tools/web_search/execute", h.cfg.ToolEngineURL)
|
||||
req, err := http.NewRequest("POST", url, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("构建新闻搜索请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := h.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求新闻搜索失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取新闻搜索结果失败: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("新闻搜索返回状态码 %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// 解析 tool-engine 单个工具执行响应: {id, output, error?}
|
||||
var result struct {
|
||||
ID string `json:"id"`
|
||||
Output string `json:"output"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("解析新闻搜索结果失败: %w", err)
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
log.Printf("[briefing] 新闻搜索失败: %s", result.Error)
|
||||
// 返回降级新闻
|
||||
return []store.NewsItem{
|
||||
{
|
||||
Title: "未能获取今日新闻",
|
||||
URL: "",
|
||||
Source: "系统",
|
||||
Summary: "新闻搜索服务暂时不可用,请稍后再试。",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
var news []store.NewsItem
|
||||
|
||||
// 尝试解析搜索结果为结构化数据
|
||||
var searchResults []struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Snippet string `json:"snippet"`
|
||||
Source string `json:"source"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(result.Output), &searchResults); err != nil {
|
||||
// 如果不是 JSON 数组,当做纯文本处理
|
||||
news = append(news, store.NewsItem{
|
||||
Title: "今日新闻",
|
||||
URL: "",
|
||||
Source: "搜索引擎",
|
||||
Summary: truncateStr(result.Output, 200),
|
||||
})
|
||||
} else {
|
||||
for _, sr := range searchResults {
|
||||
news = append(news, store.NewsItem{
|
||||
Title: sr.Title,
|
||||
URL: sr.URL,
|
||||
Source: sr.Source,
|
||||
Summary: sr.Snippet,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(news) == 0 {
|
||||
news = []store.NewsItem{
|
||||
{
|
||||
Title: "未能获取今日新闻",
|
||||
URL: "",
|
||||
Source: "系统",
|
||||
Summary: "新闻搜索服务暂时不可用,请稍后再试。",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 限制最多 5 条
|
||||
if len(news) > 5 {
|
||||
news = news[:5]
|
||||
}
|
||||
|
||||
return news, nil
|
||||
}
|
||||
|
||||
// generateAISummary 通过 AI-Core 生成人性化摘要
|
||||
func (h *BriefingHandler) generateAISummary(b *store.Briefing) (string, error) {
|
||||
if h.cfg.AICoreURL == "" {
|
||||
return "", fmt.Errorf("AI-Core URL 未配置")
|
||||
}
|
||||
|
||||
// 构建提示词
|
||||
prompt := h.buildSummaryPrompt(b)
|
||||
|
||||
reqBody, _ := json.Marshal(map[string]interface{}{
|
||||
"messages": []map[string]interface{}{
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是昔涟,一个温柔贴心的AI助手。请用温暖、亲切的语气回复,像朋友一样关心用户。回复使用中文。",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
},
|
||||
},
|
||||
"max_tokens": 500,
|
||||
"temperature": 0.7,
|
||||
})
|
||||
|
||||
url := fmt.Sprintf("%s/api/v1/chat/completions", h.cfg.AICoreURL)
|
||||
req, err := http.NewRequest("POST", url, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("构建 AI 请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := h.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("请求 AI-Core 失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("读取 AI 响应失败: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("AI-Core 返回状态码 %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// 解析 OpenAI 兼容响应
|
||||
var aiResp struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &aiResp); err != nil {
|
||||
return "", fmt.Errorf("解析 AI 响应失败: %w", err)
|
||||
}
|
||||
|
||||
if len(aiResp.Choices) == 0 {
|
||||
return "", fmt.Errorf("AI 返回空响应")
|
||||
}
|
||||
|
||||
return aiResp.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
// buildSummaryPrompt 构建 AI 摘要提示词
|
||||
func (h *BriefingHandler) buildSummaryPrompt(b *store.Briefing) string {
|
||||
var sb strings.Builder
|
||||
today := time.Now().Format("2006年01月02日")
|
||||
|
||||
sb.WriteString(fmt.Sprintf("今天是%s。请根据以下信息,用昔涟温柔的语气为用户生成一份简短的每日简报(控制在200字以内):\n\n", today))
|
||||
|
||||
// 天气
|
||||
if b.Weather != nil && b.Weather.Condition != "" {
|
||||
sb.WriteString(fmt.Sprintf("☁️ 天气:%s,%.0f°C,%s\n", b.Weather.Location, b.Weather.Temp, b.Weather.Condition))
|
||||
}
|
||||
|
||||
// 待办
|
||||
if len(b.Reminders) > 0 {
|
||||
sb.WriteString("📋 今日待办:\n")
|
||||
for _, r := range b.Reminders {
|
||||
sb.WriteString(fmt.Sprintf(" - %s\n", r.Title))
|
||||
}
|
||||
}
|
||||
|
||||
// 新闻
|
||||
if len(b.News) > 0 {
|
||||
sb.WriteString("📰 今日新闻:\n")
|
||||
for _, n := range b.News {
|
||||
if n.Summary != "" {
|
||||
sb.WriteString(fmt.Sprintf(" - %s: %s\n", n.Title, n.Summary))
|
||||
} else {
|
||||
sb.WriteString(fmt.Sprintf(" - %s\n", n.Title))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("\n请用昔涟的语气回复,包含:1) 温馨问候 2) 天气提醒 3) 待办提醒 4) 新闻简要 5) 结语祝福。简洁自然即可。")
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// buildFallbackSummary 降级摘要(不依赖 AI)
|
||||
func (h *BriefingHandler) buildFallbackSummary(b *store.Briefing) string {
|
||||
today := time.Now().Format("2006年01月02日")
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(fmt.Sprintf("早上好!今天是%s ☀️\n\n", today))
|
||||
|
||||
if b.Weather != nil && b.Weather.Condition != "" {
|
||||
sb.WriteString(fmt.Sprintf("今日%s天气:%s,%.0f°C。", b.Weather.Location, b.Weather.Condition, b.Weather.Temp))
|
||||
if b.Weather.Temp < 10 {
|
||||
sb.WriteString("天气有点凉,记得多穿件衣服哦~")
|
||||
} else if b.Weather.Temp > 30 {
|
||||
sb.WriteString("天气比较热,注意防暑降温哦~")
|
||||
} else {
|
||||
sb.WriteString("天气不错,适合出门走走呢~")
|
||||
}
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
|
||||
if len(b.Reminders) > 0 {
|
||||
sb.WriteString(fmt.Sprintf("你今天有 %d 项待办事项,记得按时完成哦!\n", len(b.Reminders)))
|
||||
} else {
|
||||
sb.WriteString("今天没有待办事项,可以轻松一下~\n")
|
||||
}
|
||||
|
||||
if len(b.News) > 0 && b.News[0].Title != "未能获取今日新闻" {
|
||||
sb.WriteString(fmt.Sprintf("\n今日热点:%s。", b.News[0].Title))
|
||||
}
|
||||
|
||||
sb.WriteString("\n\n祝你度过美好的一天!🌸")
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// pushBriefingNotification 推送简报通知到用户
|
||||
func (h *BriefingHandler) pushBriefingNotification(userID string, b *store.Briefing) {
|
||||
bodyPreview := truncateStr(b.Summary, 100)
|
||||
|
||||
notif := &ws.NotificationInfo{
|
||||
ID: "briefing_" + b.ID,
|
||||
Type: "info",
|
||||
Title: fmt.Sprintf("📋 今日简报 (%s)", b.Date),
|
||||
Body: bodyPreview,
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
Data: map[string]interface{}{
|
||||
"briefing_id": b.ID,
|
||||
"date": b.Date,
|
||||
"type": "daily_briefing",
|
||||
},
|
||||
}
|
||||
|
||||
msg := ws.ServerMessage{
|
||||
Type: "notification",
|
||||
MessageID: "briefing_" + b.ID,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
Notification: notif,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
log.Printf("[briefing] 序列化简报通知失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
h.hub.SendToUser(userID, data)
|
||||
|
||||
// 更新简报状态为已送达
|
||||
now := time.Now()
|
||||
b.Status = "delivered"
|
||||
b.DeliveredAt = &now
|
||||
if err := h.briefingStore.CreateOrUpdateBriefing(b); err != nil {
|
||||
log.Printf("[briefing] 更新简报送达状态失败: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("[briefing] 简报通知已推送: user=%s date=%s", userID, b.Date)
|
||||
}
|
||||
|
||||
// StartBriefingScheduler 启动简报调度器
|
||||
// briefingTime 格式: "HH:MM",默认 "08:00"
|
||||
func StartBriefingScheduler(handler *BriefingHandler, briefingStore *store.BriefingStore, briefingTime string) {
|
||||
if briefingTime == "" {
|
||||
briefingTime = "08:00"
|
||||
}
|
||||
|
||||
go func() {
|
||||
// 每 30 秒检查一次是否到达简报时间
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
log.Printf("[BriefingScheduler] 简报调度器已启动 (简报时间: %s)", briefingTime)
|
||||
|
||||
// 记录今天是否已触发
|
||||
lastTriggeredDate := ""
|
||||
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
currentTime := now.Format("15:04")
|
||||
currentDate := now.Format("2006-01-02")
|
||||
|
||||
// 检查是否到达简报时间且今天尚未触发
|
||||
if currentTime == briefingTime && currentDate != lastTriggeredDate {
|
||||
log.Printf("[BriefingScheduler] 触发每日简报生成: %s", currentDate)
|
||||
lastTriggeredDate = currentDate
|
||||
|
||||
// 获取所有用户
|
||||
users, err := briefingStore.GetAllUsers()
|
||||
if err != nil {
|
||||
log.Printf("[BriefingScheduler] 获取用户列表失败: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 如果没有从 reminders 表获取到用户,也尝试从 briefings 表获取
|
||||
if len(users) == 0 {
|
||||
users, _ = briefingStore.GetUsersWithBriefings()
|
||||
}
|
||||
|
||||
if len(users) == 0 {
|
||||
log.Println("[BriefingScheduler] 没有找到用户,跳过简报生成")
|
||||
continue
|
||||
}
|
||||
|
||||
for _, userID := range users {
|
||||
log.Printf("[BriefingScheduler] 为用户 %s 生成简报...", userID)
|
||||
result, err := handler.GenerateDailyBriefing(userID)
|
||||
if err != nil {
|
||||
log.Printf("[BriefingScheduler] 生成简报失败: user=%s err=%v", userID, err)
|
||||
continue
|
||||
}
|
||||
handler.pushBriefingNotification(userID, result)
|
||||
}
|
||||
|
||||
log.Printf("[BriefingScheduler] 每日简报已生成完毕,共 %d 个用户", len(users))
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// ========== 辅助函数 ==========
|
||||
|
||||
// weatherEmoji 根据天气描述返回对应 emoji
|
||||
func weatherEmoji(condition string) string {
|
||||
c := strings.ToLower(condition)
|
||||
switch {
|
||||
case strings.Contains(c, "sunny") || strings.Contains(c, "clear") || strings.Contains(c, "晴"):
|
||||
return "☀️"
|
||||
case strings.Contains(c, "partly cloudy") || strings.Contains(c, "多云"):
|
||||
return "⛅"
|
||||
case strings.Contains(c, "cloudy") || strings.Contains(c, "阴"):
|
||||
return "☁️"
|
||||
case strings.Contains(c, "rain") || strings.Contains(c, "drizzle") || strings.Contains(c, "雨"):
|
||||
return "🌧️"
|
||||
case strings.Contains(c, "thunder") || strings.Contains(c, "雷"):
|
||||
return "⛈️"
|
||||
case strings.Contains(c, "snow") || strings.Contains(c, "雪"):
|
||||
return "❄️"
|
||||
case strings.Contains(c, "fog") || strings.Contains(c, "mist") || strings.Contains(c, "雾"):
|
||||
return "🌫️"
|
||||
case strings.Contains(c, "wind") || strings.Contains(c, "风"):
|
||||
return "💨"
|
||||
default:
|
||||
return "🌤️"
|
||||
}
|
||||
}
|
||||
|
||||
// parseFloat 安全解析浮点数
|
||||
func parseFloat(s string) float64 {
|
||||
var f float64
|
||||
fmt.Sscanf(s, "%f", &f)
|
||||
return f
|
||||
}
|
||||
|
||||
// parseInt 安全解析整数
|
||||
func parseInt(s string) (int, error) {
|
||||
var n int
|
||||
_, err := fmt.Sscanf(s, "%d", &n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// truncateStr 截断字符串
|
||||
func truncateStr(s string, maxLen int) string {
|
||||
runes := []rune(s)
|
||||
if len(runes) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return string(runes[:maxLen]) + "..."
|
||||
}
|
||||
@@ -128,12 +128,15 @@ func (h *ChatHandler) handleChatMessage(client *ws.Client, msg ws.ClientMessage)
|
||||
h.hub.UpdateSessionState(client.SessionID, "thinking")
|
||||
|
||||
// 构建 AI-Core 请求
|
||||
aiReq := map[string]string{
|
||||
aiReq := map[string]interface{}{
|
||||
"user_id": client.UserID,
|
||||
"session_id": client.SessionID,
|
||||
"message": msg.Content,
|
||||
"mode": mode,
|
||||
}
|
||||
if len(msg.Attachments) > 0 {
|
||||
aiReq["attachments"] = msg.Attachments
|
||||
}
|
||||
reqBody, err := json.Marshal(aiReq)
|
||||
if err != nil {
|
||||
log.Printf("[chat] 序列化请求失败: %v", err)
|
||||
@@ -148,12 +151,16 @@ func (h *ChatHandler) handleChatMessage(client *ws.Client, msg ws.ClientMessage)
|
||||
}
|
||||
|
||||
// 缓存用户消息(在 goroutine 前完成,避免竞态)
|
||||
h.hub.CacheMessage(client.UserID, client.SessionID, ws.Message{
|
||||
userMsg := ws.Message{
|
||||
ID: "msg_" + generateID(),
|
||||
Role: "user",
|
||||
Content: msg.Content,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
})
|
||||
}
|
||||
if len(msg.Attachments) > 0 {
|
||||
userMsg.Attachments = msg.Attachments
|
||||
}
|
||||
h.hub.CacheMessage(client.UserID, client.SessionID, userMsg)
|
||||
|
||||
// 在 goroutine 中进行 AI-Core 调用和流式发送,避免阻塞 ReadPump
|
||||
go h.streamResponse(client, mode, reqBody, msg.Content)
|
||||
|
||||
@@ -0,0 +1,706 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/jpeg"
|
||||
"image/png"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/store"
|
||||
)
|
||||
|
||||
// FileHandler 文件管理处理器
|
||||
type FileHandler struct {
|
||||
store *store.FileStore
|
||||
uploadDir string
|
||||
}
|
||||
|
||||
// NewFileHandler 创建文件处理器
|
||||
func NewFileHandler(s *store.FileStore) *FileHandler {
|
||||
return &FileHandler{
|
||||
store: s,
|
||||
uploadDir: "./uploads",
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 允许的文件类型 ==========
|
||||
|
||||
var allowedMimeTypes = map[string]bool{
|
||||
// 图片
|
||||
"image/jpeg": true,
|
||||
"image/png": true,
|
||||
"image/gif": true,
|
||||
"image/webp": true,
|
||||
"image/svg+xml": true,
|
||||
// 文档
|
||||
"application/pdf": true,
|
||||
"text/plain": true,
|
||||
"text/markdown": true,
|
||||
"application/msword": true,
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": true,
|
||||
// 音频
|
||||
"audio/mpeg": true,
|
||||
"audio/wav": true,
|
||||
"audio/ogg": true,
|
||||
"audio/webm": true,
|
||||
// 视频
|
||||
"video/mp4": true,
|
||||
"video/webm": true,
|
||||
}
|
||||
|
||||
var allowedExtensions = map[string]string{
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".png": "image/png",
|
||||
".gif": "image/gif",
|
||||
".webp": "image/webp",
|
||||
".svg": "image/svg+xml",
|
||||
".pdf": "application/pdf",
|
||||
".txt": "text/plain",
|
||||
".md": "text/markdown",
|
||||
".doc": "application/msword",
|
||||
".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
".mp3": "audio/mpeg",
|
||||
".wav": "audio/wav",
|
||||
".ogg": "audio/ogg",
|
||||
".mp4": "video/mp4",
|
||||
}
|
||||
|
||||
const maxFileSize = 20 * 1024 * 1024 // 20MB
|
||||
|
||||
// ========== POST /api/v1/files/upload ==========
|
||||
|
||||
// Upload 处理文件上传
|
||||
func (h *FileHandler) Upload(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
|
||||
// Nil store guard — 数据库不可用时拒绝上传
|
||||
if h.store == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "文件存储不可用,数据库未连接", "errorType": "service_unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取上传的文件
|
||||
file, header, err := c.Request.FormFile("file")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "未找到上传文件", "errorType": "missing_file"})
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// 检查文件大小
|
||||
if header.Size > maxFileSize {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "文件大小超过限制 (最大 20MB)",
|
||||
"errorType": "file_too_large",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证文件类型
|
||||
mimeType := header.Header.Get("Content-Type")
|
||||
ext := strings.ToLower(filepath.Ext(header.Filename))
|
||||
if mimeType == "" || mimeType == "application/octet-stream" {
|
||||
// 尝试从扩展名推断
|
||||
if inferred, ok := allowedExtensions[ext]; ok {
|
||||
mimeType = inferred
|
||||
}
|
||||
}
|
||||
if !allowedMimeTypes[mimeType] {
|
||||
// 再尝试通过扩展名检查
|
||||
if _, ok := allowedExtensions[ext]; !ok {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "不支持的文件类型: " + mimeType,
|
||||
"errorType": "unsupported_type",
|
||||
})
|
||||
return
|
||||
}
|
||||
mimeType = allowedExtensions[ext]
|
||||
}
|
||||
|
||||
// 安全化文件名:移除路径分隔符和特殊字符
|
||||
safeFilename := sanitizeFilename(header.Filename)
|
||||
|
||||
// 生成文件ID (crypto/rand UUID v4)
|
||||
fileID := generateUUID()
|
||||
|
||||
// 创建按日期组织的目录
|
||||
dateDir := time.Now().Format("2006-01-02")
|
||||
storedDir := filepath.Join(h.uploadDir, dateDir)
|
||||
if err := os.MkdirAll(storedDir, 0755); err != nil {
|
||||
log.Printf("[FileHandler] 创建上传目录失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建上传目录失败", "errorType": "server_error"})
|
||||
return
|
||||
}
|
||||
|
||||
// 存储路径:以UUID+原始扩展名保存
|
||||
storedFilename := fileID + ext
|
||||
storedPath := filepath.Join(storedDir, storedFilename)
|
||||
|
||||
// 计算 SHA256 hash
|
||||
hasher := sha256.New()
|
||||
teeReader := io.TeeReader(file, hasher)
|
||||
|
||||
// 保存到磁盘
|
||||
dst, err := os.Create(storedPath)
|
||||
if err != nil {
|
||||
log.Printf("[FileHandler] 创建文件失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存文件失败", "errorType": "server_error"})
|
||||
return
|
||||
}
|
||||
defer dst.Close()
|
||||
|
||||
written, err := io.Copy(dst, teeReader)
|
||||
if err != nil {
|
||||
os.Remove(storedPath)
|
||||
log.Printf("[FileHandler] 写入文件失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "写入文件失败", "errorType": "server_error"})
|
||||
return
|
||||
}
|
||||
|
||||
hash := hex.EncodeToString(hasher.Sum(nil))
|
||||
|
||||
// 去重检查:如果存在相同hash的文件,复用已有记录
|
||||
if existing, err := h.store.GetFileByHash(hash); err == nil && existing != nil {
|
||||
// 删除刚保存的重复文件
|
||||
os.Remove(storedPath)
|
||||
log.Printf("[FileHandler] 文件去重: 复用已有文件 %s (hash=%s)", existing.ID, hash[:16])
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"id": existing.ID,
|
||||
"filename": existing.Filename,
|
||||
"mime_type": existing.MimeType,
|
||||
"size": existing.Size,
|
||||
"url": fmt.Sprintf("/api/v1/files/%s/download", existing.ID),
|
||||
"dedup": true,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 创建数据库记录
|
||||
fileRecord := &store.File{
|
||||
ID: fileID,
|
||||
UserID: userID,
|
||||
Filename: safeFilename,
|
||||
StoredPath: storedPath,
|
||||
MimeType: mimeType,
|
||||
Size: written,
|
||||
Hash: hash,
|
||||
IsPublic: false,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := h.store.CreateFile(fileRecord); err != nil {
|
||||
os.Remove(storedPath)
|
||||
log.Printf("[FileHandler] 创建文件记录失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建文件记录失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[FileHandler] 文件上传成功: %s (%s, %d bytes, hash=%s)", fileID, safeFilename, written, hash[:16])
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"id": fileID,
|
||||
"filename": safeFilename,
|
||||
"mime_type": mimeType,
|
||||
"size": written,
|
||||
"url": fmt.Sprintf("/api/v1/files/%s/download", fileID),
|
||||
})
|
||||
}
|
||||
|
||||
// ========== GET /api/v1/files ==========
|
||||
|
||||
// List 列出用户的所有文件 (支持分页)
|
||||
func (h *FileHandler) List(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
|
||||
if h.store == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "文件存储不可用", "errorType": "service_unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
||||
|
||||
files, total, err := h.store.GetUserFiles(userID, page, limit)
|
||||
if err != nil {
|
||||
log.Printf("[FileHandler] 查询文件列表失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文件列表失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为响应格式
|
||||
result := make([]gin.H, 0, len(files))
|
||||
for _, f := range files {
|
||||
item := gin.H{
|
||||
"id": f.ID,
|
||||
"user_id": f.UserID,
|
||||
"filename": f.Filename,
|
||||
"mime_type": f.MimeType,
|
||||
"size": f.Size,
|
||||
"hash": f.Hash,
|
||||
"is_public": f.IsPublic,
|
||||
"created_at": f.CreatedAt.UnixMilli(),
|
||||
"url": fmt.Sprintf("/api/v1/files/%s/download", f.ID),
|
||||
}
|
||||
// 图片类型添加缩略图URL
|
||||
if isImageType(f.MimeType) {
|
||||
item["thumbnail_url"] = fmt.Sprintf("/api/v1/files/%s/thumbnail", f.ID)
|
||||
}
|
||||
result = append(result, item)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"files": result,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
})
|
||||
}
|
||||
|
||||
// ========== GET /api/v1/files/:id ==========
|
||||
|
||||
// Get 获取文件元数据
|
||||
func (h *FileHandler) Get(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
fileID := c.Param("id")
|
||||
|
||||
if h.store == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "文件存储不可用", "errorType": "service_unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
f, err := h.store.GetFile(fileID)
|
||||
if err != nil {
|
||||
log.Printf("[FileHandler] 查询文件失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文件失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
if f == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "文件不存在",
|
||||
"errorType": "file_not_found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查权限
|
||||
if f.UserID != userID && !f.IsPublic {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此文件", "errorType": "access_denied"})
|
||||
return
|
||||
}
|
||||
|
||||
item := gin.H{
|
||||
"id": f.ID,
|
||||
"user_id": f.UserID,
|
||||
"filename": f.Filename,
|
||||
"mime_type": f.MimeType,
|
||||
"size": f.Size,
|
||||
"hash": f.Hash,
|
||||
"is_public": f.IsPublic,
|
||||
"created_at": f.CreatedAt.UnixMilli(),
|
||||
"url": fmt.Sprintf("/api/v1/files/%s/download", f.ID),
|
||||
}
|
||||
if isImageType(f.MimeType) {
|
||||
item["thumbnail_url"] = fmt.Sprintf("/api/v1/files/%s/thumbnail", f.ID)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, item)
|
||||
}
|
||||
|
||||
// ========== GET /api/v1/files/:id/download ==========
|
||||
|
||||
// Download 下载文件
|
||||
func (h *FileHandler) Download(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
fileID := c.Param("id")
|
||||
|
||||
if h.store == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "文件存储不可用", "errorType": "service_unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
f, err := h.store.GetFile(fileID)
|
||||
if err != nil {
|
||||
log.Printf("[FileHandler] 查询文件失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文件失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
if f == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "文件不存在",
|
||||
"errorType": "file_not_found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查权限
|
||||
if f.UserID != userID && !f.IsPublic {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此文件", "errorType": "access_denied"})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查文件在磁盘上是否存在
|
||||
if _, err := os.Stat(f.StoredPath); os.IsNotExist(err) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "文件实体不存在(可能已被清理)",
|
||||
"errorType": "file_missing",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 设置 Content-Disposition
|
||||
c.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, f.Filename))
|
||||
c.Header("Content-Type", f.MimeType)
|
||||
c.File(f.StoredPath)
|
||||
}
|
||||
|
||||
// ========== DELETE /api/v1/files/:id ==========
|
||||
|
||||
// Delete 删除文件
|
||||
func (h *FileHandler) Delete(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
fileID := c.Param("id")
|
||||
|
||||
if h.store == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "文件存储不可用", "errorType": "service_unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
f, err := h.store.GetFile(fileID)
|
||||
if err != nil {
|
||||
log.Printf("[FileHandler] 查询文件失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文件失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
if f == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "文件不存在",
|
||||
"errorType": "file_not_found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查权限
|
||||
if f.UserID != userID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权删除此文件", "errorType": "access_denied"})
|
||||
return
|
||||
}
|
||||
|
||||
// 删除磁盘上的文件(忽略错误,可能已被删除)
|
||||
if err := os.Remove(f.StoredPath); err != nil && !os.IsNotExist(err) {
|
||||
log.Printf("[FileHandler] 删除磁盘文件失败 (stored_path=%s): %v", f.StoredPath, err)
|
||||
}
|
||||
|
||||
// 删除数据库记录
|
||||
if err := h.store.DeleteFile(fileID); err != nil {
|
||||
log.Printf("[FileHandler] 删除文件记录失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除文件记录失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
|
||||
}
|
||||
|
||||
// ========== GET /api/v1/files/:id/thumbnail ==========
|
||||
|
||||
// Thumbnail 返回文件缩略图
|
||||
func (h *FileHandler) Thumbnail(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
fileID := c.Param("id")
|
||||
|
||||
if h.store == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "文件存储不可用", "errorType": "service_unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
f, err := h.store.GetFile(fileID)
|
||||
if err != nil {
|
||||
log.Printf("[FileHandler] 查询文件失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文件失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
if f == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "文件不存在",
|
||||
"errorType": "file_not_found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if f.UserID != userID && !f.IsPublic {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此文件", "errorType": "access_denied"})
|
||||
return
|
||||
}
|
||||
|
||||
// 如果是图片,生成缩略图
|
||||
if isImageType(f.MimeType) && f.MimeType != "image/svg+xml" {
|
||||
if thumbData, contentType, err := generateThumbnail(f.StoredPath, f.MimeType); err == nil {
|
||||
c.Header("Content-Type", contentType)
|
||||
c.Header("Cache-Control", "public, max-age=86400")
|
||||
c.Data(http.StatusOK, contentType, thumbData)
|
||||
return
|
||||
} else {
|
||||
log.Printf("[FileHandler] 生成缩略图失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 非图片文件或缩略图生成失败,返回占位图标 SVG
|
||||
placeholder := generatePlaceholderSVG(f.MimeType)
|
||||
c.Header("Content-Type", "image/svg+xml")
|
||||
c.Header("Cache-Control", "public, max-age=86400")
|
||||
c.Data(http.StatusOK, "image/svg+xml", []byte(placeholder))
|
||||
}
|
||||
|
||||
// ========== 辅助函数 ==========
|
||||
|
||||
// isImageType 判断是否图片类型
|
||||
func isImageType(mimeType string) bool {
|
||||
return strings.HasPrefix(mimeType, "image/")
|
||||
}
|
||||
|
||||
// sanitizeFilename 安全化文件名:移除路径分隔符、特殊字符
|
||||
var unsafeChars = regexp.MustCompile(`[\\/:*?"<>|]`)
|
||||
|
||||
func sanitizeFilename(name string) string {
|
||||
// 移除路径分隔符和Windows特殊字符
|
||||
name = unsafeChars.ReplaceAllString(name, "_")
|
||||
// 限制长度
|
||||
if len(name) > 255 {
|
||||
ext := filepath.Ext(name)
|
||||
base := name[:255-len(ext)]
|
||||
name = base + ext
|
||||
}
|
||||
// 为空时给默认名
|
||||
if name == "" || name == "." || name == ".." {
|
||||
name = "unnamed_file"
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
// generateThumbnail 使用 Go 标准库生成缩略图 (最大 300x300)
|
||||
func generateThumbnail(filePath, mimeType string) ([]byte, string, error) {
|
||||
// 打开源文件
|
||||
srcFile, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("打开文件失败: %w", err)
|
||||
}
|
||||
defer srcFile.Close()
|
||||
|
||||
// 解码图像
|
||||
var srcImg image.Image
|
||||
switch mimeType {
|
||||
case "image/jpeg":
|
||||
srcImg, err = jpeg.Decode(srcFile)
|
||||
case "image/png":
|
||||
srcImg, err = png.Decode(srcFile)
|
||||
default:
|
||||
// 尝试通用解码
|
||||
srcImg, _, err = image.Decode(srcFile)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("解码图像失败: %w", err)
|
||||
}
|
||||
|
||||
// 计算缩略图尺寸
|
||||
bounds := srcImg.Bounds()
|
||||
srcW := bounds.Dx()
|
||||
srcH := bounds.Dy()
|
||||
|
||||
maxDim := 300
|
||||
newW, newH := srcW, srcH
|
||||
if srcW > maxDim || srcH > maxDim {
|
||||
if srcW > srcH {
|
||||
newW = maxDim
|
||||
newH = srcH * maxDim / srcW
|
||||
} else {
|
||||
newH = maxDim
|
||||
newW = srcW * maxDim / srcH
|
||||
}
|
||||
}
|
||||
if newW < 1 {
|
||||
newW = 1
|
||||
}
|
||||
if newH < 1 {
|
||||
newH = 1
|
||||
}
|
||||
|
||||
// 使用标准库双线性缩放
|
||||
thumbImg := image.NewRGBA(image.Rect(0, 0, newW, newH))
|
||||
scaleBilinear(thumbImg, srcImg)
|
||||
|
||||
// 编码为 JPEG 输出
|
||||
var buf strings.Builder
|
||||
errWriter := &stringWriter{&buf}
|
||||
if err := jpeg.Encode(errWriter, thumbImg, &jpeg.Options{Quality: 80}); err != nil {
|
||||
return nil, "", fmt.Errorf("编码缩略图失败: %w", err)
|
||||
}
|
||||
|
||||
return []byte(buf.String()), "image/jpeg", nil
|
||||
}
|
||||
|
||||
// stringWriter 实现 io.Writer 到 strings.Builder
|
||||
type stringWriter struct {
|
||||
b *strings.Builder
|
||||
}
|
||||
|
||||
func (w *stringWriter) Write(p []byte) (int, error) {
|
||||
return w.b.Write(p)
|
||||
}
|
||||
|
||||
// generatePlaceholderSVG 为非图片文件生成占位图标 SVG
|
||||
func generatePlaceholderSVG(mimeType string) string {
|
||||
var icon, color string
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(mimeType, "audio/"):
|
||||
icon = "🎵"
|
||||
color = "#8B5CF6" // purple
|
||||
case strings.HasPrefix(mimeType, "video/"):
|
||||
icon = "🎬"
|
||||
color = "#EF4444" // red
|
||||
case strings.HasPrefix(mimeType, "application/pdf"):
|
||||
icon = "📄"
|
||||
color = "#F59E0B" // amber
|
||||
case strings.HasPrefix(mimeType, "text/"):
|
||||
icon = "📝"
|
||||
color = "#3B82F6" // blue
|
||||
default:
|
||||
icon = "📎"
|
||||
color = "#6B7280" // gray
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`<svg xmlns="http://www.w3.org/2000/svg" width="300" height="300" viewBox="0 0 300 300">
|
||||
<rect width="300" height="300" fill="%s" opacity="0.1"/>
|
||||
<text x="150" y="160" text-anchor="middle" font-size="64" fill="%s">%s</text>
|
||||
<text x="150" y="220" text-anchor="middle" font-size="16" fill="%s" opacity="0.7">%s</text>
|
||||
</svg>`, color, color, icon, color, getMimeTypeShort(mimeType))
|
||||
}
|
||||
|
||||
// getMimeTypeShort 获取MIME类型简称
|
||||
func getMimeTypeShort(mimeType string) string {
|
||||
parts := strings.SplitN(mimeType, "/", 2)
|
||||
if len(parts) < 2 {
|
||||
return mimeType
|
||||
}
|
||||
return strings.ToUpper(parts[1])
|
||||
}
|
||||
|
||||
// generateUUID 使用 crypto/rand 生成 UUID v4 格式的字符串
|
||||
func generateUUID() string {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// 降级方案:基于时间戳 + 随机数的唯一标识
|
||||
b = make([]byte, 16)
|
||||
ts := time.Now().UnixNano()
|
||||
for i := 0; i < 8; i++ {
|
||||
b[i] = byte(ts >> (i * 8))
|
||||
}
|
||||
// 用简单 PRNG 填充剩余字节
|
||||
for i := 8; i < 16; i++ {
|
||||
b[i] = byte((ts * int64(i+1)) % 256)
|
||||
}
|
||||
}
|
||||
// 设置 UUID v4 版本位 (version = 4)
|
||||
b[6] = (b[6] & 0x0f) | 0x40
|
||||
// 设置 UUID variant 位 (variant = 10xx)
|
||||
b[8] = (b[8] & 0x3f) | 0x80
|
||||
return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x",
|
||||
b[0:4], b[4:6], b[6:8], b[8:10], b[10:16])
|
||||
}
|
||||
|
||||
// scaleBilinear 使用双线性插值将 src 图像缩放到 dst 的尺寸 (纯标准库实现)
|
||||
func scaleBilinear(dst *image.RGBA, src image.Image) {
|
||||
dstBounds := dst.Bounds()
|
||||
srcBounds := src.Bounds()
|
||||
dstW := dstBounds.Dx()
|
||||
dstH := dstBounds.Dy()
|
||||
srcW := srcBounds.Dx()
|
||||
srcH := srcBounds.Dy()
|
||||
|
||||
// 计算缩放比例
|
||||
scaleX := float64(srcW) / float64(dstW)
|
||||
scaleY := float64(srcH) / float64(dstH)
|
||||
|
||||
for dy := 0; dy < dstH; dy++ {
|
||||
for dx := 0; dx < dstW; dx++ {
|
||||
// 源图像中的浮点坐标
|
||||
sx := float64(dx)*scaleX + float64(srcBounds.Min.X)
|
||||
sy := float64(dy)*scaleY + float64(srcBounds.Min.Y)
|
||||
|
||||
// 四个邻近像素的整数坐标
|
||||
x0 := int(sx)
|
||||
y0 := int(sy)
|
||||
x1 := x0 + 1
|
||||
y1 := y0 + 1
|
||||
|
||||
// 边界限制
|
||||
if x0 < srcBounds.Min.X {
|
||||
x0 = srcBounds.Min.X
|
||||
}
|
||||
if x1 >= srcBounds.Max.X {
|
||||
x1 = srcBounds.Max.X - 1
|
||||
}
|
||||
if x0 >= srcBounds.Max.X {
|
||||
x0 = srcBounds.Max.X - 1
|
||||
}
|
||||
if x1 < srcBounds.Min.X {
|
||||
x1 = srcBounds.Min.X
|
||||
}
|
||||
if y0 < srcBounds.Min.Y {
|
||||
y0 = srcBounds.Min.Y
|
||||
}
|
||||
if y1 >= srcBounds.Max.Y {
|
||||
y1 = srcBounds.Max.Y - 1
|
||||
}
|
||||
if y0 >= srcBounds.Max.Y {
|
||||
y0 = srcBounds.Max.Y - 1
|
||||
}
|
||||
if y1 < srcBounds.Min.Y {
|
||||
y1 = srcBounds.Min.Y
|
||||
}
|
||||
|
||||
// 插值权重
|
||||
fracX := sx - float64(x0)
|
||||
fracY := sy - float64(y0)
|
||||
|
||||
// 四个角的 RGBA 值
|
||||
r00, g00, b00, a00 := src.At(x0, y0).RGBA()
|
||||
r10, g10, b10, a10 := src.At(x1, y0).RGBA()
|
||||
r01, g01, b01, a01 := src.At(x0, y1).RGBA()
|
||||
r11, g11, b11, a11 := src.At(x1, y1).RGBA()
|
||||
|
||||
// 双线性插值 (在 0-65535 范围内进行)
|
||||
interp := func(c00, c10, c01, c11 uint32) uint8 {
|
||||
top := float64(c00)*(1-fracX) + float64(c10)*fracX
|
||||
bot := float64(c01)*(1-fracX) + float64(c11)*fracX
|
||||
val := top*(1-fracY) + bot*fracY
|
||||
return uint8(val / 256)
|
||||
}
|
||||
|
||||
r := interp(r00, r10, r01, r11)
|
||||
g := interp(g00, g10, g01, g11)
|
||||
b := interp(b00, b10, b01, b11)
|
||||
a := interp(a00, a10, a01, a11)
|
||||
|
||||
dst.SetRGBA(dx+int(dstBounds.Min.X), dy+int(dstBounds.Min.Y), color.RGBA{r, g, b, a})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,718 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/color"
|
||||
_ "image/gif"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/config"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/store"
|
||||
)
|
||||
|
||||
// ImageHandler 图片分析处理器
|
||||
type ImageHandler struct {
|
||||
cfg *config.Config
|
||||
fileStore *store.FileStore
|
||||
}
|
||||
|
||||
// NewImageHandler 创建图片分析处理器
|
||||
func NewImageHandler(cfg *config.Config, fileStore *store.FileStore) *ImageHandler {
|
||||
return &ImageHandler{
|
||||
cfg: cfg,
|
||||
fileStore: fileStore,
|
||||
}
|
||||
}
|
||||
|
||||
// ImageAnalysis 图片分析结果
|
||||
type ImageAnalysis struct {
|
||||
Format string `json:"format"`
|
||||
Width int `json:"width"`
|
||||
Height int `json:"height"`
|
||||
FileSize int64 `json:"file_size"`
|
||||
Description string `json:"description"`
|
||||
TopColors []ColorInfo `json:"top_colors,omitempty"`
|
||||
EXIF map[string]string `json:"exif,omitempty"`
|
||||
AnalyzedBy string `json:"analyzed_by"` // "openai_vision" | "local"
|
||||
}
|
||||
|
||||
// ColorInfo 颜色信息
|
||||
type ColorInfo struct {
|
||||
Hex string `json:"hex"`
|
||||
Percent float64 `json:"percent"`
|
||||
}
|
||||
|
||||
// AnalyzeRequestBody 分析请求体
|
||||
type AnalyzeRequestBody struct {
|
||||
FileID string `json:"file_id"`
|
||||
}
|
||||
|
||||
// ========== POST /api/v1/images/analyze ==========
|
||||
|
||||
// Analyze 分析上传的图片 (multipart/form-data 或 JSON)
|
||||
func (h *ImageHandler) Analyze(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
|
||||
// 尝试 JSON body: {"file_id": "xxx"}
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
var body AnalyzeRequestBody
|
||||
if err := c.ShouldBindJSON(&body); err != nil || body.FileID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少 file_id 字段", "errorType": "invalid_request"})
|
||||
return
|
||||
}
|
||||
h.analyzeByFileID(c, userID, body.FileID)
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试 multipart/form-data: 直接上传图片分析
|
||||
file, header, err := c.Request.FormFile("file")
|
||||
if err != nil {
|
||||
// 也尝试 "image" 字段名
|
||||
file, header, err = c.Request.FormFile("image")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "未找到图片文件 (使用 file 或 image 字段)", "errorType": "missing_file"})
|
||||
return
|
||||
}
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
h.analyzeUploadedFile(c, userID, file, header.Filename, header.Size)
|
||||
}
|
||||
|
||||
// ========== GET /api/v1/images/analyze/:file_id ==========
|
||||
|
||||
// AnalyzeByID 对已上传的文件进行分析
|
||||
func (h *ImageHandler) AnalyzeByID(c *gin.Context) {
|
||||
userID := middleware.GetUserID(c)
|
||||
fileID := c.Param("file_id")
|
||||
if fileID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少 file_id", "errorType": "invalid_request"})
|
||||
return
|
||||
}
|
||||
h.analyzeByFileID(c, userID, fileID)
|
||||
}
|
||||
|
||||
// analyzeByFileID 根据文件ID分析已存储的图片
|
||||
func (h *ImageHandler) analyzeByFileID(c *gin.Context, userID, fileID string) {
|
||||
if h.fileStore == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "文件存储不可用", "errorType": "service_unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
f, err := h.fileStore.GetFile(fileID)
|
||||
if err != nil {
|
||||
log.Printf("[ImageHandler] 查询文件失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文件失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
if f == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在", "errorType": "file_not_found"})
|
||||
return
|
||||
}
|
||||
if f.UserID != userID && !f.IsPublic {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此文件", "errorType": "access_denied"})
|
||||
return
|
||||
}
|
||||
if !isImageType(f.MimeType) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "文件不是图片类型: " + f.MimeType, "errorType": "unsupported_type"})
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.analyzeImage(f.StoredPath, f.MimeType, f.Size)
|
||||
if err != nil {
|
||||
log.Printf("[ImageHandler] 图片分析失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "图片分析失败: " + err.Error(), "errorType": "analysis_error"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
// analyzeUploadedFile 分析直接上传的图片文件
|
||||
func (h *ImageHandler) analyzeUploadedFile(c *gin.Context, userID string, file io.Reader, filename string, fileSize int64) {
|
||||
// 检查文件大小 (10MB 限制)
|
||||
const maxImageSize = 10 * 1024 * 1024
|
||||
if fileSize > maxImageSize {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "图片大小超过限制 (最大 10MB)", "errorType": "file_too_large"})
|
||||
return
|
||||
}
|
||||
|
||||
// 读取文件到内存
|
||||
data, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "读取图片失败", "errorType": "read_error"})
|
||||
return
|
||||
}
|
||||
|
||||
// 检测格式
|
||||
_, format, err := image.DecodeConfig(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无法解码图片: " + err.Error(), "errorType": "decode_error"})
|
||||
return
|
||||
}
|
||||
|
||||
mimeType := "image/" + format
|
||||
supportedFormats := map[string]bool{
|
||||
"image/jpeg": true,
|
||||
"image/png": true,
|
||||
"image/gif": true,
|
||||
}
|
||||
if !supportedFormats[mimeType] {
|
||||
// 允许所有 image/* 格式,但只对常见格式做深入分析
|
||||
}
|
||||
|
||||
// 写入临时文件进行分析
|
||||
tmpFile, err := os.CreateTemp("", "cyrene-image-*."+format)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建临时文件失败", "errorType": "server_error"})
|
||||
return
|
||||
}
|
||||
defer os.Remove(tmpFile.Name())
|
||||
defer tmpFile.Close()
|
||||
|
||||
if _, err := tmpFile.Write(data); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "写入临时文件失败", "errorType": "server_error"})
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.analyzeImage(tmpFile.Name(), mimeType, int64(len(data)))
|
||||
if err != nil {
|
||||
log.Printf("[ImageHandler] 图片分析失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "图片分析失败: " + err.Error(), "errorType": "analysis_error"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
// analyzeImage 核心分析逻辑:先尝试 OpenAI Vision,失败则降级到本地分析
|
||||
func (h *ImageHandler) analyzeImage(filePath, mimeType string, fileSize int64) (*ImageAnalysis, error) {
|
||||
// 如果配置了 OpenAI API Key,尝试使用 Vision API
|
||||
apiKey := h.cfg.LLMAPIKey
|
||||
if apiKey != "" {
|
||||
result, err := h.analyzeWithOpenAIVision(filePath, mimeType)
|
||||
if err == nil {
|
||||
return result, nil
|
||||
}
|
||||
log.Printf("[ImageHandler] OpenAI Vision 分析失败,降级到本地分析: %v", err)
|
||||
}
|
||||
|
||||
// 降级到本地分析
|
||||
return analyzeImageLocally(filePath, mimeType, fileSize)
|
||||
}
|
||||
|
||||
// analyzeWithOpenAIVision 使用 OpenAI Vision API 分析图片
|
||||
func (h *ImageHandler) analyzeWithOpenAIVision(filePath, mimeType string) (*ImageAnalysis, error) {
|
||||
// 读取图片并编码为 base64
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取图片文件失败: %w", err)
|
||||
}
|
||||
|
||||
base64Data := base64.StdEncoding.EncodeToString(data)
|
||||
dataURL := fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data)
|
||||
|
||||
// 获取本地基本信息
|
||||
localInfo, err := analyzeImageLocally(filePath, mimeType, int64(len(data)))
|
||||
if err != nil {
|
||||
localInfo = &ImageAnalysis{}
|
||||
}
|
||||
|
||||
// 构建 OpenAI Vision API 请求
|
||||
reqBody := map[string]interface{}{
|
||||
"model": h.cfg.LLMModel,
|
||||
"messages": []map[string]interface{}{
|
||||
{
|
||||
"role": "user",
|
||||
"content": []map[string]interface{}{
|
||||
{
|
||||
"type": "text",
|
||||
"text": "请详细描述这张图片的内容。用中文回答。请描述:1) 图片中的主要物体/人物 2) 场景/环境 3) 颜色和色调 4) 文字内容(如果有)5) 整体氛围和风格。请尽可能详细。",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": map[string]string{
|
||||
"url": dataURL,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"max_tokens": 500,
|
||||
}
|
||||
|
||||
jsonBody, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
apiURL := strings.TrimRight(h.cfg.LLMAPIURL, "/") + "/chat/completions"
|
||||
httpReq, err := http.NewRequest("POST", apiURL, bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+h.cfg.LLMAPIKey)
|
||||
|
||||
httpClient := &http.Client{}
|
||||
resp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("API 请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API 返回错误 (%d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
var description string
|
||||
if len(result.Choices) > 0 {
|
||||
description = result.Choices[0].Message.Content
|
||||
}
|
||||
|
||||
return &ImageAnalysis{
|
||||
Format: localInfo.Format,
|
||||
Width: localInfo.Width,
|
||||
Height: localInfo.Height,
|
||||
FileSize: localInfo.FileSize,
|
||||
Description: description,
|
||||
TopColors: localInfo.TopColors,
|
||||
EXIF: localInfo.EXIF,
|
||||
AnalyzedBy: "openai_vision",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// analyzeImageLocally 使用 Go 标准库进行本地图片分析
|
||||
func analyzeImageLocally(filePath, mimeType string, fileSize int64) (*ImageAnalysis, error) {
|
||||
// 1. 读取文件
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 2. 解码图片
|
||||
img, format, err := image.Decode(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("解码图片失败: %w", err)
|
||||
}
|
||||
|
||||
// 3. 获取尺寸
|
||||
bounds := img.Bounds()
|
||||
width := bounds.Dx()
|
||||
height := bounds.Dy()
|
||||
|
||||
// 4. 计算颜色直方图 (采样像素)
|
||||
topColors := computeColorHistogram(img, 5)
|
||||
|
||||
// 5. 读取 EXIF (简单实现: 仅 JPEG)
|
||||
exif := extractEXIF(data, format)
|
||||
|
||||
// 6. 生成描述文本
|
||||
description := generateLocalDescription(format, width, height, fileSize, topColors)
|
||||
|
||||
return &ImageAnalysis{
|
||||
Format: format,
|
||||
Width: width,
|
||||
Height: height,
|
||||
FileSize: fileSize,
|
||||
Description: description,
|
||||
TopColors: topColors,
|
||||
EXIF: exif,
|
||||
AnalyzedBy: "local",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// computeColorHistogram 计算颜色直方图,返回 top N 颜色
|
||||
func computeColorHistogram(img image.Image, topN int) []ColorInfo {
|
||||
bounds := img.Bounds()
|
||||
width := bounds.Dx()
|
||||
height := bounds.Dy()
|
||||
|
||||
// 采样间隔:每 step 个像素采样一个
|
||||
step := 1
|
||||
totalPixels := width * height
|
||||
if totalPixels > 10000 {
|
||||
step = (width * height) / 10000
|
||||
if step < 1 {
|
||||
step = 1
|
||||
}
|
||||
}
|
||||
|
||||
colorCount := make(map[string]int)
|
||||
sampledCount := 0
|
||||
|
||||
for y := bounds.Min.Y; y < bounds.Max.Y; y += step {
|
||||
for x := bounds.Min.X; x < bounds.Max.X; x += step {
|
||||
r, g, b, _ := img.At(x, y).RGBA()
|
||||
// 量化到 8-bit 并聚类(每 32 级一分组,减少颜色种类)
|
||||
qr := int(r>>8) / 32
|
||||
qg := int(g>>8) / 32
|
||||
qb := int(b>>8) / 32
|
||||
key := fmt.Sprintf("%02d_%02d_%02d", qr, qg, qb)
|
||||
colorCount[key]++
|
||||
sampledCount++
|
||||
}
|
||||
}
|
||||
|
||||
if sampledCount == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 排序取 topN
|
||||
type kv struct {
|
||||
key string
|
||||
count int
|
||||
}
|
||||
var sorted []kv
|
||||
for k, v := range colorCount {
|
||||
sorted = append(sorted, kv{k, v})
|
||||
}
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
return sorted[i].count > sorted[j].count
|
||||
})
|
||||
|
||||
result := make([]ColorInfo, 0, topN)
|
||||
for i := 0; i < topN && i < len(sorted); i++ {
|
||||
var qr, qg, qb int
|
||||
fmt.Sscanf(sorted[i].key, "%d_%d_%d", &qr, &qg, &qb)
|
||||
// 量化组的中间值
|
||||
r := qr*32 + 16
|
||||
g := qg*32 + 16
|
||||
b := qb*32 + 16
|
||||
hex := fmt.Sprintf("#%02X%02X%02X", r, g, b)
|
||||
pct := float64(sorted[i].count) / float64(sampledCount) * 100
|
||||
result = append(result, ColorInfo{
|
||||
Hex: hex,
|
||||
Percent: pct,
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// extractEXIF 简单提取 JPEG EXIF 信息
|
||||
func extractEXIF(data []byte, format string) map[string]string {
|
||||
if format != "jpeg" {
|
||||
return nil
|
||||
}
|
||||
|
||||
exif := make(map[string]string)
|
||||
|
||||
// 查找 EXIF 标记 (0xFFE1)
|
||||
for i := 0; i < len(data)-4; i++ {
|
||||
if data[i] == 0xFF && data[i+1] == 0xE1 {
|
||||
if i+10 >= len(data) {
|
||||
break
|
||||
}
|
||||
// 验证 EXIF 标识 "Exif\0\0"
|
||||
if string(data[i+4:i+10]) != "Exif\x00\x00" {
|
||||
continue
|
||||
}
|
||||
|
||||
exifStart := i + 10
|
||||
if exifStart+8 >= len(data) {
|
||||
break
|
||||
}
|
||||
|
||||
// 判断字节序
|
||||
var bigEndian bool
|
||||
if data[exifStart] == 'M' && data[exifStart+1] == 'M' {
|
||||
bigEndian = true
|
||||
} else if data[exifStart] == 'I' && data[exifStart+1] == 'I' {
|
||||
bigEndian = false
|
||||
} else {
|
||||
break
|
||||
}
|
||||
|
||||
// 读取 IFD0
|
||||
tiffStart := exifStart
|
||||
readUint16 := func(offset int) uint16 {
|
||||
if offset+2 > len(data) {
|
||||
return 0
|
||||
}
|
||||
if bigEndian {
|
||||
return uint16(data[offset])<<8 | uint16(data[offset+1])
|
||||
}
|
||||
return uint16(data[offset+1])<<8 | uint16(data[offset])
|
||||
}
|
||||
|
||||
ifd0Offset := int(readUint16(tiffStart + 4))
|
||||
if ifd0Offset < 8 {
|
||||
break
|
||||
}
|
||||
ifd0Addr := tiffStart + ifd0Offset
|
||||
if ifd0Addr+2 >= len(data) {
|
||||
break
|
||||
}
|
||||
|
||||
numEntries := int(readUint16(ifd0Addr))
|
||||
entryAddr := ifd0Addr + 2
|
||||
|
||||
// 常见 EXIF 标签
|
||||
tagNames := map[uint16]string{
|
||||
0x010F: "Make",
|
||||
0x0110: "Model",
|
||||
0x0112: "Orientation",
|
||||
0x0132: "DateTime",
|
||||
0x829A: "ExposureTime",
|
||||
0x829D: "FNumber",
|
||||
0x8827: "ISO",
|
||||
0x9003: "DateTimeOriginal",
|
||||
0x920A: "FocalLength",
|
||||
}
|
||||
|
||||
for j := 0; j < numEntries && entryAddr+12 <= len(data); j++ {
|
||||
tag := readUint16(entryAddr)
|
||||
dataType := readUint16(entryAddr + 2)
|
||||
dataCount := int(readUint16(entryAddr + 4))
|
||||
|
||||
entryAddr += 12
|
||||
|
||||
if name, ok := tagNames[tag]; ok {
|
||||
valueLen := dataCount
|
||||
switch dataType {
|
||||
case 2: // ASCII
|
||||
valueLen = dataCount
|
||||
case 3, 4: // SHORT, LONG
|
||||
valueLen = dataCount * 2
|
||||
case 5: // RATIONAL
|
||||
valueLen = dataCount * 8
|
||||
}
|
||||
|
||||
if valueLen <= 4 {
|
||||
// 值在 tag 自身中
|
||||
valData := data[entryAddr-4 : entryAddr]
|
||||
valStr := extractASCIIValue(valData, dataType, dataCount, bigEndian)
|
||||
if valStr != "" {
|
||||
exif[name] = valStr
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
break // 只处理第一个 EXIF 块
|
||||
}
|
||||
}
|
||||
|
||||
if len(exif) == 0 {
|
||||
return nil
|
||||
}
|
||||
return exif
|
||||
}
|
||||
|
||||
// extractASCIIValue 从 EXIF 数据中提取 ASCII 值
|
||||
func extractASCIIValue(data []byte, dataType uint16, count int, bigEndian bool) string {
|
||||
switch dataType {
|
||||
case 2: // ASCII string
|
||||
s := string(data)
|
||||
if idx := strings.IndexByte(s, 0); idx >= 0 {
|
||||
s = s[:idx]
|
||||
}
|
||||
return s
|
||||
case 3: // SHORT
|
||||
if len(data) >= 2 {
|
||||
var val uint16
|
||||
if bigEndian {
|
||||
val = uint16(data[0])<<8 | uint16(data[1])
|
||||
} else {
|
||||
val = uint16(data[1])<<8 | uint16(data[0])
|
||||
}
|
||||
return fmt.Sprintf("%d", val)
|
||||
}
|
||||
case 5: // RATIONAL
|
||||
// 简化处理:返回原始字节
|
||||
return ""
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// generateLocalDescription 生成本地图片描述文本
|
||||
func generateLocalDescription(format string, width, height int, fileSize int64, topColors []ColorInfo) string {
|
||||
var sb strings.Builder
|
||||
|
||||
formatNames := map[string]string{
|
||||
"jpeg": "JPEG",
|
||||
"jpg": "JPEG",
|
||||
"png": "PNG",
|
||||
"gif": "GIF",
|
||||
"webp": "WebP",
|
||||
"bmp": "BMP",
|
||||
}
|
||||
|
||||
formatName := strings.ToUpper(format)
|
||||
if name, ok := formatNames[strings.ToLower(format)]; ok {
|
||||
formatName = name
|
||||
}
|
||||
|
||||
sb.WriteString(fmt.Sprintf("这是一张 %s 格式的图片,", formatName))
|
||||
sb.WriteString(fmt.Sprintf("分辨率为 %d×%d 像素,", width, height))
|
||||
sb.WriteString(fmt.Sprintf("文件大小为 %s。", formatFileSize(fileSize)))
|
||||
|
||||
// 判断大致比例
|
||||
ratio := float64(width) / float64(height)
|
||||
if ratio > 1.8 {
|
||||
sb.WriteString("图片呈宽幅横幅比例。")
|
||||
} else if ratio < 0.6 {
|
||||
sb.WriteString("图片呈竖幅比例。")
|
||||
} else if ratio > 1.2 {
|
||||
sb.WriteString("图片接近横向画幅。")
|
||||
} else if ratio < 0.8 {
|
||||
sb.WriteString("图片接近纵向画幅。")
|
||||
} else {
|
||||
sb.WriteString("图片接近正方形比例。")
|
||||
}
|
||||
|
||||
// 描述主要颜色
|
||||
if len(topColors) > 0 {
|
||||
sb.WriteString(" 主要色调为")
|
||||
for i, c := range topColors {
|
||||
if i > 0 {
|
||||
if i == len(topColors)-1 {
|
||||
sb.WriteString(" 和 ")
|
||||
} else {
|
||||
sb.WriteString("、")
|
||||
}
|
||||
}
|
||||
colorName := getColorName(c.Hex)
|
||||
sb.WriteString(fmt.Sprintf("%s(%s, %.0f%%)", colorName, c.Hex, c.Percent))
|
||||
}
|
||||
sb.WriteString("。")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// formatFileSize 格式化文件大小
|
||||
func formatFileSize(size int64) string {
|
||||
if size < 1024 {
|
||||
return fmt.Sprintf("%d B", size)
|
||||
}
|
||||
if size < 1024*1024 {
|
||||
return fmt.Sprintf("%.1f KB", float64(size)/1024)
|
||||
}
|
||||
return fmt.Sprintf("%.1f MB", float64(size)/(1024*1024))
|
||||
}
|
||||
|
||||
// getColorName 根据 hex 颜色获取中文颜色名
|
||||
func getColorName(hex string) string {
|
||||
if len(hex) < 7 {
|
||||
return hex
|
||||
}
|
||||
var r, g, b uint8
|
||||
fmt.Sscanf(hex, "#%02X%02X%02X", &r, &g, &b)
|
||||
|
||||
// 灰度判断
|
||||
if absDiff(r, g) < 20 && absDiff(g, b) < 20 && absDiff(r, b) < 20 {
|
||||
if r < 40 {
|
||||
return "黑色"
|
||||
}
|
||||
if r < 100 {
|
||||
return "深灰色"
|
||||
}
|
||||
if r < 180 {
|
||||
return "灰色"
|
||||
}
|
||||
if r < 230 {
|
||||
return "浅灰色"
|
||||
}
|
||||
return "白色"
|
||||
}
|
||||
|
||||
// HSL 近似判断色调
|
||||
maxC := max(r, max(g, b))
|
||||
minC := min(r, min(g, b))
|
||||
delta := maxC - minC
|
||||
|
||||
if delta < 30 {
|
||||
if maxC < 60 {
|
||||
return "暗色"
|
||||
}
|
||||
if maxC > 200 {
|
||||
return "浅色"
|
||||
}
|
||||
return "中性色"
|
||||
}
|
||||
|
||||
var hue string
|
||||
switch {
|
||||
case r == maxC:
|
||||
if g >= b {
|
||||
hue = "红色"
|
||||
} else {
|
||||
hue = "品红色"
|
||||
}
|
||||
case g == maxC:
|
||||
if b >= r {
|
||||
hue = "绿色"
|
||||
} else {
|
||||
hue = "黄绿色"
|
||||
}
|
||||
default:
|
||||
if r >= g {
|
||||
hue = "紫红色"
|
||||
} else {
|
||||
hue = "蓝色"
|
||||
}
|
||||
}
|
||||
|
||||
// 亮度修饰
|
||||
if maxC < 80 {
|
||||
hue = "深" + hue
|
||||
} else if minC > 200 {
|
||||
hue = "浅" + hue
|
||||
}
|
||||
|
||||
return hue
|
||||
}
|
||||
|
||||
func absDiff(a, b uint8) int {
|
||||
if a > b {
|
||||
return int(a - b)
|
||||
}
|
||||
return int(b - a)
|
||||
}
|
||||
|
||||
func max(a, b uint8) uint8 {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func min(a, b uint8) uint8 {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// ========== color.RGBA → string 辅助 ==========
|
||||
|
||||
var _ = color.RGBA{} // 确保 color 包被使用
|
||||
@@ -0,0 +1,589 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/store"
|
||||
)
|
||||
|
||||
// KnowledgeHandler 知识库处理器
|
||||
type KnowledgeHandler struct {
|
||||
store *store.KnowledgeStore
|
||||
fileStore *store.FileStore
|
||||
}
|
||||
|
||||
// NewKnowledgeHandler 创建知识库处理器
|
||||
func NewKnowledgeHandler(s *store.KnowledgeStore, fs *store.FileStore) *KnowledgeHandler {
|
||||
return &KnowledgeHandler{store: s, fileStore: fs}
|
||||
}
|
||||
|
||||
// checkStore 检查知识库存储是否可用,不可用时返回 true(调用方应 return)
|
||||
func (h *KnowledgeHandler) checkStore(c *gin.Context) bool {
|
||||
if h.store == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||
"error": "知识库服务不可用(数据库未连接)",
|
||||
"errorType": "service_unavailable",
|
||||
})
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ========== 请求/响应类型 ==========
|
||||
|
||||
type createKBRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
type updateKBRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
type addDocRequest struct {
|
||||
Title string `json:"title" binding:"required"`
|
||||
Content string `json:"content"`
|
||||
SourceType string `json:"source_type"`
|
||||
FileID string `json:"file_id"`
|
||||
}
|
||||
|
||||
type searchRequest struct {
|
||||
Query string `json:"query" binding:"required"`
|
||||
KBIDs []string `json:"kb_ids"`
|
||||
Limit int `json:"limit"`
|
||||
}
|
||||
|
||||
// ========== POST /api/v1/knowledge/bases — 创建知识库 ==========
|
||||
|
||||
func (h *KnowledgeHandler) CreateKB(c *gin.Context) {
|
||||
if h.checkStore(c) {
|
||||
return
|
||||
}
|
||||
userID := middleware.GetUserID(c)
|
||||
|
||||
var req createKBRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请提供知识库名称", "errorType": "invalid_request"})
|
||||
return
|
||||
}
|
||||
|
||||
kb := &store.KnowledgeBase{
|
||||
ID: store.GenerateUUID(),
|
||||
UserID: userID,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
}
|
||||
|
||||
if err := h.store.CreateKB(kb); err != nil {
|
||||
log.Printf("[KnowledgeHandler] 创建知识库失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建知识库失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, kb)
|
||||
}
|
||||
|
||||
// ========== GET /api/v1/knowledge/bases — 列出用户的知识库 ==========
|
||||
|
||||
func (h *KnowledgeHandler) ListKBs(c *gin.Context) {
|
||||
if h.checkStore(c) {
|
||||
return
|
||||
}
|
||||
userID := middleware.GetUserID(c)
|
||||
|
||||
kbs, err := h.store.GetKBsByUser(userID)
|
||||
if err != nil {
|
||||
log.Printf("[KnowledgeHandler] 查询知识库列表失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库列表失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"knowledge_bases": kbs, "total": len(kbs)})
|
||||
}
|
||||
|
||||
// ========== GET /api/v1/knowledge/bases/:id — 获取知识库详情 ==========
|
||||
|
||||
func (h *KnowledgeHandler) GetKB(c *gin.Context) {
|
||||
if h.checkStore(c) {
|
||||
return
|
||||
}
|
||||
userID := middleware.GetUserID(c)
|
||||
kbID := c.Param("id")
|
||||
|
||||
kb, err := h.store.GetKB(kbID)
|
||||
if err != nil {
|
||||
log.Printf("[KnowledgeHandler] 查询知识库失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
if kb == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "知识库不存在", "errorType": "not_found"})
|
||||
return
|
||||
}
|
||||
if kb.UserID != userID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此知识库", "errorType": "access_denied"})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取文档列表
|
||||
docs, err := h.store.GetDocumentsByKB(kbID)
|
||||
if err != nil {
|
||||
log.Printf("[KnowledgeHandler] 查询文档列表失败: %v", err)
|
||||
docs = []store.KnowledgeDocument{}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"knowledge_base": kb,
|
||||
"documents": docs,
|
||||
})
|
||||
}
|
||||
|
||||
// ========== PUT /api/v1/knowledge/bases/:id — 更新知识库 ==========
|
||||
|
||||
func (h *KnowledgeHandler) UpdateKB(c *gin.Context) {
|
||||
if h.checkStore(c) {
|
||||
return
|
||||
}
|
||||
userID := middleware.GetUserID(c)
|
||||
kbID := c.Param("id")
|
||||
|
||||
var req updateKBRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请提供知识库名称", "errorType": "invalid_request"})
|
||||
return
|
||||
}
|
||||
|
||||
kb, err := h.store.GetKB(kbID)
|
||||
if err != nil {
|
||||
log.Printf("[KnowledgeHandler] 查询知识库失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
if kb == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "知识库不存在", "errorType": "not_found"})
|
||||
return
|
||||
}
|
||||
if kb.UserID != userID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权修改此知识库", "errorType": "access_denied"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.UpdateKB(kbID, req.Name, req.Description); err != nil {
|
||||
log.Printf("[KnowledgeHandler] 更新知识库失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "更新知识库失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "updated"})
|
||||
}
|
||||
|
||||
// ========== DELETE /api/v1/knowledge/bases/:id — 删除知识库 ==========
|
||||
|
||||
func (h *KnowledgeHandler) DeleteKB(c *gin.Context) {
|
||||
if h.checkStore(c) {
|
||||
return
|
||||
}
|
||||
userID := middleware.GetUserID(c)
|
||||
kbID := c.Param("id")
|
||||
|
||||
kb, err := h.store.GetKB(kbID)
|
||||
if err != nil {
|
||||
log.Printf("[KnowledgeHandler] 查询知识库失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
if kb == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "知识库不存在", "errorType": "not_found"})
|
||||
return
|
||||
}
|
||||
if kb.UserID != userID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权删除此知识库", "errorType": "access_denied"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.DeleteKB(kbID); err != nil {
|
||||
log.Printf("[KnowledgeHandler] 删除知识库失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除知识库失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
|
||||
}
|
||||
|
||||
// ========== POST /api/v1/knowledge/bases/:id/documents — 添加文档 ==========
|
||||
|
||||
func (h *KnowledgeHandler) AddDocument(c *gin.Context) {
|
||||
if h.checkStore(c) {
|
||||
return
|
||||
}
|
||||
userID := middleware.GetUserID(c)
|
||||
kbID := c.Param("id")
|
||||
|
||||
// 检查知识库是否存在且属于当前用户
|
||||
kb, err := h.store.GetKB(kbID)
|
||||
if err != nil {
|
||||
log.Printf("[KnowledgeHandler] 查询知识库失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
if kb == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "知识库不存在", "errorType": "not_found"})
|
||||
return
|
||||
}
|
||||
if kb.UserID != userID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权操作此知识库", "errorType": "access_denied"})
|
||||
return
|
||||
}
|
||||
|
||||
var req addDocRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请提供文档标题", "errorType": "invalid_request"})
|
||||
return
|
||||
}
|
||||
|
||||
if req.SourceType == "" {
|
||||
req.SourceType = "text"
|
||||
}
|
||||
|
||||
var content string
|
||||
var sourceRef string
|
||||
var contentType string
|
||||
|
||||
switch req.SourceType {
|
||||
case "text":
|
||||
content = req.Content
|
||||
contentType = "text/plain"
|
||||
if content == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请提供文档内容", "errorType": "invalid_request"})
|
||||
return
|
||||
}
|
||||
case "file":
|
||||
if req.FileID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请提供文件ID", "errorType": "invalid_request"})
|
||||
return
|
||||
}
|
||||
sourceRef = req.FileID
|
||||
|
||||
// 从 FileStore 读取文件内容
|
||||
if h.fileStore != nil {
|
||||
f, err := h.fileStore.GetFile(req.FileID)
|
||||
if err != nil || f == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在", "errorType": "not_found"})
|
||||
return
|
||||
}
|
||||
if f.UserID != userID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此文件", "errorType": "access_denied"})
|
||||
return
|
||||
}
|
||||
// 注意:这里只支持文本类型文件的读取
|
||||
// 对于二进制文件,需要更复杂的解析逻辑
|
||||
sourceRef = f.Filename
|
||||
// 从磁盘读取文件内容
|
||||
content, contentType, _ = readFileContent(f.StoredPath, f.MimeType)
|
||||
if content == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "无法读取文件内容,仅支持文本文件", "errorType": "unsupported_file"})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "文件存储不可用", "errorType": "service_unavailable"})
|
||||
return
|
||||
}
|
||||
case "url":
|
||||
sourceRef = req.FileID
|
||||
content = req.Content
|
||||
contentType = "text/html"
|
||||
default:
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的来源类型: " + req.SourceType, "errorType": "invalid_request"})
|
||||
return
|
||||
}
|
||||
|
||||
if content == "" {
|
||||
content = req.Content
|
||||
}
|
||||
if contentType == "" {
|
||||
contentType = "text/plain"
|
||||
}
|
||||
|
||||
doc := &store.KnowledgeDocument{
|
||||
ID: store.GenerateUUID(),
|
||||
KBID: kbID,
|
||||
UserID: userID,
|
||||
Title: req.Title,
|
||||
SourceType: req.SourceType,
|
||||
SourceRef: sourceRef,
|
||||
ContentType: contentType,
|
||||
RawContent: content,
|
||||
}
|
||||
|
||||
if err := h.store.AddDocument(doc); err != nil {
|
||||
log.Printf("[KnowledgeHandler] 添加文档失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "添加文档失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
|
||||
// 自动分块
|
||||
chunkCount, err := h.store.ChunkDocument(doc.ID)
|
||||
if err != nil {
|
||||
log.Printf("[KnowledgeHandler] 文档分块失败: %v", err)
|
||||
// 分块失败不影响文档创建
|
||||
}
|
||||
|
||||
doc.ChunkCount = chunkCount
|
||||
|
||||
c.JSON(http.StatusCreated, doc)
|
||||
}
|
||||
|
||||
// ========== GET /api/v1/knowledge/bases/:id/documents — 列出知识库中的文档 ==========
|
||||
|
||||
func (h *KnowledgeHandler) ListDocuments(c *gin.Context) {
|
||||
if h.checkStore(c) {
|
||||
return
|
||||
}
|
||||
userID := middleware.GetUserID(c)
|
||||
kbID := c.Param("id")
|
||||
|
||||
// 检查权限
|
||||
kb, err := h.store.GetKB(kbID)
|
||||
if err != nil {
|
||||
log.Printf("[KnowledgeHandler] 查询知识库失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
if kb == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "知识库不存在", "errorType": "not_found"})
|
||||
return
|
||||
}
|
||||
if kb.UserID != userID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此知识库", "errorType": "access_denied"})
|
||||
return
|
||||
}
|
||||
|
||||
docs, err := h.store.GetDocumentsByKB(kbID)
|
||||
if err != nil {
|
||||
log.Printf("[KnowledgeHandler] 查询文档列表失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文档列表失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"documents": docs, "total": len(docs)})
|
||||
}
|
||||
|
||||
// ========== GET /api/v1/knowledge/documents/:id — 获取文档详情 ==========
|
||||
|
||||
func (h *KnowledgeHandler) GetDocument(c *gin.Context) {
|
||||
if h.checkStore(c) {
|
||||
return
|
||||
}
|
||||
userID := middleware.GetUserID(c)
|
||||
docID := c.Param("id")
|
||||
|
||||
doc, err := h.store.GetDocument(docID)
|
||||
if err != nil {
|
||||
log.Printf("[KnowledgeHandler] 查询文档失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文档失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
if doc == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "文档不存在", "errorType": "not_found"})
|
||||
return
|
||||
}
|
||||
if doc.UserID != userID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此文档", "errorType": "access_denied"})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取分块
|
||||
chunks, err := h.store.GetChunksByDocID(docID)
|
||||
if err != nil {
|
||||
log.Printf("[KnowledgeHandler] 查询分块失败: %v", err)
|
||||
chunks = []store.KnowledgeChunk{}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"document": doc,
|
||||
"chunks": chunks,
|
||||
})
|
||||
}
|
||||
|
||||
// ========== DELETE /api/v1/knowledge/documents/:id — 删除文档 ==========
|
||||
|
||||
func (h *KnowledgeHandler) DeleteDocument(c *gin.Context) {
|
||||
if h.checkStore(c) {
|
||||
return
|
||||
}
|
||||
userID := middleware.GetUserID(c)
|
||||
docID := c.Param("id")
|
||||
|
||||
doc, err := h.store.GetDocument(docID)
|
||||
if err != nil {
|
||||
log.Printf("[KnowledgeHandler] 查询文档失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文档失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
if doc == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "文档不存在", "errorType": "not_found"})
|
||||
return
|
||||
}
|
||||
if doc.UserID != userID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "无权删除此文档", "errorType": "access_denied"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.DeleteDocument(docID); err != nil {
|
||||
log.Printf("[KnowledgeHandler] 删除文档失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除文档失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
|
||||
}
|
||||
|
||||
// ========== POST /api/v1/knowledge/search — 搜索知识库 ==========
|
||||
|
||||
func (h *KnowledgeHandler) Search(c *gin.Context) {
|
||||
if h.checkStore(c) {
|
||||
return
|
||||
}
|
||||
userID := middleware.GetUserID(c)
|
||||
|
||||
var req searchRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请提供搜索关键词", "errorType": "invalid_request"})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Limit <= 0 {
|
||||
req.Limit = 5
|
||||
}
|
||||
if req.Limit > 50 {
|
||||
req.Limit = 50
|
||||
}
|
||||
|
||||
var results []store.SearchChunkResult
|
||||
var err error
|
||||
|
||||
if len(req.KBIDs) == 0 {
|
||||
// 搜索所有知识库
|
||||
results, err = h.store.SearchAllKBs(userID, req.Query, req.Limit)
|
||||
} else {
|
||||
// 搜索指定知识库,需要验证权限
|
||||
results = []store.SearchChunkResult{}
|
||||
for _, kbID := range req.KBIDs {
|
||||
kb, checkErr := h.store.GetKB(kbID)
|
||||
if checkErr != nil || kb == nil || kb.UserID != userID {
|
||||
continue // 跳过无权限或不存在的知识库
|
||||
}
|
||||
kbResults, searchErr := h.store.SearchChunks(kbID, req.Query, req.Limit)
|
||||
if searchErr != nil {
|
||||
log.Printf("[KnowledgeHandler] 搜索知识库 %s 失败: %v", kbID, searchErr)
|
||||
continue
|
||||
}
|
||||
results = append(results, kbResults...)
|
||||
}
|
||||
|
||||
// 限制总结果数
|
||||
if len(results) > req.Limit {
|
||||
results = results[:req.Limit]
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[KnowledgeHandler] 搜索失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "搜索失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成高亮片段
|
||||
for i := range results {
|
||||
results[i].Headline = generateHeadline(results[i].Content, req.Query, 200)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"chunks": results,
|
||||
"total": len(results),
|
||||
"query": req.Query,
|
||||
})
|
||||
}
|
||||
|
||||
// ========== 辅助函数 ==========
|
||||
|
||||
// generateHeadline 生成高亮片段,提取查询关键词周围的文本
|
||||
func generateHeadline(content, query string, maxLen int) string {
|
||||
if maxLen <= 0 {
|
||||
maxLen = 200
|
||||
}
|
||||
|
||||
runes := []rune(content)
|
||||
if len(runes) <= maxLen {
|
||||
return content
|
||||
}
|
||||
|
||||
// 查找查询关键词位置
|
||||
queryRunes := []rune(query)
|
||||
pos := -1
|
||||
for i := 0; i <= len(runes)-len(queryRunes); i++ {
|
||||
match := true
|
||||
for j := 0; j < len(queryRunes); j++ {
|
||||
if runes[i+j] != queryRunes[j] {
|
||||
match = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if match {
|
||||
pos = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if pos < 0 {
|
||||
// 没有找到精确匹配,返回前 maxLen 个字符
|
||||
return string(runes[:maxLen]) + "..."
|
||||
}
|
||||
|
||||
// 以匹配位置为中心,截取上下文
|
||||
half := maxLen / 2
|
||||
start := pos - half
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
end := start + maxLen
|
||||
if end > len(runes) {
|
||||
end = len(runes)
|
||||
start = end - maxLen
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
}
|
||||
|
||||
result := string(runes[start:end])
|
||||
if start > 0 {
|
||||
result = "..." + result
|
||||
}
|
||||
if end < len(runes) {
|
||||
result = result + "..."
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// readFileContent 从磁盘读取文件内容 (仅支持文本类型)
|
||||
func readFileContent(path, mimeType string) (content string, contentType string, err error) {
|
||||
// 只支持文本类型
|
||||
if !strings.HasPrefix(mimeType, "text/") && mimeType != "application/json" {
|
||||
return "", "", nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return string(data), mimeType, nil
|
||||
}
|
||||
@@ -0,0 +1,162 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/config"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/ws"
|
||||
)
|
||||
|
||||
// NotificationHandler 通知推送处理器
|
||||
type NotificationHandler struct {
|
||||
cfg *config.Config
|
||||
hub *ws.Hub
|
||||
}
|
||||
|
||||
// NewNotificationHandler 创建通知处理器
|
||||
func NewNotificationHandler(cfg *config.Config, hub *ws.Hub) *NotificationHandler {
|
||||
return &NotificationHandler{cfg: cfg, hub: hub}
|
||||
}
|
||||
|
||||
// PushNotificationRequest 推送通知请求体
|
||||
type PushNotificationRequest struct {
|
||||
UserID string `json:"user_id" binding:"required"`
|
||||
Type string `json:"type" binding:"required,oneof=info warning success thinking reminder"`
|
||||
Title string `json:"title" binding:"required"`
|
||||
Body string `json:"body" binding:"required"`
|
||||
Data map[string]interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Push 推送通知到指定用户 (需要 JWT 认证)
|
||||
// POST /api/v1/notifications/push
|
||||
func (h *NotificationHandler) Push(c *gin.Context) {
|
||||
var req PushNotificationRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成通知
|
||||
notif := h.buildNotification(req)
|
||||
|
||||
// 序列化 WS 消息
|
||||
msg := ws.ServerMessage{
|
||||
Type: "notification",
|
||||
MessageID: "notif_" + generateID(),
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
Notification: notif,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
log.Printf("[notification] 序列化通知失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "内部错误"})
|
||||
return
|
||||
}
|
||||
|
||||
// 通过 Hub 推送给指定用户
|
||||
h.hub.SendToUser(req.UserID, data)
|
||||
|
||||
log.Printf("[notification] 通知已推送: user=%s type=%s title=%s", req.UserID, req.Type, req.Title)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"notification": gin.H{
|
||||
"id": notif.ID,
|
||||
"type": notif.Type,
|
||||
"title": notif.Title,
|
||||
"user_id": req.UserID,
|
||||
"timestamp": notif.Timestamp,
|
||||
"delivered": h.hub.UserClientCount(req.UserID) > 0,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// InternalNotify 内部服务推送通知 (使用内部 service token)
|
||||
// POST /api/v1/internal/notify
|
||||
func (h *NotificationHandler) InternalNotify(c *gin.Context) {
|
||||
var req PushNotificationRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成通知
|
||||
notif := h.buildNotification(req)
|
||||
|
||||
// 序列化 WS 消息
|
||||
msg := ws.ServerMessage{
|
||||
Type: "notification",
|
||||
MessageID: "notif_" + generateID(),
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
Notification: notif,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
log.Printf("[notification] 序列化通知失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "内部错误"})
|
||||
return
|
||||
}
|
||||
|
||||
// 通过 Hub 推送给指定用户
|
||||
h.hub.SendToUser(req.UserID, data)
|
||||
|
||||
log.Printf("[notification] 内部通知已推送: user=%s type=%s title=%s", req.UserID, req.Type, req.Title)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"notification": gin.H{
|
||||
"id": notif.ID,
|
||||
"type": notif.Type,
|
||||
"title": notif.Title,
|
||||
"user_id": req.UserID,
|
||||
"timestamp": notif.Timestamp,
|
||||
"delivered": h.hub.UserClientCount(req.UserID) > 0,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// InternalNotifyAuth 内部服务认证中间件
|
||||
func (h *NotificationHandler) InternalNotifyAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
token := c.GetHeader("X-Internal-Token")
|
||||
if token == "" {
|
||||
token = c.GetHeader("Authorization")
|
||||
if len(token) > 7 && token[:7] == "Bearer " {
|
||||
token = token[7:]
|
||||
}
|
||||
}
|
||||
|
||||
if token != h.cfg.InternalServiceToken {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "内部认证失败"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// buildNotification 构建 NotificationInfo
|
||||
func (h *NotificationHandler) buildNotification(req PushNotificationRequest) *ws.NotificationInfo {
|
||||
notifID := "notif_" + generateID()
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
|
||||
if req.Data == nil {
|
||||
req.Data = make(map[string]interface{})
|
||||
}
|
||||
|
||||
return &ws.NotificationInfo{
|
||||
ID: notifID,
|
||||
Type: req.Type,
|
||||
Title: req.Title,
|
||||
Body: req.Body,
|
||||
Timestamp: now,
|
||||
Data: req.Data,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,340 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/store"
|
||||
"github.com/yourname/cyrene-ai/gateway/internal/ws"
|
||||
)
|
||||
|
||||
// ReminderHandler 提醒处理器
|
||||
type ReminderHandler struct {
|
||||
store *store.ReminderStore
|
||||
hub *ws.Hub
|
||||
}
|
||||
|
||||
// NewReminderHandler 创建提醒处理器
|
||||
func NewReminderHandler(s *store.ReminderStore, hub *ws.Hub) *ReminderHandler {
|
||||
return &ReminderHandler{store: s, hub: hub}
|
||||
}
|
||||
|
||||
// CreateReminderRequest 创建提醒请求体
|
||||
type CreateReminderRequest struct {
|
||||
Title string `json:"title" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
RemindAt string `json:"remind_at" binding:"required"` // ISO 8601 格式
|
||||
RepeatType string `json:"repeat_type"` // none, daily, weekly, monthly
|
||||
SessionID string `json:"session_id"`
|
||||
}
|
||||
|
||||
// UpdateReminderRequest 更新提醒请求体
|
||||
type UpdateReminderRequest struct {
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
RemindAt string `json:"remind_at"`
|
||||
Status string `json:"status"` // pending, completed, cancelled
|
||||
RepeatType string `json:"repeat_type"` // none, daily, weekly, monthly
|
||||
SessionID string `json:"session_id"`
|
||||
}
|
||||
|
||||
// List 获取提醒列表
|
||||
// GET /api/v1/reminders?user_id=xxx&status=pending&limit=50&offset=0
|
||||
func (h *ReminderHandler) List(c *gin.Context) {
|
||||
userID := c.Query("user_id")
|
||||
if userID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少 user_id 参数"})
|
||||
return
|
||||
}
|
||||
|
||||
status := c.Query("status")
|
||||
limit := 50
|
||||
offset := 0
|
||||
|
||||
if l, ok := c.GetQuery("limit"); ok {
|
||||
if v, err := strconv.Atoi(l); err == nil && v > 0 {
|
||||
limit = v
|
||||
}
|
||||
}
|
||||
if o, ok := c.GetQuery("offset"); ok {
|
||||
if v, err := strconv.Atoi(o); err == nil && v >= 0 {
|
||||
offset = v
|
||||
}
|
||||
}
|
||||
|
||||
reminders, err := h.store.GetRemindersByUser(userID, status, limit, offset)
|
||||
if err != nil {
|
||||
log.Printf("[reminder] 获取提醒列表失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取提醒列表失败"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"reminders": reminders,
|
||||
"count": len(reminders),
|
||||
})
|
||||
}
|
||||
|
||||
// Create 创建新提醒
|
||||
// POST /api/v1/reminders
|
||||
func (h *ReminderHandler) Create(c *gin.Context) {
|
||||
var req CreateReminderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 从 JWT 获取 userID
|
||||
userID := middleware.GetUserID(c)
|
||||
if userID == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "未认证"})
|
||||
return
|
||||
}
|
||||
|
||||
// 解析时间
|
||||
remindAt, err := time.Parse(time.RFC3339, req.RemindAt)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "时间格式无效,请使用 ISO 8601 格式 (例如 2024-01-01T15:00:00Z)"})
|
||||
return
|
||||
}
|
||||
|
||||
// 默认值
|
||||
repeatType := req.RepeatType
|
||||
if repeatType == "" {
|
||||
repeatType = "none"
|
||||
}
|
||||
|
||||
reminder := &store.Reminder{
|
||||
ID: generateID(),
|
||||
UserID: userID,
|
||||
Title: req.Title,
|
||||
Description: req.Description,
|
||||
RemindAt: remindAt,
|
||||
Status: "pending",
|
||||
RepeatType: repeatType,
|
||||
SessionID: req.SessionID,
|
||||
Notified: false,
|
||||
}
|
||||
|
||||
if err := h.store.CreateReminder(reminder); err != nil {
|
||||
log.Printf("[reminder] 创建提醒失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建提醒失败"})
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[reminder] 提醒已创建: id=%s user=%s title=%s remind_at=%s repeat=%s",
|
||||
reminder.ID, userID, reminder.Title, remindAt.Format(time.RFC3339), repeatType)
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"success": true,
|
||||
"reminder": reminder,
|
||||
})
|
||||
}
|
||||
|
||||
// Update 更新提醒
|
||||
// PUT /api/v1/reminders/:id
|
||||
func (h *ReminderHandler) Update(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少提醒 ID"})
|
||||
return
|
||||
}
|
||||
|
||||
userID := middleware.GetUserID(c)
|
||||
|
||||
var req UpdateReminderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数无效: " + err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 先获取已有提醒
|
||||
reminders, err := h.store.GetRemindersByUser(userID, "", 100, 0)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取提醒失败"})
|
||||
return
|
||||
}
|
||||
|
||||
var existing *store.Reminder
|
||||
for i := range reminders {
|
||||
if reminders[i].ID == id {
|
||||
existing = &reminders[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if existing == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "提醒不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.Title != "" {
|
||||
existing.Title = req.Title
|
||||
}
|
||||
if req.Description != "" {
|
||||
existing.Description = req.Description
|
||||
}
|
||||
if req.RemindAt != "" {
|
||||
remindAt, err := time.Parse(time.RFC3339, req.RemindAt)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "时间格式无效"})
|
||||
return
|
||||
}
|
||||
existing.RemindAt = remindAt
|
||||
}
|
||||
if req.Status != "" {
|
||||
existing.Status = req.Status
|
||||
if req.Status == "completed" || req.Status == "cancelled" {
|
||||
now := time.Now()
|
||||
existing.CompletedAt = &now
|
||||
}
|
||||
}
|
||||
if req.RepeatType != "" {
|
||||
existing.RepeatType = req.RepeatType
|
||||
}
|
||||
if req.SessionID != "" {
|
||||
existing.SessionID = req.SessionID
|
||||
}
|
||||
|
||||
if err := h.store.UpdateReminder(id, existing); err != nil {
|
||||
log.Printf("[reminder] 更新提醒失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "更新提醒失败"})
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[reminder] 提醒已更新: id=%s", id)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"reminder": existing,
|
||||
})
|
||||
}
|
||||
|
||||
// Delete 删除提醒
|
||||
// DELETE /api/v1/reminders/:id
|
||||
func (h *ReminderHandler) Delete(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少提醒 ID"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.DeleteReminder(id); err != nil {
|
||||
log.Printf("[reminder] 删除提醒失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除提醒失败"})
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[reminder] 提醒已删除: id=%s", id)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
})
|
||||
}
|
||||
|
||||
// ========== 提醒调度器 ==========
|
||||
|
||||
// StartReminderScheduler 启动提醒调度器,每 30 秒检查一次到期提醒
|
||||
func StartReminderScheduler(s *store.ReminderStore, hub *ws.Hub) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
log.Println("[ReminderScheduler] 提醒调度器已启动 (检查间隔: 30秒)")
|
||||
|
||||
for range ticker.C {
|
||||
checkAndNotify(s, hub)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// checkAndNotify 检查到期提醒并推送通知
|
||||
func checkAndNotify(s *store.ReminderStore, hub *ws.Hub) {
|
||||
reminders, err := s.GetDueReminders()
|
||||
if err != nil {
|
||||
log.Printf("[ReminderScheduler] 获取到期提醒失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(reminders) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, r := range reminders {
|
||||
// 1. 构建 WebSocket 通知消息
|
||||
notif := &ws.NotificationInfo{
|
||||
ID: "reminder_" + r.ID,
|
||||
Type: "reminder",
|
||||
Title: r.Title,
|
||||
Body: r.Description,
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
Data: map[string]interface{}{
|
||||
"reminder_id": r.ID,
|
||||
"session_id": r.SessionID,
|
||||
},
|
||||
}
|
||||
|
||||
msg := ws.ServerMessage{
|
||||
Type: "notification",
|
||||
MessageID: "reminder_" + r.ID,
|
||||
Timestamp: time.Now().UnixMilli(),
|
||||
Notification: notif,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
log.Printf("[ReminderScheduler] 序列化通知失败: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 2. 通过 Hub 向用户推送
|
||||
hub.SendToUser(r.UserID, data)
|
||||
|
||||
// 3. 标记为已通知
|
||||
if err := s.MarkNotified(r.ID); err != nil {
|
||||
log.Printf("[ReminderScheduler] 标记已通知失败: id=%s err=%v", r.ID, err)
|
||||
}
|
||||
|
||||
// 4. 处理重复提醒
|
||||
if r.RepeatType != "" && r.RepeatType != "none" {
|
||||
nextTime := calculateNextRemindAt(r.RemindAt, r.RepeatType)
|
||||
r.RemindAt = nextTime
|
||||
r.Notified = false
|
||||
if err := s.UpdateReminder(r.ID, &r); err != nil {
|
||||
log.Printf("[ReminderScheduler] 更新重复提醒失败: id=%s err=%v", r.ID, err)
|
||||
} else {
|
||||
log.Printf("[ReminderScheduler] 重复提醒已更新: id=%s next=%s", r.ID, nextTime.Format(time.RFC3339))
|
||||
}
|
||||
} else {
|
||||
// 非重复提醒:标记为已完成
|
||||
now := time.Now()
|
||||
r.Status = "completed"
|
||||
r.CompletedAt = &now
|
||||
if err := s.UpdateReminder(r.ID, &r); err != nil {
|
||||
log.Printf("[ReminderScheduler] 标记提醒完成失败: id=%s err=%v", r.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("[ReminderScheduler] 提醒已推送: user=%s title=%s id=%s", r.UserID, r.Title, r.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// calculateNextRemindAt 计算下一次提醒时间
|
||||
func calculateNextRemindAt(current time.Time, repeatType string) time.Time {
|
||||
switch repeatType {
|
||||
case "daily":
|
||||
return current.Add(24 * time.Hour)
|
||||
case "weekly":
|
||||
return current.Add(7 * 24 * time.Hour)
|
||||
case "monthly":
|
||||
return current.AddDate(0, 1, 0)
|
||||
default:
|
||||
return current
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,11 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -305,6 +308,270 @@ func (h *SessionHandler) GetSession(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, session)
|
||||
}
|
||||
|
||||
// ========== GET /api/v1/messages/search?q=xxx&user_id=xxx&limit=50&offset=0 — 全文搜索消息 ==========
|
||||
|
||||
// SearchMessages 全文搜索消息
|
||||
func (h *SessionHandler) SearchMessages(c *gin.Context) {
|
||||
query := c.Query("q")
|
||||
if query == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少搜索关键词参数 q", "errorType": "missing_query"})
|
||||
return
|
||||
}
|
||||
|
||||
userID := c.Query("user_id")
|
||||
if userID == "" {
|
||||
userID = middleware.GetUserID(c)
|
||||
}
|
||||
|
||||
limit := 50
|
||||
offset := 0
|
||||
if l := c.Query("limit"); l != "" {
|
||||
parsed := 0
|
||||
for _, ch := range l {
|
||||
if ch < '0' || ch > '9' {
|
||||
break
|
||||
}
|
||||
parsed = parsed*10 + int(ch-'0')
|
||||
}
|
||||
if parsed > 0 && parsed <= 200 {
|
||||
limit = parsed
|
||||
}
|
||||
}
|
||||
if o := c.Query("offset"); o != "" {
|
||||
parsed := 0
|
||||
for _, ch := range o {
|
||||
if ch < '0' || ch > '9' {
|
||||
break
|
||||
}
|
||||
parsed = parsed*10 + int(ch-'0')
|
||||
}
|
||||
if parsed >= 0 {
|
||||
offset = parsed
|
||||
}
|
||||
}
|
||||
|
||||
if !h.useDB {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"results": []gin.H{},
|
||||
"total": 0,
|
||||
"query": query,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
results, total, err := h.store.SearchMessages(userID, query, limit, offset)
|
||||
if err != nil {
|
||||
log.Printf("[SessionHandler] 搜索消息失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "搜索失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为统一格式
|
||||
items := make([]gin.H, 0, len(results))
|
||||
for _, r := range results {
|
||||
items = append(items, gin.H{
|
||||
"message_id": r.MessageID,
|
||||
"session_id": r.SessionID,
|
||||
"session_title": r.SessionTitle,
|
||||
"role": r.Role,
|
||||
"content": r.Content,
|
||||
"created_at": r.CreatedAt.UnixMilli(),
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"results": items,
|
||||
"total": total,
|
||||
"query": query,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
})
|
||||
}
|
||||
|
||||
// ========== GET /api/v1/sessions/:id/export?format=json|markdown|txt — 导出会话 ==========
|
||||
|
||||
// ExportSession 导出会话为指定格式
|
||||
func (h *SessionHandler) ExportSession(c *gin.Context) {
|
||||
sessionID := c.Param("id")
|
||||
format := c.Query("format")
|
||||
if format == "" {
|
||||
format = "json"
|
||||
}
|
||||
|
||||
// 验证格式
|
||||
switch format {
|
||||
case "json", "markdown", "txt":
|
||||
// valid
|
||||
default:
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "不支持的导出格式",
|
||||
"errorType": "invalid_format",
|
||||
"hint": "支持的格式: json, markdown, txt",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if !h.useDB {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||
"error": "会话存储不可用",
|
||||
"errorType": "store_unavailable",
|
||||
"hint": "数据库连接未建立,无法导出会话",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取会话信息
|
||||
session, err := h.store.GetSession(sessionID)
|
||||
if err != nil {
|
||||
log.Printf("[SessionHandler] 查询会话失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询会话失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
if session == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": "会话不存在",
|
||||
"errorType": "session_not_found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取所有消息 (不限制数量,导出全部)
|
||||
messages, err := h.store.GetMessages(sessionID, 0)
|
||||
if err != nil {
|
||||
log.Printf("[SessionHandler] 查询消息失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询消息失败", "errorType": "db_error"})
|
||||
return
|
||||
}
|
||||
if messages == nil {
|
||||
messages = []store.Message{}
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
switch format {
|
||||
case "json":
|
||||
h.exportJSON(c, session, messages, now)
|
||||
case "markdown":
|
||||
h.exportMarkdown(c, session, messages, now)
|
||||
case "txt":
|
||||
h.exportTXT(c, session, messages, now)
|
||||
}
|
||||
}
|
||||
|
||||
// exportJSON 导出 JSON 格式
|
||||
func (h *SessionHandler) exportJSON(c *gin.Context, session *store.Session, messages []store.Message, now time.Time) {
|
||||
type msgOut struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
}
|
||||
|
||||
type sessionOut struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
}
|
||||
|
||||
type export struct {
|
||||
Session sessionOut `json:"session"`
|
||||
Messages []msgOut `json:"messages"`
|
||||
}
|
||||
|
||||
msgs := make([]msgOut, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
msgs = append(msgs, msgOut{
|
||||
Role: m.Role,
|
||||
Content: m.Content,
|
||||
CreatedAt: m.CreatedAt.UnixMilli(),
|
||||
})
|
||||
}
|
||||
|
||||
data := export{
|
||||
Session: sessionOut{
|
||||
ID: session.ID,
|
||||
Title: session.Title,
|
||||
CreatedAt: session.CreatedAt.UnixMilli(),
|
||||
UpdatedAt: session.UpdatedAt.UnixMilli(),
|
||||
},
|
||||
Messages: msgs,
|
||||
}
|
||||
|
||||
jsonBytes, err := json.MarshalIndent(data, "", " ")
|
||||
if err != nil {
|
||||
log.Printf("[SessionHandler] JSON序列化失败: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "导出失败", "errorType": "serialization_error"})
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "application/json; charset=utf-8")
|
||||
c.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="session_%s.json"`, session.ID))
|
||||
c.Data(http.StatusOK, "application/json; charset=utf-8", jsonBytes)
|
||||
}
|
||||
|
||||
// exportMarkdown 导出 Markdown 格式
|
||||
func (h *SessionHandler) exportMarkdown(c *gin.Context, session *store.Session, messages []store.Message, now time.Time) {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(fmt.Sprintf("# 对话导出: %s\n", session.Title))
|
||||
sb.WriteString(fmt.Sprintf("**会话 ID**: %s\n", session.ID))
|
||||
sb.WriteString(fmt.Sprintf("**导出时间**: %s\n", now.Format("2006-01-02 15:04:05")))
|
||||
sb.WriteString(fmt.Sprintf("**消息数量**: %d\n", len(messages)))
|
||||
sb.WriteString("\n---\n\n")
|
||||
|
||||
for _, m := range messages {
|
||||
timeStr := m.CreatedAt.Format("2006-01-02 15:04:05")
|
||||
switch m.Role {
|
||||
case "user":
|
||||
sb.WriteString(fmt.Sprintf("### 👤 用户 (%s)\n\n", timeStr))
|
||||
case "assistant":
|
||||
sb.WriteString(fmt.Sprintf("### 🤖 昔涟 (%s)\n\n", timeStr))
|
||||
case "system":
|
||||
sb.WriteString(fmt.Sprintf("### ⚙️ 系统 (%s)\n\n", timeStr))
|
||||
default:
|
||||
sb.WriteString(fmt.Sprintf("### %s (%s)\n\n", m.Role, timeStr))
|
||||
}
|
||||
sb.WriteString(m.Content)
|
||||
sb.WriteString("\n\n---\n\n")
|
||||
}
|
||||
|
||||
content := sb.String()
|
||||
c.Header("Content-Type", "text/markdown; charset=utf-8")
|
||||
c.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="session_%s.md"`, session.ID))
|
||||
c.Data(http.StatusOK, "text/markdown; charset=utf-8", []byte(content))
|
||||
}
|
||||
|
||||
// exportTXT 导出纯文本格式
|
||||
func (h *SessionHandler) exportTXT(c *gin.Context, session *store.Session, messages []store.Message, now time.Time) {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(fmt.Sprintf("对话导出: %s\n", session.Title))
|
||||
sb.WriteString(fmt.Sprintf("会话 ID: %s\n", session.ID))
|
||||
sb.WriteString(fmt.Sprintf("导出时间: %s\n", now.Format("2006-01-02 15:04:05")))
|
||||
sb.WriteString(fmt.Sprintf("消息数量: %d\n", len(messages)))
|
||||
sb.WriteString(strings.Repeat("=", 50))
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
for _, m := range messages {
|
||||
timeStr := m.CreatedAt.Format("2006-01-02 15:04:05")
|
||||
roleLabel := m.Role
|
||||
switch m.Role {
|
||||
case "user":
|
||||
roleLabel = "用户"
|
||||
case "assistant":
|
||||
roleLabel = "昔涟"
|
||||
case "system":
|
||||
roleLabel = "系统"
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("[%s] %s:\n%s\n\n", timeStr, roleLabel, m.Content))
|
||||
}
|
||||
|
||||
content := sb.String()
|
||||
c.Header("Content-Type", "text/plain; charset=utf-8")
|
||||
c.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="session_%s.txt"`, session.ID))
|
||||
c.Data(http.StatusOK, "text/plain; charset=utf-8", []byte(content))
|
||||
}
|
||||
|
||||
// 简单的工具函数
|
||||
func randomID(n int) string {
|
||||
const letters = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
|
||||
@@ -0,0 +1,179 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// VoiceHandler 语音处理器 — 代理到 voice-service
|
||||
type VoiceHandler struct {
|
||||
voiceServiceURL string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewVoiceHandler 创建语音处理器
|
||||
func NewVoiceHandler(voiceServiceURL string) *VoiceHandler {
|
||||
return &VoiceHandler{
|
||||
voiceServiceURL: voiceServiceURL,
|
||||
client: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
// Transcribe POST /api/v1/voice/transcribe
|
||||
// 代理 multipart/form-data 请求到 voice-service
|
||||
func (h *VoiceHandler) Transcribe(c *gin.Context) {
|
||||
// 限制上传大小 (10MB)
|
||||
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, 10<<20)
|
||||
|
||||
// 读取原始请求体
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败或文件过大,最大支持 10MB"})
|
||||
return
|
||||
}
|
||||
|
||||
// 构建代理请求
|
||||
url := strings.TrimRight(h.voiceServiceURL, "/") + "/api/v1/transcribe"
|
||||
|
||||
proxyReq, err := http.NewRequest("POST", url, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "构建代理请求失败"})
|
||||
return
|
||||
}
|
||||
proxyReq.Header.Set("Content-Type", c.GetHeader("Content-Type"))
|
||||
|
||||
resp, err := h.client.Do(proxyReq)
|
||||
if err != nil {
|
||||
log.Printf("[voice] Voice-Service 不可达 (Transcribe): %v", err)
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": "Voice-Service 不可达: " + err.Error(),
|
||||
"errorType": "voice_service_unreachable",
|
||||
"hint": "Voice-Service 服务未启动或不可达,请先在「服务管理」面板中启动 Voice-Service",
|
||||
})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
// 透传状态码和响应
|
||||
c.Data(resp.StatusCode, resp.Header.Get("Content-Type"), respBody)
|
||||
}
|
||||
|
||||
// TTSSynthesize POST /api/v1/voice/tts
|
||||
// 代理 JSON 请求到 voice-service TTS 合成
|
||||
func (h *VoiceHandler) TTSSynthesize(c *gin.Context) {
|
||||
// 读取 JSON 请求体
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"})
|
||||
return
|
||||
}
|
||||
|
||||
// 构建代理请求
|
||||
url := strings.TrimRight(h.voiceServiceURL, "/") + "/api/v1/tts/synthesize"
|
||||
|
||||
proxyReq, err := http.NewRequest("POST", url, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "构建代理请求失败"})
|
||||
return
|
||||
}
|
||||
proxyReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := h.client.Do(proxyReq)
|
||||
if err != nil {
|
||||
log.Printf("[voice] Voice-Service 不可达 (TTS): %v", err)
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": "Voice-Service 不可达: " + err.Error(),
|
||||
"errorType": "voice_service_unreachable",
|
||||
"hint": "Voice-Service 服务未启动或不可达",
|
||||
})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "读取 Voice-Service 响应失败"})
|
||||
return
|
||||
}
|
||||
|
||||
// 透传状态码、Content-Type 和响应体
|
||||
c.Data(resp.StatusCode, resp.Header.Get("Content-Type"), respBody)
|
||||
}
|
||||
|
||||
// TTSVoices GET /api/v1/voice/tts/voices
|
||||
// 代理请求到 voice-service 获取可用语音列表
|
||||
func (h *VoiceHandler) TTSVoices(c *gin.Context) {
|
||||
url := strings.TrimRight(h.voiceServiceURL, "/") + "/api/v1/tts/voices"
|
||||
|
||||
resp, err := h.client.Get(url)
|
||||
if err != nil {
|
||||
log.Printf("[voice] Voice-Service 不可达 (Voices): %v", err)
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": "Voice-Service 不可达: " + err.Error(),
|
||||
"errorType": "voice_service_unreachable",
|
||||
})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
// 解析并透传
|
||||
var data interface{}
|
||||
json.Unmarshal(respBody, &data)
|
||||
c.JSON(resp.StatusCode, data)
|
||||
}
|
||||
|
||||
// TTSStatus GET /api/v1/voice/tts/status
|
||||
// 代理请求到 voice-service 获取 TTS 状态
|
||||
func (h *VoiceHandler) TTSStatus(c *gin.Context) {
|
||||
url := strings.TrimRight(h.voiceServiceURL, "/") + "/api/v1/tts/status"
|
||||
|
||||
resp, err := h.client.Get(url)
|
||||
if err != nil {
|
||||
log.Printf("[voice] Voice-Service 不可达 (TTS Status): %v", err)
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": "Voice-Service 不可达: " + err.Error(),
|
||||
"errorType": "voice_service_unreachable",
|
||||
})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
var data interface{}
|
||||
json.Unmarshal(respBody, &data)
|
||||
c.JSON(resp.StatusCode, data)
|
||||
}
|
||||
|
||||
// VoiceStatus GET /api/v1/voice/status
|
||||
// 代理请求到 voice-service 获取完整状态(STT + TTS)
|
||||
func (h *VoiceHandler) VoiceStatus(c *gin.Context) {
|
||||
url := strings.TrimRight(h.voiceServiceURL, "/") + "/api/v1/status"
|
||||
|
||||
resp, err := h.client.Get(url)
|
||||
if err != nil {
|
||||
log.Printf("[voice] Voice-Service 不可达 (Status): %v", err)
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": "Voice-Service 不可达: " + err.Error(),
|
||||
"errorType": "voice_service_unreachable",
|
||||
})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
var data interface{}
|
||||
json.Unmarshal(respBody, &data)
|
||||
c.JSON(resp.StatusCode, data)
|
||||
}
|
||||
Reference in New Issue
Block a user