Files
Cyrene/backend/gateway/internal/handler/knowledge_handler.go
T

591 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package handler
import (
"html"
"log"
"net/http"
"os"
"strings"
"github.com/gin-gonic/gin"
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
"github.com/yourname/cyrene-ai/gateway/internal/store"
)
// KnowledgeHandler serves the /api/v1/knowledge HTTP endpoints
// (knowledge bases, their documents, and search).
type KnowledgeHandler struct {
	// store persists knowledge bases, documents and chunks. A nil store means
	// the database is not connected; every handler then answers 503 via
	// checkStore.
	store *store.KnowledgeStore
	// fileStore resolves uploaded files for "file" documents. It may be nil,
	// in which case adding a file-backed document answers 503.
	fileStore *store.FileStore
}
// NewKnowledgeHandler wires a KnowledgeHandler to its backing stores.
// Either store may be nil; the handlers check for that and respond with
// 503 instead of dereferencing it.
func NewKnowledgeHandler(s *store.KnowledgeStore, fs *store.FileStore) *KnowledgeHandler {
	h := new(KnowledgeHandler)
	h.store, h.fileStore = s, fs
	return h
}
// checkStore reports whether the knowledge store is unavailable (nil).
// When it returns true a 503 JSON error has already been written, and the
// calling handler must return immediately.
func (h *KnowledgeHandler) checkStore(c *gin.Context) bool {
	unavailable := h.store == nil
	if unavailable {
		c.JSON(http.StatusServiceUnavailable, gin.H{
			"error":     "知识库服务不可用(数据库未连接)",
			"errorType": "service_unavailable",
		})
	}
	return unavailable
}
// ========== Request/response types ==========

type (
	// createKBRequest is the JSON body of POST /api/v1/knowledge/bases.
	createKBRequest struct {
		Name        string `json:"name" binding:"required"`
		Description string `json:"description"`
	}

	// updateKBRequest is the JSON body of PUT /api/v1/knowledge/bases/:id.
	updateKBRequest struct {
		Name        string `json:"name" binding:"required"`
		Description string `json:"description"`
	}

	// addDocRequest is the JSON body of POST /api/v1/knowledge/bases/:id/documents.
	// SourceType is "text", "file" or "url" (empty is treated as "text");
	// FileID names an uploaded file for "file" sources.
	addDocRequest struct {
		Title      string `json:"title" binding:"required"`
		Content    string `json:"content"`
		SourceType string `json:"source_type"`
		FileID     string `json:"file_id"`
	}

	// searchRequest is the JSON body of POST /api/v1/knowledge/search.
	// An empty KBIDs searches all of the caller's knowledge bases.
	searchRequest struct {
		Query string   `json:"query" binding:"required"`
		KBIDs []string `json:"kb_ids"`
		Limit int      `json:"limit"`
	}
)
// ========== POST /api/v1/knowledge/bases — create a knowledge base ==========

// CreateKB creates a knowledge base owned by the authenticated user.
// Name and description are HTML-escaped before being stored. Responds 201
// with the created record, 400 when the body is malformed or has no name,
// and 500 when the store insert fails.
func (h *KnowledgeHandler) CreateKB(c *gin.Context) {
	if h.checkStore(c) {
		return
	}
	owner := middleware.GetUserID(c)

	var body createKBRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "请提供知识库名称", "errorType": "invalid_request"})
		return
	}

	newKB := &store.KnowledgeBase{
		ID:          store.GenerateUUID(),
		UserID:      owner,
		Name:        html.EscapeString(body.Name),
		Description: html.EscapeString(body.Description),
	}
	if createErr := h.store.CreateKB(newKB); createErr != nil {
		log.Printf("[KnowledgeHandler] 创建知识库失败: %v", createErr)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "创建知识库失败", "errorType": "db_error"})
		return
	}
	c.JSON(http.StatusCreated, newKB)
}
// ========== GET /api/v1/knowledge/bases — list the caller's knowledge bases ==========

