package handler
import (
"html"
"log"
"net/http"
"os"
"strings"
"github.com/gin-gonic/gin"
"github.com/yourname/cyrene-ai/gateway/internal/middleware"
"github.com/yourname/cyrene-ai/gateway/internal/store"
)
// KnowledgeHandler serves the knowledge-base HTTP API: knowledge bases, their
// documents, and chunk search.
type KnowledgeHandler struct {
	// store backs every knowledge-base operation. When nil, each handler in this
	// file answers 503 via checkStore.
	store *store.KnowledgeStore
	// fileStore is consulted by AddDocument for source_type "file". When nil,
	// such requests are answered with 503.
	fileStore *store.FileStore
}
// NewKnowledgeHandler builds a KnowledgeHandler from its two stores.
//
// Either store may be nil. A nil s makes every endpoint answer 503. A nil fs
// makes adding documents from uploaded files answer 503.
func NewKnowledgeHandler(s *store.KnowledgeStore, fs *store.FileStore) *KnowledgeHandler {
	h := new(KnowledgeHandler)
	h.store = s
	h.fileStore = fs
	return h
}
// checkStore reports whether the knowledge store is missing.
//
// When it returns true, a 503 response has already been written and the
// caller must return immediately.
func (h *KnowledgeHandler) checkStore(c *gin.Context) bool {
	unavailable := h.store == nil
	if unavailable {
		c.JSON(http.StatusServiceUnavailable, gin.H{
			"errorType": "service_unavailable",
			"error":     "知识库服务不可用(数据库未连接)",
		})
	}
	return unavailable
}
// ========== Request / response types ==========

// createKBRequest is the JSON body of POST /api/v1/knowledge/bases.
type createKBRequest struct {
	Name        string `json:"name" binding:"required"`
	Description string `json:"description"`
}

// updateKBRequest is the JSON body of PUT /api/v1/knowledge/bases/:id.
type updateKBRequest struct {
	Name        string `json:"name" binding:"required"`
	Description string `json:"description"`
}

// addDocRequest is the JSON body of POST /api/v1/knowledge/bases/:id/documents.
type addDocRequest struct {
	Title string `json:"title" binding:"required"`
	// Content is the document text used for source_type "text" and "url".
	Content string `json:"content"`
	// SourceType is "text" (the handler's default when empty), "file" or "url".
	SourceType string `json:"source_type"`
	// FileID identifies an uploaded file for source_type "file". For "url" the
	// handler stores it as the document's source reference.
	FileID string `json:"file_id"`
}

