Files
MonkeyCode/backend/internal/security/repo/security.go
2025-08-18 14:57:29 +08:00

350 lines
9.6 KiB
Go

package repo
import (
"context"
"fmt"
"strings"
"entgo.io/ent/dialect/sql"
"github.com/google/uuid"
"github.com/chaitin/MonkeyCode/backend/consts"
"github.com/chaitin/MonkeyCode/backend/db"
"github.com/chaitin/MonkeyCode/backend/db/securityscanning"
"github.com/chaitin/MonkeyCode/backend/db/securityscanningresult"
"github.com/chaitin/MonkeyCode/backend/db/workspace"
"github.com/chaitin/MonkeyCode/backend/db/workspacefile"
"github.com/chaitin/MonkeyCode/backend/domain"
"github.com/chaitin/MonkeyCode/backend/ent/rule"
"github.com/chaitin/MonkeyCode/backend/ent/types"
"github.com/chaitin/MonkeyCode/backend/pkg/cvt"
"github.com/chaitin/MonkeyCode/backend/pkg/entx"
"github.com/chaitin/MonkeyCode/backend/pkg/scan"
)
// SecurityScanningRepo is the ent-backed repository for security scanning
// records and their results. All queries go through the wrapped db client.
type SecurityScanningRepo struct {
	db *db.Client
}
// NewSecurityScanningRepo wraps the given db client in a SecurityScanningRepo
// and returns it as a domain.SecurityScanningRepo.
func NewSecurityScanningRepo(db *db.Client) domain.SecurityScanningRepo {
	repo := new(SecurityScanningRepo)
	repo.db = db
	return repo
}
// Create implements domain.SecurityScanningRepo.
//
// It parses req.UserID, looks up the first workspace owned by that user whose
// root path equals req.Workspace, and saves a new scanning record in the
// pending status that points at that workspace. The rule name is derived from
// req.Language. On success the ID of the new record is returned as a string.
func (s *SecurityScanningRepo) Create(ctx context.Context, req domain.CreateSecurityScanningReq) (string, error) {
	scanID := uuid.New()

	userID, err := uuid.Parse(req.UserID)
	if err != nil {
		return "", err
	}

	// Both predicates are AND-ed together by the query builder.
	ws, err := s.db.Workspace.Query().
		Where(
			workspace.UserID(userID),
			workspace.RootPath(req.Workspace),
		).
		First(ctx)
	if err != nil {
		return "", err
	}

	builder := s.db.SecurityScanning.Create().
		SetID(scanID).
		SetUserID(userID).
		SetWorkspaceID(ws.ID).
		SetLanguage(req.Language).
		SetRule(req.Language.RuleName()).
		SetWorkspace(req.Workspace).
		SetStatus(consts.SecurityScanningStatusPending)
	if _, err := builder.Save(ctx); err != nil {
		return "", err
	}

	return scanID.String(), nil
}
// Update implements domain.SecurityScanningRepo.
//
// It sets the status of the scanning record identified by id. When result is
// non-nil and result.Output is non-empty, the output is stored as the record's
// error message. When result is non-nil, one SecurityScanningResult row is
// inserted per entry in result.Results, in batches; fileMap supplies the file
// content for each finding, keyed by the finding's path as reported in the
// result.
//
// Every write is issued through the transaction handed to the entx.WithTx
// callback, so the status change and the inserted findings belong to the same
// transaction (presumably committed on nil and rolled back on error — confirm
// against entx.WithTx).
func (s *SecurityScanningRepo) Update(ctx context.Context, id string, fileMap map[string]string, status consts.SecurityScanningStatus, result *scan.Result) error {
	// Number of result rows sent to the database per bulk insert.
	const bulkSize = 10

	uid, err := uuid.Parse(id)
	if err != nil {
		return err
	}

	return entx.WithTx(ctx, s.db, func(tx *db.Tx) error {
		// Use the transaction-bound clients (tx), not s.db: writes made via
		// s.db would run outside the transaction.
		up := tx.SecurityScanning.Update().
			Where(securityscanning.ID(uid)).
			SetStatus(status)
		if result != nil && result.Output != "" {
			up.SetErrorMessage(result.Output)
		}
		if err := up.Exec(ctx); err != nil {
			return err
		}

		if result == nil {
			return nil
		}

		cs := make([]*db.SecurityScanningResultCreate, 0, bulkSize)

		// flush writes the pending batch (if any) and resets it for reuse.
		flush := func() error {
			if len(cs) == 0 {
				return nil
			}
			if err := tx.SecurityScanningResult.CreateBulk(cs...).Exec(ctx); err != nil {
				return err
			}
			cs = cs[:0]
			return nil
		}

		for _, item := range result.Results {
			c := tx.SecurityScanningResult.Create().
				SetSecurityScanningID(uid).
				SetCheckID(item.CheckID).
				SetEngineKind(item.Extra.EngineKind).
				SetLines(item.Extra.Lines).
				SetMessage(item.Extra.Message).
				SetMessageZh(item.Extra.Metadata.MessageZh).
				SetSeverity(item.Extra.Severity).
				SetAbstractEn(item.Extra.Metadata.AbstractFeysh["en-US"]).
				SetAbstractZh(item.Extra.Metadata.AbstractFeysh["zh-CN"]).
				SetCategoryEn(item.Extra.Metadata.CategoryFeysh["en-US"]).
				SetCategoryZh(item.Extra.Metadata.CategoryFeysh["zh-CN"]).
				SetConfidence(item.Extra.Metadata.Confidence).
				SetCwe([]any{item.Extra.Metadata.Cwe}).
				SetImpact(item.Extra.Metadata.Impact).
				SetOwasp([]any{item.Extra.Metadata.Owasp}).
				// Stored path has result.Prefix removed; the file content
				// lookup still uses the path exactly as reported.
				SetPath(strings.ReplaceAll(item.Path, result.Prefix, "")).
				SetFileContent(fileMap[item.Path]).
				SetStartPosition(&types.Position{
					Col:    item.Start.Col,
					Line:   item.Start.Line,
					Offset: item.Start.Offset,
				}).
				SetEndPosition(&types.Position{
					Col:    item.End.Col,
					Line:   item.End.Line,
					Offset: item.End.Offset,
				})
			cs = append(cs, c)

			if len(cs) >= bulkSize {
				if err := flush(); err != nil {
					return err
				}
			}
		}

		// Write whatever is left after the last full batch.
		return flush()
	})
}
// List implements domain.SecurityScanningRepo.
//
// It returns one page of scanning records, newest first, with their results
// and user edges loaded. When req.UserID is non-empty only that user's
// records are considered. Each item carries the per-severity risk counts
// computed by RiskCountByIDs for the records on the page.
func (s *SecurityScanningRepo) List(ctx context.Context, req domain.ListSecurityScanningReq) (*domain.ListSecurityScanningResp, error) {
	q := s.db.SecurityScanning.Query().WithResults().WithUser()

	if req.UserID != "" {
		userID, err := uuid.Parse(req.UserID)
		if err != nil {
			return nil, err
		}
		q = q.Where(securityscanning.UserID(userID))
	}

	q = q.Order(db.Desc("created_at"))
	rows, pageInfo, err := q.Page(ctx, int(req.Page), int(req.Size))
	if err != nil {
		return nil, err
	}

	// Collect the IDs on this page so risk counts are fetched in one query.
	pageIDs := cvt.Iter(rows, func(_ int, row *db.SecurityScanning) uuid.UUID {
		return row.ID
	})
	risks, err := s.RiskCountByIDs(ctx, pageIDs)
	if err != nil {
		return nil, err
	}

	toItem := func(_ int, row *db.SecurityScanning) *domain.SecurityScanningResult {
		seed := &domain.SecurityScanningResult{Risk: risks[row.ID]}
		return cvt.From(row, seed)
	}

	resp := &domain.ListSecurityScanningResp{
		PageInfo: pageInfo,
		Items:    cvt.Iter(rows, toItem),
	}
	return resp, nil
}
// ListBrief implements domain.SecurityScanningRepo.
//
// It returns one page of scanning records, newest first, converted to brief
// items. When req.UserID is non-empty only that user's records are
// considered. Every item gets the same report URL, built from req.BaseURL.
func (s *SecurityScanningRepo) ListBrief(ctx context.Context, req domain.ListSecurityScanningReq) (*domain.ListSecurityScanningBriefResp, error) {
	q := s.db.SecurityScanning.Query().WithUser().WithResults()

	if req.UserID != "" {
		userID, err := uuid.Parse(req.UserID)
		if err != nil {
			return nil, err
		}
		q = q.Where(securityscanning.UserID(userID))
	}

	q = q.Order(securityscanning.ByCreatedAt(sql.OrderDesc()))
	rows, pageInfo, err := q.Page(ctx, int(req.Page), int(req.Size))
	if err != nil {
		return nil, err
	}

	// The URL does not depend on the row, so build it once.
	reportURL := fmt.Sprintf("%s/user/codescan", req.BaseURL)
	toBrief := func(_ int, row *db.SecurityScanning) *domain.SecurityScanningBrief {
		return cvt.From(row, &domain.SecurityScanningBrief{ReportURL: reportURL})
	}

	resp := &domain.ListSecurityScanningBriefResp{
		PageInfo: pageInfo,
		Items:    cvt.Iter(rows, toBrief),
	}
	return resp, nil
}
// Detail returns every result row of the scanning identified by id, ordered
// by severity (see BySeverityOrder) and then by creation time, newest first.
// When userID is non-empty, rows are additionally restricted to scannings
// whose user ID equals it.
func (s *SecurityScanningRepo) Detail(ctx context.Context, userID, id string) ([]*domain.SecurityScanningRiskDetail, error) {
	scanID, err := uuid.Parse(id)
	if err != nil {
		return nil, err
	}

	query := s.db.SecurityScanningResult.Query().
		Where(securityscanningresult.SecurityScanningID(scanID)).
		Order(
			BySeverityOrder(),
			securityscanningresult.ByCreatedAt(sql.OrderDesc()),
		)

	if userID != "" {
		ownerID, err := uuid.Parse(userID)
		if err != nil {
			return nil, err
		}
		ownedBy := func(sel *sql.Selector) {
			sel.Where(sql.EQ(securityscanning.FieldUserID, ownerID))
		}
		query = query.Where(securityscanningresult.HasSecurityScanningWith(ownedBy))
	}

	rows, err := query.All(ctx)
	if err != nil {
		return nil, err
	}

	toDetail := func(_ int, row *db.SecurityScanningResult) *domain.SecurityScanningRiskDetail {
		return cvt.From(row, &domain.SecurityScanningRiskDetail{})
	}
	return cvt.Iter(rows, toDetail), nil
}
// RiskCountByIDs implements domain.SecurityScanningRepo.
//
// It runs one grouped query over the results of the given scanning IDs and
// returns a map from scanning ID to its counts. The selected columns are:
// severe_count (severity CRITICAL or ERROR), critical_count (severity
// WARNING) and suggest_count (severity INFO).
func (s *SecurityScanningRepo) RiskCountByIDs(ctx context.Context, ids []uuid.UUID) (map[uuid.UUID]domain.SecurityScanningRiskResult, error) {
	// Replaces the default column list with per-scanning aggregates.
	aggregate := func(sel *sql.Selector) {
		sel.Select(
			sql.As("security_scanning_id", "id"),
			sql.As("count(*) filter (where severity in ('CRITICAL', 'ERROR'))", "severe_count"),
			sql.As("count(*) filter (where severity = 'WARNING')", "critical_count"),
			sql.As("count(*) filter (where severity = 'INFO')", "suggest_count"),
		).
			GroupBy(securityscanningresult.FieldSecurityScanningID)
	}

	counts := make([]domain.SecurityScanningRiskResult, 0)
	err := s.db.SecurityScanningResult.Query().
		Where(securityscanningresult.SecurityScanningIDIn(ids...)).
		Modify(aggregate).
		Scan(ctx, &counts)
	if err != nil {
		return nil, err
	}

	byID := func(_ int, c domain.SecurityScanningRiskResult) (uuid.UUID, domain.SecurityScanningRiskResult) {
		return c.ID, c
	}
	return cvt.IterToMap(counts, byID), nil
}
// AllRunning implements domain.SecurityScanningRepo.
//
// It returns every scanning record whose status is running, oldest first.
// The query runs on a context derived through rule.SkipPermission.
func (s *SecurityScanningRepo) AllRunning(ctx context.Context) ([]*db.SecurityScanning, error) {
	unrestricted := rule.SkipPermission(ctx)

	running := s.db.SecurityScanning.Query().
		Where(securityscanning.Status(consts.SecurityScanningStatusRunning)).
		Order(securityscanning.ByCreatedAt(sql.OrderAsc()))

	return running.All(unrestricted)
}
// Get returns the first scanning record whose ID equals id, with its
// workspace edge loaded. It fails if id is not a valid UUID or if the query
// returns an error.
func (s *SecurityScanningRepo) Get(ctx context.Context, id string) (*db.SecurityScanning, error) {
	scanID, err := uuid.Parse(id)
	if err != nil {
		return nil, err
	}

	lookup := s.db.SecurityScanning.Query().
		WithWorkspaceEdge().
		Where(securityscanning.ID(scanID))

	return lookup.First(ctx)
}
// PageWorkspaceFiles implements domain.SecurityScanningRepo.
//
// It walks the files of the workspace identified by id, oldest first, in
// pages of the given size, and calls fn once per page. Iteration stops at
// the first query or callback error, or after the page that reports no next
// page. fn is always called at least once when the first query succeeds.
func (s *SecurityScanningRepo) PageWorkspaceFiles(ctx context.Context, id string, size int, fn func([]*db.WorkspaceFile) error) error {
	workspaceID, err := uuid.Parse(id)
	if err != nil {
		return err
	}

	for pageNo := 1; ; pageNo++ {
		files, info, err := s.db.WorkspaceFile.Query().
			Where(workspacefile.WorkspaceID(workspaceID)).
			Order(workspacefile.ByCreatedAt(sql.OrderAsc())).
			Page(ctx, pageNo, size)
		if err != nil {
			return err
		}

		if err := fn(files); err != nil {
			return err
		}

		if !info.HasNextPage {
			return nil
		}
	}
}
// ListDetail returns one page of result rows for the scanning identified by
// req.ID. Rows are ordered by severity (see BySeverityOrder), then by
// creation time and by ID, both descending.
func (s *SecurityScanningRepo) ListDetail(ctx context.Context, req domain.ListSecurityScanningDetailReq) (*domain.ListSecurityScanningDetailResp, error) {
	scanID, err := uuid.Parse(req.ID)
	if err != nil {
		return nil, err
	}

	rows, pageInfo, err := s.db.SecurityScanningResult.Query().
		Where(securityscanningresult.SecurityScanningID(scanID)).
		Order(
			BySeverityOrder(),
			securityscanningresult.ByCreatedAt(sql.OrderDesc()),
			securityscanningresult.ByID(sql.OrderDesc()),
		).
		Page(ctx, req.Page, req.Size)
	if err != nil {
		return nil, err
	}

	toDetail := func(_ int, row *db.SecurityScanningResult) *domain.SecurityScanningRiskDetail {
		return cvt.From(row, &domain.SecurityScanningRiskDetail{})
	}

	resp := &domain.ListSecurityScanningDetailResp{
		PageInfo: pageInfo,
		Items:    cvt.Iter(rows, toDetail),
	}
	return resp, nil
}
// BySeverityOrder returns an ordering function that sorts rows by severity,
// highest first: CRITICAL, ERROR, WARNING, INFO, then any other value.
func BySeverityOrder() func(s *sql.Selector) {
	// Maps each severity to a rank and sorts by that rank descending.
	const severityRank = "case when severity = 'CRITICAL' then 5 when severity = 'ERROR' then 4 when severity = 'WARNING' then 3 when severity = 'INFO' then 2 else 1 end desc"

	writeRank := func(b *sql.Builder) {
		b.WriteString(severityRank)
	}

	return func(s *sql.Selector) {
		s.OrderExprFunc(writeRank)
	}
}