fix(model): 兼容硅基流动 sub_type

This commit is contained in:
yokowu
2025-07-30 11:32:11 +08:00
parent 5e42469dd1
commit fb57c3dc09

View File

@@ -255,6 +255,20 @@ func (m *ModelUsecase) InitModel(ctx context.Context) error {
return m.repo.InitModel(ctx, m.cfg.InitModel.Name, m.cfg.InitModel.Key, m.cfg.InitModel.URL)
}
func (m *ModelUsecase) getQuery(req *domain.GetProviderModelListReq) request.Query {
q := make(request.Query, 0)
if req.Provider != consts.ModelProviderBaiZhiCloud && req.Provider != consts.ModelProviderSiliconFlow {
return q
}
q["type"] = "text"
q["sub_type"] = string(req.Type)
// 硅基流动不支持coder sub_type
if req.Provider == consts.ModelProviderSiliconFlow && req.Type == consts.ModelTypeCoder {
q["sub_type"] = "chat"
}
return q
}
func (m *ModelUsecase) GetProviderModelList(ctx context.Context, req *domain.GetProviderModelListReq) (*domain.GetProviderModelListResp, error) {
switch req.Provider {
case consts.ModelProviderAzureOpenAI,
@@ -266,6 +280,8 @@ func (m *ModelUsecase) GetProviderModelList(ctx context.Context, req *domain.Get
consts.ModelProviderHunyuan,
consts.ModelProviderMoonshot,
consts.ModelProviderDeepSeek,
consts.ModelProviderSiliconFlow,
consts.ModelProviderBaiZhiCloud,
consts.ModelProviderBaiLian:
u, err := url.Parse(req.BaseURL)
if err != nil {
@@ -273,11 +289,16 @@ func (m *ModelUsecase) GetProviderModelList(ctx context.Context, req *domain.Get
}
u.Path = path.Join(u.Path, "/models")
client := request.NewClient(u.Scheme, u.Host, m.client.Timeout, request.WithClient(m.client))
resp, err := request.Get[domain.OpenAIResp](client, u.Path, request.WithHeader(
request.Header{
"Authorization": fmt.Sprintf("Bearer %s", req.APIKey),
},
))
query := m.getQuery(req)
resp, err := request.Get[domain.OpenAIResp](
client, u.Path,
request.WithHeader(
request.Header{
"Authorization": fmt.Sprintf("Bearer %s", req.APIKey),
},
),
request.WithQuery(query),
)
if err != nil {
return nil, err
}
@@ -289,6 +310,7 @@ func (m *ModelUsecase) GetProviderModelList(ctx context.Context, req *domain.Get
}
}),
}, nil
case consts.ModelProviderOllama:
// get from ollama http://10.10.16.24:11434/api/tags
u, err := url.Parse(req.BaseURL)
@@ -306,56 +328,6 @@ func (m *ModelUsecase) GetProviderModelList(ctx context.Context, req *domain.Get
return request.Get[domain.GetProviderModelListResp](client, u.Path, request.WithHeader(h))
case consts.ModelProviderSiliconFlow, consts.ModelProviderBaiZhiCloud:
if req.Type == consts.ModelTypeEmbedding || req.Type == consts.ModelTypeReranker {
if req.Provider == consts.ModelProviderBaiZhiCloud {
if req.Type == consts.ModelTypeEmbedding {
return &domain.GetProviderModelListResp{
Models: []domain.ProviderModelListItem{
{
Model: "bge-m3",
},
},
}, nil
} else {
return &domain.GetProviderModelListResp{
Models: []domain.ProviderModelListItem{
{
Model: "bge-reranker-v2-m3",
},
},
}, nil
}
}
}
u, err := url.Parse(req.BaseURL)
if err != nil {
return nil, err
}
st := string(req.Type)
if req.Type == consts.ModelTypeLLM {
st = "chat"
}
client := request.NewClient(u.Scheme, u.Host, m.client.Timeout, request.WithClient(m.client))
resp, err := request.Get[domain.OpenAIResp](client, "/v1/models", request.WithHeader(
request.Header{
"Authorization": fmt.Sprintf("Bearer %s", req.APIKey),
},
), request.WithQuery(request.Query{
"type": "text",
"sub_type": st,
}))
if err != nil {
return nil, err
}
return &domain.GetProviderModelListResp{
Models: cvt.Iter(resp.Data, func(_ int, e *domain.OpenAIData) domain.ProviderModelListItem {
return domain.ProviderModelListItem{
Model: e.ID,
}
}),
}, nil
default:
return nil, fmt.Errorf("invalid provider: %s", req.Provider)
}