mirror of
https://github.com/chaitin/MonkeyCode.git
synced 2026-02-08 09:43:21 +08:00
fix(model): 兼容硅基流动 sub_type
This commit is contained in:
@@ -255,6 +255,20 @@ func (m *ModelUsecase) InitModel(ctx context.Context) error {
|
||||
return m.repo.InitModel(ctx, m.cfg.InitModel.Name, m.cfg.InitModel.Key, m.cfg.InitModel.URL)
|
||||
}
|
||||
|
||||
func (m *ModelUsecase) getQuery(req *domain.GetProviderModelListReq) request.Query {
|
||||
q := make(request.Query, 0)
|
||||
if req.Provider != consts.ModelProviderBaiZhiCloud && req.Provider != consts.ModelProviderSiliconFlow {
|
||||
return q
|
||||
}
|
||||
q["type"] = "text"
|
||||
q["sub_type"] = string(req.Type)
|
||||
// 硅基流动不支持coder sub_type
|
||||
if req.Provider == consts.ModelProviderSiliconFlow && req.Type == consts.ModelTypeCoder {
|
||||
q["sub_type"] = "chat"
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
func (m *ModelUsecase) GetProviderModelList(ctx context.Context, req *domain.GetProviderModelListReq) (*domain.GetProviderModelListResp, error) {
|
||||
switch req.Provider {
|
||||
case consts.ModelProviderAzureOpenAI,
|
||||
@@ -266,6 +280,8 @@ func (m *ModelUsecase) GetProviderModelList(ctx context.Context, req *domain.Get
|
||||
consts.ModelProviderHunyuan,
|
||||
consts.ModelProviderMoonshot,
|
||||
consts.ModelProviderDeepSeek,
|
||||
consts.ModelProviderSiliconFlow,
|
||||
consts.ModelProviderBaiZhiCloud,
|
||||
consts.ModelProviderBaiLian:
|
||||
u, err := url.Parse(req.BaseURL)
|
||||
if err != nil {
|
||||
@@ -273,11 +289,16 @@ func (m *ModelUsecase) GetProviderModelList(ctx context.Context, req *domain.Get
|
||||
}
|
||||
u.Path = path.Join(u.Path, "/models")
|
||||
client := request.NewClient(u.Scheme, u.Host, m.client.Timeout, request.WithClient(m.client))
|
||||
resp, err := request.Get[domain.OpenAIResp](client, u.Path, request.WithHeader(
|
||||
request.Header{
|
||||
"Authorization": fmt.Sprintf("Bearer %s", req.APIKey),
|
||||
},
|
||||
))
|
||||
query := m.getQuery(req)
|
||||
resp, err := request.Get[domain.OpenAIResp](
|
||||
client, u.Path,
|
||||
request.WithHeader(
|
||||
request.Header{
|
||||
"Authorization": fmt.Sprintf("Bearer %s", req.APIKey),
|
||||
},
|
||||
),
|
||||
request.WithQuery(query),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -289,6 +310,7 @@ func (m *ModelUsecase) GetProviderModelList(ctx context.Context, req *domain.Get
|
||||
}
|
||||
}),
|
||||
}, nil
|
||||
|
||||
case consts.ModelProviderOllama:
|
||||
// get from ollama http://10.10.16.24:11434/api/tags
|
||||
u, err := url.Parse(req.BaseURL)
|
||||
@@ -306,56 +328,6 @@ func (m *ModelUsecase) GetProviderModelList(ctx context.Context, req *domain.Get
|
||||
|
||||
return request.Get[domain.GetProviderModelListResp](client, u.Path, request.WithHeader(h))
|
||||
|
||||
case consts.ModelProviderSiliconFlow, consts.ModelProviderBaiZhiCloud:
|
||||
if req.Type == consts.ModelTypeEmbedding || req.Type == consts.ModelTypeReranker {
|
||||
if req.Provider == consts.ModelProviderBaiZhiCloud {
|
||||
if req.Type == consts.ModelTypeEmbedding {
|
||||
return &domain.GetProviderModelListResp{
|
||||
Models: []domain.ProviderModelListItem{
|
||||
{
|
||||
Model: "bge-m3",
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
} else {
|
||||
return &domain.GetProviderModelListResp{
|
||||
Models: []domain.ProviderModelListItem{
|
||||
{
|
||||
Model: "bge-reranker-v2-m3",
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
u, err := url.Parse(req.BaseURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
st := string(req.Type)
|
||||
if req.Type == consts.ModelTypeLLM {
|
||||
st = "chat"
|
||||
}
|
||||
client := request.NewClient(u.Scheme, u.Host, m.client.Timeout, request.WithClient(m.client))
|
||||
resp, err := request.Get[domain.OpenAIResp](client, "/v1/models", request.WithHeader(
|
||||
request.Header{
|
||||
"Authorization": fmt.Sprintf("Bearer %s", req.APIKey),
|
||||
},
|
||||
), request.WithQuery(request.Query{
|
||||
"type": "text",
|
||||
"sub_type": st,
|
||||
}))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &domain.GetProviderModelListResp{
|
||||
Models: cvt.Iter(resp.Data, func(_ int, e *domain.OpenAIData) domain.ProviderModelListItem {
|
||||
return domain.ProviderModelListItem{
|
||||
Model: e.ID,
|
||||
}
|
||||
}),
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid provider: %s", req.Provider)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user