Files
xingrin/server/internal/handler/vulnerability_snapshot.go
2026-01-15 16:19:00 +08:00

220 lines
5.6 KiB
Go

package handler
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"strconv"
"github.com/gin-gonic/gin"
"github.com/xingrin/server/internal/dto"
"github.com/xingrin/server/internal/model"
"github.com/xingrin/server/internal/pkg/csv"
"github.com/xingrin/server/internal/service"
)
// VulnerabilitySnapshotHandler handles vulnerability snapshot endpoints.
// Each handler method parses the HTTP request and delegates the actual work
// to the wrapped service.
type VulnerabilitySnapshotHandler struct {
// svc is the service every handler method calls into.
svc *service.VulnerabilitySnapshotService
}
// NewVulnerabilitySnapshotHandler builds a handler bound to the given
// vulnerability snapshot service.
func NewVulnerabilitySnapshotHandler(svc *service.VulnerabilitySnapshotService) *VulnerabilitySnapshotHandler {
	handler := new(VulnerabilitySnapshotHandler)
	handler.svc = svc
	return handler
}
// BulkCreate stores the vulnerability snapshots submitted for a scan and syncs
// them to the asset table via the service's SaveAndSync.
// POST /api/scans/:id/vulnerabilities/bulk-create
//
// Responses: 400 when the scan ID is not an integer, 404 when the service
// reports the scan does not exist, 500 on any other service error, otherwise
// a success payload carrying the snapshot and asset counts.
func (h *VulnerabilitySnapshotHandler) BulkCreate(c *gin.Context) {
	scanID, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil {
		dto.BadRequest(c, "Invalid scan ID")
		return
	}

	var payload dto.BulkCreateVulnerabilitySnapshotsRequest
	// NOTE(review): BindJSON presumably writes its own error response when it
	// returns false — confirm against the dto package.
	if bound := dto.BindJSON(c, &payload); !bound {
		return
	}

	snapshots, assets, saveErr := h.svc.SaveAndSync(scanID, payload.Vulnerabilities)
	switch {
	case saveErr == nil:
		dto.Success(c, dto.BulkCreateVulnerabilitySnapshotsResponse{
			SnapshotCount: int(snapshots),
			AssetCount:    int(assets),
		})
	case errors.Is(saveErr, service.ErrScanNotFoundForSnapshot):
		dto.NotFound(c, "Scan not found")
	default:
		dto.InternalError(c, "Failed to save vulnerability snapshots")
	}
}
// ListByScan returns paginated vulnerability snapshots for a scan.
// GET /api/scans/:id/vulnerabilities/
//
// Responses: 400 when the scan ID is not an integer, 404 when the service
// reports the scan does not exist, 500 on any other service error, otherwise
// a paginated list built from the query's page and page size.
func (h *VulnerabilitySnapshotHandler) ListByScan(c *gin.Context) {
	scanID, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		dto.BadRequest(c, "Invalid scan ID")
		return
	}

	var query dto.VulnerabilitySnapshotListQuery
	if !dto.BindQuery(c, &query) {
		return
	}

	snapshots, total, err := h.svc.ListByScan(scanID, &query)
	if err != nil {
		if errors.Is(err, service.ErrScanNotFoundForSnapshot) {
			dto.NotFound(c, "Scan not found")
			return
		}
		dto.InternalError(c, "Failed to list vulnerability snapshots")
		return
	}

	// Allocate a non-nil slice so an empty result encodes as a JSON array
	// rather than null, and size it up front to avoid regrowth on append.
	resp := make([]dto.VulnerabilitySnapshotResponse, 0, len(snapshots))
	for i := range snapshots {
		// Index instead of ranging by value to avoid copying each model.
		resp = append(resp, toVulnerabilitySnapshotResponse(&snapshots[i]))
	}

	dto.Paginated(c, resp, total, query.GetPage(), query.GetPageSize())
}
// ListAll returns paginated vulnerability snapshots for all scans.
// GET /api/vulnerability-snapshots/
//
// Responses: 500 when the service fails, otherwise a paginated list built
// from the query's page and page size.
func (h *VulnerabilitySnapshotHandler) ListAll(c *gin.Context) {
	var query dto.VulnerabilitySnapshotListQuery
	if !dto.BindQuery(c, &query) {
		return
	}

	snapshots, total, err := h.svc.ListAll(&query)
	if err != nil {
		dto.InternalError(c, "Failed to list vulnerability snapshots")
		return
	}

	// Allocate a non-nil slice so an empty result encodes as a JSON array
	// rather than null, and size it up front to avoid regrowth on append.
	resp := make([]dto.VulnerabilitySnapshotResponse, 0, len(snapshots))
	for i := range snapshots {
		// Index instead of ranging by value to avoid copying each model.
		resp = append(resp, toVulnerabilitySnapshotResponse(&snapshots[i]))
	}

	dto.Paginated(c, resp, total, query.GetPage(), query.GetPageSize())
}
// GetByID looks up a single vulnerability snapshot by its numeric ID.
// GET /api/vulnerability-snapshots/:id/
//
// Responses: 400 when the ID is not an integer, 404 when the service reports
// the snapshot does not exist, 500 on any other service error.
func (h *VulnerabilitySnapshotHandler) GetByID(c *gin.Context) {
	snapshotID, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil {
		dto.BadRequest(c, "Invalid vulnerability snapshot ID")
		return
	}

	found, lookupErr := h.svc.GetByID(snapshotID)
	switch {
	case lookupErr == nil:
		dto.OK(c, toVulnerabilitySnapshotResponse(found))
	case errors.Is(lookupErr, service.ErrVulnerabilitySnapshotNotFound):
		dto.NotFound(c, "Vulnerability snapshot not found")
	default:
		dto.InternalError(c, "Failed to get vulnerability snapshot")
	}
}
// Export streams the vulnerability snapshots of a scan as a CSV download.
// GET /api/scans/:id/vulnerabilities/export/
//
// Responses: 400 when the scan ID is not an integer, 404 when the service
// reports the scan does not exist, 500 when counting or opening the row
// stream fails. Otherwise the rows are handed to csv.StreamCSV under the file
// name "scan-<id>-vulnerabilities.csv".
func (h *VulnerabilitySnapshotHandler) Export(c *gin.Context) {
	scanID, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		dto.BadRequest(c, "Invalid scan ID")
		return
	}

	// Get count for progress estimation
	count, err := h.svc.CountByScan(scanID)
	if err != nil {
		if errors.Is(err, service.ErrScanNotFoundForSnapshot) {
			dto.NotFound(c, "Scan not found")
			return
		}
		dto.InternalError(c, "Failed to export vulnerability snapshots")
		return
	}

	rows, err := h.svc.StreamByScan(scanID)
	if err != nil {
		dto.InternalError(c, "Failed to export vulnerability snapshots")
		return
	}
	// Release the cursor on every exit path. Rows.Close is idempotent, so
	// this stays safe even if StreamCSV closes the rows itself.
	defer rows.Close()

	headers := []string{
		"url", "vuln_type", "severity", "source", "cvss_score",
		"description", "raw_output", "created_at",
	}
	filename := fmt.Sprintf("scan-%d-vulnerabilities.csv", scanID)

	// mapper turns the current row into one CSV record, in header order.
	mapper := func(r *sql.Rows) ([]string, error) {
		snapshot, err := h.svc.ScanRow(r)
		if err != nil {
			return nil, err
		}

		// Optional columns are exported as empty cells when absent.
		cvssScore := ""
		if snapshot.CVSSScore != nil {
			cvssScore = snapshot.CVSSScore.String()
		}
		rawOutput := ""
		if len(snapshot.RawOutput) > 0 {
			rawOutput = string(snapshot.RawOutput)
		}

		return []string{
			snapshot.URL,
			snapshot.VulnType,
			snapshot.Severity,
			snapshot.Source,
			cvssScore,
			snapshot.Description,
			rawOutput,
			snapshot.CreatedAt.Format("2006-01-02 15:04:05"),
		}, nil
	}

	// The error is intentionally not turned into a response here, matching
	// the original behavior.
	// NOTE(review): presumably StreamCSV has already started writing the CSV
	// response when it fails, so a JSON error can no longer be sent — confirm,
	// and consider logging the error.
	_ = csv.StreamCSV(c, rows, headers, filename, mapper, count)
}
// toVulnerabilitySnapshotResponse converts a snapshot model into its response
// DTO, copying every field as-is except RawOutput: when the model carries no
// raw output, the response gets an empty JSON object ("{}") instead.
func toVulnerabilitySnapshotResponse(s *model.VulnerabilitySnapshot) dto.VulnerabilitySnapshotResponse {
	rawOutput := s.RawOutput
	// Use len rather than a nil check so an empty but non-nil value, which is
	// not valid JSON, is also replaced. This matches the len check in Export.
	if len(rawOutput) == 0 {
		// Marshaling an empty map cannot fail, so the error is ignored.
		rawOutput, _ = json.Marshal(map[string]any{})
	}

	return dto.VulnerabilitySnapshotResponse{
		ID:          s.ID,
		ScanID:      s.ScanID,
		URL:         s.URL,
		VulnType:    s.VulnType,
		Severity:    s.Severity,
		Source:      s.Source,
		CVSSScore:   s.CVSSScore,
		Description: s.Description,
		RawOutput:   rawOutput,
		CreatedAt:   s.CreatedAt,
	}
}