package handler import ( "encoding/json" "fmt" "io" "net/http" "time" "github.com/gin-gonic/gin" "git.yeij.top/AskaEth/Cyrene/gateway/internal/config" ) // ModelConfigHandler exposes admin CRUD endpoints for model configuration. type ModelConfigHandler struct { store *config.ModelsConfigStore } func NewModelConfigHandler(store *config.ModelsConfigStore) *ModelConfigHandler { return &ModelConfigHandler{store: store} } // ---- Providers ---- func (h *ModelConfigHandler) ListProviders(c *gin.Context) { providers := h.store.ListProviders() if providers == nil { providers = []*config.ProviderConfig{} } c.JSON(http.StatusOK, gin.H{"providers": providers, "total": len(providers)}) } func (h *ModelConfigHandler) GetProvider(c *gin.Context) { name := c.Param("name") p, err := h.store.GetProvider(name) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, p) } func (h *ModelConfigHandler) SetProvider(c *gin.Context) { name := c.Param("name") var body config.ProviderConfig if err := c.ShouldBindJSON(&body); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid JSON: " + err.Error()}) return } body.Name = name if err := h.store.SetProvider(&body); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{"status": "saved", "name": name}) } func (h *ModelConfigHandler) DeleteProvider(c *gin.Context) { name := c.Param("name") if err := h.store.DeleteProvider(name); err != nil { c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{"status": "deleted", "name": name}) } // ---- Models ---- func (h *ModelConfigHandler) ListModels(c *gin.Context) { models := h.store.ListModels() if models == nil { models = []*config.ModelConfig{} } c.JSON(http.StatusOK, gin.H{"models": models, "total": len(models)}) } func (h *ModelConfigHandler) GetModel(c *gin.Context) { id := c.Param("id") m, err := h.store.GetModel(id) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, m) } func (h *ModelConfigHandler) SetModel(c *gin.Context) { id := c.Param("id") var body config.ModelConfig if err := c.ShouldBindJSON(&body); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid JSON: " + err.Error()}) return } body.ID = id if err := h.store.SetModel(&body); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{"status": "saved", "id": id}) } func (h *ModelConfigHandler) DeleteModel(c *gin.Context) { id := c.Param("id") if err := h.store.DeleteModel(id); err != nil { c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{"status": "deleted", "id": id}) } // ---- Routing ---- func (h *ModelConfigHandler) ListRouting(c *gin.Context) { routing := h.store.ListRouting() if routing == nil { routing = []*config.RoutingRule{} } c.JSON(http.StatusOK, gin.H{"routing": routing, "total": len(routing)}) } func (h *ModelConfigHandler) GetRouting(c *gin.Context) { purpose := c.Param("purpose") r, err := h.store.GetRouting(purpose) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, r) } func (h *ModelConfigHandler) SetRouting(c *gin.Context) { purpose := c.Param("purpose") var body config.RoutingRule if err := c.ShouldBindJSON(&body); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid JSON: " + err.Error()}) return } body.Purpose = purpose if err := h.store.SetRouting(&body); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{"status": "saved", "purpose": purpose}) } func (h *ModelConfigHandler) DeleteRouting(c *gin.Context) { purpose := c.Param("purpose") if err := h.store.DeleteRouting(purpose); err != nil { c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{"status": "deleted", "purpose": purpose}) } // ---- Health Check ---- func (h *ModelConfigHandler) TestProvider(c *gin.Context) { var body struct { Provider string `json:"provider"` } if err := c.ShouldBindJSON(&body); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid JSON: " + err.Error()}) return } p, err := h.store.GetProvider(body.Provider) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{ "provider": p.Name, "base_url": p.BaseURL, "message": "Provider 配置已保存,连接测试请通过实际 LLM 调用验证", }) } // ---- Remote Model List Proxy ---- // ProxyListModels forwards a request to the provider's models endpoint using the stored API key. func (h *ModelConfigHandler) ProxyListModels(c *gin.Context) { providerName := c.Param("name") modelsURL := c.Query("url") if modelsURL == "" { c.JSON(http.StatusBadRequest, gin.H{"error": "missing 'url' query parameter"}) return } p, err := h.store.GetProvider(providerName) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) return } if p.APIKey == "" { c.JSON(http.StatusBadRequest, gin.H{"error": "provider 未配置 API Key"}) return } client := &http.Client{Timeout: 15 * time.Second} req, err := http.NewRequest("GET", modelsURL, nil) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "创建请求失败: " + err.Error()}) return } req.Header.Set("Authorization", "Bearer "+p.APIKey) req.Header.Set("Accept", "application/json") resp, err := client.Do(req) if err != nil { c.JSON(http.StatusBadGateway, gin.H{"error": "请求模型列表失败: " + err.Error()}) return } defer resp.Body.Close() body, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) // 2 MB limit if err != nil { c.JSON(http.StatusBadGateway, gin.H{"error": "读取响应失败: " + err.Error()}) return } if resp.StatusCode >= 400 { c.JSON(http.StatusBadGateway, gin.H{ "error": fmt.Sprintf("Provider API 返回错误 (HTTP %d)", resp.StatusCode), "body": string(body), "models_url": modelsURL, }) return } // Parse the response body which may use different formats: // OpenAI: {"object":"list","data":[{"id":"...","object":"model",...}]} // DashScope: {"request_id":"...","data":{"models":[{"model_id":"..."}]}} // Generic: {"data":[{"id":"..."}]} or {"data":[{"model_id":"..."}]} ids := parseModelListResponse(body) if len(ids) == 0 { c.JSON(http.StatusBadGateway, gin.H{ "error": "无法从 Provider 响应中解析模型列表 (不支持的格式)", "raw": string(body), }) return } c.JSON(http.StatusOK, gin.H{ "provider": providerName, "url": modelsURL, "models": ids, "total": len(ids), }) } // parseModelListResponse attempts to extract model IDs from various provider response formats. // Supported formats: // - OpenAI-compatible: {"object":"list","data":[{"id":"gpt-4o",...}]} // - DashScope: {"data":{"models":[{"model_id":"qwen-turbo",...}]}} // - Generic: {"data":[{"id":"..."}]} or {"data":[{"model_id":"..."}]} func parseModelListResponse(body []byte) []string { var raw map[string]interface{} if err := json.Unmarshal(body, &raw); err != nil { return nil } // Strategy 1: data is an array of objects — try "id" then "model_id" if dataArr, ok := raw["data"].([]interface{}); ok { ids := extractIDs(dataArr, "id") if len(ids) > 0 { return ids } return extractIDs(dataArr, "model_id") } // Strategy 2: data is an object with a "models" array (DashScope format) if dataObj, ok := raw["data"].(map[string]interface{}); ok { if modelsArr, ok := dataObj["models"].([]interface{}); ok { ids := extractIDs(modelsArr, "model_id") if len(ids) > 0 { return ids } return extractIDs(modelsArr, "id") } } return nil } func extractIDs(items []interface{}, key string) []string { ids := make([]string, 0, len(items)) for _, item := range items { if obj, ok := item.(map[string]interface{}); ok { if v, ok := obj[key]; ok { if s, ok := v.(string); ok && s != "" { ids = append(ids, s) } } } } return ids }