// ListKBs responds with every knowledge base owned by the authenticated user
// as {"knowledge_bases": [...], "total": n}, or 500 if the store query fails.
func (h *KnowledgeHandler) ListKBs(c *gin.Context) {
	if h.checkStore(c) {
		return
	}
	bases, queryErr := h.store.GetKBsByUser(middleware.GetUserID(c))
	if queryErr != nil {
		log.Printf("[KnowledgeHandler] 查询知识库列表失败: %v", queryErr)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库列表失败", "errorType": "db_error"})
		return
	}
	c.JSON(http.StatusOK, gin.H{"total": len(bases), "knowledge_bases": bases})
}
// ========== GET /api/v1/knowledge/bases/:id — knowledge base details ==========

// GetKB responds with a knowledge base owned by the caller together with its
// documents. Errors: 500 on lookup failure, 404 when the knowledge base does
// not exist, 403 when another user owns it. A failure to list the documents
// is logged and reported as an empty document list rather than an error.
func (h *KnowledgeHandler) GetKB(c *gin.Context) {
	if h.checkStore(c) {
		return
	}
	owner := middleware.GetUserID(c)
	id := c.Param("id")

	base, lookupErr := h.store.GetKB(id)
	switch {
	case lookupErr != nil:
		log.Printf("[KnowledgeHandler] 查询知识库失败: %v", lookupErr)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库失败", "errorType": "db_error"})
		return
	case base == nil:
		c.JSON(http.StatusNotFound, gin.H{"error": "知识库不存在", "errorType": "not_found"})
		return
	case base.UserID != owner:
		c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此知识库", "errorType": "access_denied"})
		return
	}

	documents, listErr := h.store.GetDocumentsByKB(id)
	if listErr != nil {
		// Degrade gracefully: the knowledge base itself is still returned.
		log.Printf("[KnowledgeHandler] 查询文档列表失败: %v", listErr)
		documents = []store.KnowledgeDocument{}
	}
	c.JSON(http.StatusOK, gin.H{"knowledge_base": base, "documents": documents})
}
// ========== PUT /api/v1/knowledge/bases/:id — update a knowledge base ==========

// UpdateKB replaces the name and description of a knowledge base owned by the
// caller; both values are HTML-escaped before being stored. The body is
// validated before the knowledge base is looked up. Responds 200
// {"status":"updated"}; 400 for a malformed body or missing name; 404/403
// when the knowledge base is missing or owned by someone else; 500 on store
// errors.
func (h *KnowledgeHandler) UpdateKB(c *gin.Context) {
	if h.checkStore(c) {
		return
	}
	owner := middleware.GetUserID(c)
	id := c.Param("id")

	var body updateKBRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "请提供知识库名称", "errorType": "invalid_request"})
		return
	}

	current, lookupErr := h.store.GetKB(id)
	switch {
	case lookupErr != nil:
		log.Printf("[KnowledgeHandler] 查询知识库失败: %v", lookupErr)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库失败", "errorType": "db_error"})
		return
	case current == nil:
		c.JSON(http.StatusNotFound, gin.H{"error": "知识库不存在", "errorType": "not_found"})
		return
	case current.UserID != owner:
		c.JSON(http.StatusForbidden, gin.H{"error": "无权修改此知识库", "errorType": "access_denied"})
		return
	}

	name, desc := html.EscapeString(body.Name), html.EscapeString(body.Description)
	if updateErr := h.store.UpdateKB(id, name, desc); updateErr != nil {
		log.Printf("[KnowledgeHandler] 更新知识库失败: %v", updateErr)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "更新知识库失败", "errorType": "db_error"})
		return
	}
	c.JSON(http.StatusOK, gin.H{"status": "updated"})
}
// ========== DELETE /api/v1/knowledge/bases/:id — delete a knowledge base ==========

// DeleteKB deletes a knowledge base owned by the caller. Responds 200
// {"status":"deleted"}; 404 when it does not exist; 403 when another user
// owns it; 500 when the lookup or the delete fails.
func (h *KnowledgeHandler) DeleteKB(c *gin.Context) {
	if h.checkStore(c) {
		return
	}
	owner := middleware.GetUserID(c)
	id := c.Param("id")

	existing, lookupErr := h.store.GetKB(id)
	switch {
	case lookupErr != nil:
		log.Printf("[KnowledgeHandler] 查询知识库失败: %v", lookupErr)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库失败", "errorType": "db_error"})
		return
	case existing == nil:
		c.JSON(http.StatusNotFound, gin.H{"error": "知识库不存在", "errorType": "not_found"})
		return
	case existing.UserID != owner:
		c.JSON(http.StatusForbidden, gin.H{"error": "无权删除此知识库", "errorType": "access_denied"})
		return
	}

	if deleteErr := h.store.DeleteKB(id); deleteErr != nil {
		log.Printf("[KnowledgeHandler] 删除知识库失败: %v", deleteErr)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "删除知识库失败", "errorType": "db_error"})
		return
	}
	c.JSON(http.StatusOK, gin.H{"status": "deleted"})
}
// ========== POST /api/v1/knowledge/bases/:id/documents — add a document ==========

