// File: MonkeyCode/backend/internal/proxy/usecase/proxy.go
// Snapshot: 2025-08-07 22:04:00 +08:00 (179 lines, 5.9 KiB, Go)

package usecase
import (
"context"
"fmt"
"log/slog"
"os"
"path"
"strings"
"time"
"github.com/redis/go-redis/v9"
"github.com/chaitin/MonkeyCode/backend/config"
"github.com/chaitin/MonkeyCode/backend/consts"
"github.com/chaitin/MonkeyCode/backend/db"
"github.com/chaitin/MonkeyCode/backend/domain"
"github.com/chaitin/MonkeyCode/backend/pkg/cvt"
"github.com/chaitin/MonkeyCode/backend/pkg/queuerunner"
"github.com/chaitin/MonkeyCode/backend/pkg/request"
"github.com/chaitin/MonkeyCode/backend/pkg/scan"
)
// ProxyUsecase holds the repositories, logger, task queue and request client
// that the methods in this file operate on. NewProxyUsecase returns it as a
// domain.ProxyUsecase.
type ProxyUsecase struct {
	// repo backs Record, ValidateApiKey, AcceptCompletion and Report.
	repo domain.ProxyRepo
	// modelRepo is used for the cached model lookups.
	modelRepo domain.ModelRepo
	// securityRepo stores security scannings and their workspace files.
	securityRepo domain.SecurityScanningRepo
	// logger is tagged with module=ProxyUsecase by NewProxyUsecase.
	logger *slog.Logger
	// queuerunner queues security-scanning requests; TaskHandle is the handler
	// passed to it on Enqueue.
	queuerunner *queuerunner.QueueRunner[domain.CreateSecurityScanningReq]
	// client is the request client built in NewProxyUsecase; it is not
	// referenced by the methods visible in this file.
	client *request.Client
}
// NewProxyUsecase wires up a ProxyUsecase and returns it as a
// domain.ProxyUsecase.
//
// In addition to storing the dependencies it creates a request client for
// "monkeycode-scanner:8888" (debug mode taken from cfg.Debug) and starts two
// goroutines: one running the queue runner with a background context, and one
// calling requeue to re-submit scannings still reported as running.
func NewProxyUsecase(
	repo domain.ProxyRepo,
	modelRepo domain.ModelRepo,
	securityRepo domain.SecurityScanningRepo,
	logger *slog.Logger,
	cfg *config.Config,
	redis *redis.Client,
) domain.ProxyUsecase {
	scannerClient := request.NewClient("http", "monkeycode-scanner:8888", 15*time.Second)
	scannerClient.SetDebug(cfg.Debug)

	uc := new(ProxyUsecase)
	uc.repo = repo
	uc.modelRepo = modelRepo
	uc.securityRepo = securityRepo
	uc.logger = logger.With("module", "ProxyUsecase")
	// The queue runner receives the untagged logger, not uc.logger.
	uc.queuerunner = queuerunner.NewQueueRunner[domain.CreateSecurityScanningReq](cfg, redis, logger)
	uc.client = scannerClient

	go uc.queuerunner.Run(context.Background())
	go uc.requeue()

	return uc
}
// requeue re-submits every security scanning that the repository still
// reports as running, using TaskHandle as the handler.
//
// It runs with a background context (it is started from NewProxyUsecase in
// its own goroutine). A failure to list the scannings aborts the whole pass;
// a failure to enqueue a single scanning is logged and the remaining
// scannings are still processed.
func (p *ProxyUsecase) requeue() {
	ctx := context.Background()
	logger := p.logger.With("fn", "requeue")

	scannings, err := p.securityRepo.AllRunning(ctx)
	if err != nil {
		logger.With("error", err).ErrorContext(ctx, "failed to get running scannings")
		return
	}

	for _, scanning := range scannings {
		id := scanning.ID.String()
		req := domain.CreateSecurityScanningReq{
			UserID:    scanning.UserID.String(),
			Workspace: scanning.Workspace,
			Language:  consts.SecurityScanningLanguage(scanning.Rule),
		}
		// Enqueue's error used to be dropped silently; surface it so that a
		// scanning which could not be re-submitted is visible in the logs.
		if _, err := p.queuerunner.Enqueue(ctx, id, req, p.TaskHandle); err != nil {
			logger.With("id", id).With("error", err).ErrorContext(ctx, "failed to requeue scanning")
		}
	}
}
// Record forwards record to the proxy repository's Record method and returns
// its error unchanged.
func (p *ProxyUsecase) Record(ctx context.Context, record *domain.RecordParam) error {
	return p.repo.Record(ctx, record)
}
// SelectModelWithLoadBalancing implements domain.ProxyUsecase.
//
// It looks up the model for modelType through the model repository's cached
// getter (using a background context) and converts the result into a
// domain.Model. modelName is accepted but not consulted by this
// implementation. A lookup error is returned as-is with a nil model.
func (p *ProxyUsecase) SelectModelWithLoadBalancing(modelName string, modelType consts.ModelType) (*domain.Model, error) {
	stored, lookupErr := p.modelRepo.GetWithCache(context.Background(), modelType)
	if lookupErr != nil {
		return nil, lookupErr
	}

	selected := cvt.From(stored, &domain.Model{})
	return selected, nil
}
// ValidateApiKey asks the proxy repository to validate key and converts the
// returned record into a domain.ApiKey. A repository error is returned as-is
// with a nil key.
func (p *ProxyUsecase) ValidateApiKey(ctx context.Context, key string) (*domain.ApiKey, error) {
	stored, validateErr := p.repo.ValidateApiKey(ctx, key)
	if validateErr != nil {
		return nil, validateErr
	}

	converted := cvt.From(stored, &domain.ApiKey{})
	return converted, nil
}
// AcceptCompletion forwards req to the proxy repository's AcceptCompletion
// method and returns its error unchanged.
func (p *ProxyUsecase) AcceptCompletion(ctx context.Context, req *domain.AcceptCompletionReq) error {
	return p.repo.AcceptCompletion(ctx, req)
}
// Report hands req to the proxy repository's Report method.
//
// When req.Action is consts.ReportActionNewTask the LLM model is first fetched
// through the model repository's cached getter and passed along; for every
// other action a nil model is passed. A failed model lookup is logged and
// returned without calling the repository.
func (p *ProxyUsecase) Report(ctx context.Context, req *domain.ReportReq) error {
	// model stays nil unless the action requires a model lookup.
	var model *db.Model

	if req.Action == consts.ReportActionNewTask {
		// Use the caller's ctx (not context.Background()) so cancellation,
		// deadlines and request-scoped values reach the lookup as well.
		m, err := p.modelRepo.GetWithCache(ctx, consts.ModelTypeLLM)
		if err != nil {
			p.logger.With("fn", "Report").With("error", err).ErrorContext(ctx, "failed to get model")
			return err
		}
		model = m
	}

	return p.repo.Report(ctx, model, req)
}
// CreateSecurityScanning persists a new security scanning for req and then
// enqueues it with TaskHandle as the handler.
//
// It returns whatever Enqueue returns. If the repository fails to create the
// record, an empty string and that error are returned and nothing is
// enqueued. req is dereferenced, so it must not be nil.
func (p *ProxyUsecase) CreateSecurityScanning(ctx context.Context, req *domain.CreateSecurityScanningReq) (string, error) {
	scanReq := *req

	scanID, createErr := p.securityRepo.Create(ctx, scanReq)
	if createErr != nil {
		return "", createErr
	}

	return p.queuerunner.Enqueue(ctx, scanID, scanReq, p.TaskHandle)
}
// TaskHandle is the queue handler for a single security-scanning task.
//
// Steps, in order:
//  1. mark the scanning identified by task.ID as running;
//  2. load the scanning and write its workspace files under
//     /app/codes/<id>/<workspace root path> (that directory tree is removed
//     again when the function returns);
//  3. run scan.Scan over the written files with the lower-cased language as
//     the rule;
//  4. store the outcome: status failed with the scan error text as output, or
//     status success with the scan result.
//
// It returns the first error from the status update in step 1, from loading
// the scanning, from paging the workspace files, from the scan itself, or
// from storing the successful result. Failures to create a directory or write
// an individual file are logged and that file is skipped.
func (p *ProxyUsecase) TaskHandle(ctx context.Context, task *queuerunner.Task[domain.CreateSecurityScanningReq]) error {
	id := task.ID
	if err := p.securityRepo.Update(ctx, id, consts.SecurityScanningStatusRunning, nil); err != nil {
		p.logger.With("id", task.ID).With("error", err).ErrorContext(ctx, "failed to update security scanning")
		return err
	}
	p.logger.With("id", id).DebugContext(ctx, "task started")

	// Write the workspace files to disk.
	scanning, err := p.securityRepo.Get(ctx, id)
	if err != nil {
		p.logger.With("id", id).With("error", err).ErrorContext(ctx, "failed to get security scanning")
		return err
	}

	prefix := fmt.Sprintf("/app/codes/%s", id)
	rootPath := path.Join(prefix, scanning.Edges.WorkspaceEdge.RootPath)
	// Clean up everything written for this task, whatever the outcome.
	defer os.RemoveAll(prefix)

	// NOTE(review): r.Path is joined onto rootPath without checking that the
	// result stays below prefix — confirm paths are validated upstream.
	writeFiles := func(rs []*db.WorkspaceFile) error {
		for _, r := range rs {
			filename := path.Join(rootPath, r.Path)
			dir := path.Dir(filename)
			p.logger.With("path", dir).DebugContext(ctx, "create dir")
			// Errors are kept local to the closure; a file that cannot be
			// written is skipped rather than aborting the whole task.
			if err := os.MkdirAll(dir, 0755); err != nil {
				p.logger.With("path", dir).With("id", id).With("error", err).ErrorContext(ctx, "failed to create dir")
				continue
			}
			if err := os.WriteFile(filename, []byte(r.Content), 0644); err != nil {
				p.logger.With("path", filename).With("id", id).With("error", err).ErrorContext(ctx, "failed to write file")
				continue
			}
		}
		return nil
	}
	if err := p.securityRepo.PageWorkspaceFiles(ctx, scanning.WorkspaceID.String(), 20, writeFiles); err != nil {
		return err
	}

	rule := strings.ToLower(string(scanning.Language))
	result, scanErr := scan.Scan(task.ID, rootPath, rule)
	if scanErr != nil {
		// Keep the scan error in its own variable: the status update below
		// must not overwrite it, otherwise a successful update would make
		// this function log a nil error and report the failed scan as success.
		if updateErr := p.securityRepo.Update(ctx, id, consts.SecurityScanningStatusFailed, &scan.Result{
			Output: scanErr.Error(),
		}); updateErr != nil {
			p.logger.With("id", task.ID).With("error", updateErr).ErrorContext(ctx, "failed to update security scanning")
		}
		p.logger.With("id", task.ID).With("error", scanErr).ErrorContext(ctx, "failed to scan")
		return scanErr
	}

	result.Prefix = prefix
	if err := p.securityRepo.Update(ctx, id, consts.SecurityScanningStatusSuccess, result); err != nil {
		p.logger.With("id", task.ID).With("error", err).ErrorContext(ctx, "failed to update security scanning")
		return err
	}
	p.logger.With("id", task.ID).DebugContext(ctx, "task done")
	return nil
}
// ListSecurityScanning passes a copy of *req to the security-scanning
// repository's ListBrief method and returns its result unchanged. req is
// dereferenced, so it must not be nil.
func (p *ProxyUsecase) ListSecurityScanning(ctx context.Context, req *domain.ListSecurityScanningReq) (*domain.ListSecurityScanningBriefResp, error) {
	return p.securityRepo.ListBrief(ctx, *req)
}