diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 64bf52f..123da85 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -75,7 +75,7 @@ func newServer() (*Server, error) { proxyUsecase := usecase.NewProxyUsecase(proxyRepo, modelRepo, securityScanningRepo, slogLogger, configConfig, redisClient) llmProxy := proxy.NewLLMProxy(slogLogger, configConfig, proxyUsecase) openAIRepo := repo4.NewOpenAIRepo(client) - openAIUsecase := openai.NewOpenAIUsecase(configConfig, openAIRepo, modelRepo, slogLogger) + openAIUsecase := openai.NewOpenAIUsecase(configConfig, openAIRepo, modelRepo, slogLogger, redisClient) extensionRepo := repo5.NewExtensionRepo(client) extensionUsecase := usecase2.NewExtensionUsecase(extensionRepo, configConfig, slogLogger) ipdbIPDB, err := ipdb.NewIPDB(slogLogger) @@ -85,7 +85,7 @@ func newServer() (*Server, error) { userRepo := repo6.NewUserRepo(client, ipdbIPDB, redisClient, configConfig) sessionSession := session.NewSession(configConfig) userUsecase := usecase3.NewUserUsecase(configConfig, redisClient, userRepo, slogLogger, sessionSession) - proxyMiddleware := middleware.NewProxyMiddleware(proxyUsecase) + proxyMiddleware := middleware.NewProxyMiddleware(proxyUsecase, redisClient, slogLogger) activeMiddleware := middleware.NewActiveMiddleware(redisClient, slogLogger) v1Handler := v1.NewV1Handler(slogLogger, web, llmProxy, proxyUsecase, openAIUsecase, extensionUsecase, userUsecase, proxyMiddleware, activeMiddleware, configConfig) modelUsecase := usecase4.NewModelUsecase(slogLogger, modelRepo, configConfig) diff --git a/backend/consts/model.go b/backend/consts/model.go index fad90d9..f1a37f8 100644 --- a/backend/consts/model.go +++ b/backend/consts/model.go @@ -3,6 +3,7 @@ package consts type ModelStatus string const ( + ModelStatusDefault ModelStatus = "default" ModelStatusActive ModelStatus = "active" ModelStatusInactive ModelStatus = "inactive" ) diff --git a/backend/domain/model.go b/backend/domain/model.go index 4b37b19..2da8243 100644 --- a/backend/domain/model.go +++ b/backend/domain/model.go @@ -20,7 +20,7 @@ type ModelUsecase interface { } type ModelRepo interface { - GetWithCache(ctx context.Context, modelType consts.ModelType) (*db.Model, error) + GetWithCache(ctx context.Context, modelType consts.ModelType) ([]*db.Model, error) List(ctx context.Context) (*AllModelResp, error) Create(ctx context.Context, m *CreateModelReq) (*db.Model, error) Update(ctx context.Context, id string, fn func(tx *db.Tx, old *db.Model, up *db.ModelUpdateOne) error) (*db.Model, error) @@ -181,7 +181,7 @@ func (m *Model) From(e *db.Model) *Model { m.ModelType = e.ModelType m.Status = e.Status m.IsInternal = e.IsInternal - m.IsActive = e.Status == consts.ModelStatusActive + m.IsActive = e.Status == consts.ModelStatusActive || e.Status == consts.ModelStatusDefault if p := e.Parameters; p != nil { m.Param = ModelParam{ R1Enabled: p.R1Enabled, diff --git a/backend/domain/openai.go b/backend/domain/openai.go index 95ec6a9..e3cfcce 100644 --- a/backend/domain/openai.go +++ b/backend/domain/openai.go @@ -68,7 +68,7 @@ type ConfigReq struct { type ConfigResp struct { Type consts.ConfigType `json:"type"` - Content string `json:"content"` + Content any `json:"content"` } type OpenAIResp struct { Object string `json:"object"` diff --git a/backend/domain/plugin.go b/backend/domain/plugin.go new file mode 100644 index 0000000..e979246 --- /dev/null +++ b/backend/domain/plugin.go @@ -0,0 +1,52 @@ +package domain + +type PluginConfig struct { + ProviderProfiles ProviderProfiles `json:"providerProfiles"` + CtcodeTabCompletions CtcodeTabCompletions `json:"ctcodeTabCompletions"` + GlobalSettings GlobalSettings `json:"globalSettings"` +} + +type ProviderProfiles struct { + CurrentApiConfigName string `json:"currentApiConfigName"` + ApiConfigs map[string]ApiConfig `json:"apiConfigs"` + ModeApiConfigs map[string]string `json:"modeApiConfigs"` + Migrations Migrations `json:"migrations"` +} + +type ApiConfig struct { + ApiProvider string `json:"apiProvider"` + ApiModelId string `json:"apiModelId"` + OpenAiBaseUrl string `json:"openAiBaseUrl"` + OpenAiApiKey string `json:"openAiApiKey"` + OpenAiModelId string `json:"openAiModelId"` + OpenAiR1FormatEnabled bool `json:"openAiR1FormatEnabled"` + OpenAiCustomModelInfo OpenAiCustomModelInfo `json:"openAiCustomModelInfo"` + Id string `json:"id"` +} + +type OpenAiCustomModelInfo struct { + MaxTokens int `json:"maxTokens"` + ContextWindow int `json:"contextWindow"` + SupportsImages bool `json:"supportsImages"` + SupportsComputerUse bool `json:"supportsComputerUse"` + SupportsPromptCache bool `json:"supportsPromptCache"` +} + +type Migrations struct { + RateLimitSecondsMigrated bool `json:"rateLimitSecondsMigrated"` + DiffSettingsMigrated bool `json:"diffSettingsMigrated"` +} + +type CtcodeTabCompletions struct { + Enabled bool `json:"enabled"` + ApiProvider string `json:"apiProvider"` + OpenAiBaseUrl string `json:"openAiBaseUrl"` + OpenAiApiKey string `json:"openAiApiKey"` + OpenAiModelId string `json:"openAiModelId"` +} + +type GlobalSettings struct { + AllowedCommands []string `json:"allowedCommands"` + Mode string `json:"mode"` + CustomModes []string `json:"customModes"` +} diff --git a/backend/internal/middleware/proxy.go b/backend/internal/middleware/proxy.go index 702dc08..4b424c2 100644 --- a/backend/internal/middleware/proxy.go +++ b/backend/internal/middleware/proxy.go @@ -2,10 +2,13 @@ package middleware import ( "context" + "encoding/json" + "log/slog" "net/http" "strings" "github.com/labstack/echo/v4" + "github.com/redis/go-redis/v9" "github.com/chaitin/MonkeyCode/backend/domain" "github.com/chaitin/MonkeyCode/backend/ent/rule" @@ -16,15 +19,23 @@ const ( ApiContextKey = "session:apikey" ) +type proxyModelKey struct{} + type ProxyMiddleware struct { usecase domain.ProxyUsecase + redis *redis.Client + logger *slog.Logger } func NewProxyMiddleware( usecase domain.ProxyUsecase, + redis *redis.Client, + logger *slog.Logger, ) *ProxyMiddleware { return &ProxyMiddleware{ usecase: usecase, + redis: redis, + logger: logger.With("module", "ProxyMiddleware"), } } @@ -39,21 +50,54 @@ func (p *ProxyMiddleware) Auth() echo.MiddlewareFunc { return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"}) } - key, err := p.usecase.ValidateApiKey(c.Request().Context(), apiKey) - if err != nil { - return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"}) + ctx := c.Request().Context() + p.logger.With("apiKey", apiKey).DebugContext(ctx, "v1 auth") + if strings.Contains(apiKey, ".") { + s, err := p.redis.Get(ctx, apiKey).Result() + if err != nil { + p.logger.With("fn", "Auth").With("error", err).ErrorContext(ctx, "failed to get api key from redis") + return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"}) + } + var model *domain.Model + if err := json.Unmarshal([]byte(s), &model); err != nil { + p.logger.With("fn", "Auth").With("error", err).ErrorContext(ctx, "failed to unmarshal model from redis") + return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"}) + } + parts := strings.Split(apiKey, ".") + if len(parts) != 2 { + p.logger.With("fn", "Auth").With("apiKey", apiKey).ErrorContext(ctx, "invalid api key") + return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"}) + } + ctx = context.WithValue(ctx, proxyModelKey{}, model) + ctx = context.WithValue(ctx, logger.UserIDKey{}, parts[0]) + c.Set(ApiContextKey, &domain.ApiKey{ + UserID: parts[0], + Key: apiKey, + }) + } else { + key, err := p.usecase.ValidateApiKey(ctx, apiKey) + if err != nil { + return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"}) + } + ctx = context.WithValue(ctx, logger.UserIDKey{}, key.UserID) + c.Set(ApiContextKey, key) } - ctx := c.Request().Context() - ctx = context.WithValue(ctx, logger.UserIDKey{}, key.UserID) ctx = rule.SkipPermission(ctx) c.SetRequest(c.Request().WithContext(ctx)) - c.Set(ApiContextKey, key) return next(c) } } } +func GetProxyModel(ctx context.Context) *domain.Model { + m := ctx.Value(proxyModelKey{}) + if m == nil { + return nil + } + return m.(*domain.Model) +} + func GetApiKey(c echo.Context) *domain.ApiKey { i := c.Get(ApiContextKey) if i == nil { diff --git a/backend/internal/model/repo/model.go b/backend/internal/model/repo/model.go index 50c0fd9..9e18841 100644 --- a/backend/internal/model/repo/model.go +++ b/backend/internal/model/repo/model.go @@ -30,15 +30,16 @@ func NewModelRepo(db *db.Client) domain.ModelRepo { return &ModelRepo{db: db, cache: cache} } -func (r *ModelRepo) GetWithCache(ctx context.Context, modelType consts.ModelType) (*db.Model, error) { +func (r *ModelRepo) GetWithCache(ctx context.Context, modelType consts.ModelType) ([]*db.Model, error) { if v, ok := r.cache.Get(string(modelType)); ok { - return v.(*db.Model), nil + return v.([]*db.Model), nil } m, err := r.db.Model.Query(). Where(model.ModelType(modelType)). - Where(model.Status(consts.ModelStatusActive)). - Only(ctx) + Where(model.StatusIn(consts.ModelStatusActive, consts.ModelStatusDefault)). + Order(ByStatusOrder()). + All(ctx) if err != nil { return nil, err } @@ -47,14 +48,22 @@ func (r *ModelRepo) GetWithCache(ctx context.Context, modelType consts.ModelType return m, nil } +func ByStatusOrder() func(s *sql.Selector) { + return func(s *sql.Selector) { + s.OrderExprFunc(func(b *sql.Builder) { + b.WriteString("case when status = 'default' then 3 when status = 'active' then 2 else 1 end desc") + }) + } +} + func (r *ModelRepo) Create(ctx context.Context, m *domain.CreateModelReq) (*db.Model, error) { n, err := r.db.Model.Query().Where(model.ModelType(m.ModelType)).Count(ctx) if err != nil { return nil, err } - status := consts.ModelStatusInactive + status := consts.ModelStatusActive if n == 0 { - status = consts.ModelStatusActive + status = consts.ModelStatusDefault } r.cache.Delete(string(m.ModelType)) diff --git a/backend/internal/model/usecase/model.go b/backend/internal/model/usecase/model.go index f046ea9..a7d5dd9 100644 --- a/backend/internal/model/usecase/model.go +++ b/backend/internal/model/usecase/model.go @@ -108,14 +108,27 @@ func (m *ModelUsecase) Update(ctx context.Context, req *domain.UpdateModelReq) ( up.SetShowName(*req.ShowName) } if req.Status != nil { - if *req.Status == consts.ModelStatusActive { + if *req.Status == consts.ModelStatusDefault { if err := tx.Model.Update(). + Where(model.Status(consts.ModelStatusDefault)). Where(model.ModelType(old.ModelType)). - SetStatus(consts.ModelStatusInactive). + SetStatus(consts.ModelStatusActive). Exec(ctx); err != nil { return err } } + if *req.Status == consts.ModelStatusActive { + n, err := tx.Model.Query(). + Where(model.Status(consts.ModelStatusDefault)). + Where(model.ModelType(old.ModelType)). + Count(ctx) + if err != nil { + return err + } + if n == 0 { + *req.Status = consts.ModelStatusDefault + } + } up.SetStatus(*req.Status) } if req.Param != nil { diff --git a/backend/internal/openai/usecase/openai.go b/backend/internal/openai/usecase/openai.go index 5a4abe4..8fedd7d 100644 --- a/backend/internal/openai/usecase/openai.go +++ b/backend/internal/openai/usecase/openai.go @@ -1,18 +1,21 @@ package openai import ( - "bytes" "context" - "html/template" + "encoding/json" + "errors" + "fmt" "log/slog" + "time" - "github.com/chaitin/MonkeyCode/backend/ent/types" - "github.com/chaitin/MonkeyCode/backend/pkg/cvt" + "github.com/redis/go-redis/v9" "github.com/chaitin/MonkeyCode/backend/config" "github.com/chaitin/MonkeyCode/backend/consts" "github.com/chaitin/MonkeyCode/backend/db" "github.com/chaitin/MonkeyCode/backend/domain" + "github.com/chaitin/MonkeyCode/backend/ent/types" + "github.com/chaitin/MonkeyCode/backend/pkg/cvt" ) type OpenAIUsecase struct { @@ -20,6 +23,7 @@ type OpenAIUsecase struct { modelRepo domain.ModelRepo cfg *config.Config logger *slog.Logger + redis *redis.Client } func NewOpenAIUsecase( @@ -27,12 +31,14 @@ func NewOpenAIUsecase( repo domain.OpenAIRepo, modelRepo domain.ModelRepo, logger *slog.Logger, + redis *redis.Client, ) domain.OpenAIUsecase { return &OpenAIUsecase{ repo: repo, modelRepo: modelRepo, cfg: cfg, logger: logger, + redis: redis, } } @@ -42,10 +48,6 @@ func (u *OpenAIUsecase) ModelList(ctx context.Context) (*domain.ModelListResp, e return nil, err } - for _, v := range models { - u.logger.DebugContext(ctx, "model", slog.Any("model", v)) - } - resp := &domain.ModelListResp{ Object: "list", Data: cvt.Iter(models, func(_ int, m *db.Model) *domain.ModelData { @@ -61,44 +63,106 @@ func (u *OpenAIUsecase) GetConfig(ctx context.Context, req *domain.ConfigReq) (* if err != nil { return nil, err } - llm, err := u.modelRepo.GetWithCache(ctx, consts.ModelTypeLLM) + llms, err := u.modelRepo.GetWithCache(ctx, consts.ModelTypeLLM) if err != nil { return nil, err } - coder, err := u.modelRepo.GetWithCache(ctx, consts.ModelTypeCoder) + coders, err := u.modelRepo.GetWithCache(ctx, consts.ModelTypeCoder) if err != nil { return nil, err } + u.logger.With( + "llms", len(llms), + "coders", len(coders), + ).DebugContext(ctx, "get config") + + if len(llms) == 0 || len(coders) == 0 { + return nil, errors.New("no model") + } + + llm := llms[0] + coder := coders[0] + coderkey := fmt.Sprintf("%s.%s", apiKey.UserID.String(), coder.ID.String()) + if err = u.redis.Get(ctx, coderkey).Err(); err != nil { + b, err := json.Marshal(cvt.From(coder, &domain.Model{})) + if err != nil { + return nil, err + } + if err = u.redis.Set(ctx, coderkey, string(b), time.Hour*24).Err(); err != nil { + return nil, err + } + } + if llm.Parameters == nil { llm.Parameters = types.DefaultModelParam() } - t, err := template.New("config").Parse(string(config.ConfigTmpl)) - if err != nil { - return nil, err + config := &domain.PluginConfig{ + ProviderProfiles: domain.ProviderProfiles{ + CurrentApiConfigName: "default", + ApiConfigs: map[string]domain.ApiConfig{}, + ModeApiConfigs: map[string]string{ + "code": "59admorkig4", + "architect": "59admorkig4", + "ask": "59admorkig4", + "debug": "59admorkig4", + "deepresearch": "59admorkig4", + }, + Migrations: domain.Migrations{ + RateLimitSecondsMigrated: true, + DiffSettingsMigrated: true, + }, + }, + CtcodeTabCompletions: domain.CtcodeTabCompletions{ + Enabled: true, + ApiProvider: "openai", + OpenAiBaseUrl: req.BaseURL + "/v1", + OpenAiApiKey: coderkey, + OpenAiModelId: coder.ModelName, + }, } - u.logger.With("param", llm.Parameters).DebugContext(ctx, "get config") - cnt := bytes.NewBuffer(nil) - data := map[string]any{ - "apiBase": req.BaseURL, - "apikey": apiKey.Key, - "chatModel": llm.ModelName, - "codeModel": coder.ModelName, - "r1Enabled": llm.Parameters.R1Enabled, - "maxTokens": llm.Parameters.MaxTokens, - "contextWindow": llm.Parameters.ContextWindow, - "supportsImages": llm.Parameters.SupprtImages, - "supportsComputerUse": llm.Parameters.SupportComputerUse, - "supportsPromptCache": llm.Parameters.SupportPromptCache, - } - if err := t.Execute(cnt, data); err != nil { - return nil, err + for _, m := range llms { + key := fmt.Sprintf("%s.%s", apiKey.UserID.String(), m.ID.String()) + if m.Parameters == nil { + m.Parameters = types.DefaultModelParam() + } + name := fmt.Sprintf("%s (%s)", m.ModelName, m.Provider) + if m.Status == consts.ModelStatusDefault { + name = "default" + } + config.ProviderProfiles.ApiConfigs[name] = domain.ApiConfig{ + ApiProvider: "openai", + ApiModelId: m.ModelName, + OpenAiBaseUrl: req.BaseURL + "/v1", + OpenAiApiKey: key, + OpenAiModelId: m.ModelName, + OpenAiR1FormatEnabled: m.Parameters.R1Enabled, + OpenAiCustomModelInfo: domain.OpenAiCustomModelInfo{ + MaxTokens: m.Parameters.MaxTokens, + ContextWindow: m.Parameters.ContextWindow, + SupportsImages: m.Parameters.SupprtImages, + SupportsComputerUse: m.Parameters.SupportComputerUse, + SupportsPromptCache: m.Parameters.SupportPromptCache, + }, + Id: m.ID.String(), + } + + if err = u.redis.Get(ctx, key).Err(); err == nil { + continue + } + b, err := json.Marshal(cvt.From(m, &domain.Model{})) + if err != nil { + return nil, err + } + if err := u.redis.Set(ctx, key, string(b), time.Hour*24).Err(); err != nil { + return nil, err + } } return &domain.ConfigResp{ Type: req.Type, - Content: cnt.String(), + Content: config, }, nil } diff --git a/backend/internal/proxy/proxy.go b/backend/internal/proxy/proxy.go index 612c2dc..bb1f0fe 100644 --- a/backend/internal/proxy/proxy.go +++ b/backend/internal/proxy/proxy.go @@ -16,6 +16,7 @@ import ( "github.com/chaitin/MonkeyCode/backend/config" "github.com/chaitin/MonkeyCode/backend/consts" "github.com/chaitin/MonkeyCode/backend/domain" + "github.com/chaitin/MonkeyCode/backend/internal/middleware" "github.com/chaitin/MonkeyCode/backend/pkg/logger" "github.com/chaitin/MonkeyCode/backend/pkg/tee" ) @@ -96,11 +97,16 @@ func (l *LLMProxy) rewrite(r *httputil.ProxyRequest) { return } - m, err := l.usecase.SelectModelWithLoadBalancing("", mt) - if err != nil { - l.logger.ErrorContext(r.In.Context(), "select model with load balancing failed", slog.String("path", r.In.URL.Path), slog.Any("err", err)) - return + var m *domain.Model + var err error + if m = middleware.GetProxyModel(r.In.Context()); m == nil { + m, err = l.usecase.SelectModelWithLoadBalancing("", mt) + if err != nil { + l.logger.ErrorContext(r.In.Context(), "select model with load balancing failed", slog.String("path", r.In.URL.Path), slog.Any("err", err)) + return + } } + ul, err := url.Parse(m.APIBase) if err != nil { l.logger.ErrorContext(r.In.Context(), "parse model api base failed", slog.String("path", r.In.URL.Path), slog.Any("err", err)) diff --git a/backend/internal/proxy/usecase/proxy.go b/backend/internal/proxy/usecase/proxy.go index 4970ce9..b3b9ea1 100644 --- a/backend/internal/proxy/usecase/proxy.go +++ b/backend/internal/proxy/usecase/proxy.go @@ -9,6 +9,7 @@ import ( "path" "time" + "github.com/google/uuid" "github.com/redis/go-redis/v9" "github.com/chaitin/MonkeyCode/backend/config" @@ -16,6 +17,7 @@ import ( "github.com/chaitin/MonkeyCode/backend/db" "github.com/chaitin/MonkeyCode/backend/domain" "github.com/chaitin/MonkeyCode/backend/ent/rule" + "github.com/chaitin/MonkeyCode/backend/internal/middleware" "github.com/chaitin/MonkeyCode/backend/pkg/cvt" "github.com/chaitin/MonkeyCode/backend/pkg/queuerunner" "github.com/chaitin/MonkeyCode/backend/pkg/request" @@ -85,7 +87,16 @@ func (p *ProxyUsecase) SelectModelWithLoadBalancing(modelName string, modelType if err != nil { return nil, err } - return cvt.From(model, &domain.Model{}), nil + if len(model) == 0 { + return nil, fmt.Errorf("no model found") + } + m := model[0] + for _, mm := range model { + if mm.Status == consts.ModelStatusDefault { + m = mm + } + } + return cvt.From(m, &domain.Model{}), nil } func (p *ProxyUsecase) ValidateApiKey(ctx context.Context, key string) (*domain.ApiKey, error) { @@ -102,12 +113,33 @@ func (p *ProxyUsecase) AcceptCompletion(ctx context.Context, req *domain.AcceptC func (p *ProxyUsecase) Report(ctx context.Context, req *domain.ReportReq) error { var model *db.Model - var err error if req.Action == consts.ReportActionNewTask { - model, err = p.modelRepo.GetWithCache(context.Background(), consts.ModelTypeLLM) - if err != nil { - p.logger.With("fn", "Report").With("error", err).ErrorContext(ctx, "failed to get model") - return err + m := middleware.GetProxyModel(ctx) + if m == nil { + ms, err := p.modelRepo.GetWithCache(ctx, consts.ModelTypeLLM) + if err != nil { + return fmt.Errorf("get model with cache failed: %w", err) + } + if len(ms) == 0 { + return fmt.Errorf("no model found") + } + model = ms[0] + for _, mm := range ms { + if mm.Status == consts.ModelStatusDefault { + model = mm + break + } + } + } else { + mid, err := uuid.Parse(m.ID) + if err != nil { + return fmt.Errorf("parse proxy model id failed: %w", err) + } + model = &db.Model{ + ID: mid, + ModelName: m.ModelName, + ModelType: m.ModelType, + } } } return p.repo.Report(ctx, model, req) diff --git a/backend/pro b/backend/pro index 482ab56..a566152 160000 --- a/backend/pro +++ b/backend/pro @@ -1 +1 @@ -Subproject commit 482ab56daf6d7d205f47e25c057bcab4804a6b73 +Subproject commit a566152111f123a8ccb3bef9a21b1d40f0d5bb47 diff --git a/ui/src/api/types.ts b/ui/src/api/types.ts index e2966ea..f7f8d11 100644 --- a/ui/src/api/types.ts +++ b/ui/src/api/types.ts @@ -19,6 +19,7 @@ export enum GithubComChaitinMonkeyCodeBackendConstsModelType { } export enum GithubComChaitinMonkeyCodeBackendConstsModelStatus { + ModelStatusDefault = "default", ModelStatusActive = "active", ModelStatusInactive = "inactive", } diff --git a/ui/src/pages/model/components/modelCard.tsx b/ui/src/pages/model/components/modelCard.tsx index e52af6c..cf838eb 100644 --- a/ui/src/pages/model/components/modelCard.tsx +++ b/ui/src/pages/model/components/modelCard.tsx @@ -6,7 +6,6 @@ import { } from '@/api/Model'; import { DomainModel, GithubComChaitinMonkeyCodeBackendConstsModelStatus, GithubComChaitinMonkeyCodeBackendConstsModelType, } from '@/api/types'; import { Stack, Box, Button, Grid2 as Grid, ButtonBase } from '@mui/material'; -import StyledLabel from '@/components/label'; import { Icon, Modal, message } from '@c-x/ui'; import { addCommasToNumber } from '@/utils'; import NoData from '@/assets/images/nodata.png'; @@ -76,6 +75,31 @@ const ModelItem = ({ }); }; + const onSetDefaultModel = () => { + Modal.confirm({ + title: '设为默认模型', + content: ( + <> + 确定要设置{' '} + + {data.model_name} + {' '} + 为默认模型吗? + + ), + onOk: () => { + putUpdateModel({ + id: data.id, + status: GithubComChaitinMonkeyCodeBackendConstsModelStatus.ModelStatusDefault, + provider: data.provider!, + }).then(() => { + message.success('设为默认模型成功'); + refresh(); + }); + }, + }); + }; + const onActiveModel = () => { Modal.confirm({ title: '激活模型', @@ -101,6 +125,100 @@ const ModelItem = ({ }); }; + // 添加状态标签渲染函数 + const renderStatusLabel = () => { + // 根据 is_active 和 status 字段判断状态 + if (data.is_active && data.status === GithubComChaitinMonkeyCodeBackendConstsModelStatus.ModelStatusActive) { + return ( + + 可选 + + ); + } else if (data.status === GithubComChaitinMonkeyCodeBackendConstsModelStatus.ModelStatusInactive) { + return ( + + 未激活 + + ); + } else { + // 默认状态 + return ( + + 默认 + + ); + } + }; + return ( + {/* 美化的右上角状态标签 */} + + {renderStatusLabel()} + + {data.show_name || '未命名'} @@ -195,11 +326,25 @@ const ModelItem = ({ gap={2} sx={{ mt: 2 }} > - - {data.is_active ? '正在使用' : '未激活'} - + - {!data.is_active && ( + {(data.status === GithubComChaitinMonkeyCodeBackendConstsModelStatus.ModelStatusActive || + data.status === GithubComChaitinMonkeyCodeBackendConstsModelStatus.ModelStatusInactive) && ( + + 设为默认 + + )} + + {data.status === GithubComChaitinMonkeyCodeBackendConstsModelStatus.ModelStatusInactive && (