// AddDocument adds a document to the knowledge base identified by :id and then
// asks the store to split it into chunks.
//
// The knowledge base must exist and be owned by the caller. source_type
// (default "text") selects where the document content comes from:
//   - "text": the request's content, which must be non-empty;
//   - "file": the upload named by file_id, which must be owned by the caller
//     and is read from disk by readFileContent (text-like MIME types only);
//   - "url":  the request's content as given, with file_id recorded as the
//     source reference; nothing is fetched here.
//
// Responds 201 with the stored document including its chunk count. A
// chunking failure is logged but does not fail the request.
func (h *KnowledgeHandler) AddDocument(c *gin.Context) {
	if h.checkStore(c) {
		return
	}
	userID := middleware.GetUserID(c)
	kbID := c.Param("id")

	// The target knowledge base must exist and belong to the caller.
	kb, err := h.store.GetKB(kbID)
	if err != nil {
		log.Printf("[KnowledgeHandler] 查询知识库失败: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库失败", "errorType": "db_error"})
		return
	}
	if kb == nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "知识库不存在", "errorType": "not_found"})
		return
	}
	if kb.UserID != userID {
		c.JSON(http.StatusForbidden, gin.H{"error": "无权操作此知识库", "errorType": "access_denied"})
		return
	}

	var req addDocRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "请提供文档标题", "errorType": "invalid_request"})
		return
	}
	if req.SourceType == "" {
		req.SourceType = "text"
	}

	var content, sourceRef, contentType string
	switch req.SourceType {
	case "text":
		content = req.Content
		contentType = "text/plain"
		if content == "" {
			c.JSON(http.StatusBadRequest, gin.H{"error": "请提供文档内容", "errorType": "invalid_request"})
			return
		}
	case "file":
		if req.FileID == "" {
			c.JSON(http.StatusBadRequest, gin.H{"error": "请提供文件ID", "errorType": "invalid_request"})
			return
		}
		if h.fileStore == nil {
			c.JSON(http.StatusServiceUnavailable, gin.H{"error": "文件存储不可用", "errorType": "service_unavailable"})
			return
		}
		f, err := h.fileStore.GetFile(req.FileID)
		if err != nil || f == nil {
			if err != nil {
				// NOTE(review): GetFile's not-found contract is not visible here,
				// so errors are still answered with 404 as before. They are logged
				// so that genuine storage failures are no longer silent.
				log.Printf("[KnowledgeHandler] 查询文件 %s 失败: %v", req.FileID, err)
			}
			c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在", "errorType": "not_found"})
			return
		}
		if f.UserID != userID {
			c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此文件", "errorType": "access_denied"})
			return
		}
		sourceRef = f.Filename
		// readFileContent only handles text-like types; binary files yield
		// empty content and are rejected below.
		var readErr error
		content, contentType, readErr = readFileContent(f.StoredPath, f.MimeType)
		if readErr != nil {
			// The client still gets the "unsupported file" answer, but the
			// underlying read failure is no longer discarded.
			log.Printf("[KnowledgeHandler] 读取文件 %s 失败: %v", req.FileID, readErr)
		}
		if content == "" {
			c.JSON(http.StatusBadRequest, gin.H{"error": "无法读取文件内容,仅支持文本文件", "errorType": "unsupported_file"})
			return
		}
	case "url":
		sourceRef = req.FileID
		content = req.Content
		contentType = "text/html"
	default:
		c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的来源类型: " + req.SourceType, "errorType": "invalid_request"})
		return
	}
	// Defensive default in case a source ever reports no content type.
	if contentType == "" {
		contentType = "text/plain"
	}

	doc := &store.KnowledgeDocument{
		ID:          store.GenerateUUID(),
		KBID:        kbID,
		UserID:      userID,
		Title:       html.EscapeString(req.Title),
		SourceType:  html.EscapeString(req.SourceType),
		SourceRef:   sourceRef,
		ContentType: contentType,
		RawContent:  content,
	}
	if err := h.store.AddDocument(doc); err != nil {
		log.Printf("[KnowledgeHandler] 添加文档失败: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "添加文档失败", "errorType": "db_error"})
		return
	}

	// Chunk immediately. A failure here leaves the document in place.
	chunkCount, err := h.store.ChunkDocument(doc.ID)
	if err != nil {
		log.Printf("[KnowledgeHandler] 文档分块失败: %v", err)
	}
	doc.ChunkCount = chunkCount
	c.JSON(http.StatusCreated, doc)
}
// ========== GET /api/v1/knowledge/bases/:id/documents — list a knowledge base's documents ==========

