feat: 第五轮开发 - 14项未来路线图功能完整实现
W1-W14 全部完成: - W1: 消息搜索 (ILIKE全文检索 + SearchModal) - W2: 对话导出 (JSON/Markdown/TXT三格式) - W3: 记忆时间线 DevTools 可视化 - W4: 通知推送系统 (WebSocket + Browser Notification API) - W5: 定时提醒 (30s轮询 + 重复提醒 + WebSocket推送) - W6: 每日简报 (08:00自动生成: 天气+新闻+提醒+AI摘要) - W7: IoT场景自动化 (规则引擎 10s轮询 + 条件评估 + 场景执行) - W8: 语音输入 (浏览器 Speech Recognition API) - W9: STT服务 (voice-service + whisper.cpp) - W10: TTS服务 (浏览器 Speech Synthesis + edge-tts三档回退) - W11: 文件管理 (上传/下载/缩略图/纯Go bilinear缩放) - W12: 知识库RAG (PostgreSQL tsvector + 文档分块 + 检索) - W13: 多模态 (图片上传+分析: Vision API + 本地Go分析回退) - W14: PWA (Service Worker + 离线页 + install prompt) 总计: 6个Go微服务 + 10+前端组件 + 10+ PostgreSQL表 + 4个后台调度器
This commit is contained in:
@@ -0,0 +1,365 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AutomationRule 自动化规则模型
|
||||
type AutomationRule struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
TriggerType string `json:"trigger_type"`
|
||||
TriggerConfig *json.RawMessage `json:"trigger_config"`
|
||||
Conditions *json.RawMessage `json:"conditions"`
|
||||
Actions *json.RawMessage `json:"actions"`
|
||||
Enabled bool `json:"enabled"`
|
||||
LastTriggeredAt *time.Time `json:"last_triggered_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// AutomationScene 自动化场景模型
|
||||
type AutomationScene struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
Icon string `json:"icon"`
|
||||
RuleIDs *json.RawMessage `json:"rule_ids"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// AutomationStore 自动化持久化存储
|
||||
type AutomationStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewAutomationStore 使用已有数据库连接初始化自动化存储并自动建表
|
||||
func NewAutomationStore(db *sql.DB) (*AutomationStore, error) {
|
||||
store := &AutomationStore{db: db}
|
||||
|
||||
if err := store.migrate(); err != nil {
|
||||
return nil, fmt.Errorf("自动化表迁移失败: %w", err)
|
||||
}
|
||||
|
||||
log.Println("[AutomationStore] 自动化持久化存储已初始化")
|
||||
return store, nil
|
||||
}
|
||||
|
||||
// migrate 自动创建表结构
|
||||
func (s *AutomationStore) migrate() error {
|
||||
queries := []string{
|
||||
`CREATE TABLE IF NOT EXISTS automation_rules (
|
||||
id VARCHAR(64) PRIMARY KEY,
|
||||
user_id VARCHAR(64) NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
description TEXT DEFAULT '',
|
||||
trigger_type VARCHAR(32) NOT NULL,
|
||||
trigger_config JSONB DEFAULT '{}',
|
||||
conditions JSONB DEFAULT '[]',
|
||||
actions JSONB NOT NULL DEFAULT '[]',
|
||||
enabled BOOLEAN DEFAULT TRUE,
|
||||
last_triggered_at TIMESTAMP,
|
||||
created_at TIMESTAMP DEFAULT NOW(),
|
||||
updated_at TIMESTAMP DEFAULT NOW()
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_automation_rules_user_id ON automation_rules(user_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_automation_rules_enabled ON automation_rules(enabled)`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS automation_scenes (
|
||||
id VARCHAR(64) PRIMARY KEY,
|
||||
user_id VARCHAR(64) NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
icon VARCHAR(64) DEFAULT '',
|
||||
rule_ids JSONB DEFAULT '[]',
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_automation_scenes_user_id ON automation_scenes(user_id)`,
|
||||
}
|
||||
|
||||
for _, q := range queries {
|
||||
if _, err := s.db.Exec(q); err != nil {
|
||||
return fmt.Errorf("迁移SQL执行失败: %w\nSQL: %s", err, q)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ========== Rule CRUD ==========
|
||||
|
||||
// CreateRule 创建新规则
|
||||
func (s *AutomationStore) CreateRule(rule *AutomationRule) error {
|
||||
now := time.Now()
|
||||
rule.CreatedAt = now
|
||||
rule.UpdatedAt = now
|
||||
|
||||
triggerConfig := jsonNull(rule.TriggerConfig)
|
||||
conditions := jsonNull(rule.Conditions)
|
||||
actions := jsonNull(rule.Actions)
|
||||
|
||||
_, err := s.db.Exec(
|
||||
`INSERT INTO automation_rules (id, user_id, name, description, trigger_type, trigger_config, conditions, actions, enabled, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)`,
|
||||
rule.ID, rule.UserID, rule.Name, rule.Description, rule.TriggerType,
|
||||
triggerConfig, conditions, actions, rule.Enabled, rule.CreatedAt, rule.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建规则失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRulesByUser 获取用户的所有规则
|
||||
func (s *AutomationStore) GetRulesByUser(userID string) ([]AutomationRule, error) {
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, user_id, name, description, trigger_type, trigger_config, conditions, actions, enabled, last_triggered_at, created_at, updated_at
|
||||
FROM automation_rules WHERE user_id = $1
|
||||
ORDER BY created_at DESC`,
|
||||
userID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询用户规则失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var rules []AutomationRule
|
||||
for rows.Next() {
|
||||
var r AutomationRule
|
||||
if err := rows.Scan(&r.ID, &r.UserID, &r.Name, &r.Description, &r.TriggerType,
|
||||
&r.TriggerConfig, &r.Conditions, &r.Actions, &r.Enabled, &r.LastTriggeredAt,
|
||||
&r.CreatedAt, &r.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描规则行失败: %w", err)
|
||||
}
|
||||
rules = append(rules, r)
|
||||
}
|
||||
|
||||
if rules == nil {
|
||||
rules = []AutomationRule{}
|
||||
}
|
||||
return rules, rows.Err()
|
||||
}
|
||||
|
||||
// GetRule 获取单个规则
|
||||
func (s *AutomationStore) GetRule(id string) (*AutomationRule, error) {
|
||||
var r AutomationRule
|
||||
err := s.db.QueryRow(
|
||||
`SELECT id, user_id, name, description, trigger_type, trigger_config, conditions, actions, enabled, last_triggered_at, created_at, updated_at
|
||||
FROM automation_rules WHERE id = $1`,
|
||||
id,
|
||||
).Scan(&r.ID, &r.UserID, &r.Name, &r.Description, &r.TriggerType,
|
||||
&r.TriggerConfig, &r.Conditions, &r.Actions, &r.Enabled, &r.LastTriggeredAt,
|
||||
&r.CreatedAt, &r.UpdatedAt)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("查询规则失败: %w", err)
|
||||
}
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
// UpdateRule 更新规则
|
||||
func (s *AutomationStore) UpdateRule(rule *AutomationRule) error {
|
||||
triggerConfig := jsonNull(rule.TriggerConfig)
|
||||
conditions := jsonNull(rule.Conditions)
|
||||
actions := jsonNull(rule.Actions)
|
||||
|
||||
_, err := s.db.Exec(
|
||||
`UPDATE automation_rules SET name = $1, description = $2, trigger_type = $3,
|
||||
trigger_config = $4, conditions = $5, actions = $6, enabled = $7, updated_at = NOW()
|
||||
WHERE id = $8`,
|
||||
rule.Name, rule.Description, rule.TriggerType,
|
||||
triggerConfig, conditions, actions, rule.Enabled, rule.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新规则失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteRule 删除规则
|
||||
func (s *AutomationStore) DeleteRule(id string) error {
|
||||
_, err := s.db.Exec(`DELETE FROM automation_rules WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除规则失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetEnabledRules 获取所有启用的规则(供引擎使用)
|
||||
func (s *AutomationStore) GetEnabledRules() ([]AutomationRule, error) {
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, user_id, name, description, trigger_type, trigger_config, conditions, actions, enabled, last_triggered_at, created_at, updated_at
|
||||
FROM automation_rules WHERE enabled = TRUE
|
||||
ORDER BY created_at ASC`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询启用的规则失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var rules []AutomationRule
|
||||
for rows.Next() {
|
||||
var r AutomationRule
|
||||
if err := rows.Scan(&r.ID, &r.UserID, &r.Name, &r.Description, &r.TriggerType,
|
||||
&r.TriggerConfig, &r.Conditions, &r.Actions, &r.Enabled, &r.LastTriggeredAt,
|
||||
&r.CreatedAt, &r.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描规则行失败: %w", err)
|
||||
}
|
||||
rules = append(rules, r)
|
||||
}
|
||||
|
||||
if rules == nil {
|
||||
rules = []AutomationRule{}
|
||||
}
|
||||
return rules, rows.Err()
|
||||
}
|
||||
|
||||
// MarkRuleTriggered 更新 last_triggered_at
|
||||
func (s *AutomationStore) MarkRuleTriggered(id string) error {
|
||||
_, err := s.db.Exec(
|
||||
`UPDATE automation_rules SET last_triggered_at = NOW(), updated_at = NOW() WHERE id = $1`,
|
||||
id,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("标记规则触发失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ========== Scene CRUD ==========
|
||||
|
||||
// CreateScene 创建新场景
|
||||
func (s *AutomationStore) CreateScene(scene *AutomationScene) error {
|
||||
ruleIDs := jsonNull(scene.RuleIDs)
|
||||
|
||||
_, err := s.db.Exec(
|
||||
`INSERT INTO automation_scenes (id, user_id, name, icon, rule_ids, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, NOW())`,
|
||||
scene.ID, scene.UserID, scene.Name, scene.Icon, ruleIDs,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建场景失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetScenesByUser 获取用户的所有场景
|
||||
func (s *AutomationStore) GetScenesByUser(userID string) ([]AutomationScene, error) {
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, user_id, name, icon, rule_ids, created_at
|
||||
FROM automation_scenes WHERE user_id = $1
|
||||
ORDER BY created_at DESC`,
|
||||
userID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询用户场景失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var scenes []AutomationScene
|
||||
for rows.Next() {
|
||||
var sc AutomationScene
|
||||
if err := rows.Scan(&sc.ID, &sc.UserID, &sc.Name, &sc.Icon, &sc.RuleIDs, &sc.CreatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描场景行失败: %w", err)
|
||||
}
|
||||
scenes = append(scenes, sc)
|
||||
}
|
||||
|
||||
if scenes == nil {
|
||||
scenes = []AutomationScene{}
|
||||
}
|
||||
return scenes, rows.Err()
|
||||
}
|
||||
|
||||
// GetScene 获取单个场景
|
||||
func (s *AutomationStore) GetScene(id string) (*AutomationScene, error) {
|
||||
var sc AutomationScene
|
||||
err := s.db.QueryRow(
|
||||
`SELECT id, user_id, name, icon, rule_ids, created_at
|
||||
FROM automation_scenes WHERE id = $1`,
|
||||
id,
|
||||
).Scan(&sc.ID, &sc.UserID, &sc.Name, &sc.Icon, &sc.RuleIDs, &sc.CreatedAt)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("查询场景失败: %w", err)
|
||||
}
|
||||
return &sc, nil
|
||||
}
|
||||
|
||||
// UpdateScene 更新场景
|
||||
func (s *AutomationStore) UpdateScene(scene *AutomationScene) error {
|
||||
ruleIDs := jsonNull(scene.RuleIDs)
|
||||
|
||||
_, err := s.db.Exec(
|
||||
`UPDATE automation_scenes SET name = $1, icon = $2, rule_ids = $3 WHERE id = $4`,
|
||||
scene.Name, scene.Icon, ruleIDs, scene.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新场景失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteScene 删除场景
|
||||
func (s *AutomationStore) DeleteScene(id string) error {
|
||||
_, err := s.db.Exec(`DELETE FROM automation_scenes WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除场景失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSceneRules 根据 scene 的 rule_ids 取出所有关联的 rules
|
||||
func (s *AutomationStore) GetSceneRules(sceneID string) ([]AutomationRule, error) {
|
||||
sc, err := s.GetScene(sceneID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sc == nil {
|
||||
return []AutomationRule{}, nil
|
||||
}
|
||||
|
||||
var ruleIDs []string
|
||||
if sc.RuleIDs != nil {
|
||||
if err := json.Unmarshal(*sc.RuleIDs, &ruleIDs); err != nil {
|
||||
return nil, fmt.Errorf("解析场景规则ID失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(ruleIDs) == 0 {
|
||||
return []AutomationRule{}, nil
|
||||
}
|
||||
|
||||
// 构建 IN 查询
|
||||
var rules []AutomationRule
|
||||
for _, rid := range ruleIDs {
|
||||
r, err := s.GetRule(rid)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询场景关联规则失败: %w", err)
|
||||
}
|
||||
if r != nil {
|
||||
rules = append(rules, *r)
|
||||
}
|
||||
}
|
||||
|
||||
if rules == nil {
|
||||
rules = []AutomationRule{}
|
||||
}
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
// jsonNull 将 *json.RawMessage 转为可写入数据库的 JSON 或 null
|
||||
func jsonNull(raw *json.RawMessage) interface{} {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
return []byte(*raw)
|
||||
}
|
||||
@@ -0,0 +1,321 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Briefing 每日简报模型
|
||||
type Briefing struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Date string `json:"date"` // YYYY-MM-DD
|
||||
Weather *WeatherData `json:"weather"`
|
||||
News []NewsItem `json:"news"`
|
||||
Reminders []BriefReminder `json:"reminders"`
|
||||
Summary string `json:"summary"`
|
||||
Status string `json:"status"` // pending, generated, delivered
|
||||
GeneratedAt *time.Time `json:"generated_at,omitempty"`
|
||||
DeliveredAt *time.Time `json:"delivered_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// WeatherData 天气数据
|
||||
type WeatherData struct {
|
||||
Location string `json:"location"`
|
||||
Temp float64 `json:"temp"`
|
||||
Condition string `json:"condition"`
|
||||
Icon string `json:"icon"`
|
||||
}
|
||||
|
||||
// NewsItem 新闻条目
|
||||
type NewsItem struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Source string `json:"source"`
|
||||
Summary string `json:"summary"`
|
||||
}
|
||||
|
||||
// BriefReminder 简报中的提醒摘要
|
||||
type BriefReminder struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
RemindAt string `json:"remind_at"`
|
||||
}
|
||||
|
||||
// BriefingStore 每日简报持久化存储
|
||||
type BriefingStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewBriefingStore 使用已有数据库连接初始化简报存储并自动建表
|
||||
func NewBriefingStore(db *sql.DB) (*BriefingStore, error) {
|
||||
store := &BriefingStore{db: db}
|
||||
|
||||
if err := store.migrate(); err != nil {
|
||||
return nil, fmt.Errorf("简报表迁移失败: %w", err)
|
||||
}
|
||||
|
||||
log.Println("[BriefingStore] 简报持久化存储已初始化")
|
||||
return store, nil
|
||||
}
|
||||
|
||||
// migrate 自动创建简报表结构
|
||||
func (s *BriefingStore) migrate() error {
|
||||
queries := []string{
|
||||
`CREATE TABLE IF NOT EXISTS daily_briefings (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
user_id VARCHAR(255) NOT NULL,
|
||||
date DATE NOT NULL,
|
||||
weather JSONB DEFAULT '{}',
|
||||
news JSONB DEFAULT '[]',
|
||||
reminders JSONB DEFAULT '[]',
|
||||
summary TEXT DEFAULT '',
|
||||
status VARCHAR(20) DEFAULT 'pending',
|
||||
generated_at TIMESTAMPTZ,
|
||||
delivered_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
UNIQUE(user_id, date)
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_briefings_user_id ON daily_briefings(user_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_briefings_date ON daily_briefings(date)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_briefings_user_date ON daily_briefings(user_id, date)`,
|
||||
}
|
||||
|
||||
for _, q := range queries {
|
||||
if _, err := s.db.Exec(q); err != nil {
|
||||
return fmt.Errorf("迁移SQL执行失败: %w\nSQL: %s", err, q)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateOrUpdateBriefing upsert 简报
|
||||
func (s *BriefingStore) CreateOrUpdateBriefing(b *Briefing) error {
|
||||
weatherJSON, err := json.Marshal(b.Weather)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化天气数据失败: %w", err)
|
||||
}
|
||||
newsJSON, err := json.Marshal(b.News)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化新闻数据失败: %w", err)
|
||||
}
|
||||
remindersJSON, err := json.Marshal(b.Reminders)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化提醒数据失败: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.db.Exec(
|
||||
`INSERT INTO daily_briefings (id, user_id, date, weather, news, reminders, summary, status, generated_at, delivered_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||
ON CONFLICT (user_id, date) DO UPDATE SET
|
||||
weather = EXCLUDED.weather,
|
||||
news = EXCLUDED.news,
|
||||
reminders = EXCLUDED.reminders,
|
||||
summary = EXCLUDED.summary,
|
||||
status = EXCLUDED.status,
|
||||
generated_at = EXCLUDED.generated_at,
|
||||
delivered_at = EXCLUDED.delivered_at`,
|
||||
b.ID, b.UserID, b.Date, string(weatherJSON), string(newsJSON), string(remindersJSON),
|
||||
b.Summary, b.Status, b.GeneratedAt, b.DeliveredAt,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("upsert 简报失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetBriefingByDate 获取指定日期简报
|
||||
func (s *BriefingStore) GetBriefingByDate(userID, date string) (*Briefing, error) {
|
||||
row := s.db.QueryRow(
|
||||
`SELECT id, user_id, date::TEXT, weather, news, reminders, summary, status, generated_at, delivered_at, created_at
|
||||
FROM daily_briefings WHERE user_id = $1 AND date = $2::DATE`,
|
||||
userID, date,
|
||||
)
|
||||
|
||||
b, err := s.scanBriefing(row)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("查询简报失败: %w", err)
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// GetLatestBriefings 获取最近简报列表
|
||||
func (s *BriefingStore) GetLatestBriefings(userID string, limit int) ([]Briefing, error) {
|
||||
if limit <= 0 {
|
||||
limit = 7
|
||||
}
|
||||
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, user_id, date::TEXT, weather, news, reminders, summary, status, generated_at, delivered_at, created_at
|
||||
FROM daily_briefings WHERE user_id = $1
|
||||
ORDER BY date DESC LIMIT $2`,
|
||||
userID, limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询简报列表失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var briefings []Briefing
|
||||
for rows.Next() {
|
||||
var (
|
||||
id, uid, date, summary, status string
|
||||
weatherRaw, newsRaw, remindersRaw []byte
|
||||
generatedAt, deliveredAt, createdAt sql.NullTime
|
||||
)
|
||||
if err := rows.Scan(&id, &uid, &date, &weatherRaw, &newsRaw, &remindersRaw,
|
||||
&summary, &status, &generatedAt, &deliveredAt, &createdAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描简报行失败: %w", err)
|
||||
}
|
||||
|
||||
b := Briefing{
|
||||
ID: id,
|
||||
UserID: uid,
|
||||
Date: date,
|
||||
Summary: summary,
|
||||
Status: status,
|
||||
}
|
||||
|
||||
if weatherRaw != nil {
|
||||
var w WeatherData
|
||||
if err := json.Unmarshal(weatherRaw, &w); err == nil {
|
||||
b.Weather = &w
|
||||
}
|
||||
}
|
||||
if newsRaw != nil {
|
||||
json.Unmarshal(newsRaw, &b.News)
|
||||
}
|
||||
if remindersRaw != nil {
|
||||
json.Unmarshal(remindersRaw, &b.Reminders)
|
||||
}
|
||||
if generatedAt.Valid {
|
||||
b.GeneratedAt = &generatedAt.Time
|
||||
}
|
||||
if deliveredAt.Valid {
|
||||
b.DeliveredAt = &deliveredAt.Time
|
||||
}
|
||||
b.CreatedAt = createdAt.Time
|
||||
|
||||
// 确保切片不为 nil
|
||||
if b.News == nil {
|
||||
b.News = []NewsItem{}
|
||||
}
|
||||
if b.Reminders == nil {
|
||||
b.Reminders = []BriefReminder{}
|
||||
}
|
||||
if b.Weather == nil {
|
||||
b.Weather = &WeatherData{}
|
||||
}
|
||||
|
||||
briefings = append(briefings, b)
|
||||
}
|
||||
|
||||
if briefings == nil {
|
||||
briefings = []Briefing{}
|
||||
}
|
||||
return briefings, rows.Err()
|
||||
}
|
||||
|
||||
// GetUsersWithBriefings 获取拥有简报的所有用户 ID 列表(用于调度器)
|
||||
func (s *BriefingStore) GetUsersWithBriefings() ([]string, error) {
|
||||
rows, err := s.db.Query(`SELECT DISTINCT user_id FROM daily_briefings`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询简报用户列表失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var userIDs []string
|
||||
for rows.Next() {
|
||||
var uid string
|
||||
if err := rows.Scan(&uid); err != nil {
|
||||
return nil, fmt.Errorf("扫描用户ID失败: %w", err)
|
||||
}
|
||||
userIDs = append(userIDs, uid)
|
||||
}
|
||||
if userIDs == nil {
|
||||
userIDs = []string{}
|
||||
}
|
||||
return userIDs, rows.Err()
|
||||
}
|
||||
|
||||
// GetAllUsers 获取所有用户 ID(从 reminders 表获取,作为降级方案)
|
||||
func (s *BriefingStore) GetAllUsers() ([]string, error) {
|
||||
rows, err := s.db.Query(`SELECT DISTINCT user_id FROM reminders`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询用户列表失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var userIDs []string
|
||||
for rows.Next() {
|
||||
var uid string
|
||||
if err := rows.Scan(&uid); err != nil {
|
||||
return nil, fmt.Errorf("扫描用户ID失败: %w", err)
|
||||
}
|
||||
userIDs = append(userIDs, uid)
|
||||
}
|
||||
if userIDs == nil {
|
||||
userIDs = []string{}
|
||||
}
|
||||
return userIDs, rows.Err()
|
||||
}
|
||||
|
||||
// scanBriefing 扫描单行简报
|
||||
func (s *BriefingStore) scanBriefing(row *sql.Row) (*Briefing, error) {
|
||||
var (
|
||||
id, uid, date, summary, status string
|
||||
weatherRaw, newsRaw, remindersRaw []byte
|
||||
generatedAt, deliveredAt, createdAt sql.NullTime
|
||||
)
|
||||
|
||||
if err := row.Scan(&id, &uid, &date, &weatherRaw, &newsRaw, &remindersRaw,
|
||||
&summary, &status, &generatedAt, &deliveredAt, &createdAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b := &Briefing{
|
||||
ID: id,
|
||||
UserID: uid,
|
||||
Date: date,
|
||||
Summary: summary,
|
||||
Status: status,
|
||||
}
|
||||
|
||||
if weatherRaw != nil {
|
||||
var w WeatherData
|
||||
if err := json.Unmarshal(weatherRaw, &w); err == nil {
|
||||
b.Weather = &w
|
||||
}
|
||||
}
|
||||
if b.Weather == nil {
|
||||
b.Weather = &WeatherData{}
|
||||
}
|
||||
if newsRaw != nil {
|
||||
json.Unmarshal(newsRaw, &b.News)
|
||||
}
|
||||
if b.News == nil {
|
||||
b.News = []NewsItem{}
|
||||
}
|
||||
if remindersRaw != nil {
|
||||
json.Unmarshal(remindersRaw, &b.Reminders)
|
||||
}
|
||||
if b.Reminders == nil {
|
||||
b.Reminders = []BriefReminder{}
|
||||
}
|
||||
if generatedAt.Valid {
|
||||
b.GeneratedAt = &generatedAt.Time
|
||||
}
|
||||
if deliveredAt.Valid {
|
||||
b.DeliveredAt = &deliveredAt.Time
|
||||
}
|
||||
b.CreatedAt = createdAt.Time
|
||||
|
||||
return b, nil
|
||||
}
|
||||
@@ -0,0 +1,172 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
// File 文件元数据模型
|
||||
type File struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Filename string `json:"filename"`
|
||||
StoredPath string `json:"stored_path"`
|
||||
MimeType string `json:"mime_type"`
|
||||
Size int64 `json:"size"`
|
||||
Hash string `json:"hash"`
|
||||
IsPublic bool `json:"is_public"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// FileStore 文件元数据持久化存储
|
||||
type FileStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewFileStore 使用已有数据库连接初始化文件存储并自动建表
|
||||
func NewFileStore(db *sql.DB) (*FileStore, error) {
|
||||
store := &FileStore{db: db}
|
||||
|
||||
if err := store.migrate(); err != nil {
|
||||
return nil, fmt.Errorf("文件表迁移失败: %w", err)
|
||||
}
|
||||
|
||||
log.Println("[FileStore] 文件持久化存储已初始化")
|
||||
return store, nil
|
||||
}
|
||||
|
||||
// migrate 自动创建文件表结构
|
||||
func (s *FileStore) migrate() error {
|
||||
queries := []string{
|
||||
`CREATE TABLE IF NOT EXISTS files (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
user_id VARCHAR(255) NOT NULL,
|
||||
filename VARCHAR(500) NOT NULL,
|
||||
stored_path VARCHAR(1000) NOT NULL,
|
||||
mime_type VARCHAR(255) NOT NULL DEFAULT 'application/octet-stream',
|
||||
size BIGINT NOT NULL DEFAULT 0,
|
||||
hash VARCHAR(64) NOT NULL DEFAULT '',
|
||||
is_public BOOLEAN DEFAULT FALSE,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_files_user_id ON files(user_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_files_hash ON files(hash)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_files_created_at ON files(user_id, created_at DESC)`,
|
||||
}
|
||||
|
||||
for _, q := range queries {
|
||||
if _, err := s.db.Exec(q); err != nil {
|
||||
return fmt.Errorf("迁移SQL执行失败: %w\nSQL: %s", err, q)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateFile 创建文件元数据记录
|
||||
func (s *FileStore) CreateFile(f *File) error {
|
||||
_, err := s.db.Exec(
|
||||
`INSERT INTO files (id, user_id, filename, stored_path, mime_type, size, hash, is_public, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
|
||||
f.ID, f.UserID, f.Filename, f.StoredPath, f.MimeType, f.Size, f.Hash, f.IsPublic, f.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建文件记录失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFile 根据ID获取文件元数据
|
||||
func (s *FileStore) GetFile(id string) (*File, error) {
|
||||
var f File
|
||||
err := s.db.QueryRow(
|
||||
`SELECT id, user_id, filename, stored_path, mime_type, size, hash, is_public, created_at
|
||||
FROM files WHERE id = $1`,
|
||||
id,
|
||||
).Scan(&f.ID, &f.UserID, &f.Filename, &f.StoredPath, &f.MimeType, &f.Size, &f.Hash, &f.IsPublic, &f.CreatedAt)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("查询文件失败: %w", err)
|
||||
}
|
||||
return &f, nil
|
||||
}
|
||||
|
||||
// GetUserFiles 获取用户的所有文件,支持分页
|
||||
func (s *FileStore) GetUserFiles(userID string, page, limit int) ([]File, int, error) {
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if limit <= 0 || limit > 100 {
|
||||
limit = 20
|
||||
}
|
||||
offset := (page - 1) * limit
|
||||
|
||||
// 获取总数
|
||||
var total int
|
||||
if err := s.db.QueryRow(
|
||||
`SELECT COUNT(*) FROM files WHERE user_id = $1`,
|
||||
userID,
|
||||
).Scan(&total); err != nil {
|
||||
return nil, 0, fmt.Errorf("查询文件总数失败: %w", err)
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, user_id, filename, stored_path, mime_type, size, hash, is_public, created_at
|
||||
FROM files WHERE user_id = $1
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $2 OFFSET $3`,
|
||||
userID, limit, offset,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("查询用户文件失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var files []File
|
||||
for rows.Next() {
|
||||
var f File
|
||||
if err := rows.Scan(&f.ID, &f.UserID, &f.Filename, &f.StoredPath, &f.MimeType, &f.Size, &f.Hash, &f.IsPublic, &f.CreatedAt); err != nil {
|
||||
return nil, 0, fmt.Errorf("扫描文件行失败: %w", err)
|
||||
}
|
||||
files = append(files, f)
|
||||
}
|
||||
|
||||
if files == nil {
|
||||
files = []File{}
|
||||
}
|
||||
return files, total, rows.Err()
|
||||
}
|
||||
|
||||
// GetFileByHash 根据SHA256哈希查找文件(用于去重)
|
||||
func (s *FileStore) GetFileByHash(hash string) (*File, error) {
|
||||
if hash == "" {
|
||||
return nil, nil
|
||||
}
|
||||
var f File
|
||||
err := s.db.QueryRow(
|
||||
`SELECT id, user_id, filename, stored_path, mime_type, size, hash, is_public, created_at
|
||||
FROM files WHERE hash = $1
|
||||
ORDER BY created_at ASC LIMIT 1`,
|
||||
hash,
|
||||
).Scan(&f.ID, &f.UserID, &f.Filename, &f.StoredPath, &f.MimeType, &f.Size, &f.Hash, &f.IsPublic, &f.CreatedAt)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("按哈希查询文件失败: %w", err)
|
||||
}
|
||||
return &f, nil
|
||||
}
|
||||
|
||||
// DeleteFile 删除文件元数据记录
|
||||
func (s *FileStore) DeleteFile(id string) error {
|
||||
_, err := s.db.Exec(`DELETE FROM files WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除文件记录失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,822 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// ========== 模型定义 ==========
|
||||
|
||||
// KnowledgeBase 知识库
|
||||
type KnowledgeBase struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
DocumentCount int `json:"document_count"`
|
||||
ChunkCount int `json:"chunk_count"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// KnowledgeDocument 知识库文档
|
||||
type KnowledgeDocument struct {
|
||||
ID string `json:"id"`
|
||||
KBID string `json:"kb_id"`
|
||||
UserID string `json:"user_id"`
|
||||
Title string `json:"title"`
|
||||
SourceType string `json:"source_type"` // "file", "text", "url"
|
||||
SourceRef string `json:"source_ref"` // 文件 ID 或 URL
|
||||
ContentType string `json:"content_type"` // "text/plain", "text/markdown", "text/html"
|
||||
RawContent string `json:"raw_content"`
|
||||
ChunkCount int `json:"chunk_count"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// KnowledgeChunk 文档分块
|
||||
type KnowledgeChunk struct {
|
||||
ID string `json:"id"`
|
||||
DocID string `json:"doc_id"`
|
||||
KBID string `json:"kb_id"`
|
||||
ChunkIndex int `json:"chunk_index"`
|
||||
Content string `json:"content"`
|
||||
TokenCount int `json:"token_count"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// SearchChunkResult 搜索结果的块,包含额外上下文信息
|
||||
type SearchChunkResult struct {
|
||||
KnowledgeChunk
|
||||
Relevance float64 `json:"relevance"`
|
||||
DocumentTitle string `json:"document_title"`
|
||||
KBName string `json:"kb_name"`
|
||||
Headline string `json:"headline"`
|
||||
}
|
||||
|
||||
// ========== KnowledgeStore ==========
|
||||
|
||||
// KnowledgeStore 知识库持久化存储
|
||||
type KnowledgeStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewKnowledgeStore 使用已有数据库连接初始化知识库存储并自动建表
|
||||
func NewKnowledgeStore(db *sql.DB) (*KnowledgeStore, error) {
|
||||
store := &KnowledgeStore{db: db}
|
||||
|
||||
if err := store.migrate(); err != nil {
|
||||
return nil, fmt.Errorf("知识库表迁移失败: %w", err)
|
||||
}
|
||||
|
||||
log.Println("[KnowledgeStore] 知识库持久化存储已初始化")
|
||||
return store, nil
|
||||
}
|
||||
|
||||
// migrate 自动创建知识库相关表结构
|
||||
func (s *KnowledgeStore) migrate() error {
|
||||
queries := []string{
|
||||
// 知识库表
|
||||
`CREATE TABLE IF NOT EXISTS knowledge_bases (
|
||||
id VARCHAR(64) PRIMARY KEY,
|
||||
user_id VARCHAR(64) NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
description TEXT DEFAULT '',
|
||||
document_count INT DEFAULT 0,
|
||||
chunk_count INT DEFAULT 0,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ DEFAULT NOW()
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_kb_user_id ON knowledge_bases(user_id)`,
|
||||
|
||||
// 文档表
|
||||
`CREATE TABLE IF NOT EXISTS knowledge_documents (
|
||||
id VARCHAR(64) PRIMARY KEY,
|
||||
kb_id VARCHAR(64) NOT NULL REFERENCES knowledge_bases(id) ON DELETE CASCADE,
|
||||
user_id VARCHAR(64) NOT NULL,
|
||||
title VARCHAR(512) NOT NULL,
|
||||
source_type VARCHAR(32) DEFAULT 'text',
|
||||
source_ref VARCHAR(1024) DEFAULT '',
|
||||
content_type VARCHAR(64) DEFAULT 'text/plain',
|
||||
raw_content TEXT DEFAULT '',
|
||||
chunk_count INT DEFAULT 0,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_kd_kb_id ON knowledge_documents(kb_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_kd_user_id ON knowledge_documents(user_id)`,
|
||||
|
||||
// 分块表
|
||||
`CREATE TABLE IF NOT EXISTS knowledge_chunks (
|
||||
id VARCHAR(64) PRIMARY KEY,
|
||||
doc_id VARCHAR(64) NOT NULL REFERENCES knowledge_documents(id) ON DELETE CASCADE,
|
||||
kb_id VARCHAR(64) NOT NULL,
|
||||
chunk_index INT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
token_count INT DEFAULT 0,
|
||||
tsv TSVECTOR,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_kc_doc_id ON knowledge_chunks(doc_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_kc_kb_id ON knowledge_chunks(kb_id)`,
|
||||
}
|
||||
|
||||
for _, q := range queries {
|
||||
if _, err := s.db.Exec(q); err != nil {
|
||||
return fmt.Errorf("迁移SQL执行失败: %w\nSQL: %s", err, q)
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试创建 GIN 索引(可能因权限或扩展问题失败,但不影响功能)
|
||||
_, err := s.db.Exec(`CREATE INDEX IF NOT EXISTS idx_kc_tsv_gin ON knowledge_chunks USING GIN(tsv)`)
|
||||
if err != nil {
|
||||
log.Printf("[KnowledgeStore] ⚠ GIN索引创建失败(将使用ILIKE降级搜索): %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ========== 知识库 CRUD ==========
|
||||
|
||||
// CreateKB 创建知识库
|
||||
func (s *KnowledgeStore) CreateKB(kb *KnowledgeBase) error {
|
||||
now := time.Now()
|
||||
if kb.CreatedAt.IsZero() {
|
||||
kb.CreatedAt = now
|
||||
}
|
||||
if kb.UpdatedAt.IsZero() {
|
||||
kb.UpdatedAt = now
|
||||
}
|
||||
_, err := s.db.Exec(
|
||||
`INSERT INTO knowledge_bases (id, user_id, name, description, document_count, chunk_count, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
|
||||
kb.ID, kb.UserID, kb.Name, kb.Description, kb.DocumentCount, kb.ChunkCount, kb.CreatedAt, kb.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建知识库失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetKBsByUser 获取用户的所有知识库
|
||||
func (s *KnowledgeStore) GetKBsByUser(userID string) ([]KnowledgeBase, error) {
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, user_id, name, description, document_count, chunk_count, created_at, updated_at
|
||||
FROM knowledge_bases WHERE user_id = $1
|
||||
ORDER BY updated_at DESC`,
|
||||
userID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询知识库列表失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var kbs []KnowledgeBase
|
||||
for rows.Next() {
|
||||
var kb KnowledgeBase
|
||||
if err := rows.Scan(&kb.ID, &kb.UserID, &kb.Name, &kb.Description,
|
||||
&kb.DocumentCount, &kb.ChunkCount, &kb.CreatedAt, &kb.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描知识库行失败: %w", err)
|
||||
}
|
||||
kbs = append(kbs, kb)
|
||||
}
|
||||
if kbs == nil {
|
||||
kbs = []KnowledgeBase{}
|
||||
}
|
||||
return kbs, rows.Err()
|
||||
}
|
||||
|
||||
// GetKB 获取单个知识库
|
||||
func (s *KnowledgeStore) GetKB(id string) (*KnowledgeBase, error) {
|
||||
var kb KnowledgeBase
|
||||
err := s.db.QueryRow(
|
||||
`SELECT id, user_id, name, description, document_count, chunk_count, created_at, updated_at
|
||||
FROM knowledge_bases WHERE id = $1`,
|
||||
id,
|
||||
).Scan(&kb.ID, &kb.UserID, &kb.Name, &kb.Description,
|
||||
&kb.DocumentCount, &kb.ChunkCount, &kb.CreatedAt, &kb.UpdatedAt)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("查询知识库失败: %w", err)
|
||||
}
|
||||
return &kb, nil
|
||||
}
|
||||
|
||||
// UpdateKB 更新知识库名称和描述
|
||||
func (s *KnowledgeStore) UpdateKB(id string, name, description string) error {
|
||||
_, err := s.db.Exec(
|
||||
`UPDATE knowledge_bases SET name = $1, description = $2, updated_at = NOW() WHERE id = $3`,
|
||||
name, description, id,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新知识库失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteKB 删除知识库(级联删除文档和块)
|
||||
func (s *KnowledgeStore) DeleteKB(id string) error {
|
||||
_, err := s.db.Exec(`DELETE FROM knowledge_bases WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除知识库失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateKBStats 更新知识库的统计计数
|
||||
func (s *KnowledgeStore) updateKBStats(kbID string) error {
|
||||
_, err := s.db.Exec(
|
||||
`UPDATE knowledge_bases SET
|
||||
document_count = (SELECT COUNT(*) FROM knowledge_documents WHERE kb_id = $1),
|
||||
chunk_count = (SELECT COUNT(*) FROM knowledge_chunks WHERE kb_id = $1),
|
||||
updated_at = NOW()
|
||||
WHERE id = $1`,
|
||||
kbID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// ========== 文档 CRUD ==========
|
||||
|
||||
// AddDocument 添加文档,返回创建的文档
|
||||
func (s *KnowledgeStore) AddDocument(doc *KnowledgeDocument) error {
|
||||
if doc.CreatedAt.IsZero() {
|
||||
doc.CreatedAt = time.Now()
|
||||
}
|
||||
_, err := s.db.Exec(
|
||||
`INSERT INTO knowledge_documents (id, kb_id, user_id, title, source_type, source_ref, content_type, raw_content, chunk_count, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)`,
|
||||
doc.ID, doc.KBID, doc.UserID, doc.Title, doc.SourceType, doc.SourceRef,
|
||||
doc.ContentType, doc.RawContent, doc.ChunkCount, doc.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("添加文档失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新知识库统计
|
||||
if err := s.updateKBStats(doc.KBID); err != nil {
|
||||
log.Printf("[KnowledgeStore] 更新知识库统计失败: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDocument 获取单个文档
|
||||
func (s *KnowledgeStore) GetDocument(id string) (*KnowledgeDocument, error) {
|
||||
var doc KnowledgeDocument
|
||||
err := s.db.QueryRow(
|
||||
`SELECT id, kb_id, user_id, title, source_type, source_ref, content_type, raw_content, chunk_count, created_at
|
||||
FROM knowledge_documents WHERE id = $1`,
|
||||
id,
|
||||
).Scan(&doc.ID, &doc.KBID, &doc.UserID, &doc.Title, &doc.SourceType, &doc.SourceRef,
|
||||
&doc.ContentType, &doc.RawContent, &doc.ChunkCount, &doc.CreatedAt)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("查询文档失败: %w", err)
|
||||
}
|
||||
return &doc, nil
|
||||
}
|
||||
|
||||
// GetDocumentsByKB 获取知识库中的所有文档
|
||||
func (s *KnowledgeStore) GetDocumentsByKB(kbID string) ([]KnowledgeDocument, error) {
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, kb_id, user_id, title, source_type, source_ref, content_type, raw_content, chunk_count, created_at
|
||||
FROM knowledge_documents WHERE kb_id = $1
|
||||
ORDER BY created_at DESC`,
|
||||
kbID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询文档列表失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var docs []KnowledgeDocument
|
||||
for rows.Next() {
|
||||
var doc KnowledgeDocument
|
||||
if err := rows.Scan(&doc.ID, &doc.KBID, &doc.UserID, &doc.Title, &doc.SourceType, &doc.SourceRef,
|
||||
&doc.ContentType, &doc.RawContent, &doc.ChunkCount, &doc.CreatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描文档行失败: %w", err)
|
||||
}
|
||||
docs = append(docs, doc)
|
||||
}
|
||||
if docs == nil {
|
||||
docs = []KnowledgeDocument{}
|
||||
}
|
||||
return docs, rows.Err()
|
||||
}
|
||||
|
||||
// UpdateDocumentChunkCount 更新文档的分块计数
|
||||
func (s *KnowledgeStore) UpdateDocumentChunkCount(docID string, count int) error {
|
||||
_, err := s.db.Exec(
|
||||
`UPDATE knowledge_documents SET chunk_count = $1 WHERE id = $2`,
|
||||
count, docID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteDocument 删除文档(级联删除块)
|
||||
func (s *KnowledgeStore) DeleteDocument(id string) error {
|
||||
// 先获取 kb_id 以便后续更新统计
|
||||
var kbID string
|
||||
err := s.db.QueryRow(`SELECT kb_id FROM knowledge_documents WHERE id = $1`, id).Scan(&kbID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("查询文档失败: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.db.Exec(`DELETE FROM knowledge_documents WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除文档失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新知识库统计
|
||||
if err := s.updateKBStats(kbID); err != nil {
|
||||
log.Printf("[KnowledgeStore] 更新知识库统计失败: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ========== 分块操作 ==========
|
||||
|
||||
// AddChunk 添加单个分块
|
||||
func (s *KnowledgeStore) AddChunk(chunk *KnowledgeChunk) error {
|
||||
if chunk.CreatedAt.IsZero() {
|
||||
chunk.CreatedAt = time.Now()
|
||||
}
|
||||
|
||||
// 尝试使用 to_tsvector('chinese', content) 设置 tsv
|
||||
// 如果中文分词不可用,使用 simple 配置
|
||||
_, err := s.db.Exec(
|
||||
`INSERT INTO knowledge_chunks (id, doc_id, kb_id, chunk_index, content, token_count, tsv, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6,
|
||||
CASE WHEN (SELECT count(*) FROM pg_ts_config WHERE cfgname = 'chinese') > 0
|
||||
THEN to_tsvector('chinese', $5)
|
||||
ELSE to_tsvector('simple', $5)
|
||||
END,
|
||||
$7)`,
|
||||
chunk.ID, chunk.DocID, chunk.KBID, chunk.ChunkIndex, chunk.Content, chunk.TokenCount, chunk.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
// 降级:不使用 tsv
|
||||
_, err = s.db.Exec(
|
||||
`INSERT INTO knowledge_chunks (id, doc_id, kb_id, chunk_index, content, token_count, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)`,
|
||||
chunk.ID, chunk.DocID, chunk.KBID, chunk.ChunkIndex, chunk.Content, chunk.TokenCount, chunk.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("添加分块失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteChunksByDocID 删除文档的所有分块
|
||||
func (s *KnowledgeStore) DeleteChunksByDocID(docID string) error {
|
||||
_, err := s.db.Exec(`DELETE FROM knowledge_chunks WHERE doc_id = $1`, docID)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetChunksByDocID 获取文档的所有分块
|
||||
func (s *KnowledgeStore) GetChunksByDocID(docID string) ([]KnowledgeChunk, error) {
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, doc_id, kb_id, chunk_index, content, token_count, created_at
|
||||
FROM knowledge_chunks WHERE doc_id = $1
|
||||
ORDER BY chunk_index ASC`,
|
||||
docID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询分块失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var chunks []KnowledgeChunk
|
||||
for rows.Next() {
|
||||
var c KnowledgeChunk
|
||||
if err := rows.Scan(&c.ID, &c.DocID, &c.KBID, &c.ChunkIndex, &c.Content, &c.TokenCount, &c.CreatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描分块行失败: %w", err)
|
||||
}
|
||||
chunks = append(chunks, c)
|
||||
}
|
||||
if chunks == nil {
|
||||
chunks = []KnowledgeChunk{}
|
||||
}
|
||||
return chunks, rows.Err()
|
||||
}
|
||||
|
||||
// ========== 分块逻辑 ==========
|
||||
|
||||
// ChunkDocument 将文档分块并存储
|
||||
func (s *KnowledgeStore) ChunkDocument(docID string) (int, error) {
|
||||
// 获取文档
|
||||
doc, err := s.GetDocument(docID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if doc == nil {
|
||||
return 0, fmt.Errorf("文档不存在: %s", docID)
|
||||
}
|
||||
|
||||
// 删除旧的分块
|
||||
if err := s.DeleteChunksByDocID(docID); err != nil {
|
||||
return 0, fmt.Errorf("删除旧分块失败: %w", err)
|
||||
}
|
||||
|
||||
// 分块
|
||||
chunks := splitTextIntoChunks(doc.RawContent, 500, 50)
|
||||
|
||||
// 存储分块
|
||||
for i, content := range chunks {
|
||||
chunk := &KnowledgeChunk{
|
||||
ID: generateUUIDv4(),
|
||||
DocID: docID,
|
||||
KBID: doc.KBID,
|
||||
ChunkIndex: i,
|
||||
Content: content,
|
||||
TokenCount: estimateTokenCount(content),
|
||||
}
|
||||
if err := s.AddChunk(chunk); err != nil {
|
||||
return 0, fmt.Errorf("存储分块 %d 失败: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新文档的分块计数
|
||||
if err := s.UpdateDocumentChunkCount(docID, len(chunks)); err != nil {
|
||||
log.Printf("[KnowledgeStore] 更新文档分块计数失败: %v", err)
|
||||
}
|
||||
|
||||
// 更新知识库统计
|
||||
if err := s.updateKBStats(doc.KBID); err != nil {
|
||||
log.Printf("[KnowledgeStore] 更新知识库统计失败: %v", err)
|
||||
}
|
||||
|
||||
return len(chunks), nil
|
||||
}
|
||||
|
||||
// ========== 搜索 ==========
|
||||
|
||||
// SearchChunks 在指定知识库中搜索
|
||||
func (s *KnowledgeStore) SearchChunks(kbID, query string, limit int) ([]SearchChunkResult, error) {
|
||||
if limit <= 0 {
|
||||
limit = 5
|
||||
}
|
||||
|
||||
// 尝试使用 PostgreSQL 全文搜索
|
||||
results, err := s.searchWithFullText(kbID, query, limit)
|
||||
if err != nil {
|
||||
log.Printf("[KnowledgeStore] 全文搜索失败,降级为ILIKE: %v", err)
|
||||
// 降级为 ILIKE
|
||||
results, err = s.searchWithILike(kbID, query, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if results == nil {
|
||||
results = []SearchChunkResult{}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// SearchAllKBs 在用户的所有知识库中搜索
|
||||
func (s *KnowledgeStore) SearchAllKBs(userID, query string, limit int) ([]SearchChunkResult, error) {
|
||||
if limit <= 0 {
|
||||
limit = 5
|
||||
}
|
||||
|
||||
results, err := s.searchAllWithILike(userID, query, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if results == nil {
|
||||
results = []SearchChunkResult{}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// searchWithFullText 使用 PostgreSQL ts_rank + plainto_tsquery 搜索
|
||||
func (s *KnowledgeStore) searchWithFullText(kbID, query string, limit int) ([]SearchChunkResult, error) {
|
||||
rows, err := s.db.Query(
|
||||
`SELECT kc.id, kc.doc_id, kc.kb_id, kc.chunk_index, kc.content, kc.token_count, kc.created_at,
|
||||
ts_rank(kc.tsv, plainto_tsquery('chinese', $2)) AS relevance,
|
||||
kd.title AS document_title,
|
||||
kb.name AS kb_name
|
||||
FROM knowledge_chunks kc
|
||||
JOIN knowledge_documents kd ON kc.doc_id = kd.id
|
||||
JOIN knowledge_bases kb ON kc.kb_id = kb.id
|
||||
WHERE kc.kb_id = $1 AND kc.tsv @@ plainto_tsquery('chinese', $2)
|
||||
ORDER BY relevance DESC
|
||||
LIMIT $3`,
|
||||
kbID, query, limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanSearchResults(rows)
|
||||
}
|
||||
|
||||
// searchWithILike 使用 ILIKE 降级搜索
|
||||
func (s *KnowledgeStore) searchWithILike(kbID, query string, limit int) ([]SearchChunkResult, error) {
|
||||
// 构建 ILIKE 模式
|
||||
keywords := tokenizeQuery(query)
|
||||
if len(keywords) == 0 {
|
||||
return []SearchChunkResult{}, nil
|
||||
}
|
||||
|
||||
// 对每个关键词构建 ILIKE 条件
|
||||
conditions := make([]string, len(keywords))
|
||||
args := []interface{}{kbID}
|
||||
placeholderIdx := 2
|
||||
for i, kw := range keywords {
|
||||
conditions[i] = fmt.Sprintf("kc.content ILIKE $%d", placeholderIdx)
|
||||
args = append(args, "%"+kw+"%")
|
||||
placeholderIdx++
|
||||
}
|
||||
args = append(args, limit)
|
||||
|
||||
querySQL := fmt.Sprintf(
|
||||
`SELECT kc.id, kc.doc_id, kc.kb_id, kc.chunk_index, kc.content, kc.token_count, kc.created_at,
|
||||
0.0 AS relevance,
|
||||
kd.title AS document_title,
|
||||
kb.name AS kb_name
|
||||
FROM knowledge_chunks kc
|
||||
JOIN knowledge_documents kd ON kc.doc_id = kd.id
|
||||
JOIN knowledge_bases kb ON kc.kb_id = kb.id
|
||||
WHERE kc.kb_id = $1 AND (%s)
|
||||
LIMIT $%d`,
|
||||
strings.Join(conditions, " AND "),
|
||||
placeholderIdx,
|
||||
)
|
||||
|
||||
rows, err := s.db.Query(querySQL, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ILIKE搜索失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanSearchResults(rows)
|
||||
}
|
||||
|
||||
// searchAllWithILike 跨所有用户知识库使用 ILIKE 搜索
|
||||
func (s *KnowledgeStore) searchAllWithILike(userID, query string, limit int) ([]SearchChunkResult, error) {
|
||||
keywords := tokenizeQuery(query)
|
||||
if len(keywords) == 0 {
|
||||
return []SearchChunkResult{}, nil
|
||||
}
|
||||
|
||||
conditions := make([]string, len(keywords))
|
||||
args := []interface{}{userID}
|
||||
placeholderIdx := 2
|
||||
for i, kw := range keywords {
|
||||
conditions[i] = fmt.Sprintf("kc.content ILIKE $%d", placeholderIdx)
|
||||
args = append(args, "%"+kw+"%")
|
||||
placeholderIdx++
|
||||
}
|
||||
args = append(args, limit)
|
||||
|
||||
querySQL := fmt.Sprintf(
|
||||
`SELECT kc.id, kc.doc_id, kc.kb_id, kc.chunk_index, kc.content, kc.token_count, kc.created_at,
|
||||
0.0 AS relevance,
|
||||
kd.title AS document_title,
|
||||
kb.name AS kb_name
|
||||
FROM knowledge_chunks kc
|
||||
JOIN knowledge_documents kd ON kc.doc_id = kd.id
|
||||
JOIN knowledge_bases kb ON kc.kb_id = kb.id
|
||||
WHERE kb.user_id = $1 AND (%s)
|
||||
ORDER BY kc.created_at DESC
|
||||
LIMIT $%d`,
|
||||
strings.Join(conditions, " AND "),
|
||||
placeholderIdx,
|
||||
)
|
||||
|
||||
rows, err := s.db.Query(querySQL, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("全知识库ILIKE搜索失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanSearchResults(rows)
|
||||
}
|
||||
|
||||
// scanSearchResults 扫描搜索结果
|
||||
func scanSearchResults(rows *sql.Rows) ([]SearchChunkResult, error) {
|
||||
var results []SearchChunkResult
|
||||
for rows.Next() {
|
||||
var r SearchChunkResult
|
||||
if err := rows.Scan(&r.ID, &r.DocID, &r.KBID, &r.ChunkIndex, &r.Content,
|
||||
&r.TokenCount, &r.CreatedAt, &r.Relevance, &r.DocumentTitle, &r.KBName); err != nil {
|
||||
return nil, fmt.Errorf("扫描搜索结果行失败: %w", err)
|
||||
}
|
||||
// 生成高亮片段
|
||||
r.Headline = r.Content
|
||||
results = append(results, r)
|
||||
}
|
||||
if results == nil {
|
||||
results = []SearchChunkResult{}
|
||||
}
|
||||
return results, rows.Err()
|
||||
}
|
||||
|
||||
// ========== 文本分块函数 ==========
|
||||
|
||||
// splitTextIntoChunks 将文本按 maxLen 分块,块之间有 overlap 字符重叠
|
||||
func splitTextIntoChunks(text string, maxLen int, overlap int) []string {
|
||||
if text == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 按段落分割
|
||||
paragraphs := strings.Split(text, "\n\n")
|
||||
var chunks []string
|
||||
var currentChunk strings.Builder
|
||||
|
||||
for _, para := range paragraphs {
|
||||
para = strings.TrimSpace(para)
|
||||
if para == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
paraLen := utf8.RuneCountInString(para)
|
||||
|
||||
if paraLen <= maxLen {
|
||||
// 如果当前块 + 段落不超过 maxLen,追加到当前块
|
||||
if utf8.RuneCountInString(currentChunk.String()) == 0 {
|
||||
currentChunk.WriteString(para)
|
||||
} else if utf8.RuneCountInString(currentChunk.String())+1+paraLen <= maxLen {
|
||||
currentChunk.WriteString("\n\n")
|
||||
currentChunk.WriteString(para)
|
||||
} else {
|
||||
// 保存当前块,开始新块
|
||||
chunks = append(chunks, currentChunk.String())
|
||||
currentChunk.Reset()
|
||||
currentChunk.WriteString(para)
|
||||
}
|
||||
} else {
|
||||
// 段落超过 maxLen,需要按句子分割
|
||||
// 先保存当前块
|
||||
if currentChunk.Len() > 0 {
|
||||
chunks = append(chunks, currentChunk.String())
|
||||
currentChunk.Reset()
|
||||
}
|
||||
|
||||
// 按句子分割
|
||||
sentences := splitIntoSentences(para)
|
||||
for _, sent := range sentences {
|
||||
sent = strings.TrimSpace(sent)
|
||||
if sent == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
sentLen := utf8.RuneCountInString(sent)
|
||||
|
||||
if sentLen <= maxLen {
|
||||
if utf8.RuneCountInString(currentChunk.String()) == 0 {
|
||||
currentChunk.WriteString(sent)
|
||||
} else if utf8.RuneCountInString(currentChunk.String())+sentLen <= maxLen {
|
||||
currentChunk.WriteString(sent)
|
||||
} else {
|
||||
chunks = append(chunks, currentChunk.String())
|
||||
currentChunk.Reset()
|
||||
currentChunk.WriteString(sent)
|
||||
}
|
||||
} else {
|
||||
// 句子超过 maxLen,按 maxLen 截断
|
||||
if currentChunk.Len() > 0 {
|
||||
chunks = append(chunks, currentChunk.String())
|
||||
currentChunk.Reset()
|
||||
}
|
||||
|
||||
// 按 maxLen 截断,带 overlap
|
||||
runes := []rune(sent)
|
||||
start := 0
|
||||
for start < len(runes) {
|
||||
end := start + maxLen
|
||||
if end > len(runes) {
|
||||
end = len(runes)
|
||||
}
|
||||
chunks = append(chunks, string(runes[start:end]))
|
||||
if end >= len(runes) {
|
||||
break
|
||||
}
|
||||
// 下一块从 end-overlap 开始
|
||||
start = end - overlap
|
||||
if start <= 0 {
|
||||
start = end
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 保存最后一个块
|
||||
if currentChunk.Len() > 0 {
|
||||
chunks = append(chunks, currentChunk.String())
|
||||
}
|
||||
|
||||
return chunks
|
||||
}
|
||||
|
||||
// splitIntoSentences 按句子分割文本(中文。!?和英文标点)
|
||||
func splitIntoSentences(text string) []string {
|
||||
var sentences []string
|
||||
runes := []rune(text)
|
||||
var current strings.Builder
|
||||
|
||||
for i := 0; i < len(runes); i++ {
|
||||
current.WriteRune(runes[i])
|
||||
|
||||
// 检查句子结束标志
|
||||
if runes[i] == '。' || runes[i] == '!' || runes[i] == '?' ||
|
||||
runes[i] == '!' || runes[i] == '?' ||
|
||||
(runes[i] == '\n' && i+1 < len(runes) && runes[i+1] != '\n') {
|
||||
sentences = append(sentences, current.String())
|
||||
current.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
// 剩余内容
|
||||
if current.Len() > 0 {
|
||||
remaining := strings.TrimSpace(current.String())
|
||||
if remaining != "" {
|
||||
sentences = append(sentences, remaining)
|
||||
}
|
||||
}
|
||||
|
||||
return sentences
|
||||
}
|
||||
|
||||
// estimateTokenCount 估算 token 数量(中文按每个字符1.2个token,英文按每4个字符1个token)
|
||||
func estimateTokenCount(text string) int {
|
||||
runes := []rune(text)
|
||||
total := 0
|
||||
for _, r := range runes {
|
||||
if r >= 0x4e00 && r <= 0x9fff {
|
||||
// 中文字符,约1.2个token
|
||||
total += 1
|
||||
}
|
||||
}
|
||||
// 非中文字符粗略估算:字符数/4
|
||||
nonChinese := len(runes) - total
|
||||
total = int(float64(total)*1.2) + nonChinese/4
|
||||
if total < 1 {
|
||||
total = 1
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// tokenizeQuery 将查询字符串分词(简单按空格和标点分割)
|
||||
func tokenizeQuery(query string) []string {
|
||||
// 按空格、中文标点、英文标点分割
|
||||
query = strings.TrimSpace(query)
|
||||
if query == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 先用空格分割
|
||||
parts := strings.Fields(query)
|
||||
var tokens []string
|
||||
for _, part := range parts {
|
||||
part = strings.Trim(part, "。!?!?,.;::;、()()[]{}《》\"'")
|
||||
if part != "" {
|
||||
tokens = append(tokens, part)
|
||||
}
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
// GenerateUUID 使用 crypto/rand 生成 UUID v4 格式的字符串(导出供其他包使用)
|
||||
func GenerateUUID() string {
|
||||
return generateUUIDv4()
|
||||
}
|
||||
|
||||
// generateUUIDv4 使用 crypto/rand 生成 UUID v4 格式的字符串
|
||||
func generateUUIDv4() string {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// 降级方案:基于时间戳 + 简单随机
|
||||
ts := time.Now().UnixNano()
|
||||
for i := 0; i < 16; i++ {
|
||||
b[i] = byte((ts >> (i * 4)) & 0xFF)
|
||||
}
|
||||
}
|
||||
// 设置 UUID v4 版本位 (version = 4)
|
||||
b[6] = (b[6] & 0x0f) | 0x40
|
||||
// 设置 UUID variant 位 (variant = 10xx)
|
||||
b[8] = (b[8] & 0x3f) | 0x80
|
||||
return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x",
|
||||
b[0:4], b[4:6], b[6:8], b[8:10], b[10:16])
|
||||
}
|
||||
@@ -0,0 +1,195 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Reminder 提醒模型
|
||||
type Reminder struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
RemindAt time.Time `json:"remind_at"`
|
||||
Status string `json:"status"` // pending, completed, cancelled
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
CompletedAt *time.Time `json:"completed_at,omitempty"`
|
||||
RepeatType string `json:"repeat_type"` // none, daily, weekly, monthly
|
||||
SessionID string `json:"session_id"`
|
||||
Notified bool `json:"notified"`
|
||||
}
|
||||
|
||||
// ReminderStore 提醒持久化存储
|
||||
type ReminderStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewReminderStore 使用已有数据库连接初始化提醒存储并自动建表
|
||||
func NewReminderStore(db *sql.DB) (*ReminderStore, error) {
|
||||
store := &ReminderStore{db: db}
|
||||
|
||||
if err := store.migrate(); err != nil {
|
||||
return nil, fmt.Errorf("提醒表迁移失败: %w", err)
|
||||
}
|
||||
|
||||
log.Println("[ReminderStore] 提醒持久化存储已初始化")
|
||||
return store, nil
|
||||
}
|
||||
|
||||
// migrate 自动创建提醒表结构
|
||||
func (s *ReminderStore) migrate() error {
|
||||
queries := []string{
|
||||
`CREATE TABLE IF NOT EXISTS reminders (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
user_id VARCHAR(255) NOT NULL,
|
||||
title VARCHAR(500) NOT NULL,
|
||||
description TEXT DEFAULT '',
|
||||
remind_at TIMESTAMPTZ NOT NULL,
|
||||
status VARCHAR(20) DEFAULT 'pending',
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
completed_at TIMESTAMPTZ,
|
||||
repeat_type VARCHAR(20) DEFAULT '',
|
||||
session_id VARCHAR(36) DEFAULT '',
|
||||
notified BOOLEAN DEFAULT FALSE
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_reminders_user_id ON reminders(user_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_reminders_remind_at ON reminders(remind_at)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_reminders_status ON reminders(status)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_reminders_due ON reminders(remind_at, status, notified)`,
|
||||
}
|
||||
|
||||
for _, q := range queries {
|
||||
if _, err := s.db.Exec(q); err != nil {
|
||||
return fmt.Errorf("迁移SQL执行失败: %w\nSQL: %s", err, q)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateReminder 创建新提醒
|
||||
func (s *ReminderStore) CreateReminder(r *Reminder) error {
|
||||
_, err := s.db.Exec(
|
||||
`INSERT INTO reminders (id, user_id, title, description, remind_at, status, repeat_type, session_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
|
||||
r.ID, r.UserID, r.Title, r.Description, r.RemindAt, r.Status, r.RepeatType, r.SessionID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建提醒失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRemindersByUser 获取用户的提醒列表(可按状态筛选,按 remind_at 升序)
|
||||
func (s *ReminderStore) GetRemindersByUser(userID, status string, limit, offset int) ([]Reminder, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
|
||||
if status != "" {
|
||||
rows, err = s.db.Query(
|
||||
`SELECT id, user_id, title, description, remind_at, status, created_at, completed_at, repeat_type, session_id, notified
|
||||
FROM reminders WHERE user_id = $1 AND status = $2
|
||||
ORDER BY remind_at ASC LIMIT $3 OFFSET $4`,
|
||||
userID, status, limit, offset,
|
||||
)
|
||||
} else {
|
||||
rows, err = s.db.Query(
|
||||
`SELECT id, user_id, title, description, remind_at, status, created_at, completed_at, repeat_type, session_id, notified
|
||||
FROM reminders WHERE user_id = $1
|
||||
ORDER BY remind_at ASC LIMIT $2 OFFSET $3`,
|
||||
userID, limit, offset,
|
||||
)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询提醒列表失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var reminders []Reminder
|
||||
for rows.Next() {
|
||||
var r Reminder
|
||||
if err := rows.Scan(&r.ID, &r.UserID, &r.Title, &r.Description, &r.RemindAt,
|
||||
&r.Status, &r.CreatedAt, &r.CompletedAt, &r.RepeatType, &r.SessionID, &r.Notified); err != nil {
|
||||
return nil, fmt.Errorf("扫描提醒行失败: %w", err)
|
||||
}
|
||||
reminders = append(reminders, r)
|
||||
}
|
||||
|
||||
if reminders == nil {
|
||||
reminders = []Reminder{}
|
||||
}
|
||||
return reminders, rows.Err()
|
||||
}
|
||||
|
||||
// GetDueReminders 获取所有到期且未通知的提醒
|
||||
func (s *ReminderStore) GetDueReminders() ([]Reminder, error) {
|
||||
rows, err := s.db.Query(
|
||||
`SELECT id, user_id, title, description, remind_at, status, created_at, completed_at, repeat_type, session_id, notified
|
||||
FROM reminders
|
||||
WHERE remind_at <= NOW() AND status = 'pending' AND notified = FALSE
|
||||
ORDER BY remind_at ASC`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询到期提醒失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var reminders []Reminder
|
||||
for rows.Next() {
|
||||
var r Reminder
|
||||
if err := rows.Scan(&r.ID, &r.UserID, &r.Title, &r.Description, &r.RemindAt,
|
||||
&r.Status, &r.CreatedAt, &r.CompletedAt, &r.RepeatType, &r.SessionID, &r.Notified); err != nil {
|
||||
return nil, fmt.Errorf("扫描到期提醒行失败: %w", err)
|
||||
}
|
||||
reminders = append(reminders, r)
|
||||
}
|
||||
|
||||
if reminders == nil {
|
||||
reminders = []Reminder{}
|
||||
}
|
||||
return reminders, rows.Err()
|
||||
}
|
||||
|
||||
// MarkNotified 标记提醒为已通知
|
||||
func (s *ReminderStore) MarkNotified(id string) error {
|
||||
_, err := s.db.Exec(
|
||||
`UPDATE reminders SET notified = TRUE WHERE id = $1`,
|
||||
id,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("标记提醒已通知失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateReminder 更新提醒字段
|
||||
func (s *ReminderStore) UpdateReminder(id string, r *Reminder) error {
|
||||
_, err := s.db.Exec(
|
||||
`UPDATE reminders SET title = $1, description = $2, remind_at = $3, status = $4,
|
||||
completed_at = $5, repeat_type = $6, session_id = $7, notified = $8
|
||||
WHERE id = $9`,
|
||||
r.Title, r.Description, r.RemindAt, r.Status, r.CompletedAt, r.RepeatType, r.SessionID, r.Notified, id,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新提醒失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteReminder 删除提醒
|
||||
func (s *ReminderStore) DeleteReminder(id string) error {
|
||||
_, err := s.db.Exec(`DELETE FROM reminders WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除提醒失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -253,6 +253,70 @@ func (s *SessionStore) ClearSessionMessages(sessionID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SearchResult 搜索结果
|
||||
type SearchResult struct {
|
||||
MessageID int `json:"message_id"`
|
||||
SessionID string `json:"session_id"`
|
||||
SessionTitle string `json:"session_title"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// SearchMessages 全文搜索消息 (使用 ILIKE 进行模糊匹配)
|
||||
// 返回搜索结果列表、总数和可能的错误
|
||||
func (s *SessionStore) SearchMessages(userID, query string, limit, offset int) ([]SearchResult, int, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
|
||||
// 获取匹配总数
|
||||
var total int
|
||||
countSQL := `SELECT COUNT(*) FROM messages m
|
||||
JOIN sessions s ON m.session_id = s.id
|
||||
WHERE s.user_id = $1 AND m.content ILIKE '%' || $2 || '%'`
|
||||
if err := s.db.QueryRow(countSQL, userID, query).Scan(&total); err != nil {
|
||||
return nil, 0, fmt.Errorf("搜索计数失败: %w", err)
|
||||
}
|
||||
|
||||
// 分页查询,关联 sessions 获取会话标题
|
||||
rows, err := s.db.Query(
|
||||
`SELECT m.id, m.session_id, COALESCE(s.title, '') AS session_title, m.role, m.content, m.created_at
|
||||
FROM messages m
|
||||
JOIN sessions s ON m.session_id = s.id
|
||||
WHERE s.user_id = $1 AND m.content ILIKE '%' || $2 || '%'
|
||||
ORDER BY m.created_at DESC
|
||||
LIMIT $3 OFFSET $4`,
|
||||
userID, query, limit, offset,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("搜索消息失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []SearchResult
|
||||
for rows.Next() {
|
||||
var r SearchResult
|
||||
if err := rows.Scan(&r.MessageID, &r.SessionID, &r.SessionTitle, &r.Role, &r.Content, &r.CreatedAt); err != nil {
|
||||
return nil, 0, fmt.Errorf("扫描搜索结果行失败: %w", err)
|
||||
}
|
||||
results = append(results, r)
|
||||
}
|
||||
|
||||
if results == nil {
|
||||
results = []SearchResult{}
|
||||
}
|
||||
return results, total, rows.Err()
|
||||
}
|
||||
|
||||
// DB 返回底层数据库连接,供其他 store 复用
|
||||
func (s *SessionStore) DB() *sql.DB {
|
||||
return s.db
|
||||
}
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func (s *SessionStore) Close() error {
|
||||
if s.db != nil {
|
||||
|
||||
Reference in New Issue
Block a user