// searchRequest is the JSON body of POST /api/v1/knowledge/search.
type searchRequest struct {
	Query string `json:"query" binding:"required"`
	// KBIDs restricts the search to these knowledge bases. Empty means all of
	// the caller's knowledge bases.
	KBIDs []string `json:"kb_ids"`
	// Limit caps the number of results. The handler defaults it to 5 and clamps
	// it to at most 50.
	Limit int `json:"limit"`
}
// CreateKB handles POST /api/v1/knowledge/bases and creates a knowledge base
// owned by the authenticated user.
//
// The body is createKBRequest. The name is required and must not be blank.
// The trimmed name and the description are HTML-escaped before storage.
//
// Responses:
//   - 201 with the created knowledge base.
//   - 400 on a malformed body or blank name.
//   - 500 on a store error.
func (h *KnowledgeHandler) CreateKB(c *gin.Context) {
	if h.checkStore(c) {
		return
	}
	userID := middleware.GetUserID(c)

	var req createKBRequest
	// binding:"required" only rejects the empty string, so a name made of
	// whitespace alone must be rejected explicitly.
	if err := c.ShouldBindJSON(&req); err != nil || strings.TrimSpace(req.Name) == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "请提供知识库名称", "errorType": "invalid_request"})
		return
	}

	kb := &store.KnowledgeBase{
		ID:          store.GenerateUUID(),
		UserID:      userID,
		Name:        html.EscapeString(strings.TrimSpace(req.Name)),
		Description: html.EscapeString(req.Description),
	}
	if err := h.store.CreateKB(kb); err != nil {
		log.Printf("[KnowledgeHandler] 创建知识库失败: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "创建知识库失败", "errorType": "db_error"})
		return
	}
	c.JSON(http.StatusCreated, kb)
}
// ListKBs handles GET /api/v1/knowledge/bases.
//
// It returns every knowledge base owned by the authenticated user, together
// with their count, under "knowledge_bases" and "total".
func (h *KnowledgeHandler) ListKBs(c *gin.Context) {
	if h.checkStore(c) {
		return
	}
	kbs, err := h.store.GetKBsByUser(middleware.GetUserID(c))
	if err == nil {
		c.JSON(http.StatusOK, gin.H{"total": len(kbs), "knowledge_bases": kbs})
		return
	}
	log.Printf("[KnowledgeHandler] 查询知识库列表失败: %v", err)
	c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库列表失败", "errorType": "db_error"})
}
// GetKB handles GET /api/v1/knowledge/bases/:id and returns the knowledge base
// together with its documents.
//
// It responds 404 when the knowledge base does not exist and 403 when it
// belongs to another user. A failure to load the documents is logged and
// reported as an empty document list rather than as an error.
func (h *KnowledgeHandler) GetKB(c *gin.Context) {
	if h.checkStore(c) {
		return
	}
	caller := middleware.GetUserID(c)
	id := c.Param("id")

	base, err := h.store.GetKB(id)
	switch {
	case err != nil:
		log.Printf("[KnowledgeHandler] 查询知识库失败: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库失败", "errorType": "db_error"})
		return
	case base == nil:
		c.JSON(http.StatusNotFound, gin.H{"error": "知识库不存在", "errorType": "not_found"})
		return
	case base.UserID != caller:
		c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此知识库", "errorType": "access_denied"})
		return
	}

	docs, docErr := h.store.GetDocumentsByKB(id)
	if docErr != nil {
		log.Printf("[KnowledgeHandler] 查询文档列表失败: %v", docErr)
		docs = []store.KnowledgeDocument{}
	}
	c.JSON(http.StatusOK, gin.H{"documents": docs, "knowledge_base": base})
}
// UpdateKB handles PUT /api/v1/knowledge/bases/:id. It replaces the name and
// description of a knowledge base owned by the caller.
//
// The body is updateKBRequest. The name is required and must not be blank.
// The trimmed name and the description are HTML-escaped before storage.
//
// Responses:
//   - {"status": "updated"} on success.
//   - 400 on a malformed body or blank name.
//   - 404 when the knowledge base does not exist.
//   - 403 when it belongs to another user.
//   - 500 on store errors.
func (h *KnowledgeHandler) UpdateKB(c *gin.Context) {
	if h.checkStore(c) {
		return
	}
	userID := middleware.GetUserID(c)
	kbID := c.Param("id")

	var req updateKBRequest
	// binding:"required" only rejects the empty string, so a name made of
	// whitespace alone must be rejected explicitly.
	if err := c.ShouldBindJSON(&req); err != nil || strings.TrimSpace(req.Name) == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "请提供知识库名称", "errorType": "invalid_request"})
		return
	}

	kb, err := h.store.GetKB(kbID)
	if err != nil {
		log.Printf("[KnowledgeHandler] 查询知识库失败: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库失败", "errorType": "db_error"})
		return
	}
	if kb == nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "知识库不存在", "errorType": "not_found"})
		return
	}
	if kb.UserID != userID {
		c.JSON(http.StatusForbidden, gin.H{"error": "无权修改此知识库", "errorType": "access_denied"})
		return
	}

	name := html.EscapeString(strings.TrimSpace(req.Name))
	if err := h.store.UpdateKB(kbID, name, html.EscapeString(req.Description)); err != nil {
		log.Printf("[KnowledgeHandler] 更新知识库失败: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "更新知识库失败", "errorType": "db_error"})
		return
	}
	c.JSON(http.StatusOK, gin.H{"status": "updated"})
}
// DeleteKB handles DELETE /api/v1/knowledge/bases/:id.
//
// Only the owner may delete a knowledge base. Success is reported as
// {"status": "deleted"}. It responds 404 for an unknown id and 403 for
// another user's knowledge base.
func (h *KnowledgeHandler) DeleteKB(c *gin.Context) {
	if h.checkStore(c) {
		return
	}
	caller := middleware.GetUserID(c)
	id := c.Param("id")

	target, err := h.store.GetKB(id)
	switch {
	case err != nil:
		log.Printf("[KnowledgeHandler] 查询知识库失败: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库失败", "errorType": "db_error"})
		return
	case target == nil:
		c.JSON(http.StatusNotFound, gin.H{"error": "知识库不存在", "errorType": "not_found"})
		return
	case target.UserID != caller:
		c.JSON(http.StatusForbidden, gin.H{"error": "无权删除此知识库", "errorType": "access_denied"})
		return
	}

	if err = h.store.DeleteKB(id); err != nil {
		log.Printf("[KnowledgeHandler] 删除知识库失败: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "删除知识库失败", "errorType": "db_error"})
		return
	}
	c.JSON(http.StatusOK, gin.H{"status": "deleted"})
}
// AddDocument handles POST /api/v1/knowledge/bases/:id/documents. It adds a
// document to a knowledge base owned by the caller and then chunks it.
//
// The body is addDocRequest. source_type defaults to "text" and selects where
// the document content comes from:
//   - "text": content comes from the request and must be non-empty.
//   - "file": content is read from the caller's uploaded file file_id.
//     Only textual MIME types are accepted; see readFileContent.
//   - "url": content is taken from the request as-is, and file_id is stored as
//     the source reference. Nothing is fetched here.
//
// It responds 201 with the stored document, with ChunkCount filled in.
// Chunking failures are logged but do not fail the request.
func (h *KnowledgeHandler) AddDocument(c *gin.Context) {
	if h.checkStore(c) {
		return
	}
	userID := middleware.GetUserID(c)
	kbID := c.Param("id")

	// The target knowledge base must exist and belong to the caller.
	kb, err := h.store.GetKB(kbID)
	if err != nil {
		log.Printf("[KnowledgeHandler] 查询知识库失败: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库失败", "errorType": "db_error"})
		return
	}
	if kb == nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "知识库不存在", "errorType": "not_found"})
		return
	}
	if kb.UserID != userID {
		c.JSON(http.StatusForbidden, gin.H{"error": "无权操作此知识库", "errorType": "access_denied"})
		return
	}

	var req addDocRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "请提供文档标题", "errorType": "invalid_request"})
		return
	}
	if req.SourceType == "" {
		req.SourceType = "text"
	}

	// Resolve the document body. Every branch that falls through sets content,
	// sourceRef (possibly empty) and a non-empty contentType.
	var content, sourceRef, contentType string
	switch req.SourceType {
	case "text":
		if req.Content == "" {
			c.JSON(http.StatusBadRequest, gin.H{"error": "请提供文档内容", "errorType": "invalid_request"})
			return
		}
		content = req.Content
		contentType = "text/plain"
	case "file":
		if req.FileID == "" {
			c.JSON(http.StatusBadRequest, gin.H{"error": "请提供文件ID", "errorType": "invalid_request"})
			return
		}
		if h.fileStore == nil {
			c.JSON(http.StatusServiceUnavailable, gin.H{"error": "文件存储不可用", "errorType": "service_unavailable"})
			return
		}
		f, err := h.fileStore.GetFile(req.FileID)
		if err != nil || f == nil {
			if err != nil {
				// Still answered as 404 (the store may report "not found" via
				// err), but logged so genuine store failures are visible.
				log.Printf("[KnowledgeHandler] 查询文件失败: %v", err)
			}
			c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在", "errorType": "not_found"})
			return
		}
		if f.UserID != userID {
			c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此文件", "errorType": "access_denied"})
			return
		}
		// Only textual files are supported; binary formats would need a parser.
		content, contentType, err = readFileContent(f.StoredPath, f.MimeType)
		if err != nil {
			// A disk read failure is a server problem, not an unsupported file.
			log.Printf("[KnowledgeHandler] 读取文件内容失败: %v", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "读取文件内容失败", "errorType": "file_read_error"})
			return
		}
		if content == "" {
			c.JSON(http.StatusBadRequest, gin.H{"error": "无法读取文件内容,仅支持文本文件", "errorType": "unsupported_file"})
			return
		}
		sourceRef = f.Filename
	case "url":
		// Nothing is fetched: the caller supplies the content, and file_id is
		// kept as the reference.
		sourceRef = req.FileID
		content = req.Content
		contentType = "text/html"
	default:
		c.JSON(http.StatusBadRequest, gin.H{"error": "不支持的来源类型: " + req.SourceType, "errorType": "invalid_request"})
		return
	}

	doc := &store.KnowledgeDocument{
		ID:          store.GenerateUUID(),
		KBID:        kbID,
		UserID:      userID,
		Title:       html.EscapeString(req.Title),
		SourceType:  req.SourceType, // validated above: "text", "file" or "url"
		SourceRef:   sourceRef,
		ContentType: contentType,
		RawContent:  content,
	}
	if err := h.store.AddDocument(doc); err != nil {
		log.Printf("[KnowledgeHandler] 添加文档失败: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "添加文档失败", "errorType": "db_error"})
		return
	}

	// Chunk immediately. A failure here does not undo the document creation.
	chunkCount, err := h.store.ChunkDocument(doc.ID)
	if err != nil {
		log.Printf("[KnowledgeHandler] 文档分块失败: %v", err)
	}
	doc.ChunkCount = chunkCount
	c.JSON(http.StatusCreated, doc)
}
// ListDocuments handles GET /api/v1/knowledge/bases/:id/documents.
//
// It returns the documents of a knowledge base owned by the caller, with their
// count, under "documents" and "total".
func (h *KnowledgeHandler) ListDocuments(c *gin.Context) {
	if h.checkStore(c) {
		return
	}
	caller := middleware.GetUserID(c)
	id := c.Param("id")

	base, err := h.store.GetKB(id)
	switch {
	case err != nil:
		log.Printf("[KnowledgeHandler] 查询知识库失败: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询知识库失败", "errorType": "db_error"})
		return
	case base == nil:
		c.JSON(http.StatusNotFound, gin.H{"error": "知识库不存在", "errorType": "not_found"})
		return
	case base.UserID != caller:
		c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此知识库", "errorType": "access_denied"})
		return
	}

	docs, err := h.store.GetDocumentsByKB(id)
	if err == nil {
		c.JSON(http.StatusOK, gin.H{"total": len(docs), "documents": docs})
		return
	}
	log.Printf("[KnowledgeHandler] 查询文档列表失败: %v", err)
	c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文档列表失败", "errorType": "db_error"})
}
// GetDocument handles GET /api/v1/knowledge/documents/:id and returns the
// document together with its chunks.
//
// Only the document's owner may read it. A failure to load the chunks is
// logged and reported as an empty chunk list.
func (h *KnowledgeHandler) GetDocument(c *gin.Context) {
	if h.checkStore(c) {
		return
	}
	caller := middleware.GetUserID(c)
	id := c.Param("id")

	document, err := h.store.GetDocument(id)
	switch {
	case err != nil:
		log.Printf("[KnowledgeHandler] 查询文档失败: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文档失败", "errorType": "db_error"})
		return
	case document == nil:
		c.JSON(http.StatusNotFound, gin.H{"error": "文档不存在", "errorType": "not_found"})
		return
	case document.UserID != caller:
		c.JSON(http.StatusForbidden, gin.H{"error": "无权访问此文档", "errorType": "access_denied"})
		return
	}

	chunks, chunkErr := h.store.GetChunksByDocID(id)
	if chunkErr != nil {
		log.Printf("[KnowledgeHandler] 查询分块失败: %v", chunkErr)
		chunks = []store.KnowledgeChunk{}
	}
	c.JSON(http.StatusOK, gin.H{"chunks": chunks, "document": document})
}
// DeleteDocument handles DELETE /api/v1/knowledge/documents/:id.
//
// Only the document's owner may delete it. Success is reported as
// {"status": "deleted"}.
func (h *KnowledgeHandler) DeleteDocument(c *gin.Context) {
	if h.checkStore(c) {
		return
	}
	caller := middleware.GetUserID(c)
	id := c.Param("id")

	document, err := h.store.GetDocument(id)
	switch {
	case err != nil:
		log.Printf("[KnowledgeHandler] 查询文档失败: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "查询文档失败", "errorType": "db_error"})
		return
	case document == nil:
		c.JSON(http.StatusNotFound, gin.H{"error": "文档不存在", "errorType": "not_found"})
		return
	case document.UserID != caller:
		c.JSON(http.StatusForbidden, gin.H{"error": "无权删除此文档", "errorType": "access_denied"})
		return
	}

	if err = h.store.DeleteDocument(id); err != nil {
		log.Printf("[KnowledgeHandler] 删除文档失败: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "删除文档失败", "errorType": "db_error"})
		return
	}
	c.JSON(http.StatusOK, gin.H{"status": "deleted"})
}
// Search handles POST /api/v1/knowledge/search and searches the chunks of the
// caller's knowledge bases.
//
// The body is searchRequest. query is required and must not be blank. limit
// defaults to 5 and is capped at 50.
//
// With no kb_ids, every knowledge base of the caller is searched. Otherwise
// each listed id is searched in order, with duplicates ignored. Ids that do
// not exist, belong to another user, or fail to load are skipped. Results are
// concatenated in kb_ids order and truncated to limit.
//
// Every result gets a Headline snippet around the query. "chunks" is always a
// JSON array, never null.
func (h *KnowledgeHandler) Search(c *gin.Context) {
	if h.checkStore(c) {
		return
	}
	userID := middleware.GetUserID(c)

	var req searchRequest
	// binding:"required" only rejects "", so whitespace-only queries are
	// rejected explicitly.
	if err := c.ShouldBindJSON(&req); err != nil || strings.TrimSpace(req.Query) == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "请提供搜索关键词", "errorType": "invalid_request"})
		return
	}
	switch {
	case req.Limit <= 0:
		req.Limit = 5
	case req.Limit > 50:
		req.Limit = 50
	}

	var results []store.SearchChunkResult
	if len(req.KBIDs) == 0 {
		// Search every knowledge base of the caller.
		var err error
		results, err = h.store.SearchAllKBs(userID, req.Query, req.Limit)
		if err != nil {
			log.Printf("[KnowledgeHandler] 搜索失败: %v", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "搜索失败", "errorType": "db_error"})
			return
		}
	} else {
		// Search only the listed knowledge bases the caller owns.
		seen := make(map[string]struct{}, len(req.KBIDs))
		for _, kbID := range req.KBIDs {
			if _, dup := seen[kbID]; dup {
				continue // a repeated id would only duplicate results
			}
			seen[kbID] = struct{}{}

			kb, err := h.store.GetKB(kbID)
			if err != nil {
				log.Printf("[KnowledgeHandler] 查询知识库 %s 失败: %v", kbID, err)
				continue
			}
			if kb == nil || kb.UserID != userID {
				continue // missing or not owned by the caller: skip
			}
			kbResults, err := h.store.SearchChunks(kbID, req.Query, req.Limit)
			if err != nil {
				log.Printf("[KnowledgeHandler] 搜索知识库 %s 失败: %v", kbID, err)
				continue
			}
			results = append(results, kbResults...)
			// Results are kept in kb_ids order and cut at Limit, so later
			// knowledge bases cannot contribute once Limit is reached.
			if len(results) >= req.Limit {
				break
			}
		}
		if len(results) > req.Limit {
			results = results[:req.Limit]
		}
	}

	// Encode "no results" as [] rather than null.
	if results == nil {
		results = []store.SearchChunkResult{}
	}
	for i := range results {
		results[i].Headline = generateHeadline(results[i].Content, req.Query, 200)
	}
	c.JSON(http.StatusOK, gin.H{
		"chunks": results,
		"total":  len(results),
		"query":  req.Query,
	})
}
// ========== Helpers ==========

