diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 630e8b6..bfdf3a2 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -40,6 +40,10 @@ func main() { panic(err) } + if err := s.modelV1.InitModel(); err != nil { + panic(err) + } + svc := service.NewService(service.WithPprof()) svc.Add(s) if err := svc.Run(); err != nil { diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index b4e36d8..d8831bb 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -1,6 +1,6 @@ // Code generated by Wire. DO NOT EDIT. -//go:generate go run -mod=mod github.com/google/wire/cmd/wire +//go:generate go run github.com/google/wire/cmd/wire //go:build !wireinject // +build !wireinject @@ -59,7 +59,7 @@ func newServer(dir string) (*Server, error) { proxyMiddleware := middleware.NewProxyMiddleware(proxyUsecase) v1Handler := v1.NewV1Handler(slogLogger, web, domainProxy, openAIUsecase, proxyMiddleware) modelRepo := repo3.NewModelRepo(client) - modelUsecase := usecase2.NewModelUsecase(modelRepo) + modelUsecase := usecase2.NewModelUsecase(slogLogger, modelRepo, configConfig) sessionSession := session.NewSession(configConfig) authMiddleware := middleware.NewAuthMiddleware(sessionSession, slogLogger) modelHandler := v1_2.NewModelHandler(web, modelUsecase, authMiddleware, slogLogger) diff --git a/backend/config/config.go b/backend/config/config.go index c3a86bf..f4080c5 100644 --- a/backend/config/config.go +++ b/backend/config/config.go @@ -60,6 +60,12 @@ type Config struct { VSCode struct { VSIXFile string `mapstructure:"vsix_file"` } `mapstructure:"vscode"` + + InitModel struct { + ModelName string `mapstructure:"model_name"` + ModelKey string `mapstructure:"model_key"` + ModelURL string `mapstructure:"model_url"` + } `mapstructure:"init_model"` } func Init(dir string) (*Config, error) { diff --git a/backend/config/config.yaml b/backend/config/config.yaml index b3f4a1e..6f102b7 100644 --- a/backend/config/config.yaml +++ b/backend/config/config.yaml @@ -27,4 +27,8 @@ llm_proxy: client_pool_size: 10 request_log_path: /app/request/logs vscode: - vsix_file: /app/static/monkeycode.vsix \ No newline at end of file + vsix_file: /app/static/monkeycode.vsix +init_model: + model_name: "" + model_key: "" + model_url: "" \ No newline at end of file diff --git a/backend/domain/model.go b/backend/domain/model.go index 039666f..583f26b 100644 --- a/backend/domain/model.go +++ b/backend/domain/model.go @@ -16,6 +16,7 @@ type ModelUsecase interface { Update(ctx context.Context, req *UpdateModelReq) (*Model, error) Check(ctx context.Context, req *CheckModelReq) (*Model, error) GetTokenUsage(ctx context.Context, modelType consts.ModelType) (*ModelTokenUsageResp, error) + InitModel(ctx context.Context) error } type ModelRepo interface { @@ -25,6 +26,7 @@ type ModelRepo interface { MyModelList(ctx context.Context, req *MyModelListReq) ([]*db.Model, error) ModelUsage(ctx context.Context, ids []uuid.UUID) (map[uuid.UUID]ModelUsage, error) GetTokenUsage(ctx context.Context, modelType consts.ModelType) (*ModelTokenUsageResp, error) + InitModel(ctx context.Context, modelName, modelKey, modelURL string) error } type MyModelListReq struct { diff --git a/backend/internal/model/handler/http/v1/model.go b/backend/internal/model/handler/http/v1/model.go index 432d682..cfd929a 100644 --- a/backend/internal/model/handler/http/v1/model.go +++ b/backend/internal/model/handler/http/v1/model.go @@ -1,6 +1,7 @@ package v1 import ( + "context" "log/slog" "github.com/GoYoko/web" @@ -151,3 +152,7 @@ func (h *ModelHandler) GetTokenUsage(c *web.Context, req domain.GetTokenUsageReq } return c.Success(resp) } + +func (h *ModelHandler) InitModel() error { + return h.usecase.InitModel(context.Background()) +} diff --git a/backend/internal/model/repo/model.go b/backend/internal/model/repo/model.go index b3ad010..c12d6e9 100644 --- a/backend/internal/model/repo/model.go +++ b/backend/internal/model/repo/model.go @@ -9,6 +9,7 @@ import ( "github.com/chaitin/MonkeyCode/backend/consts" "github.com/chaitin/MonkeyCode/backend/db" + "github.com/chaitin/MonkeyCode/backend/db/admin" "github.com/chaitin/MonkeyCode/backend/db/model" "github.com/chaitin/MonkeyCode/backend/db/task" "github.com/chaitin/MonkeyCode/backend/domain" @@ -177,3 +178,31 @@ func (r *ModelRepo) List(ctx context.Context) (*domain.AllModelResp, error) { } return resp, nil } + +func (r *ModelRepo) InitModel(ctx context.Context, modelName, modelKey, modelURL string) error { + n, err := r.db.Model.Query(). + Where(model.ModelName(modelName)). + Where(model.Provider("百智云")). + Count(ctx) + if err != nil { + return err + } + if n > 0 { + return nil + } + + a, err := r.db.Admin.Query().Where(admin.Username("admin")).Only(ctx) + if err != nil { + return err + } + + return r.db.Model.Create(). + SetAPIKey(modelKey). + SetModelName(modelName). + SetModelType(consts.ModelTypeCoder). + SetAPIBase(modelURL). + SetProvider("百智云"). + SetStatus(consts.ModelStatusActive). + SetUserID(a.ID). + Exec(ctx) +} diff --git a/backend/internal/model/usecase/model.go b/backend/internal/model/usecase/model.go index b79fe4d..9f13027 100644 --- a/backend/internal/model/usecase/model.go +++ b/backend/internal/model/usecase/model.go @@ -2,9 +2,11 @@ package usecase import ( "context" + "log/slog" "github.com/google/uuid" + "github.com/chaitin/MonkeyCode/backend/config" "github.com/chaitin/MonkeyCode/backend/db/model" "github.com/chaitin/MonkeyCode/backend/pkg/cvt" @@ -14,11 +16,17 @@ import ( ) type ModelUsecase struct { - repo domain.ModelRepo + logger *slog.Logger + repo domain.ModelRepo + cfg *config.Config } -func NewModelUsecase(repo domain.ModelRepo) domain.ModelUsecase { - return &ModelUsecase{repo: repo} +func NewModelUsecase( + logger *slog.Logger, + repo domain.ModelRepo, + cfg *config.Config, +) domain.ModelUsecase { + return &ModelUsecase{repo: repo, cfg: cfg, logger: logger} } func (m *ModelUsecase) Check(ctx context.Context, req *domain.CheckModelReq) (*domain.Model, error) { @@ -98,3 +106,8 @@ func (m *ModelUsecase) Update(ctx context.Context, req *domain.UpdateModelReq) ( } return cvt.From(model, &domain.Model{}), nil } + +func (m *ModelUsecase) InitModel(ctx context.Context) error { + m.logger.With("init_model", m.cfg.InitModel).Debug("init model") + return m.repo.InitModel(ctx, m.cfg.InitModel.ModelName, m.cfg.InitModel.ModelKey, m.cfg.InitModel.ModelURL) +}