591 lines
16 KiB
Go
591 lines
16 KiB
Go
package handler
|
||
|
||
import (
|
||
"html"
|
||
"log"
|
||
"net/http"
|
||
"os"
|
||
"strings"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
|
||
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
|
||
"github.com/yourname/cyrene-ai/gateway/internal/store"
|
||
)
|
||
|
||
// KnowledgeHandler 知识库处理器
|
||
type KnowledgeHandler struct {
|
||
store *store.KnowledgeStore
|
||
fileStore *store.FileStore
|
||
}
|
||
|
||
// NewKnowledgeHandler 创建知识库处理器
|
||
func NewKnowledgeHandler(s *store.KnowledgeStore, fs *store.FileStore) *KnowledgeHandler {
|
||
return &KnowledgeHandler{store: s, fileStore: fs}
|
||
}
|
||
|
||
// checkStore 检查知识库存储是否可用,不可用时返回 true(调用方应 return)
|
||
func (h *KnowledgeHandler) checkStore(c *gin.Context) bool {
|
||
if h.store == nil {
|
||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||
"error": "知识库服务不可用(数据库未连接)",
|
||
"errorType": "service_unavailable",
|
||
})
|
||
return true
|
||
}
|
||
return false
|
||
}
|
||
|
||
// ========== 请求/响应类型 ==========
|
||
|
||
// createKBRequest is the JSON body accepted by CreateKB.
type createKBRequest struct {
	// Name is mandatory; binding fails when it is absent or empty.
	Name string `json:"name" binding:"required"`
	// Description is optional free text.
	Description string `json:"description"`
}
|
||
|
||
// updateKBRequest is the JSON body accepted by UpdateKB.
type updateKBRequest struct {
	// Name is mandatory; binding fails when it is absent or empty.
	Name string `json:"name" binding:"required"`
	// Description is optional free text.
	Description string `json:"description"`
}
|
||
|
||
// addDocRequest is the JSON body accepted by AddDocument.
type addDocRequest struct {
	// Title is mandatory; binding fails when it is absent or empty.
	Title string `json:"title" binding:"required"`
	// Content carries the document text for the "text" and "url" sources.
	Content string `json:"content"`
	// SourceType is "text", "file" or "url"; AddDocument treats an empty
	// value as "text".
	SourceType string `json:"source_type"`
	// FileID identifies the uploaded file for the "file" source. AddDocument
	// also copies it into the source reference for the "url" source.
	FileID string `json:"file_id"`
}
|
||
|
||
// searchRequest is the JSON body accepted by Search.
type searchRequest struct {
	// Query is mandatory; binding fails when it is absent or empty.
	Query string `json:"query" binding:"required"`
	// KBIDs restricts the search to these knowledge bases. When empty, Search
	// covers every knowledge base of the caller.
	KBIDs []string `json:"kb_ids"`
	// Limit caps the number of returned chunks. Search substitutes 5 for
	// non-positive values and clamps anything above 50.
	Limit int `json:"limit"`
}
|
||
|
||
// ========== POST /api/v1/knowledge/bases — 创建知识库 ==========
|
||
|
||
func (h *KnowledgeHandler) CreateKB(c *gin.Context) {
|
||
if h.checkStore(c) {
|
||
return
|
||
}
|
||
userID := middleware.GetUserID(c)
|
||
|
||
var req createKBRequest
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "请提供知识库名称", "errorType": "invalid_request"})
|
||
return
|
||
}
|
||
|
||
kb := &store.KnowledgeBase{
|
||
ID: store.GenerateUUID(),
|
||
UserID: userID,
|
||
Name: html.EscapeString(req.Name),
|
||
Description: html.EscapeString(req.Description),
|
||
}
|
||
|
||
if err := h.store.CreateKB(kb); err != nil {
|
||
log.Printf("[KnowledgeHandler] 创建知识库失败: %v", err)
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建知识库失败", "errorType": "db_error"})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusCreated, kb)
|
||
}
|
||
|
||
// ========== GET /api/v1/knowledge/bases — 列出用户的知识库 ==========
|
||
|
||
func (h *KnowledgeHandler) ListKBs(c *gin.Context) {
|
||
if h.checkStore(c) {
|
||
return
|
||
}
|
||
userID := middleware.GetUserID(c)
|
||
|
||
kbs, err := h.store.GetKBsByUser(userID)
|
||
if err != nil {
|
||
log.Printf("[KnowledgeHandler] 查询知识库列表失败: %v", err)
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库列表失败", "errorType": "db_error"})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{"knowledge_bases": kbs, "total": len(kbs)})
|
||
}
|
||
|
||
// ========== GET /api/v1/knowledge/bases/:id — 获取知识库详情 ==========
|
||
|
||
func (h *KnowledgeHandler) GetKB(c *gin.Context) {
|
||
if h.checkStore(c) {
|
||
return
|
||
}
|
||
userID := middleware.GetUserID(c)
|
||
kbID := c.Param("id")
|
||
|
||
kb, err := h.store.GetKB(kbID)
|
||
if err != nil {
|
||
log.Printf("[KnowledgeHandler] 查询知识库失败: %v", err)
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库失败", "errorType": "db_error"})
|
||
return
|
||
}
|
||
if kb == nil {
|
||
c.JSON(http.StatusNotFound, gin.H{"error": "知识库不存在", "errorType": "not_found"})
|
||
return
|
||
}
|
||
if kb.UserID != userID {
|
||
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此知识库", "errorType": "access_denied"})
|
||
return
|
||
}
|
||
|
||
// 获取文档列表
|
||
docs, err := h.store.GetDocumentsByKB(kbID)
|
||
if err != nil {
|
||
log.Printf("[KnowledgeHandler] 查询文档列表失败: %v", err)
|
||
docs = []store.KnowledgeDocument{}
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"knowledge_base": kb,
|
||
"documents": docs,
|
||
})
|
||
}
|
||
|
||
// ========== PUT /api/v1/knowledge/bases/:id — 更新知识库 ==========
|
||
|
||
func (h *KnowledgeHandler) UpdateKB(c *gin.Context) {
|
||
if h.checkStore(c) {
|
||
return
|
||
}
|
||
userID := middleware.GetUserID(c)
|
||
kbID := c.Param("id")
|
||
|
||
var req updateKBRequest
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "请提供知识库名称", "errorType": "invalid_request"})
|
||
return
|
||
}
|
||
|
||
kb, err := h.store.GetKB(kbID)
|
||
if err != nil {
|
||
log.Printf("[KnowledgeHandler] 查询知识库失败: %v", err)
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库失败", "errorType": "db_error"})
|
||
return
|
||
}
|
||
if kb == nil {
|
||
c.JSON(http.StatusNotFound, gin.H{"error": "知识库不存在", "errorType": "not_found"})
|
||
return
|
||
}
|
||
if kb.UserID != userID {
|
||
c.JSON(http.StatusForbidden, gin.H{"error": "无权修改此知识库", "errorType": "access_denied"})
|
||
return
|
||
}
|
||
|
||
if err := h.store.UpdateKB(kbID, html.EscapeString(req.Name), html.EscapeString(req.Description)); err != nil {
|
||
log.Printf("[KnowledgeHandler] 更新知识库失败: %v", err)
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "更新知识库失败", "errorType": "db_error"})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{"status": "updated"})
|
||
}
|
||
|
||
// ========== DELETE /api/v1/knowledge/bases/:id — 删除知识库 ==========
|
||
|
||
func (h *KnowledgeHandler) DeleteKB(c *gin.Context) {
|
||
if h.checkStore(c) {
|
||
return
|
||
}
|
||
userID := middleware.GetUserID(c)
|
||
kbID := c.Param("id")
|
||
|
||
kb, err := h.store.GetKB(kbID)
|
||
if err != nil {
|
||
log.Printf("[KnowledgeHandler] 查询知识库失败: %v", err)
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库失败", "errorType": "db_error"})
|
||
return
|
||
}
|
||
if kb == nil {
|
||
c.JSON(http.StatusNotFound, gin.H{"error": "知识库不存在", "errorType": "not_found"})
|
||
return
|
||
}
|
||
if kb.UserID != userID {
|
||
c.JSON(http.StatusForbidden, gin.H{"error": "无权删除此知识库", "errorType": "access_denied"})
|
||
return
|
||
}
|
||
|
||
if err := h.store.DeleteKB(kbID); err != nil {
|
||
log.Printf("[KnowledgeHandler] 删除知识库失败: %v", err)
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除知识库失败", "errorType": "db_error"})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
|
||
}
|
||
|
||
// ========== POST /api/v1/knowledge/bases/:id/documents — 添加文档 ==========
|
||
|
||
func (h *KnowledgeHandler) AddDocument(c *gin.Context) {
|
||
if h.checkStore(c) {
|
||
return
|
||
}
|
||
userID := middleware.GetUserID(c)
|
||
kbID := c.Param("id")
|
||
|
||
// 检查知识库是否存在且属于当前用户
|
||
kb, err := h.store.GetKB(kbID)
|
||
if err != nil {
|
||
log.Printf("[KnowledgeHandler] 查询知识库失败: %v", err)
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库失败", "errorType": "db_error"})
|
||
return
|
||
}
|
||
if kb == nil {
|
||
c.JSON(http.StatusNotFound, gin.H{"error": "知识库不存在", "errorType": "not_found"})
|
||
return
|
||
}
|
||
if kb.UserID != userID {
|
||
c.JSON(http.StatusForbidden, gin.H{"error": "无权操作此知识库", "errorType": "access_denied"})
|
||
return
|
||
}
|
||
|
||
var req addDocRequest
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "请提供文档标题", "errorType": "invalid_request"})
|
||
return
|
||
}
|
||
|
||
if req.SourceType == "" {
|
||
req.SourceType = "text"
|
||
}
|
||
|
||
var content string
|
||
var sourceRef string
|
||
var contentType string
|
||
|
||
switch req.SourceType {
|
||
case "text":
|
||
content = req.Content
|
||
contentType = "text/plain"
|
||
if content == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "请提供文档内容", "errorType": "invalid_request"})
|
||
return
|
||
}
|
||
case "file":
|
||
if req.FileID == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "请提供文件ID", "errorType": "invalid_request"})
|
||
return
|
||
}
|
||
sourceRef = req.FileID
|
||
|
||
// 从 FileStore 读取文件内容
|
||
if h.fileStore != nil {
|
||
f, err := h.fileStore.GetFile(req.FileID)
|
||
if err != nil || f == nil {
|
||
c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在", "errorType": "not_found"})
|
||
return
|
||
}
|
||
if f.UserID != userID {
|
||
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此文件", "errorType": "access_denied"})
|
||
return
|
||
}
|
||
// 注意:这里只支持文本类型文件的读取
|
||
// 对于二进制文件,需要更复杂的解析逻辑
|
||
sourceRef = f.Filename
|
||
// 从磁盘读取文件内容
|
||
content, contentType, _ = readFileContent(f.StoredPath, f.MimeType)
|
||
if content == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "无法读取文件内容,仅支持文本文件", "errorType": "unsupported_file"})
|
||
return
|
||
}
|
||
} else {
|
||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "文件存储不可用", "errorType": "service_unavailable"})
|
||
return
|
||
}
|
||
case "url":
|
||
sourceRef = req.FileID
|
||
content = req.Content
|
||
contentType = "text/html"
|
||
default:
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的来源类型: " + req.SourceType, "errorType": "invalid_request"})
|
||
return
|
||
}
|
||
|
||
if content == "" {
|
||
content = req.Content
|
||
}
|
||
if contentType == "" {
|
||
contentType = "text/plain"
|
||
}
|
||
|
||
doc := &store.KnowledgeDocument{
|
||
ID: store.GenerateUUID(),
|
||
KBID: kbID,
|
||
UserID: userID,
|
||
Title: html.EscapeString(req.Title),
|
||
SourceType: html.EscapeString(req.SourceType),
|
||
SourceRef: sourceRef,
|
||
ContentType: contentType,
|
||
RawContent: content,
|
||
}
|
||
|
||
if err := h.store.AddDocument(doc); err != nil {
|
||
log.Printf("[KnowledgeHandler] 添加文档失败: %v", err)
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "添加文档失败", "errorType": "db_error"})
|
||
return
|
||
}
|
||
|
||
// 自动分块
|
||
chunkCount, err := h.store.ChunkDocument(doc.ID)
|
||
if err != nil {
|
||
log.Printf("[KnowledgeHandler] 文档分块失败: %v", err)
|
||
// 分块失败不影响文档创建
|
||
}
|
||
|
||
doc.ChunkCount = chunkCount
|
||
|
||
c.JSON(http.StatusCreated, doc)
|
||
}
|
||
|
||
// ========== GET /api/v1/knowledge/bases/:id/documents — 列出知识库中的文档 ==========
|
||
|
||
func (h *KnowledgeHandler) ListDocuments(c *gin.Context) {
|
||
if h.checkStore(c) {
|
||
return
|
||
}
|
||
userID := middleware.GetUserID(c)
|
||
kbID := c.Param("id")
|
||
|
||
// 检查权限
|
||
kb, err := h.store.GetKB(kbID)
|
||
if err != nil {
|
||
log.Printf("[KnowledgeHandler] 查询知识库失败: %v", err)
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库失败", "errorType": "db_error"})
|
||
return
|
||
}
|
||
if kb == nil {
|
||
c.JSON(http.StatusNotFound, gin.H{"error": "知识库不存在", "errorType": "not_found"})
|
||
return
|
||
}
|
||
if kb.UserID != userID {
|
||
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此知识库", "errorType": "access_denied"})
|
||
return
|
||
}
|
||
|
||
docs, err := h.store.GetDocumentsByKB(kbID)
|
||
if err != nil {
|
||
log.Printf("[KnowledgeHandler] 查询文档列表失败: %v", err)
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文档列表失败", "errorType": "db_error"})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{"documents": docs, "total": len(docs)})
|
||
}
|
||
|
||
// ========== GET /api/v1/knowledge/documents/:id — 获取文档详情 ==========
|
||
|
||
func (h *KnowledgeHandler) GetDocument(c *gin.Context) {
|
||
if h.checkStore(c) {
|
||
return
|
||
}
|
||
userID := middleware.GetUserID(c)
|
||
docID := c.Param("id")
|
||
|
||
doc, err := h.store.GetDocument(docID)
|
||
if err != nil {
|
||
log.Printf("[KnowledgeHandler] 查询文档失败: %v", err)
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文档失败", "errorType": "db_error"})
|
||
return
|
||
}
|
||
if doc == nil {
|
||
c.JSON(http.StatusNotFound, gin.H{"error": "文档不存在", "errorType": "not_found"})
|
||
return
|
||
}
|
||
if doc.UserID != userID {
|
||
c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此文档", "errorType": "access_denied"})
|
||
return
|
||
}
|
||
|
||
// 获取分块
|
||
chunks, err := h.store.GetChunksByDocID(docID)
|
||
if err != nil {
|
||
log.Printf("[KnowledgeHandler] 查询分块失败: %v", err)
|
||
chunks = []store.KnowledgeChunk{}
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"document": doc,
|
||
"chunks": chunks,
|
||
})
|
||
}
|
||
|
||
// ========== DELETE /api/v1/knowledge/documents/:id — 删除文档 ==========
|
||
|
||
func (h *KnowledgeHandler) DeleteDocument(c *gin.Context) {
|
||
if h.checkStore(c) {
|
||
return
|
||
}
|
||
userID := middleware.GetUserID(c)
|
||
docID := c.Param("id")
|
||
|
||
doc, err := h.store.GetDocument(docID)
|
||
if err != nil {
|
||
log.Printf("[KnowledgeHandler] 查询文档失败: %v", err)
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文档失败", "errorType": "db_error"})
|
||
return
|
||
}
|
||
if doc == nil {
|
||
c.JSON(http.StatusNotFound, gin.H{"error": "文档不存在", "errorType": "not_found"})
|
||
return
|
||
}
|
||
if doc.UserID != userID {
|
||
c.JSON(http.StatusForbidden, gin.H{"error": "无权删除此文档", "errorType": "access_denied"})
|
||
return
|
||
}
|
||
|
||
if err := h.store.DeleteDocument(docID); err != nil {
|
||
log.Printf("[KnowledgeHandler] 删除文档失败: %v", err)
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "删除文档失败", "errorType": "db_error"})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
|
||
}
|
||
|
||
// ========== POST /api/v1/knowledge/search — 搜索知识库 ==========
|
||
|
||
func (h *KnowledgeHandler) Search(c *gin.Context) {
|
||
if h.checkStore(c) {
|
||
return
|
||
}
|
||
userID := middleware.GetUserID(c)
|
||
|
||
var req searchRequest
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "请提供搜索关键词", "errorType": "invalid_request"})
|
||
return
|
||
}
|
||
|
||
if req.Limit <= 0 {
|
||
req.Limit = 5
|
||
}
|
||
if req.Limit > 50 {
|
||
req.Limit = 50
|
||
}
|
||
|
||
var results []store.SearchChunkResult
|
||
var err error
|
||
|
||
if len(req.KBIDs) == 0 {
|
||
// 搜索所有知识库
|
||
results, err = h.store.SearchAllKBs(userID, req.Query, req.Limit)
|
||
} else {
|
||
// 搜索指定知识库,需要验证权限
|
||
results = []store.SearchChunkResult{}
|
||
for _, kbID := range req.KBIDs {
|
||
kb, checkErr := h.store.GetKB(kbID)
|
||
if checkErr != nil || kb == nil || kb.UserID != userID {
|
||
continue // 跳过无权限或不存在的知识库
|
||
}
|
||
kbResults, searchErr := h.store.SearchChunks(kbID, req.Query, req.Limit)
|
||
if searchErr != nil {
|
||
log.Printf("[KnowledgeHandler] 搜索知识库 %s 失败: %v", kbID, searchErr)
|
||
continue
|
||
}
|
||
results = append(results, kbResults...)
|
||
}
|
||
|
||
// 限制总结果数
|
||
if len(results) > req.Limit {
|
||
results = results[:req.Limit]
|
||
}
|
||
}
|
||
|
||
if err != nil {
|
||
log.Printf("[KnowledgeHandler] 搜索失败: %v", err)
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "搜索失败", "errorType": "db_error"})
|
||
return
|
||
}
|
||
|
||
// 生成高亮片段
|
||
for i := range results {
|
||
results[i].Headline = generateHeadline(results[i].Content, req.Query, 200)
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"chunks": results,
|
||
"total": len(results),
|
||
"query": req.Query,
|
||
})
|
||
}
|
||
|
||
// ========== 辅助函数 ==========
|
||
|
||
// generateHeadline returns a snippet of content, at most maxLen runes long
// (not counting ellipses), centered on the first occurrence of query.
//
// The occurrence is located case-insensitively, so a query "go" also centers
// the snippet on "Go". Lengths and offsets are counted in runes, so multi-byte
// text is never cut inside a character.
//
//   - A non-positive maxLen is replaced by 200.
//   - Content of at most maxLen runes is returned unchanged.
//   - If query does not occur, the first maxLen runes followed by "..." are
//     returned.
//   - Otherwise the window starts maxLen/2 runes before the match, is shifted
//     to stay inside the content, and gets "..." on each side that was cut.
func generateHeadline(content, query string, maxLen int) string {
	if maxLen <= 0 {
		maxLen = 200
	}

	runes := []rune(content)
	if len(runes) <= maxLen {
		return content
	}

	// Match on lower-cased copies. strings.ToLower maps rune for rune, so an
	// index into haystack is also a valid index into runes; if the lengths
	// ever disagree, fall back to exact matching rather than risk a skewed
	// window.
	haystack := []rune(strings.ToLower(content))
	needle := []rune(strings.ToLower(query))
	if len(haystack) != len(runes) {
		haystack, needle = runes, []rune(query)
	}

	pos := -1
search:
	for i := 0; i+len(needle) <= len(haystack); i++ {
		for j, r := range needle {
			if haystack[i+j] != r {
				continue search
			}
		}
		pos = i
		break
	}

	if pos < 0 {
		// No occurrence: fall back to the beginning of the content.
		return string(runes[:maxLen]) + "..."
	}

	// Center the window on the match, then clamp it to the content bounds.
	start := pos - maxLen/2
	if start < 0 {
		start = 0
	}
	end := start + maxLen
	if end > len(runes) {
		end = len(runes)
		start = end - maxLen
		if start < 0 {
			start = 0
		}
	}

	result := string(runes[start:end])
	if start > 0 {
		result = "..." + result
	}
	if end < len(runes) {
		result += "..."
	}
	return result
}
|
||
|
||
// readFileContent reads the file at path and returns its bytes as a string,
// but only for text-like media types: anything under "text/" and
// "application/json".
//
// The media type is compared case-insensitively and without parameters, so
// "Text/Plain" and "application/json; charset=utf-8" are accepted. For any
// other type the function returns ("", "", nil) without touching the disk.
// On success contentType is mimeType exactly as passed in. A read failure is
// returned as err with empty content and contentType.
func readFileContent(path, mimeType string) (content string, contentType string, err error) {
	// Reduce "type/subtype; param=value" to a normalized "type/subtype".
	mediaType := mimeType
	if i := strings.IndexByte(mediaType, ';'); i >= 0 {
		mediaType = mediaType[:i]
	}
	mediaType = strings.ToLower(strings.TrimSpace(mediaType))

	if !strings.HasPrefix(mediaType, "text/") && mediaType != "application/json" {
		return "", "", nil
	}

	data, err := os.ReadFile(path)
	if err != nil {
		return "", "", err
	}
	return string(data), mimeType, nil
}
|