// ListDocuments responds with {"documents": [...], "total": n} for a
// knowledge base owned by the caller. Errors: 404 when the knowledge base does
// not exist, 403 when another user owns it, 500 on store failures.
func (h *KnowledgeHandler) ListDocuments(c *gin.Context) {
	if h.checkStore(c) {
		return
	}
	owner := middleware.GetUserID(c)
	id := c.Param("id")

	base, lookupErr := h.store.GetKB(id)
	switch {
	case lookupErr != nil:
		log.Printf("[KnowledgeHandler] 查询知识库失败: %v", lookupErr)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库失败", "errorType": "db_error"})
		return
	case base == nil:
		c.JSON(http.StatusNotFound, gin.H{"error": "知识库不存在", "errorType": "not_found"})
		return
	case base.UserID != owner:
		c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此知识库", "errorType": "access_denied"})
		return
	}

	documents, listErr := h.store.GetDocumentsByKB(id)
	if listErr != nil {
		log.Printf("[KnowledgeHandler] 查询文档列表失败: %v", listErr)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文档列表失败", "errorType": "db_error"})
		return
	}
	c.JSON(http.StatusOK, gin.H{"total": len(documents), "documents": documents})
}
// ========== GET /api/v1/knowledge/documents/:id — document details ==========

// GetDocument responds with a document owned by the caller together with its
// chunks. Errors: 500 on lookup failure, 404 when the document does not
// exist, 403 when another user owns it. A failure to load the chunks is
// logged and reported as an empty chunk list.
func (h *KnowledgeHandler) GetDocument(c *gin.Context) {
	if h.checkStore(c) {
		return
	}
	owner := middleware.GetUserID(c)
	id := c.Param("id")

	document, lookupErr := h.store.GetDocument(id)
	switch {
	case lookupErr != nil:
		log.Printf("[KnowledgeHandler] 查询文档失败: %v", lookupErr)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文档失败", "errorType": "db_error"})
		return
	case document == nil:
		c.JSON(http.StatusNotFound, gin.H{"error": "文档不存在", "errorType": "not_found"})
		return
	case document.UserID != owner:
		c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此文档", "errorType": "access_denied"})
		return
	}

	pieces, chunkErr := h.store.GetChunksByDocID(id)
	if chunkErr != nil {
		// Degrade gracefully: the document itself is still returned.
		log.Printf("[KnowledgeHandler] 查询分块失败: %v", chunkErr)
		pieces = []store.KnowledgeChunk{}
	}
	c.JSON(http.StatusOK, gin.H{"document": document, "chunks": pieces})
}
// ========== DELETE /api/v1/knowledge/documents/:id — delete a document ==========