// generateHeadline returns a snippet of content, at most maxLen runes long,
// positioned around the first match of query. A "..." marker is added on each
// side where text was cut.
//
// Matching rules:
//   - Matching is case-insensitive.
//   - If the query as a whole does not occur, the earliest occurrence of any of
//     its whitespace-separated terms is used.
//   - If nothing matches, the first maxLen runes are returned.
//
// Content that already fits in maxLen runes is returned unchanged. A maxLen of
// 0 or less means 200.
func generateHeadline(content, query string, maxLen int) string {
	if maxLen <= 0 {
		maxLen = 200
	}
	runes := []rune(content)
	if len(runes) <= maxLen {
		return content
	}

	// Match against lower-cased runes. strings.ToLower maps one rune to one rune,
	// so indices into haystack are indices into runes. If that ever fails to
	// hold, fall back to case-sensitive matching rather than mis-slicing.
	fold := func(s string) []rune { return []rune(strings.ToLower(s)) }
	haystack := fold(content)
	if len(haystack) != len(runes) {
		haystack = runes
		fold = func(s string) []rune { return []rune(s) }
	}

	// indexOf returns the rune index of the first occurrence of needle in
	// haystack, or -1 if it is absent or empty.
	indexOf := func(needle []rune) int {
		if len(needle) == 0 {
			return -1
		}
		for i := 0; i+len(needle) <= len(haystack); i++ {
			j := 0
			for j < len(needle) && haystack[i+j] == needle[j] {
				j++
			}
			if j == len(needle) {
				return i
			}
		}
		return -1
	}

	pos := indexOf(fold(query))
	if pos < 0 {
		// Multi-word queries often match documents where the words are not
		// adjacent; anchor on the earliest individual term instead.
		for _, term := range strings.Fields(query) {
			if p := indexOf(fold(term)); p >= 0 && (pos < 0 || p < pos) {
				pos = p
			}
		}
	}
	if pos < 0 {
		return string(runes[:maxLen]) + "..."
	}

	// Start the window maxLen/2 runes before the match. If it would run past the
	// end, shift it left. len(runes) > maxLen here, so start stays positive.
	start := pos - maxLen/2
	if start < 0 {
		start = 0
	}
	end := start + maxLen
	if end > len(runes) {
		end = len(runes)
		start = end - maxLen
	}

	snippet := string(runes[start:end])
	if start > 0 {
		snippet = "..." + snippet
	}
	if end < len(runes) {
		snippet += "..."
	}
	return snippet
}
// readFileContent reads a stored upload from disk and returns it as text.
//
// Only textual media types are accepted: any "text/*" type plus
// "application/json". The check ignores case and MIME parameters, as RFC 2045
// specifies, so "Text/Plain" and "application/json; charset=utf-8" also pass.
//
// Results:
//   - Other media types yield ("", "", nil), so the caller can report the file
//     as unsupported.
//   - On success, contentType is mimeType exactly as passed in.
//   - A failure to read the file is returned as err.
func readFileContent(path, mimeType string) (content string, contentType string, err error) {
	mediaType, _, _ := strings.Cut(mimeType, ";")
	mediaType = strings.ToLower(strings.TrimSpace(mediaType))
	if !strings.HasPrefix(mediaType, "text/") && mediaType != "application/json" {
		return "", "", nil
	}
	data, err := os.ReadFile(path)
	if err != nil {
		return "", "", err
	}
	return string(data), mimeType, nil
}