Merge pull request #322 from yokowu/feat-multi-model

feat: 支持插件选择模型
This commit is contained in:
Yoko
2025-08-29 18:32:32 +08:00
committed by GitHub
14 changed files with 434 additions and 67 deletions

View File

@@ -75,7 +75,7 @@ func newServer() (*Server, error) {
proxyUsecase := usecase.NewProxyUsecase(proxyRepo, modelRepo, securityScanningRepo, slogLogger, configConfig, redisClient)
llmProxy := proxy.NewLLMProxy(slogLogger, configConfig, proxyUsecase)
openAIRepo := repo4.NewOpenAIRepo(client)
openAIUsecase := openai.NewOpenAIUsecase(configConfig, openAIRepo, modelRepo, slogLogger)
openAIUsecase := openai.NewOpenAIUsecase(configConfig, openAIRepo, modelRepo, slogLogger, redisClient)
extensionRepo := repo5.NewExtensionRepo(client)
extensionUsecase := usecase2.NewExtensionUsecase(extensionRepo, configConfig, slogLogger)
ipdbIPDB, err := ipdb.NewIPDB(slogLogger)
@@ -85,7 +85,7 @@ func newServer() (*Server, error) {
userRepo := repo6.NewUserRepo(client, ipdbIPDB, redisClient, configConfig)
sessionSession := session.NewSession(configConfig)
userUsecase := usecase3.NewUserUsecase(configConfig, redisClient, userRepo, slogLogger, sessionSession)
proxyMiddleware := middleware.NewProxyMiddleware(proxyUsecase)
proxyMiddleware := middleware.NewProxyMiddleware(proxyUsecase, redisClient, slogLogger)
activeMiddleware := middleware.NewActiveMiddleware(redisClient, slogLogger)
v1Handler := v1.NewV1Handler(slogLogger, web, llmProxy, proxyUsecase, openAIUsecase, extensionUsecase, userUsecase, proxyMiddleware, activeMiddleware, configConfig)
modelUsecase := usecase4.NewModelUsecase(slogLogger, modelRepo, configConfig)

View File

@@ -3,6 +3,7 @@ package consts
type ModelStatus string
const (
ModelStatusDefault ModelStatus = "default"
ModelStatusActive ModelStatus = "active"
ModelStatusInactive ModelStatus = "inactive"
)

View File

@@ -20,7 +20,7 @@ type ModelUsecase interface {
}
type ModelRepo interface {
GetWithCache(ctx context.Context, modelType consts.ModelType) (*db.Model, error)
GetWithCache(ctx context.Context, modelType consts.ModelType) ([]*db.Model, error)
List(ctx context.Context) (*AllModelResp, error)
Create(ctx context.Context, m *CreateModelReq) (*db.Model, error)
Update(ctx context.Context, id string, fn func(tx *db.Tx, old *db.Model, up *db.ModelUpdateOne) error) (*db.Model, error)
@@ -181,7 +181,7 @@ func (m *Model) From(e *db.Model) *Model {
m.ModelType = e.ModelType
m.Status = e.Status
m.IsInternal = e.IsInternal
m.IsActive = e.Status == consts.ModelStatusActive
m.IsActive = e.Status == consts.ModelStatusActive || e.Status == consts.ModelStatusDefault
if p := e.Parameters; p != nil {
m.Param = ModelParam{
R1Enabled: p.R1Enabled,

View File

@@ -68,7 +68,7 @@ type ConfigReq struct {
type ConfigResp struct {
Type consts.ConfigType `json:"type"`
Content string `json:"content"`
Content any `json:"content"`
}
type OpenAIResp struct {
Object string `json:"object"`

52
backend/domain/plugin.go Normal file
View File

@@ -0,0 +1,52 @@
package domain
type PluginConfig struct {
ProviderProfiles ProviderProfiles `json:"providerProfiles"`
CtcodeTabCompletions CtcodeTabCompletions `json:"ctcodeTabCompletions"`
GlobalSettings GlobalSettings `json:"globalSettings"`
}
type ProviderProfiles struct {
CurrentApiConfigName string `json:"currentApiConfigName"`
ApiConfigs map[string]ApiConfig `json:"apiConfigs"`
ModeApiConfigs map[string]string `json:"modeApiConfigs"`
Migrations Migrations `json:"migrations"`
}
type ApiConfig struct {
ApiProvider string `json:"apiProvider"`
ApiModelId string `json:"apiModelId"`
OpenAiBaseUrl string `json:"openAiBaseUrl"`
OpenAiApiKey string `json:"openAiApiKey"`
OpenAiModelId string `json:"openAiModelId"`
OpenAiR1FormatEnabled bool `json:"openAiR1FormatEnabled"`
OpenAiCustomModelInfo OpenAiCustomModelInfo `json:"openAiCustomModelInfo"`
Id string `json:"id"`
}
type OpenAiCustomModelInfo struct {
MaxTokens int `json:"maxTokens"`
ContextWindow int `json:"contextWindow"`
SupportsImages bool `json:"supportsImages"`
SupportsComputerUse bool `json:"supportsComputerUse"`
SupportsPromptCache bool `json:"supportsPromptCache"`
}
type Migrations struct {
RateLimitSecondsMigrated bool `json:"rateLimitSecondsMigrated"`
DiffSettingsMigrated bool `json:"diffSettingsMigrated"`
}
type CtcodeTabCompletions struct {
Enabled bool `json:"enabled"`
ApiProvider string `json:"apiProvider"`
OpenAiBaseUrl string `json:"openAiBaseUrl"`
OpenAiApiKey string `json:"openAiApiKey"`
OpenAiModelId string `json:"openAiModelId"`
}
type GlobalSettings struct {
AllowedCommands []string `json:"allowedCommands"`
Mode string `json:"mode"`
CustomModes []string `json:"customModes"`
}

View File

@@ -2,10 +2,13 @@ package middleware
import (
"context"
"encoding/json"
"log/slog"
"net/http"
"strings"
"github.com/labstack/echo/v4"
"github.com/redis/go-redis/v9"
"github.com/chaitin/MonkeyCode/backend/domain"
"github.com/chaitin/MonkeyCode/backend/ent/rule"
@@ -16,15 +19,23 @@ const (
ApiContextKey = "session:apikey"
)
type proxyModelKey struct{}
type ProxyMiddleware struct {
usecase domain.ProxyUsecase
redis *redis.Client
logger *slog.Logger
}
func NewProxyMiddleware(
usecase domain.ProxyUsecase,
redis *redis.Client,
logger *slog.Logger,
) *ProxyMiddleware {
return &ProxyMiddleware{
usecase: usecase,
redis: redis,
logger: logger.With("module", "ProxyMiddleware"),
}
}
@@ -39,21 +50,54 @@ func (p *ProxyMiddleware) Auth() echo.MiddlewareFunc {
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
}
key, err := p.usecase.ValidateApiKey(c.Request().Context(), apiKey)
if err != nil {
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
ctx := c.Request().Context()
p.logger.With("apiKey", apiKey).DebugContext(ctx, "v1 auth")
if strings.Contains(apiKey, ".") {
s, err := p.redis.Get(ctx, apiKey).Result()
if err != nil {
p.logger.With("fn", "Auth").With("error", err).ErrorContext(ctx, "failed to get api key from redis")
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
}
var model *domain.Model
if err := json.Unmarshal([]byte(s), &model); err != nil {
p.logger.With("fn", "Auth").With("error", err).ErrorContext(ctx, "failed to unmarshal model from redis")
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
}
parts := strings.Split(apiKey, ".")
if len(parts) != 2 {
p.logger.With("fn", "Auth").With("apiKey", apiKey).ErrorContext(ctx, "invalid api key")
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
}
ctx = context.WithValue(ctx, proxyModelKey{}, model)
ctx = context.WithValue(ctx, logger.UserIDKey{}, parts[0])
c.Set(ApiContextKey, &domain.ApiKey{
UserID: parts[0],
Key: apiKey,
})
} else {
key, err := p.usecase.ValidateApiKey(ctx, apiKey)
if err != nil {
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
}
ctx = context.WithValue(ctx, logger.UserIDKey{}, key.UserID)
c.Set(ApiContextKey, key)
}
ctx := c.Request().Context()
ctx = context.WithValue(ctx, logger.UserIDKey{}, key.UserID)
ctx = rule.SkipPermission(ctx)
c.SetRequest(c.Request().WithContext(ctx))
c.Set(ApiContextKey, key)
return next(c)
}
}
}
func GetProxyModel(ctx context.Context) *domain.Model {
m := ctx.Value(proxyModelKey{})
if m == nil {
return nil
}
return m.(*domain.Model)
}
func GetApiKey(c echo.Context) *domain.ApiKey {
i := c.Get(ApiContextKey)
if i == nil {

View File

@@ -30,15 +30,16 @@ func NewModelRepo(db *db.Client) domain.ModelRepo {
return &ModelRepo{db: db, cache: cache}
}
func (r *ModelRepo) GetWithCache(ctx context.Context, modelType consts.ModelType) (*db.Model, error) {
func (r *ModelRepo) GetWithCache(ctx context.Context, modelType consts.ModelType) ([]*db.Model, error) {
if v, ok := r.cache.Get(string(modelType)); ok {
return v.(*db.Model), nil
return v.([]*db.Model), nil
}
m, err := r.db.Model.Query().
Where(model.ModelType(modelType)).
Where(model.Status(consts.ModelStatusActive)).
Only(ctx)
Where(model.StatusIn(consts.ModelStatusActive, consts.ModelStatusDefault)).
Order(ByStatusOrder()).
All(ctx)
if err != nil {
return nil, err
}
@@ -47,14 +48,22 @@ func (r *ModelRepo) GetWithCache(ctx context.Context, modelType consts.ModelType
return m, nil
}
func ByStatusOrder() func(s *sql.Selector) {
return func(s *sql.Selector) {
s.OrderExprFunc(func(b *sql.Builder) {
b.WriteString("case when status = 'default' then 3 when status = 'active' then 2 else 1 end desc")
})
}
}
func (r *ModelRepo) Create(ctx context.Context, m *domain.CreateModelReq) (*db.Model, error) {
n, err := r.db.Model.Query().Where(model.ModelType(m.ModelType)).Count(ctx)
if err != nil {
return nil, err
}
status := consts.ModelStatusInactive
status := consts.ModelStatusActive
if n == 0 {
status = consts.ModelStatusActive
status = consts.ModelStatusDefault
}
r.cache.Delete(string(m.ModelType))

View File

@@ -108,14 +108,27 @@ func (m *ModelUsecase) Update(ctx context.Context, req *domain.UpdateModelReq) (
up.SetShowName(*req.ShowName)
}
if req.Status != nil {
if *req.Status == consts.ModelStatusActive {
if *req.Status == consts.ModelStatusDefault {
if err := tx.Model.Update().
Where(model.Status(consts.ModelStatusDefault)).
Where(model.ModelType(old.ModelType)).
SetStatus(consts.ModelStatusInactive).
SetStatus(consts.ModelStatusActive).
Exec(ctx); err != nil {
return err
}
}
if *req.Status == consts.ModelStatusActive {
n, err := tx.Model.Query().
Where(model.Status(consts.ModelStatusDefault)).
Where(model.ModelType(old.ModelType)).
Count(ctx)
if err != nil {
return err
}
if n == 0 {
*req.Status = consts.ModelStatusDefault
}
}
up.SetStatus(*req.Status)
}
if req.Param != nil {

View File

@@ -1,18 +1,21 @@
package openai
import (
"bytes"
"context"
"html/template"
"encoding/json"
"errors"
"fmt"
"log/slog"
"time"
"github.com/chaitin/MonkeyCode/backend/ent/types"
"github.com/chaitin/MonkeyCode/backend/pkg/cvt"
"github.com/redis/go-redis/v9"
"github.com/chaitin/MonkeyCode/backend/config"
"github.com/chaitin/MonkeyCode/backend/consts"
"github.com/chaitin/MonkeyCode/backend/db"
"github.com/chaitin/MonkeyCode/backend/domain"
"github.com/chaitin/MonkeyCode/backend/ent/types"
"github.com/chaitin/MonkeyCode/backend/pkg/cvt"
)
type OpenAIUsecase struct {
@@ -20,6 +23,7 @@ type OpenAIUsecase struct {
modelRepo domain.ModelRepo
cfg *config.Config
logger *slog.Logger
redis *redis.Client
}
func NewOpenAIUsecase(
@@ -27,12 +31,14 @@ func NewOpenAIUsecase(
repo domain.OpenAIRepo,
modelRepo domain.ModelRepo,
logger *slog.Logger,
redis *redis.Client,
) domain.OpenAIUsecase {
return &OpenAIUsecase{
repo: repo,
modelRepo: modelRepo,
cfg: cfg,
logger: logger,
redis: redis,
}
}
@@ -42,10 +48,6 @@ func (u *OpenAIUsecase) ModelList(ctx context.Context) (*domain.ModelListResp, e
return nil, err
}
for _, v := range models {
u.logger.DebugContext(ctx, "model", slog.Any("model", v))
}
resp := &domain.ModelListResp{
Object: "list",
Data: cvt.Iter(models, func(_ int, m *db.Model) *domain.ModelData {
@@ -61,44 +63,106 @@ func (u *OpenAIUsecase) GetConfig(ctx context.Context, req *domain.ConfigReq) (*
if err != nil {
return nil, err
}
llm, err := u.modelRepo.GetWithCache(ctx, consts.ModelTypeLLM)
llms, err := u.modelRepo.GetWithCache(ctx, consts.ModelTypeLLM)
if err != nil {
return nil, err
}
coder, err := u.modelRepo.GetWithCache(ctx, consts.ModelTypeCoder)
coders, err := u.modelRepo.GetWithCache(ctx, consts.ModelTypeCoder)
if err != nil {
return nil, err
}
u.logger.With(
"llms", len(llms),
"coders", len(coders),
).DebugContext(ctx, "get config")
if len(llms) == 0 || len(coders) == 0 {
return nil, errors.New("no model")
}
llm := llms[0]
coder := coders[0]
coderkey := fmt.Sprintf("%s.%s", apiKey.UserID.String(), coder.ID.String())
if err = u.redis.Get(ctx, coderkey).Err(); err != nil {
b, err := json.Marshal(cvt.From(coder, &domain.Model{}))
if err != nil {
return nil, err
}
if err = u.redis.Set(ctx, coderkey, string(b), time.Hour*24).Err(); err != nil {
return nil, err
}
}
if llm.Parameters == nil {
llm.Parameters = types.DefaultModelParam()
}
t, err := template.New("config").Parse(string(config.ConfigTmpl))
if err != nil {
return nil, err
config := &domain.PluginConfig{
ProviderProfiles: domain.ProviderProfiles{
CurrentApiConfigName: "default",
ApiConfigs: map[string]domain.ApiConfig{},
ModeApiConfigs: map[string]string{
"code": "59admorkig4",
"architect": "59admorkig4",
"ask": "59admorkig4",
"debug": "59admorkig4",
"deepresearch": "59admorkig4",
},
Migrations: domain.Migrations{
RateLimitSecondsMigrated: true,
DiffSettingsMigrated: true,
},
},
CtcodeTabCompletions: domain.CtcodeTabCompletions{
Enabled: true,
ApiProvider: "openai",
OpenAiBaseUrl: req.BaseURL + "/v1",
OpenAiApiKey: coderkey,
OpenAiModelId: coder.ModelName,
},
}
u.logger.With("param", llm.Parameters).DebugContext(ctx, "get config")
cnt := bytes.NewBuffer(nil)
data := map[string]any{
"apiBase": req.BaseURL,
"apikey": apiKey.Key,
"chatModel": llm.ModelName,
"codeModel": coder.ModelName,
"r1Enabled": llm.Parameters.R1Enabled,
"maxTokens": llm.Parameters.MaxTokens,
"contextWindow": llm.Parameters.ContextWindow,
"supportsImages": llm.Parameters.SupprtImages,
"supportsComputerUse": llm.Parameters.SupportComputerUse,
"supportsPromptCache": llm.Parameters.SupportPromptCache,
}
if err := t.Execute(cnt, data); err != nil {
return nil, err
for _, m := range llms {
key := fmt.Sprintf("%s.%s", apiKey.UserID.String(), m.ID.String())
if m.Parameters == nil {
m.Parameters = types.DefaultModelParam()
}
name := fmt.Sprintf("%s (%s)", m.ModelName, m.Provider)
if m.Status == consts.ModelStatusDefault {
name = "default"
}
config.ProviderProfiles.ApiConfigs[name] = domain.ApiConfig{
ApiProvider: "openai",
ApiModelId: m.ModelName,
OpenAiBaseUrl: req.BaseURL + "/v1",
OpenAiApiKey: key,
OpenAiModelId: m.ModelName,
OpenAiR1FormatEnabled: m.Parameters.R1Enabled,
OpenAiCustomModelInfo: domain.OpenAiCustomModelInfo{
MaxTokens: m.Parameters.MaxTokens,
ContextWindow: m.Parameters.ContextWindow,
SupportsImages: m.Parameters.SupprtImages,
SupportsComputerUse: m.Parameters.SupportComputerUse,
SupportsPromptCache: m.Parameters.SupportPromptCache,
},
Id: m.ID.String(),
}
if err = u.redis.Get(ctx, key).Err(); err == nil {
continue
}
b, err := json.Marshal(cvt.From(m, &domain.Model{}))
if err != nil {
return nil, err
}
if err := u.redis.Set(ctx, key, string(b), time.Hour*24).Err(); err != nil {
return nil, err
}
}
return &domain.ConfigResp{
Type: req.Type,
Content: cnt.String(),
Content: config,
}, nil
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/chaitin/MonkeyCode/backend/config"
"github.com/chaitin/MonkeyCode/backend/consts"
"github.com/chaitin/MonkeyCode/backend/domain"
"github.com/chaitin/MonkeyCode/backend/internal/middleware"
"github.com/chaitin/MonkeyCode/backend/pkg/logger"
"github.com/chaitin/MonkeyCode/backend/pkg/tee"
)
@@ -96,11 +97,16 @@ func (l *LLMProxy) rewrite(r *httputil.ProxyRequest) {
return
}
m, err := l.usecase.SelectModelWithLoadBalancing("", mt)
if err != nil {
l.logger.ErrorContext(r.In.Context(), "select model with load balancing failed", slog.String("path", r.In.URL.Path), slog.Any("err", err))
return
var m *domain.Model
var err error
if m = middleware.GetProxyModel(r.In.Context()); m == nil {
m, err = l.usecase.SelectModelWithLoadBalancing("", mt)
if err != nil {
l.logger.ErrorContext(r.In.Context(), "select model with load balancing failed", slog.String("path", r.In.URL.Path), slog.Any("err", err))
return
}
}
ul, err := url.Parse(m.APIBase)
if err != nil {
l.logger.ErrorContext(r.In.Context(), "parse model api base failed", slog.String("path", r.In.URL.Path), slog.Any("err", err))

View File

@@ -9,6 +9,7 @@ import (
"path"
"time"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
"github.com/chaitin/MonkeyCode/backend/config"
@@ -16,6 +17,7 @@ import (
"github.com/chaitin/MonkeyCode/backend/db"
"github.com/chaitin/MonkeyCode/backend/domain"
"github.com/chaitin/MonkeyCode/backend/ent/rule"
"github.com/chaitin/MonkeyCode/backend/internal/middleware"
"github.com/chaitin/MonkeyCode/backend/pkg/cvt"
"github.com/chaitin/MonkeyCode/backend/pkg/queuerunner"
"github.com/chaitin/MonkeyCode/backend/pkg/request"
@@ -85,7 +87,16 @@ func (p *ProxyUsecase) SelectModelWithLoadBalancing(modelName string, modelType
if err != nil {
return nil, err
}
return cvt.From(model, &domain.Model{}), nil
if len(model) == 0 {
return nil, fmt.Errorf("no model found")
}
m := model[0]
for _, mm := range model {
if mm.Status == consts.ModelStatusDefault {
m = mm
}
}
return cvt.From(m, &domain.Model{}), nil
}
func (p *ProxyUsecase) ValidateApiKey(ctx context.Context, key string) (*domain.ApiKey, error) {
@@ -102,12 +113,33 @@ func (p *ProxyUsecase) AcceptCompletion(ctx context.Context, req *domain.AcceptC
func (p *ProxyUsecase) Report(ctx context.Context, req *domain.ReportReq) error {
var model *db.Model
var err error
if req.Action == consts.ReportActionNewTask {
model, err = p.modelRepo.GetWithCache(context.Background(), consts.ModelTypeLLM)
if err != nil {
p.logger.With("fn", "Report").With("error", err).ErrorContext(ctx, "failed to get model")
return err
m := middleware.GetProxyModel(ctx)
if m == nil {
ms, err := p.modelRepo.GetWithCache(ctx, consts.ModelTypeLLM)
if err != nil {
return fmt.Errorf("get model with cache failed: %w", err)
}
if len(ms) == 0 {
return fmt.Errorf("no model found")
}
model = ms[0]
for _, mm := range ms {
if mm.Status == consts.ModelStatusDefault {
model = mm
break
}
}
} else {
mid, err := uuid.Parse(m.ID)
if err != nil {
return fmt.Errorf("parse proxy model id failed: %w", err)
}
model = &db.Model{
ID: mid,
ModelName: m.ModelName,
ModelType: m.ModelType,
}
}
}
return p.repo.Report(ctx, model, req)

View File

@@ -19,6 +19,7 @@ export enum GithubComChaitinMonkeyCodeBackendConstsModelType {
}
export enum GithubComChaitinMonkeyCodeBackendConstsModelStatus {
ModelStatusDefault = "default",
ModelStatusActive = "active",
ModelStatusInactive = "inactive",
}

View File

@@ -6,7 +6,6 @@ import {
} from '@/api/Model';
import { DomainModel, GithubComChaitinMonkeyCodeBackendConstsModelStatus, GithubComChaitinMonkeyCodeBackendConstsModelType, } from '@/api/types';
import { Stack, Box, Button, Grid2 as Grid, ButtonBase } from '@mui/material';
import StyledLabel from '@/components/label';
import { Icon, Modal, message } from '@c-x/ui';
import { addCommasToNumber } from '@/utils';
import NoData from '@/assets/images/nodata.png';
@@ -76,6 +75,31 @@ const ModelItem = ({
});
};
const onSetDefaultModel = () => {
Modal.confirm({
title: '设为默认模型',
content: (
<>
{' '}
<Box component='span' sx={{ fontWeight: 700, color: 'text.primary' }}>
{data.model_name}
</Box>{' '}
</>
),
onOk: () => {
putUpdateModel({
id: data.id,
status: GithubComChaitinMonkeyCodeBackendConstsModelStatus.ModelStatusDefault,
provider: data.provider!,
}).then(() => {
message.success('设为默认模型成功');
refresh();
});
},
});
};
const onActiveModel = () => {
Modal.confirm({
title: '激活模型',
@@ -101,6 +125,100 @@ const ModelItem = ({
});
};
// 添加状态标签渲染函数
const renderStatusLabel = () => {
// 根据 is_active 和 status 字段判断状态
if (data.is_active && data.status === GithubComChaitinMonkeyCodeBackendConstsModelStatus.ModelStatusActive) {
return (
<Box
sx={{
px: 1.5,
py: 0.5,
backgroundColor: 'success.main',
color: 'success.contrastText',
borderRadius: '0 0 0 8px', // 左下角圆角,贴合卡片右上角
fontSize: 12,
fontWeight: 600,
boxShadow: '0 2px 4px rgba(0,0,0,0.1)',
position: 'relative',
'&::before': {
content: '""',
position: 'absolute',
top: 0,
right: 0,
width: 0,
height: 0,
borderLeft: '6px solid transparent',
borderTop: '6px solid',
borderTopColor: 'success.dark',
}
}}
>
</Box>
);
} else if (data.status === GithubComChaitinMonkeyCodeBackendConstsModelStatus.ModelStatusInactive) {
return (
<Box
sx={{
px: 1.5,
py: 0.5,
backgroundColor: 'grey.400',
color: 'grey.50',
borderRadius: '0 0 0 8px',
fontSize: 12,
fontWeight: 600,
boxShadow: '0 2px 4px rgba(0,0,0,0.1)',
position: 'relative',
'&::before': {
content: '""',
position: 'absolute',
top: 0,
right: 0,
width: 0,
height: 0,
borderLeft: '6px solid transparent',
borderTop: '6px solid',
borderTopColor: 'grey.600',
}
}}
>
</Box>
);
} else {
// 默认状态
return (
<Box
sx={{
px: 1.5,
py: 0.5,
backgroundColor: 'primary.main',
color: 'primary.contrastText',
borderRadius: '0 0 0 8px',
fontSize: 12,
fontWeight: 600,
boxShadow: '0 2px 4px rgba(0,0,0,0.1)',
position: 'relative',
'&::before': {
content: '""',
position: 'absolute',
top: 0,
right: 0,
width: 0,
height: 0,
borderLeft: '6px solid transparent',
borderTop: '6px solid',
borderTopColor: 'primary.dark',
}
}}
>
</Box>
);
}
};
return (
<Card
sx={{
@@ -109,7 +227,7 @@ const ModelItem = ({
transition: 'all 0.3s ease',
borderStyle: 'solid',
borderWidth: '1px',
borderColor: data.is_active ? 'success.main' : 'transparent',
borderColor: 'transparent',
boxShadow:
'0px 0px 10px 0px rgba(68, 80, 91, 0.1), 0px 0px 2px 0px rgba(68, 80, 91, 0.1)',
'&:hover': {
@@ -118,6 +236,18 @@ const ModelItem = ({
},
}}
>
{/* 美化的右上角状态标签 */}
<Box
sx={{
position: 'absolute',
top: 0,
right: 0,
zIndex: 2,
}}
>
{renderStatusLabel()}
</Box>
<Stack
direction='row'
alignItems='center'
@@ -129,7 +259,7 @@ const ModelItem = ({
type={
DEFAULT_MODEL_PROVIDERS[data.provider as keyof typeof DEFAULT_MODEL_PROVIDERS]?.icon
}
sx={{ fontSize: 24 }}
sx={{ fontSize: 24, color: data.is_active ? 'inherit' : 'grey.400' }}
/>
<Stack
direction='row'
@@ -143,6 +273,7 @@ const ModelItem = ({
overflow: 'hidden',
textOverflow: 'ellipsis',
whiteSpace: 'nowrap',
color: data.is_active ? 'inherit' : 'grey.400',
}}
>
{data.show_name || '未命名'}
@@ -195,11 +326,25 @@ const ModelItem = ({
gap={2}
sx={{ mt: 2 }}
>
<Stack direction='row' alignItems='center'>
<StyledLabel color={data.is_active ? 'success' : 'disabled'}>{data.is_active ? '正在使用' : '未激活'}</StyledLabel>
</Stack>
<Stack direction='row' alignItems='center'> </Stack>
<Stack direction='row' sx={{ button: { minWidth: 0 } }} gap={2}>
{!data.is_active && (
{(data.status === GithubComChaitinMonkeyCodeBackendConstsModelStatus.ModelStatusActive ||
data.status === GithubComChaitinMonkeyCodeBackendConstsModelStatus.ModelStatusInactive) && (
<ButtonBase
disableRipple
sx={{
color: 'text.primary',
'&:hover': {
fontWeight: 700
},
}}
onClick={onSetDefaultModel}
>
</ButtonBase>
)}
{data.status === GithubComChaitinMonkeyCodeBackendConstsModelStatus.ModelStatusInactive && (
<ButtonBase
disableRipple
sx={{