// DeleteDocument deletes a document owned by the caller. Responds 200
// {"status":"deleted"}; 404 when it does not exist; 403 when another user
// owns it; 500 when the lookup or the delete fails.
func (h *KnowledgeHandler) DeleteDocument(c *gin.Context) {
	if h.checkStore(c) {
		return
	}
	owner := middleware.GetUserID(c)
	id := c.Param("id")

	document, lookupErr := h.store.GetDocument(id)
	switch {
	case lookupErr != nil:
		log.Printf("[KnowledgeHandler] 查询文档失败: %v", lookupErr)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文档失败", "errorType": "db_error"})
		return
	case document == nil:
		c.JSON(http.StatusNotFound, gin.H{"error": "文档不存在", "errorType": "not_found"})
		return
	case document.UserID != owner:
		c.JSON(http.StatusForbidden, gin.H{"error": "无权删除此文档", "errorType": "access_denied"})
		return
	}

	if deleteErr := h.store.DeleteDocument(id); deleteErr != nil {
		log.Printf("[KnowledgeHandler] 删除文档失败: %v", deleteErr)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "删除文档失败", "errorType": "db_error"})
		return
	}
	c.JSON(http.StatusOK, gin.H{"status": "deleted"})
}
// ========== POST /api/v1/knowledge/search — search knowledge bases ==========

// Search runs a keyword search over the caller's knowledge bases.
//
// The body is a searchRequest:
//   - query is required; surrounding whitespace is trimmed, and a blank query
//     is rejected with 400.
//   - limit is clamped to 1..50 and defaults to 5.
//   - When kb_ids is empty, every knowledge base owned by the caller is
//     searched via SearchAllKBs.
//   - Otherwise each distinct listed ID is searched on its own. Duplicate IDs
//     are searched once, and IDs that fail to load, do not exist, or belong to
//     another user are skipped.
//
// Per-KB results are concatenated in request order and truncated to limit.
// No cross-KB re-ranking is done here.
//
// Responds 200 with {"chunks", "total", "query"}. Each chunk's Headline is a
// snippet around the query. "chunks" is always a JSON array, never null.
// Responds 500 if SearchAllKBs fails.
func (h *KnowledgeHandler) Search(c *gin.Context) {
	if h.checkStore(c) {
		return
	}
	userID := middleware.GetUserID(c)

	var req searchRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "请提供搜索关键词", "errorType": "invalid_request"})
		return
	}
	// binding:"required" only rejects "", so a whitespace-only query would
	// otherwise reach the store.
	req.Query = strings.TrimSpace(req.Query)
	if req.Query == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "请提供搜索关键词", "errorType": "invalid_request"})
		return
	}
	if req.Limit <= 0 {
		req.Limit = 5
	}
	if req.Limit > 50 {
		req.Limit = 50
	}

	var results []store.SearchChunkResult
	if len(req.KBIDs) == 0 {
		// Search every knowledge base the caller owns.
		var err error
		results, err = h.store.SearchAllKBs(userID, req.Query, req.Limit)
		if err != nil {
			log.Printf("[KnowledgeHandler] 搜索失败: %v", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "搜索失败", "errorType": "db_error"})
			return
		}
	} else {
		// Search only the requested knowledge bases the caller owns. Each ID is
		// searched once, so repeated IDs cannot duplicate chunks in the result.
		seen := make(map[string]struct{}, len(req.KBIDs))
		for _, kbID := range req.KBIDs {
			if _, dup := seen[kbID]; dup {
				continue
			}
			seen[kbID] = struct{}{}

			kb, checkErr := h.store.GetKB(kbID)
			if checkErr != nil {
				log.Printf("[KnowledgeHandler] 查询知识库 %s 失败: %v", kbID, checkErr)
				continue
			}
			if kb == nil || kb.UserID != userID {
				continue // skip missing or foreign knowledge bases
			}
			kbResults, searchErr := h.store.SearchChunks(kbID, req.Query, req.Limit)
			if searchErr != nil {
				log.Printf("[KnowledgeHandler] 搜索知识库 %s 失败: %v", kbID, searchErr)
				continue
			}
			results = append(results, kbResults...)
			// Later KBs could only contribute entries that the truncation
			// below would drop, so stop querying once the limit is reached.
			if len(results) >= req.Limit {
				break
			}
		}
		if len(results) > req.Limit {
			results = results[:req.Limit]
		}
	}
	if results == nil {
		results = []store.SearchChunkResult{} // serialise as [] rather than null
	}

	// Attach a short snippet around the query to each hit.
	for i := range results {
		results[i].Headline = generateHeadline(results[i].Content, req.Query, 200)
	}
	c.JSON(http.StatusOK, gin.H{
		"chunks": results,
		"total":  len(results),
		"query":  req.Query,
	})
}
// ========== Helpers ==========

