mirror of
https://github.com/chaitin/MonkeyCode.git
synced 2026-02-02 14:53:55 +08:00
@@ -75,7 +75,7 @@ func newServer() (*Server, error) {
|
||||
proxyUsecase := usecase.NewProxyUsecase(proxyRepo, modelRepo, securityScanningRepo, slogLogger, configConfig, redisClient)
|
||||
llmProxy := proxy.NewLLMProxy(slogLogger, configConfig, proxyUsecase)
|
||||
openAIRepo := repo4.NewOpenAIRepo(client)
|
||||
openAIUsecase := openai.NewOpenAIUsecase(configConfig, openAIRepo, modelRepo, slogLogger)
|
||||
openAIUsecase := openai.NewOpenAIUsecase(configConfig, openAIRepo, modelRepo, slogLogger, redisClient)
|
||||
extensionRepo := repo5.NewExtensionRepo(client)
|
||||
extensionUsecase := usecase2.NewExtensionUsecase(extensionRepo, configConfig, slogLogger)
|
||||
ipdbIPDB, err := ipdb.NewIPDB(slogLogger)
|
||||
@@ -85,7 +85,7 @@ func newServer() (*Server, error) {
|
||||
userRepo := repo6.NewUserRepo(client, ipdbIPDB, redisClient, configConfig)
|
||||
sessionSession := session.NewSession(configConfig)
|
||||
userUsecase := usecase3.NewUserUsecase(configConfig, redisClient, userRepo, slogLogger, sessionSession)
|
||||
proxyMiddleware := middleware.NewProxyMiddleware(proxyUsecase)
|
||||
proxyMiddleware := middleware.NewProxyMiddleware(proxyUsecase, redisClient, slogLogger)
|
||||
activeMiddleware := middleware.NewActiveMiddleware(redisClient, slogLogger)
|
||||
v1Handler := v1.NewV1Handler(slogLogger, web, llmProxy, proxyUsecase, openAIUsecase, extensionUsecase, userUsecase, proxyMiddleware, activeMiddleware, configConfig)
|
||||
modelUsecase := usecase4.NewModelUsecase(slogLogger, modelRepo, configConfig)
|
||||
|
||||
@@ -3,6 +3,7 @@ package consts
|
||||
type ModelStatus string
|
||||
|
||||
const (
|
||||
ModelStatusDefault ModelStatus = "default"
|
||||
ModelStatusActive ModelStatus = "active"
|
||||
ModelStatusInactive ModelStatus = "inactive"
|
||||
)
|
||||
|
||||
@@ -20,7 +20,7 @@ type ModelUsecase interface {
|
||||
}
|
||||
|
||||
type ModelRepo interface {
|
||||
GetWithCache(ctx context.Context, modelType consts.ModelType) (*db.Model, error)
|
||||
GetWithCache(ctx context.Context, modelType consts.ModelType) ([]*db.Model, error)
|
||||
List(ctx context.Context) (*AllModelResp, error)
|
||||
Create(ctx context.Context, m *CreateModelReq) (*db.Model, error)
|
||||
Update(ctx context.Context, id string, fn func(tx *db.Tx, old *db.Model, up *db.ModelUpdateOne) error) (*db.Model, error)
|
||||
@@ -181,7 +181,7 @@ func (m *Model) From(e *db.Model) *Model {
|
||||
m.ModelType = e.ModelType
|
||||
m.Status = e.Status
|
||||
m.IsInternal = e.IsInternal
|
||||
m.IsActive = e.Status == consts.ModelStatusActive
|
||||
m.IsActive = e.Status == consts.ModelStatusActive || e.Status == consts.ModelStatusDefault
|
||||
if p := e.Parameters; p != nil {
|
||||
m.Param = ModelParam{
|
||||
R1Enabled: p.R1Enabled,
|
||||
|
||||
@@ -68,7 +68,7 @@ type ConfigReq struct {
|
||||
|
||||
type ConfigResp struct {
|
||||
Type consts.ConfigType `json:"type"`
|
||||
Content string `json:"content"`
|
||||
Content any `json:"content"`
|
||||
}
|
||||
type OpenAIResp struct {
|
||||
Object string `json:"object"`
|
||||
|
||||
52
backend/domain/plugin.go
Normal file
52
backend/domain/plugin.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package domain
|
||||
|
||||
type PluginConfig struct {
|
||||
ProviderProfiles ProviderProfiles `json:"providerProfiles"`
|
||||
CtcodeTabCompletions CtcodeTabCompletions `json:"ctcodeTabCompletions"`
|
||||
GlobalSettings GlobalSettings `json:"globalSettings"`
|
||||
}
|
||||
|
||||
type ProviderProfiles struct {
|
||||
CurrentApiConfigName string `json:"currentApiConfigName"`
|
||||
ApiConfigs map[string]ApiConfig `json:"apiConfigs"`
|
||||
ModeApiConfigs map[string]string `json:"modeApiConfigs"`
|
||||
Migrations Migrations `json:"migrations"`
|
||||
}
|
||||
|
||||
type ApiConfig struct {
|
||||
ApiProvider string `json:"apiProvider"`
|
||||
ApiModelId string `json:"apiModelId"`
|
||||
OpenAiBaseUrl string `json:"openAiBaseUrl"`
|
||||
OpenAiApiKey string `json:"openAiApiKey"`
|
||||
OpenAiModelId string `json:"openAiModelId"`
|
||||
OpenAiR1FormatEnabled bool `json:"openAiR1FormatEnabled"`
|
||||
OpenAiCustomModelInfo OpenAiCustomModelInfo `json:"openAiCustomModelInfo"`
|
||||
Id string `json:"id"`
|
||||
}
|
||||
|
||||
type OpenAiCustomModelInfo struct {
|
||||
MaxTokens int `json:"maxTokens"`
|
||||
ContextWindow int `json:"contextWindow"`
|
||||
SupportsImages bool `json:"supportsImages"`
|
||||
SupportsComputerUse bool `json:"supportsComputerUse"`
|
||||
SupportsPromptCache bool `json:"supportsPromptCache"`
|
||||
}
|
||||
|
||||
type Migrations struct {
|
||||
RateLimitSecondsMigrated bool `json:"rateLimitSecondsMigrated"`
|
||||
DiffSettingsMigrated bool `json:"diffSettingsMigrated"`
|
||||
}
|
||||
|
||||
type CtcodeTabCompletions struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
ApiProvider string `json:"apiProvider"`
|
||||
OpenAiBaseUrl string `json:"openAiBaseUrl"`
|
||||
OpenAiApiKey string `json:"openAiApiKey"`
|
||||
OpenAiModelId string `json:"openAiModelId"`
|
||||
}
|
||||
|
||||
type GlobalSettings struct {
|
||||
AllowedCommands []string `json:"allowedCommands"`
|
||||
Mode string `json:"mode"`
|
||||
CustomModes []string `json:"customModes"`
|
||||
}
|
||||
@@ -2,10 +2,13 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/chaitin/MonkeyCode/backend/domain"
|
||||
"github.com/chaitin/MonkeyCode/backend/ent/rule"
|
||||
@@ -16,15 +19,23 @@ const (
|
||||
ApiContextKey = "session:apikey"
|
||||
)
|
||||
|
||||
type proxyModelKey struct{}
|
||||
|
||||
type ProxyMiddleware struct {
|
||||
usecase domain.ProxyUsecase
|
||||
redis *redis.Client
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewProxyMiddleware(
|
||||
usecase domain.ProxyUsecase,
|
||||
redis *redis.Client,
|
||||
logger *slog.Logger,
|
||||
) *ProxyMiddleware {
|
||||
return &ProxyMiddleware{
|
||||
usecase: usecase,
|
||||
redis: redis,
|
||||
logger: logger.With("module", "ProxyMiddleware"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,21 +50,54 @@ func (p *ProxyMiddleware) Auth() echo.MiddlewareFunc {
|
||||
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
|
||||
}
|
||||
|
||||
key, err := p.usecase.ValidateApiKey(c.Request().Context(), apiKey)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
|
||||
ctx := c.Request().Context()
|
||||
p.logger.With("apiKey", apiKey).DebugContext(ctx, "v1 auth")
|
||||
if strings.Contains(apiKey, ".") {
|
||||
s, err := p.redis.Get(ctx, apiKey).Result()
|
||||
if err != nil {
|
||||
p.logger.With("fn", "Auth").With("error", err).ErrorContext(ctx, "failed to get api key from redis")
|
||||
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
|
||||
}
|
||||
var model *domain.Model
|
||||
if err := json.Unmarshal([]byte(s), &model); err != nil {
|
||||
p.logger.With("fn", "Auth").With("error", err).ErrorContext(ctx, "failed to unmarshal model from redis")
|
||||
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
|
||||
}
|
||||
parts := strings.Split(apiKey, ".")
|
||||
if len(parts) != 2 {
|
||||
p.logger.With("fn", "Auth").With("apiKey", apiKey).ErrorContext(ctx, "invalid api key")
|
||||
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
|
||||
}
|
||||
ctx = context.WithValue(ctx, proxyModelKey{}, model)
|
||||
ctx = context.WithValue(ctx, logger.UserIDKey{}, parts[0])
|
||||
c.Set(ApiContextKey, &domain.ApiKey{
|
||||
UserID: parts[0],
|
||||
Key: apiKey,
|
||||
})
|
||||
} else {
|
||||
key, err := p.usecase.ValidateApiKey(ctx, apiKey)
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusUnauthorized, echo.Map{"error": "Unauthorized"})
|
||||
}
|
||||
ctx = context.WithValue(ctx, logger.UserIDKey{}, key.UserID)
|
||||
c.Set(ApiContextKey, key)
|
||||
}
|
||||
|
||||
ctx := c.Request().Context()
|
||||
ctx = context.WithValue(ctx, logger.UserIDKey{}, key.UserID)
|
||||
ctx = rule.SkipPermission(ctx)
|
||||
c.SetRequest(c.Request().WithContext(ctx))
|
||||
c.Set(ApiContextKey, key)
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GetProxyModel(ctx context.Context) *domain.Model {
|
||||
m := ctx.Value(proxyModelKey{})
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return m.(*domain.Model)
|
||||
}
|
||||
|
||||
func GetApiKey(c echo.Context) *domain.ApiKey {
|
||||
i := c.Get(ApiContextKey)
|
||||
if i == nil {
|
||||
|
||||
@@ -30,15 +30,16 @@ func NewModelRepo(db *db.Client) domain.ModelRepo {
|
||||
return &ModelRepo{db: db, cache: cache}
|
||||
}
|
||||
|
||||
func (r *ModelRepo) GetWithCache(ctx context.Context, modelType consts.ModelType) (*db.Model, error) {
|
||||
func (r *ModelRepo) GetWithCache(ctx context.Context, modelType consts.ModelType) ([]*db.Model, error) {
|
||||
if v, ok := r.cache.Get(string(modelType)); ok {
|
||||
return v.(*db.Model), nil
|
||||
return v.([]*db.Model), nil
|
||||
}
|
||||
|
||||
m, err := r.db.Model.Query().
|
||||
Where(model.ModelType(modelType)).
|
||||
Where(model.Status(consts.ModelStatusActive)).
|
||||
Only(ctx)
|
||||
Where(model.StatusIn(consts.ModelStatusActive, consts.ModelStatusDefault)).
|
||||
Order(ByStatusOrder()).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -47,14 +48,22 @@ func (r *ModelRepo) GetWithCache(ctx context.Context, modelType consts.ModelType
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func ByStatusOrder() func(s *sql.Selector) {
|
||||
return func(s *sql.Selector) {
|
||||
s.OrderExprFunc(func(b *sql.Builder) {
|
||||
b.WriteString("case when status = 'default' then 3 when status = 'active' then 2 else 1 end desc")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ModelRepo) Create(ctx context.Context, m *domain.CreateModelReq) (*db.Model, error) {
|
||||
n, err := r.db.Model.Query().Where(model.ModelType(m.ModelType)).Count(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
status := consts.ModelStatusInactive
|
||||
status := consts.ModelStatusActive
|
||||
if n == 0 {
|
||||
status = consts.ModelStatusActive
|
||||
status = consts.ModelStatusDefault
|
||||
}
|
||||
|
||||
r.cache.Delete(string(m.ModelType))
|
||||
|
||||
@@ -108,14 +108,27 @@ func (m *ModelUsecase) Update(ctx context.Context, req *domain.UpdateModelReq) (
|
||||
up.SetShowName(*req.ShowName)
|
||||
}
|
||||
if req.Status != nil {
|
||||
if *req.Status == consts.ModelStatusActive {
|
||||
if *req.Status == consts.ModelStatusDefault {
|
||||
if err := tx.Model.Update().
|
||||
Where(model.Status(consts.ModelStatusDefault)).
|
||||
Where(model.ModelType(old.ModelType)).
|
||||
SetStatus(consts.ModelStatusInactive).
|
||||
SetStatus(consts.ModelStatusActive).
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if *req.Status == consts.ModelStatusActive {
|
||||
n, err := tx.Model.Query().
|
||||
Where(model.Status(consts.ModelStatusDefault)).
|
||||
Where(model.ModelType(old.ModelType)).
|
||||
Count(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
*req.Status = consts.ModelStatusDefault
|
||||
}
|
||||
}
|
||||
up.SetStatus(*req.Status)
|
||||
}
|
||||
if req.Param != nil {
|
||||
|
||||
@@ -1,18 +1,21 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"html/template"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/chaitin/MonkeyCode/backend/ent/types"
|
||||
"github.com/chaitin/MonkeyCode/backend/pkg/cvt"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/chaitin/MonkeyCode/backend/config"
|
||||
"github.com/chaitin/MonkeyCode/backend/consts"
|
||||
"github.com/chaitin/MonkeyCode/backend/db"
|
||||
"github.com/chaitin/MonkeyCode/backend/domain"
|
||||
"github.com/chaitin/MonkeyCode/backend/ent/types"
|
||||
"github.com/chaitin/MonkeyCode/backend/pkg/cvt"
|
||||
)
|
||||
|
||||
type OpenAIUsecase struct {
|
||||
@@ -20,6 +23,7 @@ type OpenAIUsecase struct {
|
||||
modelRepo domain.ModelRepo
|
||||
cfg *config.Config
|
||||
logger *slog.Logger
|
||||
redis *redis.Client
|
||||
}
|
||||
|
||||
func NewOpenAIUsecase(
|
||||
@@ -27,12 +31,14 @@ func NewOpenAIUsecase(
|
||||
repo domain.OpenAIRepo,
|
||||
modelRepo domain.ModelRepo,
|
||||
logger *slog.Logger,
|
||||
redis *redis.Client,
|
||||
) domain.OpenAIUsecase {
|
||||
return &OpenAIUsecase{
|
||||
repo: repo,
|
||||
modelRepo: modelRepo,
|
||||
cfg: cfg,
|
||||
logger: logger,
|
||||
redis: redis,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,10 +48,6 @@ func (u *OpenAIUsecase) ModelList(ctx context.Context) (*domain.ModelListResp, e
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, v := range models {
|
||||
u.logger.DebugContext(ctx, "model", slog.Any("model", v))
|
||||
}
|
||||
|
||||
resp := &domain.ModelListResp{
|
||||
Object: "list",
|
||||
Data: cvt.Iter(models, func(_ int, m *db.Model) *domain.ModelData {
|
||||
@@ -61,44 +63,106 @@ func (u *OpenAIUsecase) GetConfig(ctx context.Context, req *domain.ConfigReq) (*
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
llm, err := u.modelRepo.GetWithCache(ctx, consts.ModelTypeLLM)
|
||||
llms, err := u.modelRepo.GetWithCache(ctx, consts.ModelTypeLLM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
coder, err := u.modelRepo.GetWithCache(ctx, consts.ModelTypeCoder)
|
||||
coders, err := u.modelRepo.GetWithCache(ctx, consts.ModelTypeCoder)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u.logger.With(
|
||||
"llms", len(llms),
|
||||
"coders", len(coders),
|
||||
).DebugContext(ctx, "get config")
|
||||
|
||||
if len(llms) == 0 || len(coders) == 0 {
|
||||
return nil, errors.New("no model")
|
||||
}
|
||||
|
||||
llm := llms[0]
|
||||
coder := coders[0]
|
||||
coderkey := fmt.Sprintf("%s.%s", apiKey.UserID.String(), coder.ID.String())
|
||||
if err = u.redis.Get(ctx, coderkey).Err(); err != nil {
|
||||
b, err := json.Marshal(cvt.From(coder, &domain.Model{}))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = u.redis.Set(ctx, coderkey, string(b), time.Hour*24).Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if llm.Parameters == nil {
|
||||
llm.Parameters = types.DefaultModelParam()
|
||||
}
|
||||
|
||||
t, err := template.New("config").Parse(string(config.ConfigTmpl))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
config := &domain.PluginConfig{
|
||||
ProviderProfiles: domain.ProviderProfiles{
|
||||
CurrentApiConfigName: "default",
|
||||
ApiConfigs: map[string]domain.ApiConfig{},
|
||||
ModeApiConfigs: map[string]string{
|
||||
"code": "59admorkig4",
|
||||
"architect": "59admorkig4",
|
||||
"ask": "59admorkig4",
|
||||
"debug": "59admorkig4",
|
||||
"deepresearch": "59admorkig4",
|
||||
},
|
||||
Migrations: domain.Migrations{
|
||||
RateLimitSecondsMigrated: true,
|
||||
DiffSettingsMigrated: true,
|
||||
},
|
||||
},
|
||||
CtcodeTabCompletions: domain.CtcodeTabCompletions{
|
||||
Enabled: true,
|
||||
ApiProvider: "openai",
|
||||
OpenAiBaseUrl: req.BaseURL + "/v1",
|
||||
OpenAiApiKey: coderkey,
|
||||
OpenAiModelId: coder.ModelName,
|
||||
},
|
||||
}
|
||||
|
||||
u.logger.With("param", llm.Parameters).DebugContext(ctx, "get config")
|
||||
cnt := bytes.NewBuffer(nil)
|
||||
data := map[string]any{
|
||||
"apiBase": req.BaseURL,
|
||||
"apikey": apiKey.Key,
|
||||
"chatModel": llm.ModelName,
|
||||
"codeModel": coder.ModelName,
|
||||
"r1Enabled": llm.Parameters.R1Enabled,
|
||||
"maxTokens": llm.Parameters.MaxTokens,
|
||||
"contextWindow": llm.Parameters.ContextWindow,
|
||||
"supportsImages": llm.Parameters.SupprtImages,
|
||||
"supportsComputerUse": llm.Parameters.SupportComputerUse,
|
||||
"supportsPromptCache": llm.Parameters.SupportPromptCache,
|
||||
}
|
||||
if err := t.Execute(cnt, data); err != nil {
|
||||
return nil, err
|
||||
for _, m := range llms {
|
||||
key := fmt.Sprintf("%s.%s", apiKey.UserID.String(), m.ID.String())
|
||||
if m.Parameters == nil {
|
||||
m.Parameters = types.DefaultModelParam()
|
||||
}
|
||||
name := fmt.Sprintf("%s (%s)", m.ModelName, m.Provider)
|
||||
if m.Status == consts.ModelStatusDefault {
|
||||
name = "default"
|
||||
}
|
||||
config.ProviderProfiles.ApiConfigs[name] = domain.ApiConfig{
|
||||
ApiProvider: "openai",
|
||||
ApiModelId: m.ModelName,
|
||||
OpenAiBaseUrl: req.BaseURL + "/v1",
|
||||
OpenAiApiKey: key,
|
||||
OpenAiModelId: m.ModelName,
|
||||
OpenAiR1FormatEnabled: m.Parameters.R1Enabled,
|
||||
OpenAiCustomModelInfo: domain.OpenAiCustomModelInfo{
|
||||
MaxTokens: m.Parameters.MaxTokens,
|
||||
ContextWindow: m.Parameters.ContextWindow,
|
||||
SupportsImages: m.Parameters.SupprtImages,
|
||||
SupportsComputerUse: m.Parameters.SupportComputerUse,
|
||||
SupportsPromptCache: m.Parameters.SupportPromptCache,
|
||||
},
|
||||
Id: m.ID.String(),
|
||||
}
|
||||
|
||||
if err = u.redis.Get(ctx, key).Err(); err == nil {
|
||||
continue
|
||||
}
|
||||
b, err := json.Marshal(cvt.From(m, &domain.Model{}))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := u.redis.Set(ctx, key, string(b), time.Hour*24).Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &domain.ConfigResp{
|
||||
Type: req.Type,
|
||||
Content: cnt.String(),
|
||||
Content: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/chaitin/MonkeyCode/backend/config"
|
||||
"github.com/chaitin/MonkeyCode/backend/consts"
|
||||
"github.com/chaitin/MonkeyCode/backend/domain"
|
||||
"github.com/chaitin/MonkeyCode/backend/internal/middleware"
|
||||
"github.com/chaitin/MonkeyCode/backend/pkg/logger"
|
||||
"github.com/chaitin/MonkeyCode/backend/pkg/tee"
|
||||
)
|
||||
@@ -96,11 +97,16 @@ func (l *LLMProxy) rewrite(r *httputil.ProxyRequest) {
|
||||
return
|
||||
}
|
||||
|
||||
m, err := l.usecase.SelectModelWithLoadBalancing("", mt)
|
||||
if err != nil {
|
||||
l.logger.ErrorContext(r.In.Context(), "select model with load balancing failed", slog.String("path", r.In.URL.Path), slog.Any("err", err))
|
||||
return
|
||||
var m *domain.Model
|
||||
var err error
|
||||
if m = middleware.GetProxyModel(r.In.Context()); m == nil {
|
||||
m, err = l.usecase.SelectModelWithLoadBalancing("", mt)
|
||||
if err != nil {
|
||||
l.logger.ErrorContext(r.In.Context(), "select model with load balancing failed", slog.String("path", r.In.URL.Path), slog.Any("err", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ul, err := url.Parse(m.APIBase)
|
||||
if err != nil {
|
||||
l.logger.ErrorContext(r.In.Context(), "parse model api base failed", slog.String("path", r.In.URL.Path), slog.Any("err", err))
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/chaitin/MonkeyCode/backend/config"
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
"github.com/chaitin/MonkeyCode/backend/db"
|
||||
"github.com/chaitin/MonkeyCode/backend/domain"
|
||||
"github.com/chaitin/MonkeyCode/backend/ent/rule"
|
||||
"github.com/chaitin/MonkeyCode/backend/internal/middleware"
|
||||
"github.com/chaitin/MonkeyCode/backend/pkg/cvt"
|
||||
"github.com/chaitin/MonkeyCode/backend/pkg/queuerunner"
|
||||
"github.com/chaitin/MonkeyCode/backend/pkg/request"
|
||||
@@ -85,7 +87,16 @@ func (p *ProxyUsecase) SelectModelWithLoadBalancing(modelName string, modelType
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cvt.From(model, &domain.Model{}), nil
|
||||
if len(model) == 0 {
|
||||
return nil, fmt.Errorf("no model found")
|
||||
}
|
||||
m := model[0]
|
||||
for _, mm := range model {
|
||||
if mm.Status == consts.ModelStatusDefault {
|
||||
m = mm
|
||||
}
|
||||
}
|
||||
return cvt.From(m, &domain.Model{}), nil
|
||||
}
|
||||
|
||||
func (p *ProxyUsecase) ValidateApiKey(ctx context.Context, key string) (*domain.ApiKey, error) {
|
||||
@@ -102,12 +113,33 @@ func (p *ProxyUsecase) AcceptCompletion(ctx context.Context, req *domain.AcceptC
|
||||
|
||||
func (p *ProxyUsecase) Report(ctx context.Context, req *domain.ReportReq) error {
|
||||
var model *db.Model
|
||||
var err error
|
||||
if req.Action == consts.ReportActionNewTask {
|
||||
model, err = p.modelRepo.GetWithCache(context.Background(), consts.ModelTypeLLM)
|
||||
if err != nil {
|
||||
p.logger.With("fn", "Report").With("error", err).ErrorContext(ctx, "failed to get model")
|
||||
return err
|
||||
m := middleware.GetProxyModel(ctx)
|
||||
if m == nil {
|
||||
ms, err := p.modelRepo.GetWithCache(ctx, consts.ModelTypeLLM)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get model with cache failed: %w", err)
|
||||
}
|
||||
if len(ms) == 0 {
|
||||
return fmt.Errorf("no model found")
|
||||
}
|
||||
model = ms[0]
|
||||
for _, mm := range ms {
|
||||
if mm.Status == consts.ModelStatusDefault {
|
||||
model = mm
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mid, err := uuid.Parse(m.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse proxy model id failed: %w", err)
|
||||
}
|
||||
model = &db.Model{
|
||||
ID: mid,
|
||||
ModelName: m.ModelName,
|
||||
ModelType: m.ModelType,
|
||||
}
|
||||
}
|
||||
}
|
||||
return p.repo.Report(ctx, model, req)
|
||||
|
||||
Submodule backend/pro updated: 482ab56daf...a566152111
@@ -19,6 +19,7 @@ export enum GithubComChaitinMonkeyCodeBackendConstsModelType {
|
||||
}
|
||||
|
||||
export enum GithubComChaitinMonkeyCodeBackendConstsModelStatus {
|
||||
ModelStatusDefault = "default",
|
||||
ModelStatusActive = "active",
|
||||
ModelStatusInactive = "inactive",
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import {
|
||||
} from '@/api/Model';
|
||||
import { DomainModel, GithubComChaitinMonkeyCodeBackendConstsModelStatus, GithubComChaitinMonkeyCodeBackendConstsModelType, } from '@/api/types';
|
||||
import { Stack, Box, Button, Grid2 as Grid, ButtonBase } from '@mui/material';
|
||||
import StyledLabel from '@/components/label';
|
||||
import { Icon, Modal, message } from '@c-x/ui';
|
||||
import { addCommasToNumber } from '@/utils';
|
||||
import NoData from '@/assets/images/nodata.png';
|
||||
@@ -76,6 +75,31 @@ const ModelItem = ({
|
||||
});
|
||||
};
|
||||
|
||||
const onSetDefaultModel = () => {
|
||||
Modal.confirm({
|
||||
title: '设为默认模型',
|
||||
content: (
|
||||
<>
|
||||
确定要设置{' '}
|
||||
<Box component='span' sx={{ fontWeight: 700, color: 'text.primary' }}>
|
||||
{data.model_name}
|
||||
</Box>{' '}
|
||||
为默认模型吗?
|
||||
</>
|
||||
),
|
||||
onOk: () => {
|
||||
putUpdateModel({
|
||||
id: data.id,
|
||||
status: GithubComChaitinMonkeyCodeBackendConstsModelStatus.ModelStatusDefault,
|
||||
provider: data.provider!,
|
||||
}).then(() => {
|
||||
message.success('设为默认模型成功');
|
||||
refresh();
|
||||
});
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const onActiveModel = () => {
|
||||
Modal.confirm({
|
||||
title: '激活模型',
|
||||
@@ -101,6 +125,100 @@ const ModelItem = ({
|
||||
});
|
||||
};
|
||||
|
||||
// 添加状态标签渲染函数
|
||||
const renderStatusLabel = () => {
|
||||
// 根据 is_active 和 status 字段判断状态
|
||||
if (data.is_active && data.status === GithubComChaitinMonkeyCodeBackendConstsModelStatus.ModelStatusActive) {
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
px: 1.5,
|
||||
py: 0.5,
|
||||
backgroundColor: 'success.main',
|
||||
color: 'success.contrastText',
|
||||
borderRadius: '0 0 0 8px', // 左下角圆角,贴合卡片右上角
|
||||
fontSize: 12,
|
||||
fontWeight: 600,
|
||||
boxShadow: '0 2px 4px rgba(0,0,0,0.1)',
|
||||
position: 'relative',
|
||||
'&::before': {
|
||||
content: '""',
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
right: 0,
|
||||
width: 0,
|
||||
height: 0,
|
||||
borderLeft: '6px solid transparent',
|
||||
borderTop: '6px solid',
|
||||
borderTopColor: 'success.dark',
|
||||
}
|
||||
}}
|
||||
>
|
||||
可选
|
||||
</Box>
|
||||
);
|
||||
} else if (data.status === GithubComChaitinMonkeyCodeBackendConstsModelStatus.ModelStatusInactive) {
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
px: 1.5,
|
||||
py: 0.5,
|
||||
backgroundColor: 'grey.400',
|
||||
color: 'grey.50',
|
||||
borderRadius: '0 0 0 8px',
|
||||
fontSize: 12,
|
||||
fontWeight: 600,
|
||||
boxShadow: '0 2px 4px rgba(0,0,0,0.1)',
|
||||
position: 'relative',
|
||||
'&::before': {
|
||||
content: '""',
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
right: 0,
|
||||
width: 0,
|
||||
height: 0,
|
||||
borderLeft: '6px solid transparent',
|
||||
borderTop: '6px solid',
|
||||
borderTopColor: 'grey.600',
|
||||
}
|
||||
}}
|
||||
>
|
||||
未激活
|
||||
</Box>
|
||||
);
|
||||
} else {
|
||||
// 默认状态
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
px: 1.5,
|
||||
py: 0.5,
|
||||
backgroundColor: 'primary.main',
|
||||
color: 'primary.contrastText',
|
||||
borderRadius: '0 0 0 8px',
|
||||
fontSize: 12,
|
||||
fontWeight: 600,
|
||||
boxShadow: '0 2px 4px rgba(0,0,0,0.1)',
|
||||
position: 'relative',
|
||||
'&::before': {
|
||||
content: '""',
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
right: 0,
|
||||
width: 0,
|
||||
height: 0,
|
||||
borderLeft: '6px solid transparent',
|
||||
borderTop: '6px solid',
|
||||
borderTopColor: 'primary.dark',
|
||||
}
|
||||
}}
|
||||
>
|
||||
默认
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Card
|
||||
sx={{
|
||||
@@ -109,7 +227,7 @@ const ModelItem = ({
|
||||
transition: 'all 0.3s ease',
|
||||
borderStyle: 'solid',
|
||||
borderWidth: '1px',
|
||||
borderColor: data.is_active ? 'success.main' : 'transparent',
|
||||
borderColor: 'transparent',
|
||||
boxShadow:
|
||||
'0px 0px 10px 0px rgba(68, 80, 91, 0.1), 0px 0px 2px 0px rgba(68, 80, 91, 0.1)',
|
||||
'&:hover': {
|
||||
@@ -118,6 +236,18 @@ const ModelItem = ({
|
||||
},
|
||||
}}
|
||||
>
|
||||
{/* 美化的右上角状态标签 */}
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
right: 0,
|
||||
zIndex: 2,
|
||||
}}
|
||||
>
|
||||
{renderStatusLabel()}
|
||||
</Box>
|
||||
|
||||
<Stack
|
||||
direction='row'
|
||||
alignItems='center'
|
||||
@@ -129,7 +259,7 @@ const ModelItem = ({
|
||||
type={
|
||||
DEFAULT_MODEL_PROVIDERS[data.provider as keyof typeof DEFAULT_MODEL_PROVIDERS]?.icon
|
||||
}
|
||||
sx={{ fontSize: 24 }}
|
||||
sx={{ fontSize: 24, color: data.is_active ? 'inherit' : 'grey.400' }}
|
||||
/>
|
||||
<Stack
|
||||
direction='row'
|
||||
@@ -143,6 +273,7 @@ const ModelItem = ({
|
||||
overflow: 'hidden',
|
||||
textOverflow: 'ellipsis',
|
||||
whiteSpace: 'nowrap',
|
||||
color: data.is_active ? 'inherit' : 'grey.400',
|
||||
}}
|
||||
>
|
||||
{data.show_name || '未命名'}
|
||||
@@ -195,11 +326,25 @@ const ModelItem = ({
|
||||
gap={2}
|
||||
sx={{ mt: 2 }}
|
||||
>
|
||||
<Stack direction='row' alignItems='center'>
|
||||
<StyledLabel color={data.is_active ? 'success' : 'disabled'}>{data.is_active ? '正在使用' : '未激活'}</StyledLabel>
|
||||
</Stack>
|
||||
<Stack direction='row' alignItems='center'> </Stack>
|
||||
<Stack direction='row' sx={{ button: { minWidth: 0 } }} gap={2}>
|
||||
{!data.is_active && (
|
||||
{(data.status === GithubComChaitinMonkeyCodeBackendConstsModelStatus.ModelStatusActive ||
|
||||
data.status === GithubComChaitinMonkeyCodeBackendConstsModelStatus.ModelStatusInactive) && (
|
||||
<ButtonBase
|
||||
disableRipple
|
||||
sx={{
|
||||
color: 'text.primary',
|
||||
'&:hover': {
|
||||
fontWeight: 700
|
||||
},
|
||||
}}
|
||||
onClick={onSetDefaultModel}
|
||||
>
|
||||
设为默认
|
||||
</ButtonBase>
|
||||
)}
|
||||
|
||||
{data.status === GithubComChaitinMonkeyCodeBackendConstsModelStatus.ModelStatusInactive && (
|
||||
<ButtonBase
|
||||
disableRipple
|
||||
sx={{
|
||||
|
||||
Reference in New Issue
Block a user