diff --git a/backend/internal/model/usecase/model.go b/backend/internal/model/usecase/model.go index 792b51a..076c59d 100644 --- a/backend/internal/model/usecase/model.go +++ b/backend/internal/model/usecase/model.go @@ -255,6 +255,20 @@ func (m *ModelUsecase) InitModel(ctx context.Context) error { return m.repo.InitModel(ctx, m.cfg.InitModel.Name, m.cfg.InitModel.Key, m.cfg.InitModel.URL) } +func (m *ModelUsecase) getQuery(req *domain.GetProviderModelListReq) request.Query { + q := make(request.Query, 0) + if req.Provider != consts.ModelProviderBaiZhiCloud && req.Provider != consts.ModelProviderSiliconFlow { + return q + } + q["type"] = "text" + q["sub_type"] = string(req.Type) + // 硅基流动不支持coder sub_type + if req.Provider == consts.ModelProviderSiliconFlow && req.Type == consts.ModelTypeCoder { + q["sub_type"] = "chat" + } + return q +} + func (m *ModelUsecase) GetProviderModelList(ctx context.Context, req *domain.GetProviderModelListReq) (*domain.GetProviderModelListResp, error) { switch req.Provider { case consts.ModelProviderAzureOpenAI, @@ -266,6 +280,8 @@ func (m *ModelUsecase) GetProviderModelList(ctx context.Context, req *domain.Get consts.ModelProviderHunyuan, consts.ModelProviderMoonshot, consts.ModelProviderDeepSeek, + consts.ModelProviderSiliconFlow, + consts.ModelProviderBaiZhiCloud, consts.ModelProviderBaiLian: u, err := url.Parse(req.BaseURL) if err != nil { @@ -273,11 +289,16 @@ func (m *ModelUsecase) GetProviderModelList(ctx context.Context, req *domain.Get } u.Path = path.Join(u.Path, "/models") client := request.NewClient(u.Scheme, u.Host, m.client.Timeout, request.WithClient(m.client)) - resp, err := request.Get[domain.OpenAIResp](client, u.Path, request.WithHeader( - request.Header{ - "Authorization": fmt.Sprintf("Bearer %s", req.APIKey), - }, - )) + query := m.getQuery(req) + resp, err := request.Get[domain.OpenAIResp]( + client, u.Path, + request.WithHeader( + request.Header{ + "Authorization": fmt.Sprintf("Bearer %s", req.APIKey), + }, + ), + request.WithQuery(query), + ) if err != nil { return nil, err } @@ -289,6 +310,7 @@ func (m *ModelUsecase) GetProviderModelList(ctx context.Context, req *domain.Get } }), }, nil + case consts.ModelProviderOllama: // get from ollama http://10.10.16.24:11434/api/tags u, err := url.Parse(req.BaseURL) @@ -306,56 +328,6 @@ func (m *ModelUsecase) GetProviderModelList(ctx context.Context, req *domain.Get return request.Get[domain.GetProviderModelListResp](client, u.Path, request.WithHeader(h)) - case consts.ModelProviderSiliconFlow, consts.ModelProviderBaiZhiCloud: - if req.Type == consts.ModelTypeEmbedding || req.Type == consts.ModelTypeReranker { - if req.Provider == consts.ModelProviderBaiZhiCloud { - if req.Type == consts.ModelTypeEmbedding { - return &domain.GetProviderModelListResp{ - Models: []domain.ProviderModelListItem{ - { - Model: "bge-m3", - }, - }, - }, nil - } else { - return &domain.GetProviderModelListResp{ - Models: []domain.ProviderModelListItem{ - { - Model: "bge-reranker-v2-m3", - }, - }, - }, nil - } - } - } - u, err := url.Parse(req.BaseURL) - if err != nil { - return nil, err - } - st := string(req.Type) - if req.Type == consts.ModelTypeLLM { - st = "chat" - } - client := request.NewClient(u.Scheme, u.Host, m.client.Timeout, request.WithClient(m.client)) - resp, err := request.Get[domain.OpenAIResp](client, "/v1/models", request.WithHeader( - request.Header{ - "Authorization": fmt.Sprintf("Bearer %s", req.APIKey), - }, - ), request.WithQuery(request.Query{ - "type": "text", - "sub_type": st, - })) - if err != nil { - return nil, err - } - - return &domain.GetProviderModelListResp{ - Models: cvt.Iter(resp.Data, func(_ int, e *domain.OpenAIData) domain.ProviderModelListItem { - return domain.ProviderModelListItem{ - Model: e.ID, - } - }), - }, nil default: return nil, fmt.Errorf("invalid provider: %s", req.Provider) }