// generateHeadline returns a display snippet of at most maxLen runes taken
// from content around the first place the query occurs.
//
// Matching rules:
//   - Matching is case-insensitive.
//   - The whole query, with surrounding whitespace trimmed, is tried first.
//   - If it does not occur, each whitespace-separated term is tried in order,
//     so multi-word queries still yield a relevant window.
//
// Window placement:
//   - The window starts up to maxLen/2 runes before the match.
//   - It is shifted left when it would run past the end of content.
//   - "..." marks each side where text was cut.
//
// When nothing matches, the first maxLen runes are returned followed by
// "...". Content of at most maxLen runes is returned unchanged, and
// maxLen <= 0 means 200. All lengths are counted in runes, so multi-byte
// (e.g. CJK) text is never split inside a character.
func generateHeadline(content, query string, maxLen int) string {
	if maxLen <= 0 {
		maxLen = 200
	}
	runes := []rune(content)
	if len(runes) <= maxLen {
		return content
	}

	// Match against a lower-cased copy but slice the original runes.
	// strings.ToLower maps one rune to one rune, so indices line up. The
	// length check is only a safety net against ever slicing misaligned
	// offsets.
	folded := []rune(strings.ToLower(content))
	if len(folded) != len(runes) {
		folded = runes
	}
	pos := indexRuneSlice(folded, []rune(strings.ToLower(strings.TrimSpace(query))))
	if pos < 0 {
		for _, term := range strings.Fields(query) {
			if pos = indexRuneSlice(folded, []rune(strings.ToLower(term))); pos >= 0 {
				break
			}
		}
	}
	if pos < 0 {
		// No match at all: fall back to the leading text.
		return string(runes[:maxLen]) + "..."
	}

	start := pos - maxLen/2
	if start < 0 {
		start = 0
	}
	end := start + maxLen
	if end > len(runes) {
		// len(runes) > maxLen here, so the shifted start stays positive.
		end = len(runes)
		start = end - maxLen
	}

	snippet := string(runes[start:end])
	if start > 0 {
		snippet = "..." + snippet
	}
	if end < len(runes) {
		snippet += "..."
	}
	return snippet
}

// indexRuneSlice returns the index of the first occurrence of needle in
// haystack, or -1 when needle is empty or does not occur.
func indexRuneSlice(haystack, needle []rune) int {
	if len(needle) == 0 {
		return -1
	}
outer:
	for i := 0; i+len(needle) <= len(haystack); i++ {
		for j, r := range needle {
			if haystack[i+j] != r {
				continue outer
			}
		}
		return i
	}
	return -1
}
// readFileContent reads the file at path and returns its contents as a
// string, together with mimeType as the content type.
//
// It only does so when mimeType describes a text format (see
// isTextMediaType). For any other type it returns ("", "", nil) without
// touching the disk, and callers treat the empty content as "unsupported".
// A leading UTF-8 byte-order mark is stripped so it does not end up in the
// stored text. Read errors from os.ReadFile are returned unchanged.
func readFileContent(path, mimeType string) (content string, contentType string, err error) {
	if !isTextMediaType(mimeType) {
		return "", "", nil
	}
	data, err := os.ReadFile(path)
	if err != nil {
		return "", "", err
	}
	return strings.TrimPrefix(string(data), "\uFEFF"), mimeType, nil
}

// isTextMediaType reports whether mimeType names a format whose bytes can be
// stored directly as text: text/*, JSON, XML, YAML, or a +json/+xml
// structured-syntax suffix type (RFC 6838). Parameters such as
// "; charset=utf-8" are ignored. The comparison is case-insensitive because
// media type names are (RFC 2045 §5.1).
func isTextMediaType(mimeType string) bool {
	base := mimeType
	if i := strings.IndexByte(base, ';'); i >= 0 {
		base = base[:i]
	}
	base = strings.ToLower(strings.TrimSpace(base))
	if strings.HasPrefix(base, "text/") {
		return true
	}
	switch base {
	case "application/json", "application/xml", "application/yaml", "application/x-yaml":
		return true
	}
	return strings.HasPrefix(base, "application/") &&
		(strings.HasSuffix(base, "+json") || strings.HasSuffix(base, "+xml"))
}