From 4b357a07d7558917feb12c7aacb67c8afb1a11de Mon Sep 17 00:00:00 2001
From: yokowu <18836617@qq.com>
Date: Thu, 28 Aug 2025 16:45:23 +0800
Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E6=8F=92=E4=BB=B6?=
=?UTF-8?q?=E9=80=89=E6=8B=A9=E6=A8=A1=E5=9E=8B?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
backend/cmd/server/wire_gen.go | 4 +-
backend/consts/model.go | 1 +
backend/domain/model.go | 4 +-
backend/domain/openai.go | 2 +-
backend/domain/plugin.go | 52 +++++++
backend/internal/middleware/proxy.go | 56 ++++++-
backend/internal/model/repo/model.go | 21 ++-
backend/internal/model/usecase/model.go | 17 ++-
backend/internal/openai/usecase/openai.go | 124 +++++++++++----
backend/internal/proxy/proxy.go | 14 +-
backend/internal/proxy/usecase/proxy.go | 44 +++++-
backend/pro | 2 +-
ui/src/api/types.ts | 1 +
ui/src/pages/model/components/modelCard.tsx | 159 +++++++++++++++++++-
14 files changed, 434 insertions(+), 67 deletions(-)
create mode 100644 backend/domain/plugin.go
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 64bf52f..123da85 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -75,7 +75,7 @@ func newServer() (*Server, error) {
proxyUsecase := usecase.NewProxyUsecase(proxyRepo, modelRepo, securityScanningRepo, slogLogger, configConfig, redisClient)
llmProxy := proxy.NewLLMProxy(slogLogger, configConfig, proxyUsecase)
openAIRepo := repo4.NewOpenAIRepo(client)
- openAIUsecase := openai.NewOpenAIUsecase(configConfig, openAIRepo, modelRepo, slogLogger)
+ openAIUsecase := openai.NewOpenAIUsecase(configConfig, openAIRepo, modelRepo, slogLogger, redisClient)
extensionRepo := repo5.NewExtensionRepo(client)
extensionUsecase := usecase2.NewExtensionUsecase(extensionRepo, configConfig, slogLogger)
ipdbIPDB, err := ipdb.NewIPDB(slogLogger)
@@ -85,7 +85,7 @@ func newServer() (*Server, error) {
userRepo := repo6.NewUserRepo(client, ipdbIPDB, redisClient, configConfig)
sessionSession := session.NewSession(configConfig)
userUsecase := usecase3.NewUserUsecase(configConfig, redisClient, userRepo, slogLogger, sessionSession)
- proxyMiddleware := middleware.NewProxyMiddleware(proxyUsecase)
+ proxyMiddleware := middleware.NewProxyMiddleware(proxyUsecase, redisClient, slogLogger)
activeMiddleware := middleware.NewActiveMiddleware(redisClient, slogLogger)
v1Handler := v1.NewV1Handler(slogLogger, web, llmProxy, proxyUsecase, openAIUsecase, extensionUsecase, userUsecase, proxyMiddleware, activeMiddleware, configConfig)
modelUsecase := usecase4.NewModelUsecase(slogLogger, modelRepo, configConfig)
diff --git a/backend/consts/model.go b/backend/consts/model.go
index fad90d9..f1a37f8 100644
--- a/backend/consts/model.go
+++ b/backend/consts/model.go
@@ -3,6 +3,7 @@ package consts
type ModelStatus string
const (
+ ModelStatusDefault ModelStatus = "default"
ModelStatusActive ModelStatus = "active"
ModelStatusInactive ModelStatus = "inactive"
)
diff --git a/backend/domain/model.go b/backend/domain/model.go
index 4b37b19..2da8243 100644
--- a/backend/domain/model.go
+++ b/backend/domain/model.go
@@ -20,7 +20,7 @@ type ModelUsecase interface {
}
type ModelRepo interface {
- GetWithCache(ctx context.Context, modelType consts.ModelType) (*db.Model, error)
+ GetWithCache(ctx context.Context, modelType consts.ModelType) ([]*db.Model, error)
List(ctx context.Context) (*AllModelResp, error)
Create(ctx context.Context, m *CreateModelReq) (*db.Model, error)
Update(ctx context.Context, id string, fn func(tx *db.Tx, old *db.Model, up *db.ModelUpdateOne) error) (*db.Model, error)
@@ -181,7 +181,7 @@ func (m *Model) From(e *db.Model) *Model {
m.ModelType = e.ModelType
m.Status = e.Status
m.IsInternal = e.IsInternal
- m.IsActive = e.Status == consts.ModelStatusActive
+ m.IsActive = e.Status == consts.ModelStatusActive || e.Status == consts.ModelStatusDefault
if p := e.Parameters; p != nil {
m.Param = ModelParam{
R1Enabled: p.R1Enabled,
diff --git a/backend/domain/openai.go b/backend/domain/openai.go
index 95ec6a9..e3cfcce 100644
--- a/backend/domain/openai.go
+++ b/backend/domain/openai.go
@@ -68,7 +68,7 @@ type ConfigReq struct {
type ConfigResp struct {
Type consts.ConfigType `json:"type"`
- Content string `json:"content"`
+ Content any `json:"content"`
}
type OpenAIResp struct {
Object string `json:"object"`
diff --git a/backend/domain/plugin.go b/backend/domain/plugin.go
new file mode 100644
index 0000000..e979246
--- /dev/null
+++ b/backend/domain/plugin.go
@@ -0,0 +1,52 @@
+package domain
+
+type PluginConfig struct {
+ ProviderProfiles ProviderProfiles `json:"providerProfiles"`
+ CtcodeTabCompletions CtcodeTabCompletions `json:"ctcodeTabCompletions"`
+ GlobalSettings GlobalSettings `json:"globalSettings"`
+}
+
+type ProviderProfiles struct {
+ CurrentApiConfigName string `json:"currentApiConfigName"`
+ ApiConfigs map[string]ApiConfig `json:"apiConfigs"`
+ ModeApiConfigs map[string]string `json:"modeApiConfigs"`
+ Migrations Migrations `json:"migrations"`
+}
+
+type ApiConfig struct {
+ ApiProvider string `json:"apiProvider"`
+ ApiModelId string `json:"apiModelId"`
+ OpenAiBaseUrl string `json:"openAiBaseUrl"`
+ OpenAiApiKey string `json:"openAiApiKey"`
+ OpenAiModelId string `json:"openAiModelId"`
+ OpenAiR1FormatEnabled bool `json:"openAiR1FormatEnabled"`
+ OpenAiCustomModelInfo OpenAiCustomModelInfo `json:"openAiCustomModelInfo"`
+ Id string `json:"id"`
+}
+
+type OpenAiCustomModelInfo struct {
+ MaxTokens int `json:"maxTokens"`
+ ContextWindow int `json:"contextWindow"`
+ SupportsImages bool `json:"supportsImages"`
+ SupportsComputerUse bool `json:"supportsComputerUse"`
+ SupportsPromptCache bool `json:"supportsPromptCache"`
+}
+
+type Migrations struct {
+ RateLimitSecondsMigrated bool `json:"rateLimitSecondsMigrated"`
+ DiffSettingsMigrated bool `json:"diffSettingsMigrated"`
+}
+
+type CtcodeTabCompletions struct {
+ Enabled bool `json:"enabled"`
+ ApiProvider string `json:"apiProvider"`
+ OpenAiBaseUrl string `json:"openAiBaseUrl"`
+ OpenAiApiKey string `json:"openAiApiKey"`
+ OpenAiModelId string `json:"openAiModelId"`
+}
+
+type GlobalSettings struct {
+ AllowedCommands []string `json:"allowedCommands"`
+ Mode string `json:"mode"`
+ CustomModes []string `json:"customModes"`
+}
diff --git a/backend/internal/middleware/proxy.go b/backend/internal/middleware/proxy.go
index 702dc08..4b424c2 100644
--- a/backend/internal/middleware/proxy.go
+++ b/backend/internal/middleware/proxy.go
@@ -2,10 +2,13 @@ package middleware
import (
"context"
+ "encoding/json"
+ "log/slog"
"net/http"
"strings"
"github.com/labstack/echo/v4"
+ "github.com/redis/go-redis/v9"
"github.com/chaitin/MonkeyCode/backend/domain"
"github.com/chaitin/MonkeyCode/backend/ent/rule"
@@ -16,15 +19,23 @@ const (
ApiContextKey = "session:apikey"
)
+type proxyModelKey struct{}
+
type ProxyMiddleware struct {
usecase domain.ProxyUsecase
+ redis *redis.Client
+ logger *slog.Logger
}
func NewProxyMiddleware(
usecase domain.ProxyUsecase,
+ redis *redis.Client,
+ logger *slog.Logger,
) *ProxyMiddleware {
return &ProxyMiddleware{
usecase: usecase,
+ redis: redis,
+ logger: logger.With("module", "ProxyMiddleware"),
}
}
@@ -39,21 +50,54 @@ func (p *ProxyMiddleware) Auth() echo.MiddlewareFunc {
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
}
- key, err := p.usecase.ValidateApiKey(c.Request().Context(), apiKey)
- if err != nil {
- return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
+ ctx := c.Request().Context()
+ p.logger.With("apiKey", apiKey).DebugContext(ctx, "v1 auth")
+ if strings.Contains(apiKey, ".") {
+ s, err := p.redis.Get(ctx, apiKey).Result()
+ if err != nil {
+ p.logger.With("fn", "Auth").With("error", err).ErrorContext(ctx, "failed to get api key from redis")
+ return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
+ }
+ var model *domain.Model
+ if err := json.Unmarshal([]byte(s), &model); err != nil {
+ p.logger.With("fn", "Auth").With("error", err).ErrorContext(ctx, "failed to unmarshal model from redis")
+ return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
+ }
+ parts := strings.Split(apiKey, ".")
+ if len(parts) != 2 {
+ p.logger.With("fn", "Auth").With("apiKey", apiKey).ErrorContext(ctx, "invalid api key")
+ return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
+ }
+ ctx = context.WithValue(ctx, proxyModelKey{}, model)
+ ctx = context.WithValue(ctx, logger.UserIDKey{}, parts[0])
+ c.Set(ApiContextKey, &domain.ApiKey{
+ UserID: parts[0],
+ Key: apiKey,
+ })
+ } else {
+ key, err := p.usecase.ValidateApiKey(ctx, apiKey)
+ if err != nil {
+ return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
+ }
+ ctx = context.WithValue(ctx, logger.UserIDKey{}, key.UserID)
+ c.Set(ApiContextKey, key)
}
- ctx := c.Request().Context()
- ctx = context.WithValue(ctx, logger.UserIDKey{}, key.UserID)
ctx = rule.SkipPermission(ctx)
c.SetRequest(c.Request().WithContext(ctx))
- c.Set(ApiContextKey, key)
return next(c)
}
}
}
+func GetProxyModel(ctx context.Context) *domain.Model {
+ m := ctx.Value(proxyModelKey{})
+ if m == nil {
+ return nil
+ }
+ return m.(*domain.Model)
+}
+
func GetApiKey(c echo.Context) *domain.ApiKey {
i := c.Get(ApiContextKey)
if i == nil {
diff --git a/backend/internal/model/repo/model.go b/backend/internal/model/repo/model.go
index 50c0fd9..9e18841 100644
--- a/backend/internal/model/repo/model.go
+++ b/backend/internal/model/repo/model.go
@@ -30,15 +30,16 @@ func NewModelRepo(db *db.Client) domain.ModelRepo {
return &ModelRepo{db: db, cache: cache}
}
-func (r *ModelRepo) GetWithCache(ctx context.Context, modelType consts.ModelType) (*db.Model, error) {
+func (r *ModelRepo) GetWithCache(ctx context.Context, modelType consts.ModelType) ([]*db.Model, error) {
if v, ok := r.cache.Get(string(modelType)); ok {
- return v.(*db.Model), nil
+ return v.([]*db.Model), nil
}
m, err := r.db.Model.Query().
Where(model.ModelType(modelType)).
- Where(model.Status(consts.ModelStatusActive)).
- Only(ctx)
+ Where(model.StatusIn(consts.ModelStatusActive, consts.ModelStatusDefault)).
+ Order(ByStatusOrder()).
+ All(ctx)
if err != nil {
return nil, err
}
@@ -47,14 +48,22 @@ func (r *ModelRepo) GetWithCache(ctx context.Context, modelType consts.ModelType
return m, nil
}
+func ByStatusOrder() func(s *sql.Selector) {
+ return func(s *sql.Selector) {
+ s.OrderExprFunc(func(b *sql.Builder) {
+ b.WriteString("case when status = 'default' then 3 when status = 'active' then 2 else 1 end desc")
+ })
+ }
+}
+
func (r *ModelRepo) Create(ctx context.Context, m *domain.CreateModelReq) (*db.Model, error) {
n, err := r.db.Model.Query().Where(model.ModelType(m.ModelType)).Count(ctx)
if err != nil {
return nil, err
}
- status := consts.ModelStatusInactive
+ status := consts.ModelStatusActive
if n == 0 {
- status = consts.ModelStatusActive
+ status = consts.ModelStatusDefault
}
r.cache.Delete(string(m.ModelType))
diff --git a/backend/internal/model/usecase/model.go b/backend/internal/model/usecase/model.go
index f046ea9..a7d5dd9 100644
--- a/backend/internal/model/usecase/model.go
+++ b/backend/internal/model/usecase/model.go
@@ -108,14 +108,27 @@ func (m *ModelUsecase) Update(ctx context.Context, req *domain.UpdateModelReq) (
up.SetShowName(*req.ShowName)
}
if req.Status != nil {
- if *req.Status == consts.ModelStatusActive {
+ if *req.Status == consts.ModelStatusDefault {
if err := tx.Model.Update().
+ Where(model.Status(consts.ModelStatusDefault)).
Where(model.ModelType(old.ModelType)).
- SetStatus(consts.ModelStatusInactive).
+ SetStatus(consts.ModelStatusActive).
Exec(ctx); err != nil {
return err
}
}
+ if *req.Status == consts.ModelStatusActive {
+ n, err := tx.Model.Query().
+ Where(model.Status(consts.ModelStatusDefault)).
+ Where(model.ModelType(old.ModelType)).
+ Count(ctx)
+ if err != nil {
+ return err
+ }
+ if n == 0 {
+ *req.Status = consts.ModelStatusDefault
+ }
+ }
up.SetStatus(*req.Status)
}
if req.Param != nil {
diff --git a/backend/internal/openai/usecase/openai.go b/backend/internal/openai/usecase/openai.go
index 5a4abe4..8fedd7d 100644
--- a/backend/internal/openai/usecase/openai.go
+++ b/backend/internal/openai/usecase/openai.go
@@ -1,18 +1,21 @@
package openai
import (
- "bytes"
"context"
- "html/template"
+ "encoding/json"
+ "errors"
+ "fmt"
"log/slog"
+ "time"
- "github.com/chaitin/MonkeyCode/backend/ent/types"
- "github.com/chaitin/MonkeyCode/backend/pkg/cvt"
+ "github.com/redis/go-redis/v9"
"github.com/chaitin/MonkeyCode/backend/config"
"github.com/chaitin/MonkeyCode/backend/consts"
"github.com/chaitin/MonkeyCode/backend/db"
"github.com/chaitin/MonkeyCode/backend/domain"
+ "github.com/chaitin/MonkeyCode/backend/ent/types"
+ "github.com/chaitin/MonkeyCode/backend/pkg/cvt"
)
type OpenAIUsecase struct {
@@ -20,6 +23,7 @@ type OpenAIUsecase struct {
modelRepo domain.ModelRepo
cfg *config.Config
logger *slog.Logger
+ redis *redis.Client
}
func NewOpenAIUsecase(
@@ -27,12 +31,14 @@ func NewOpenAIUsecase(
repo domain.OpenAIRepo,
modelRepo domain.ModelRepo,
logger *slog.Logger,
+ redis *redis.Client,
) domain.OpenAIUsecase {
return &OpenAIUsecase{
repo: repo,
modelRepo: modelRepo,
cfg: cfg,
logger: logger,
+ redis: redis,
}
}
@@ -42,10 +48,6 @@ func (u *OpenAIUsecase) ModelList(ctx context.Context) (*domain.ModelListResp, e
return nil, err
}
- for _, v := range models {
- u.logger.DebugContext(ctx, "model", slog.Any("model", v))
- }
-
resp := &domain.ModelListResp{
Object: "list",
Data: cvt.Iter(models, func(_ int, m *db.Model) *domain.ModelData {
@@ -61,44 +63,106 @@ func (u *OpenAIUsecase) GetConfig(ctx context.Context, req *domain.ConfigReq) (*
if err != nil {
return nil, err
}
- llm, err := u.modelRepo.GetWithCache(ctx, consts.ModelTypeLLM)
+ llms, err := u.modelRepo.GetWithCache(ctx, consts.ModelTypeLLM)
if err != nil {
return nil, err
}
- coder, err := u.modelRepo.GetWithCache(ctx, consts.ModelTypeCoder)
+ coders, err := u.modelRepo.GetWithCache(ctx, consts.ModelTypeCoder)
if err != nil {
return nil, err
}
+ u.logger.With(
+ "llms", len(llms),
+ "coders", len(coders),
+ ).DebugContext(ctx, "get config")
+
+ if len(llms) == 0 || len(coders) == 0 {
+ return nil, errors.New("no model")
+ }
+
+ llm := llms[0]
+ coder := coders[0]
+ coderkey := fmt.Sprintf("%s.%s", apiKey.UserID.String(), coder.ID.String())
+ if err = u.redis.Get(ctx, coderkey).Err(); err != nil {
+ b, err := json.Marshal(cvt.From(coder, &domain.Model{}))
+ if err != nil {
+ return nil, err
+ }
+ if err = u.redis.Set(ctx, coderkey, string(b), time.Hour*24).Err(); err != nil {
+ return nil, err
+ }
+ }
+
if llm.Parameters == nil {
llm.Parameters = types.DefaultModelParam()
}
- t, err := template.New("config").Parse(string(config.ConfigTmpl))
- if err != nil {
- return nil, err
+ config := &domain.PluginConfig{
+ ProviderProfiles: domain.ProviderProfiles{
+ CurrentApiConfigName: "default",
+ ApiConfigs: map[string]domain.ApiConfig{},
+ ModeApiConfigs: map[string]string{
+ "code": "59admorkig4",
+ "architect": "59admorkig4",
+ "ask": "59admorkig4",
+ "debug": "59admorkig4",
+ "deepresearch": "59admorkig4",
+ },
+ Migrations: domain.Migrations{
+ RateLimitSecondsMigrated: true,
+ DiffSettingsMigrated: true,
+ },
+ },
+ CtcodeTabCompletions: domain.CtcodeTabCompletions{
+ Enabled: true,
+ ApiProvider: "openai",
+ OpenAiBaseUrl: req.BaseURL + "/v1",
+ OpenAiApiKey: coderkey,
+ OpenAiModelId: coder.ModelName,
+ },
}
- u.logger.With("param", llm.Parameters).DebugContext(ctx, "get config")
- cnt := bytes.NewBuffer(nil)
- data := map[string]any{
- "apiBase": req.BaseURL,
- "apikey": apiKey.Key,
- "chatModel": llm.ModelName,
- "codeModel": coder.ModelName,
- "r1Enabled": llm.Parameters.R1Enabled,
- "maxTokens": llm.Parameters.MaxTokens,
- "contextWindow": llm.Parameters.ContextWindow,
- "supportsImages": llm.Parameters.SupprtImages,
- "supportsComputerUse": llm.Parameters.SupportComputerUse,
- "supportsPromptCache": llm.Parameters.SupportPromptCache,
- }
- if err := t.Execute(cnt, data); err != nil {
- return nil, err
+ for _, m := range llms {
+ key := fmt.Sprintf("%s.%s", apiKey.UserID.String(), m.ID.String())
+ if m.Parameters == nil {
+ m.Parameters = types.DefaultModelParam()
+ }
+ name := fmt.Sprintf("%s (%s)", m.ModelName, m.Provider)
+ if m.Status == consts.ModelStatusDefault {
+ name = "default"
+ }
+ config.ProviderProfiles.ApiConfigs[name] = domain.ApiConfig{
+ ApiProvider: "openai",
+ ApiModelId: m.ModelName,
+ OpenAiBaseUrl: req.BaseURL + "/v1",
+ OpenAiApiKey: key,
+ OpenAiModelId: m.ModelName,
+ OpenAiR1FormatEnabled: m.Parameters.R1Enabled,
+ OpenAiCustomModelInfo: domain.OpenAiCustomModelInfo{
+ MaxTokens: m.Parameters.MaxTokens,
+ ContextWindow: m.Parameters.ContextWindow,
+ SupportsImages: m.Parameters.SupprtImages,
+ SupportsComputerUse: m.Parameters.SupportComputerUse,
+ SupportsPromptCache: m.Parameters.SupportPromptCache,
+ },
+ Id: m.ID.String(),
+ }
+
+ if err = u.redis.Get(ctx, key).Err(); err == nil {
+ continue
+ }
+ b, err := json.Marshal(cvt.From(m, &domain.Model{}))
+ if err != nil {
+ return nil, err
+ }
+ if err := u.redis.Set(ctx, key, string(b), time.Hour*24).Err(); err != nil {
+ return nil, err
+ }
}
return &domain.ConfigResp{
Type: req.Type,
- Content: cnt.String(),
+ Content: config,
}, nil
}
diff --git a/backend/internal/proxy/proxy.go b/backend/internal/proxy/proxy.go
index 612c2dc..bb1f0fe 100644
--- a/backend/internal/proxy/proxy.go
+++ b/backend/internal/proxy/proxy.go
@@ -16,6 +16,7 @@ import (
"github.com/chaitin/MonkeyCode/backend/config"
"github.com/chaitin/MonkeyCode/backend/consts"
"github.com/chaitin/MonkeyCode/backend/domain"
+ "github.com/chaitin/MonkeyCode/backend/internal/middleware"
"github.com/chaitin/MonkeyCode/backend/pkg/logger"
"github.com/chaitin/MonkeyCode/backend/pkg/tee"
)
@@ -96,11 +97,16 @@ func (l *LLMProxy) rewrite(r *httputil.ProxyRequest) {
return
}
- m, err := l.usecase.SelectModelWithLoadBalancing("", mt)
- if err != nil {
- l.logger.ErrorContext(r.In.Context(), "select model with load balancing failed", slog.String("path", r.In.URL.Path), slog.Any("err", err))
- return
+ var m *domain.Model
+ var err error
+ if m = middleware.GetProxyModel(r.In.Context()); m == nil {
+ m, err = l.usecase.SelectModelWithLoadBalancing("", mt)
+ if err != nil {
+ l.logger.ErrorContext(r.In.Context(), "select model with load balancing failed", slog.String("path", r.In.URL.Path), slog.Any("err", err))
+ return
+ }
}
+
ul, err := url.Parse(m.APIBase)
if err != nil {
l.logger.ErrorContext(r.In.Context(), "parse model api base failed", slog.String("path", r.In.URL.Path), slog.Any("err", err))
diff --git a/backend/internal/proxy/usecase/proxy.go b/backend/internal/proxy/usecase/proxy.go
index 4970ce9..b3b9ea1 100644
--- a/backend/internal/proxy/usecase/proxy.go
+++ b/backend/internal/proxy/usecase/proxy.go
@@ -9,6 +9,7 @@ import (
"path"
"time"
+ "github.com/google/uuid"
"github.com/redis/go-redis/v9"
"github.com/chaitin/MonkeyCode/backend/config"
@@ -16,6 +17,7 @@ import (
"github.com/chaitin/MonkeyCode/backend/db"
"github.com/chaitin/MonkeyCode/backend/domain"
"github.com/chaitin/MonkeyCode/backend/ent/rule"
+ "github.com/chaitin/MonkeyCode/backend/internal/middleware"
"github.com/chaitin/MonkeyCode/backend/pkg/cvt"
"github.com/chaitin/MonkeyCode/backend/pkg/queuerunner"
"github.com/chaitin/MonkeyCode/backend/pkg/request"
@@ -85,7 +87,16 @@ func (p *ProxyUsecase) SelectModelWithLoadBalancing(modelName string, modelType
if err != nil {
return nil, err
}
- return cvt.From(model, &domain.Model{}), nil
+ if len(model) == 0 {
+ return nil, fmt.Errorf("no model found")
+ }
+ m := model[0]
+ for _, mm := range model {
+ if mm.Status == consts.ModelStatusDefault {
+ m = mm
+ }
+ }
+ return cvt.From(m, &domain.Model{}), nil
}
func (p *ProxyUsecase) ValidateApiKey(ctx context.Context, key string) (*domain.ApiKey, error) {
@@ -102,12 +113,33 @@ func (p *ProxyUsecase) AcceptCompletion(ctx context.Context, req *domain.AcceptC
func (p *ProxyUsecase) Report(ctx context.Context, req *domain.ReportReq) error {
var model *db.Model
- var err error
if req.Action == consts.ReportActionNewTask {
- model, err = p.modelRepo.GetWithCache(context.Background(), consts.ModelTypeLLM)
- if err != nil {
- p.logger.With("fn", "Report").With("error", err).ErrorContext(ctx, "failed to get model")
- return err
+ m := middleware.GetProxyModel(ctx)
+ if m == nil {
+ ms, err := p.modelRepo.GetWithCache(ctx, consts.ModelTypeLLM)
+ if err != nil {
+ return fmt.Errorf("get model with cache failed: %w", err)
+ }
+ if len(ms) == 0 {
+ return fmt.Errorf("no model found")
+ }
+ model = ms[0]
+ for _, mm := range ms {
+ if mm.Status == consts.ModelStatusDefault {
+ model = mm
+ break
+ }
+ }
+ } else {
+ mid, err := uuid.Parse(m.ID)
+ if err != nil {
+ return fmt.Errorf("parse proxy model id failed: %w", err)
+ }
+ model = &db.Model{
+ ID: mid,
+ ModelName: m.ModelName,
+ ModelType: m.ModelType,
+ }
}
}
return p.repo.Report(ctx, model, req)
diff --git a/backend/pro b/backend/pro
index 482ab56..a566152 160000
--- a/backend/pro
+++ b/backend/pro
@@ -1 +1 @@
-Subproject commit 482ab56daf6d7d205f47e25c057bcab4804a6b73
+Subproject commit a566152111f123a8ccb3bef9a21b1d40f0d5bb47
diff --git a/ui/src/api/types.ts b/ui/src/api/types.ts
index e2966ea..f7f8d11 100644
--- a/ui/src/api/types.ts
+++ b/ui/src/api/types.ts
@@ -19,6 +19,7 @@ export enum GithubComChaitinMonkeyCodeBackendConstsModelType {
}
export enum GithubComChaitinMonkeyCodeBackendConstsModelStatus {
+ ModelStatusDefault = "default",
ModelStatusActive = "active",
ModelStatusInactive = "inactive",
}
diff --git a/ui/src/pages/model/components/modelCard.tsx b/ui/src/pages/model/components/modelCard.tsx
index e52af6c..cf838eb 100644
--- a/ui/src/pages/model/components/modelCard.tsx
+++ b/ui/src/pages/model/components/modelCard.tsx
@@ -6,7 +6,6 @@ import {
} from '@/api/Model';
import { DomainModel, GithubComChaitinMonkeyCodeBackendConstsModelStatus, GithubComChaitinMonkeyCodeBackendConstsModelType, } from '@/api/types';
import { Stack, Box, Button, Grid2 as Grid, ButtonBase } from '@mui/material';
-import StyledLabel from '@/components/label';
import { Icon, Modal, message } from '@c-x/ui';
import { addCommasToNumber } from '@/utils';
import NoData from '@/assets/images/nodata.png';
@@ -76,6 +75,31 @@ const ModelItem = ({
});
};
+ const onSetDefaultModel = () => {
+ Modal.confirm({
+ title: '设为默认模型',
+ content: (
+ <>
+ 确定要设置{' '}
+
+ {data.model_name}
+ {' '}
+ 为默认模型吗?
+ >
+ ),
+ onOk: () => {
+ putUpdateModel({
+ id: data.id,
+ status: GithubComChaitinMonkeyCodeBackendConstsModelStatus.ModelStatusDefault,
+ provider: data.provider!,
+ }).then(() => {
+ message.success('设为默认模型成功');
+ refresh();
+ });
+ },
+ });
+ };
+
const onActiveModel = () => {
Modal.confirm({
title: '激活模型',
@@ -101,6 +125,100 @@ const ModelItem = ({
});
};
+ // 添加状态标签渲染函数
+ const renderStatusLabel = () => {
+ // 根据 is_active 和 status 字段判断状态
+ if (data.is_active && data.status === GithubComChaitinMonkeyCodeBackendConstsModelStatus.ModelStatusActive) {
+ return (
+
+ 可选
+
+ );
+ } else if (data.status === GithubComChaitinMonkeyCodeBackendConstsModelStatus.ModelStatusInactive) {
+ return (
+
+ 未激活
+
+ );
+ } else {
+ // 默认状态
+ return (
+
+ 默认
+
+ );
+ }
+ };
+
return (
+ {/* 美化的右上角状态标签 */}
+
+ {renderStatusLabel()}
+
+
{data.show_name || '未命名'}
@@ -195,11 +326,25 @@ const ModelItem = ({
gap={2}
sx={{ mt: 2 }}
>
-
- {data.is_active ? '正在使用' : '未激活'}
-
+
- {!data.is_active && (
+ {(data.status === GithubComChaitinMonkeyCodeBackendConstsModelStatus.ModelStatusActive ||
+ data.status === GithubComChaitinMonkeyCodeBackendConstsModelStatus.ModelStatusInactive) && (
+
+ 设为默认
+
+ )}
+
+ {data.status === GithubComChaitinMonkeyCodeBackendConstsModelStatus.ModelStatusInactive && (