Files
xingrin/server/internal/handler/vulnerability_snapshot_test.go
2026-01-15 16:19:00 +08:00

668 lines
20 KiB
Go

package handler
import (
	"database/sql"
	"encoding/json"
	"errors"
	"net/http"
	"net/http/httptest"
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/shopspring/decimal"
	"github.com/xingrin/server/internal/dto"
	"github.com/xingrin/server/internal/model"
	"github.com/xingrin/server/internal/service"
	"gorm.io/datatypes"
)
// MockVulnerabilitySnapshotService is a hand-rolled test double for the
// vulnerability snapshot service. Every method delegates to the matching
// *Func field when that field is set; otherwise it returns zero values and a
// nil error, so a test only needs to stub the calls it actually exercises.
type MockVulnerabilitySnapshotService struct {
	SaveAndSyncFunc  func(scanID int, items []dto.VulnerabilitySnapshotItem) (int64, int64, error)
	ListByScanFunc   func(scanID int, query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error)
	ListAllFunc      func(query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error)
	GetByIDFunc      func(id int) (*model.VulnerabilitySnapshot, error)
	StreamByScanFunc func(scanID int) (*sql.Rows, error)
	CountByScanFunc  func(scanID int) (int64, error)
	ScanRowFunc      func(rows *sql.Rows) (*model.VulnerabilitySnapshot, error)
}

// SaveAndSync delegates to SaveAndSyncFunc. Without a stub it reports zero
// snapshots and zero assets written and a nil error.
func (m *MockVulnerabilitySnapshotService) SaveAndSync(scanID int, items []dto.VulnerabilitySnapshotItem) (int64, int64, error) {
	if m.SaveAndSyncFunc != nil {
		return m.SaveAndSyncFunc(scanID, items)
	}
	return 0, 0, nil
}

// ListByScan delegates to ListByScanFunc. Without a stub it returns a nil
// slice, a total of 0 and a nil error.
func (m *MockVulnerabilitySnapshotService) ListByScan(scanID int, query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error) {
	if m.ListByScanFunc != nil {
		return m.ListByScanFunc(scanID, query)
	}
	return nil, 0, nil
}

// ListAll delegates to ListAllFunc. Without a stub it returns a nil slice,
// a total of 0 and a nil error.
func (m *MockVulnerabilitySnapshotService) ListAll(query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error) {
	if m.ListAllFunc != nil {
		return m.ListAllFunc(query)
	}
	return nil, 0, nil
}

// GetByID delegates to GetByIDFunc. Without a stub it returns (nil, nil);
// callers must not assume a non-nil snapshot on a nil error.
func (m *MockVulnerabilitySnapshotService) GetByID(id int) (*model.VulnerabilitySnapshot, error) {
	if m.GetByIDFunc != nil {
		return m.GetByIDFunc(id)
	}
	return nil, nil
}

// StreamByScan delegates to StreamByScanFunc. Without a stub it returns
// (nil, nil), i.e. no *sql.Rows to iterate.
func (m *MockVulnerabilitySnapshotService) StreamByScan(scanID int) (*sql.Rows, error) {
	if m.StreamByScanFunc != nil {
		return m.StreamByScanFunc(scanID)
	}
	return nil, nil
}

// CountByScan delegates to CountByScanFunc. Without a stub it returns a
// count of 0 and a nil error.
func (m *MockVulnerabilitySnapshotService) CountByScan(scanID int) (int64, error) {
	if m.CountByScanFunc != nil {
		return m.CountByScanFunc(scanID)
	}
	return 0, nil
}

// ScanRow delegates to ScanRowFunc. Without a stub it returns (nil, nil).
func (m *MockVulnerabilitySnapshotService) ScanRow(rows *sql.Rows) (*model.VulnerabilitySnapshot, error) {
	if m.ScanRowFunc != nil {
		return m.ScanRowFunc(rows)
	}
	return nil, nil
}
// TestVulnerabilitySnapshotBulkCreate exercises the bulk-create flow: scan-ID
// parsing, JSON body binding, mapping of service errors to HTTP status codes,
// and the success payload. The handler is defined inline in the test and is
// backed by MockVulnerabilitySnapshotService.
func TestVulnerabilitySnapshotBulkCreate(t *testing.T) {
	gin.SetMode(gin.TestMode)

	tests := []struct {
		name           string
		scanID         string
		body           string
		mockFunc       func(scanID int, items []dto.VulnerabilitySnapshotItem) (int64, int64, error)
		expectedStatus int
		expectedBody   string
	}{
		{
			name:   "successful bulk create",
			scanID: "1",
			body:   `{"vulnerabilities":[{"url":"https://example.com/vuln","vulnType":"XSS","severity":"high"}]}`,
			mockFunc: func(scanID int, items []dto.VulnerabilitySnapshotItem) (int64, int64, error) {
				return 1, 1, nil
			},
			expectedStatus: http.StatusOK,
			expectedBody:   `"snapshotCount":1,"assetCount":1`,
		},
		{
			name:           "invalid scan ID",
			scanID:         "invalid",
			body:           `{"vulnerabilities":[]}`,
			expectedStatus: http.StatusBadRequest,
			expectedBody:   `"message":"Invalid scan ID"`,
		},
		{
			name:   "scan not found",
			scanID: "999",
			body:   `{"vulnerabilities":[{"url":"https://example.com/vuln","vulnType":"XSS","severity":"high"}]}`,
			mockFunc: func(scanID int, items []dto.VulnerabilitySnapshotItem) (int64, int64, error) {
				return 0, 0, service.ErrScanNotFoundForSnapshot
			},
			expectedStatus: http.StatusNotFound,
			expectedBody:   `"message":"Scan not found"`,
		},
		{
			name:   "multiple vulnerabilities",
			scanID: "1",
			body:   `{"vulnerabilities":[{"url":"https://example.com/vuln1","vulnType":"XSS","severity":"high"},{"url":"https://example.com/vuln2","vulnType":"SQLi","severity":"critical"}]}`,
			mockFunc: func(scanID int, items []dto.VulnerabilitySnapshotItem) (int64, int64, error) {
				return 2, 2, nil
			},
			expectedStatus: http.StatusOK,
			expectedBody:   `"snapshotCount":2,"assetCount":2`,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Record every scan ID the service receives so we can check that
			// the handler forwards the ID from the URL, not a constant.
			var gotScanIDs []int
			mockSvc := &MockVulnerabilitySnapshotService{}
			if tt.mockFunc != nil {
				mockSvc.SaveAndSyncFunc = func(scanID int, items []dto.VulnerabilitySnapshotItem) (int64, int64, error) {
					gotScanIDs = append(gotScanIDs, scanID)
					return tt.mockFunc(scanID, items)
				}
			}

			router := gin.New()
			router.POST("/api/scans/:id/vulnerabilities/bulk-create", func(c *gin.Context) {
				scanID, err := strconv.Atoi(c.Param("id"))
				if err != nil {
					dto.BadRequest(c, "Invalid scan ID")
					return
				}
				var req dto.BulkCreateVulnerabilitySnapshotsRequest
				if err := c.ShouldBindJSON(&req); err != nil {
					dto.BadRequest(c, "Invalid request body")
					return
				}
				snapshotCount, assetCount, err := mockSvc.SaveAndSync(scanID, req.Vulnerabilities)
				if err != nil {
					// errors.Is also matches the sentinel when it is wrapped.
					if errors.Is(err, service.ErrScanNotFoundForSnapshot) {
						dto.NotFound(c, "Scan not found")
						return
					}
					dto.InternalError(c, "Failed to save vulnerability snapshots")
					return
				}
				dto.Success(c, dto.BulkCreateVulnerabilitySnapshotsResponse{
					SnapshotCount: int(snapshotCount),
					AssetCount:    int(assetCount),
				})
			})

			req := httptest.NewRequest(http.MethodPost, "/api/scans/"+tt.scanID+"/vulnerabilities/bulk-create", strings.NewReader(tt.body))
			req.Header.Set("Content-Type", "application/json")
			w := httptest.NewRecorder()
			router.ServeHTTP(w, req)

			if w.Code != tt.expectedStatus {
				t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code)
			}
			if !strings.Contains(w.Body.String(), tt.expectedBody) {
				t.Errorf("expected body to contain %q, got %q", tt.expectedBody, w.Body.String())
			}
			if tt.mockFunc != nil {
				wantScanID, err := strconv.Atoi(tt.scanID)
				if err != nil {
					t.Fatalf("test case scanID %q is not numeric: %v", tt.scanID, err)
				}
				if len(gotScanIDs) != 1 || gotScanIDs[0] != wantScanID {
					t.Errorf("expected SaveAndSync to be called once with scan ID %d, got calls %v", wantScanID, gotScanIDs)
				}
			}
		})
	}
}
// TestVulnerabilitySnapshotListByScan exercises the per-scan listing flow:
// scan-ID parsing, query binding (page, pageSize, severity), mapping of
// service errors to HTTP status codes, and the paginated response envelope.
// The test expects GetPage()/GetPageSize() to default to 1 and 20 when no
// pagination parameters are given. The handler is defined inline and backed
// by MockVulnerabilitySnapshotService.
func TestVulnerabilitySnapshotListByScan(t *testing.T) {
	gin.SetMode(gin.TestMode)
	now := time.Now()
	score := decimal.NewFromFloat(7.5)
	mockSnapshots := []model.VulnerabilitySnapshot{
		{ID: 1, ScanID: 1, URL: "https://example.com/vuln1", VulnType: "XSS", Severity: "high", CVSSScore: &score, CreatedAt: now},
		{ID: 2, ScanID: 1, URL: "https://example.com/vuln2", VulnType: "SQLi", Severity: "critical", CreatedAt: now},
	}

	// mockFunc receives the *subtest's* t so that assertions made inside the
	// service stub are attributed to the subtest that triggered them rather
	// than to the parent test.
	tests := []struct {
		name           string
		scanID         string
		queryParams    string
		mockFunc       func(t *testing.T, scanID int, query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error)
		expectedStatus int
		checkResponse  func(t *testing.T, body string)
	}{
		{
			name:        "list with default pagination",
			scanID:      "1",
			queryParams: "",
			mockFunc: func(t *testing.T, scanID int, query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error) {
				if query.GetPage() != 1 {
					t.Errorf("expected page 1, got %d", query.GetPage())
				}
				if query.GetPageSize() != 20 {
					t.Errorf("expected pageSize 20, got %d", query.GetPageSize())
				}
				return mockSnapshots, 2, nil
			},
			expectedStatus: http.StatusOK,
			checkResponse: func(t *testing.T, body string) {
				var resp dto.PaginatedResponse[dto.VulnerabilitySnapshotResponse]
				if err := json.Unmarshal([]byte(body), &resp); err != nil {
					t.Fatalf("failed to unmarshal response: %v", err)
				}
				if resp.Total != 2 {
					t.Errorf("expected total 2, got %d", resp.Total)
				}
				if resp.Page != 1 {
					t.Errorf("expected page 1, got %d", resp.Page)
				}
			},
		},
		{
			name:        "list with custom pagination",
			scanID:      "1",
			queryParams: "?page=2&pageSize=10",
			mockFunc: func(t *testing.T, scanID int, query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error) {
				if query.GetPage() != 2 {
					t.Errorf("expected page 2, got %d", query.GetPage())
				}
				if query.GetPageSize() != 10 {
					t.Errorf("expected pageSize 10, got %d", query.GetPageSize())
				}
				return []model.VulnerabilitySnapshot{}, 30, nil
			},
			expectedStatus: http.StatusOK,
			checkResponse: func(t *testing.T, body string) {
				var resp dto.PaginatedResponse[dto.VulnerabilitySnapshotResponse]
				if err := json.Unmarshal([]byte(body), &resp); err != nil {
					t.Fatalf("failed to unmarshal response: %v", err)
				}
				if resp.Page != 2 {
					t.Errorf("expected page 2, got %d", resp.Page)
				}
				if resp.PageSize != 10 {
					t.Errorf("expected pageSize 10, got %d", resp.PageSize)
				}
			},
		},
		{
			name:        "list with severity filter",
			scanID:      "1",
			queryParams: "?severity=critical",
			mockFunc: func(t *testing.T, scanID int, query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error) {
				if query.Severity != "critical" {
					t.Errorf("expected severity 'critical', got %q", query.Severity)
				}
				return mockSnapshots[1:], 1, nil
			},
			expectedStatus: http.StatusOK,
			checkResponse: func(t *testing.T, body string) {
				var resp dto.PaginatedResponse[dto.VulnerabilitySnapshotResponse]
				if err := json.Unmarshal([]byte(body), &resp); err != nil {
					t.Fatalf("failed to unmarshal response: %v", err)
				}
				if resp.Total != 1 {
					t.Errorf("expected total 1, got %d", resp.Total)
				}
			},
		},
		{
			name:        "scan not found",
			scanID:      "999",
			queryParams: "",
			mockFunc: func(t *testing.T, scanID int, query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error) {
				return nil, 0, service.ErrScanNotFoundForSnapshot
			},
			expectedStatus: http.StatusNotFound,
			checkResponse:  nil,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockSvc := &MockVulnerabilitySnapshotService{
				ListByScanFunc: func(scanID int, query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error) {
					// The handler must forward the scan ID taken from the URL.
					if want, err := strconv.Atoi(tt.scanID); err == nil && scanID != want {
						t.Errorf("expected scan ID %d, got %d", want, scanID)
					}
					return tt.mockFunc(t, scanID, query)
				},
			}

			router := gin.New()
			router.GET("/api/scans/:id/vulnerabilities/", func(c *gin.Context) {
				scanID, err := strconv.Atoi(c.Param("id"))
				if err != nil {
					dto.BadRequest(c, "Invalid scan ID")
					return
				}
				var query dto.VulnerabilitySnapshotListQuery
				if err := c.ShouldBindQuery(&query); err != nil {
					dto.BadRequest(c, "Invalid query parameters")
					return
				}
				snapshots, total, err := mockSvc.ListByScan(scanID, &query)
				if err != nil {
					if errors.Is(err, service.ErrScanNotFoundForSnapshot) {
						dto.NotFound(c, "Scan not found")
						return
					}
					dto.InternalError(c, "Failed to list vulnerability snapshots")
					return
				}
				// Non-nil slice: encoding/json renders a nil slice as null,
				// whereas an empty page should be [].
				resp := make([]dto.VulnerabilitySnapshotResponse, 0, len(snapshots))
				for _, s := range snapshots {
					resp = append(resp, toVulnerabilitySnapshotResponse(&s))
				}
				dto.Paginated(c, resp, total, query.GetPage(), query.GetPageSize())
			})

			req := httptest.NewRequest(http.MethodGet, "/api/scans/"+tt.scanID+"/vulnerabilities/"+tt.queryParams, nil)
			w := httptest.NewRecorder()
			router.ServeHTTP(w, req)

			if w.Code != tt.expectedStatus {
				t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code)
			}
			if tt.checkResponse != nil {
				tt.checkResponse(t, w.Body.String())
			}
		})
	}
}
// TestVulnerabilitySnapshotListAll exercises the cross-scan listing flow:
// query binding (including the free-text filter), delegation to ListAll, and
// the paginated response envelope. The handler is defined inline and backed
// by MockVulnerabilitySnapshotService.
func TestVulnerabilitySnapshotListAll(t *testing.T) {
	gin.SetMode(gin.TestMode)
	now := time.Now()
	mockSnapshots := []model.VulnerabilitySnapshot{
		{ID: 1, ScanID: 1, URL: "https://example.com/vuln1", VulnType: "XSS", Severity: "high", CreatedAt: now},
		{ID: 2, ScanID: 2, URL: "https://example.com/vuln2", VulnType: "SQLi", Severity: "critical", CreatedAt: now},
	}

	// mockFunc receives the *subtest's* t so that assertions made inside the
	// service stub fail the subtest that triggered them, not the parent test.
	tests := []struct {
		name           string
		queryParams    string
		mockFunc       func(t *testing.T, query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error)
		expectedStatus int
		checkResponse  func(t *testing.T, body string)
	}{
		{
			name:        "list all with default pagination",
			queryParams: "",
			mockFunc: func(t *testing.T, query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error) {
				return mockSnapshots, 2, nil
			},
			expectedStatus: http.StatusOK,
			checkResponse: func(t *testing.T, body string) {
				var resp dto.PaginatedResponse[dto.VulnerabilitySnapshotResponse]
				if err := json.Unmarshal([]byte(body), &resp); err != nil {
					t.Fatalf("failed to unmarshal response: %v", err)
				}
				if resp.Total != 2 {
					t.Errorf("expected total 2, got %d", resp.Total)
				}
			},
		},
		{
			name:        "list all with filter",
			queryParams: "?filter=XSS",
			mockFunc: func(t *testing.T, query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error) {
				if query.Filter != "XSS" {
					t.Errorf("expected filter 'XSS', got %q", query.Filter)
				}
				return mockSnapshots[:1], 1, nil
			},
			expectedStatus: http.StatusOK,
			checkResponse: func(t *testing.T, body string) {
				var resp dto.PaginatedResponse[dto.VulnerabilitySnapshotResponse]
				if err := json.Unmarshal([]byte(body), &resp); err != nil {
					t.Fatalf("failed to unmarshal response: %v", err)
				}
				if resp.Total != 1 {
					t.Errorf("expected total 1, got %d", resp.Total)
				}
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockSvc := &MockVulnerabilitySnapshotService{
				ListAllFunc: func(query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error) {
					return tt.mockFunc(t, query)
				},
			}

			router := gin.New()
			router.GET("/api/vulnerability-snapshots/", func(c *gin.Context) {
				var query dto.VulnerabilitySnapshotListQuery
				if err := c.ShouldBindQuery(&query); err != nil {
					dto.BadRequest(c, "Invalid query parameters")
					return
				}
				snapshots, total, err := mockSvc.ListAll(&query)
				if err != nil {
					dto.InternalError(c, "Failed to list vulnerability snapshots")
					return
				}
				// Non-nil slice: encoding/json renders a nil slice as null,
				// whereas an empty page should be [].
				resp := make([]dto.VulnerabilitySnapshotResponse, 0, len(snapshots))
				for _, s := range snapshots {
					resp = append(resp, toVulnerabilitySnapshotResponse(&s))
				}
				dto.Paginated(c, resp, total, query.GetPage(), query.GetPageSize())
			})

			req := httptest.NewRequest(http.MethodGet, "/api/vulnerability-snapshots/"+tt.queryParams, nil)
			w := httptest.NewRecorder()
			router.ServeHTTP(w, req)

			if w.Code != tt.expectedStatus {
				t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code)
			}
			if tt.checkResponse != nil {
				tt.checkResponse(t, w.Body.String())
			}
		})
	}
}
// TestVulnerabilitySnapshotGetByID exercises the single-snapshot lookup flow:
// ID parsing, mapping of service errors to HTTP status codes, and the JSON
// shape of a found snapshot. The handler is defined inline and backed by
// MockVulnerabilitySnapshotService.
func TestVulnerabilitySnapshotGetByID(t *testing.T) {
	gin.SetMode(gin.TestMode)
	now := time.Now()
	score := decimal.NewFromFloat(7.5)
	mockSnapshot := &model.VulnerabilitySnapshot{
		ID:          1,
		ScanID:      1,
		URL:         "https://example.com/vuln",
		VulnType:    "XSS",
		Severity:    "high",
		Source:      "nuclei",
		CVSSScore:   &score,
		Description: "XSS vulnerability found",
		RawOutput:   datatypes.JSON(`{"template":"xss-test"}`),
		CreatedAt:   now,
	}

	tests := []struct {
		name           string
		id             string
		mockFunc       func(id int) (*model.VulnerabilitySnapshot, error)
		expectedStatus int
		checkResponse  func(t *testing.T, body string)
	}{
		{
			name: "get existing snapshot",
			id:   "1",
			mockFunc: func(id int) (*model.VulnerabilitySnapshot, error) {
				return mockSnapshot, nil
			},
			expectedStatus: http.StatusOK,
			checkResponse: func(t *testing.T, body string) {
				var resp dto.VulnerabilitySnapshotResponse
				if err := json.Unmarshal([]byte(body), &resp); err != nil {
					t.Fatalf("failed to unmarshal response: %v", err)
				}
				if resp.ID != 1 {
					t.Errorf("expected ID 1, got %d", resp.ID)
				}
				if resp.URL != "https://example.com/vuln" {
					t.Errorf("expected URL 'https://example.com/vuln', got %q", resp.URL)
				}
			},
		},
		{
			name:           "invalid ID",
			id:             "invalid",
			expectedStatus: http.StatusBadRequest,
			checkResponse:  nil,
		},
		{
			name: "snapshot not found",
			id:   "999",
			mockFunc: func(id int) (*model.VulnerabilitySnapshot, error) {
				return nil, service.ErrVulnerabilitySnapshotNotFound
			},
			expectedStatus: http.StatusNotFound,
			checkResponse:  nil,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Record every ID the service receives so we can check that the
			// handler forwards the parsed path parameter.
			var gotIDs []int
			mockSvc := &MockVulnerabilitySnapshotService{}
			if tt.mockFunc != nil {
				mockSvc.GetByIDFunc = func(id int) (*model.VulnerabilitySnapshot, error) {
					gotIDs = append(gotIDs, id)
					return tt.mockFunc(id)
				}
			}

			router := gin.New()
			router.GET("/api/vulnerability-snapshots/:id/", func(c *gin.Context) {
				id, err := strconv.Atoi(c.Param("id"))
				if err != nil {
					dto.BadRequest(c, "Invalid vulnerability snapshot ID")
					return
				}
				snapshot, err := mockSvc.GetByID(id)
				if err != nil {
					if errors.Is(err, service.ErrVulnerabilitySnapshotNotFound) {
						dto.NotFound(c, "Vulnerability snapshot not found")
						return
					}
					dto.InternalError(c, "Failed to get vulnerability snapshot")
					return
				}
				if snapshot == nil {
					// The mock's default GetByID returns (nil, nil). Fail with a
					// status mismatch instead of passing nil to
					// toVulnerabilitySnapshotResponse.
					dto.InternalError(c, "Failed to get vulnerability snapshot")
					return
				}
				dto.OK(c, toVulnerabilitySnapshotResponse(snapshot))
			})

			req := httptest.NewRequest(http.MethodGet, "/api/vulnerability-snapshots/"+tt.id+"/", nil)
			w := httptest.NewRecorder()
			router.ServeHTTP(w, req)

			if w.Code != tt.expectedStatus {
				t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code)
			}
			if tt.checkResponse != nil {
				tt.checkResponse(t, w.Body.String())
			}
			if tt.mockFunc != nil {
				wantID, err := strconv.Atoi(tt.id)
				if err != nil {
					t.Fatalf("test case id %q is not numeric: %v", tt.id, err)
				}
				if len(gotIDs) != 1 || gotIDs[0] != wantID {
					t.Errorf("expected GetByID to be called once with ID %d, got calls %v", wantID, gotIDs)
				}
			}
		})
	}
}
// TestVulnerabilitySnapshotPaginationProperties checks the page-count property
// totalPages = ceil(total / pageSize) on boundary cases: an empty result, a
// single item, an exact multiple of pageSize, and a partial last page.
//
// NOTE(review): the formula is computed inside the test rather than by calling
// production pagination code, so this only documents the expected behavior.
func TestVulnerabilitySnapshotPaginationProperties(t *testing.T) {
	tests := []struct {
		total     int64
		pageSize  int
		wantPages int
	}{
		{total: 0, pageSize: 20, wantPages: 0},
		{total: 1, pageSize: 20, wantPages: 1},
		{total: 20, pageSize: 20, wantPages: 1},
		{total: 21, pageSize: 20, wantPages: 2},
		{total: 100, pageSize: 10, wantPages: 10},
		{total: 101, pageSize: 10, wantPages: 11},
	}

	for _, tt := range tests {
		name := "total=" + strconv.FormatInt(tt.total, 10) + "_pageSize=" + strconv.Itoa(tt.pageSize)
		t.Run(name, func(t *testing.T) {
			// Integer ceiling division, done in int64 so a large total is not
			// truncated. total == 0 yields 0 without a special case.
			size := int64(tt.pageSize)
			totalPages := int((tt.total + size - 1) / size)
			if totalPages != tt.wantPages {
				t.Errorf("total=%d, pageSize=%d: expected totalPages=%d, got %d",
					tt.total, tt.pageSize, tt.wantPages, totalPages)
			}
		})
	}
}
// TestVulnerabilitySnapshotFilterProperties verifies that the `filter` query
// parameter is bound into VulnerabilitySnapshotListQuery.Filter unchanged and
// forwarded to ListAll, including the case where the parameter is absent.
func TestVulnerabilitySnapshotFilterProperties(t *testing.T) {
	gin.SetMode(gin.TestMode)
	filterTests := []string{
		"",
		"XSS",
		"SQLi",
	}

	for _, filter := range filterTests {
		t.Run("filter_"+filter, func(t *testing.T) {
			var (
				called         bool
				receivedFilter string
			)
			mockSvc := &MockVulnerabilitySnapshotService{
				ListAllFunc: func(query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error) {
					called = true
					receivedFilter = query.Filter
					return nil, 0, nil
				},
			}

			router := gin.New()
			router.GET("/api/vulnerability-snapshots/", func(c *gin.Context) {
				var query dto.VulnerabilitySnapshotListQuery
				if err := c.ShouldBindQuery(&query); err != nil {
					dto.BadRequest(c, "Invalid query parameters")
					return
				}
				if _, _, err := mockSvc.ListAll(&query); err != nil {
					dto.InternalError(c, "Failed to list vulnerability snapshots")
					return
				}
				dto.Paginated(c, []dto.VulnerabilitySnapshotResponse{}, 0, 1, 20)
			})

			target := "/api/vulnerability-snapshots/"
			if filter != "" {
				target += "?filter=" + filter
			}
			req := httptest.NewRequest(http.MethodGet, target, nil)
			w := httptest.NewRecorder()
			router.ServeHTTP(w, req)

			// Without these checks the empty-filter case would pass even if
			// the route never ran, since receivedFilter starts out as "".
			if w.Code != http.StatusOK {
				t.Fatalf("expected status %d, got %d (body %q)", http.StatusOK, w.Code, w.Body.String())
			}
			if !called {
				t.Fatal("expected ListAll to be called")
			}
			if receivedFilter != filter {
				t.Errorf("expected filter %q, got %q", filter, receivedFilter)
			}
		})
	}
}
// TestVulnerabilitySnapshotSeverityFilterProperties verifies that the
// `severity` query parameter is bound into VulnerabilitySnapshotListQuery.Severity
// unchanged and forwarded to ListByScan, for every severity level the test
// lists and for the absent-parameter case.
func TestVulnerabilitySnapshotSeverityFilterProperties(t *testing.T) {
	gin.SetMode(gin.TestMode)
	severityTests := []string{
		"",
		"unknown",
		"info",
		"low",
		"medium",
		"high",
		"critical",
	}

	for _, severity := range severityTests {
		t.Run("severity_"+severity, func(t *testing.T) {
			var (
				called           bool
				receivedSeverity string
			)
			mockSvc := &MockVulnerabilitySnapshotService{
				ListByScanFunc: func(scanID int, query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error) {
					called = true
					receivedSeverity = query.Severity
					return nil, 0, nil
				},
			}

			router := gin.New()
			router.GET("/api/scans/:id/vulnerabilities/", func(c *gin.Context) {
				var query dto.VulnerabilitySnapshotListQuery
				if err := c.ShouldBindQuery(&query); err != nil {
					dto.BadRequest(c, "Invalid query parameters")
					return
				}
				if _, _, err := mockSvc.ListByScan(1, &query); err != nil {
					dto.InternalError(c, "Failed to list vulnerability snapshots")
					return
				}
				dto.Paginated(c, []dto.VulnerabilitySnapshotResponse{}, 0, 1, 20)
			})

			target := "/api/scans/1/vulnerabilities/"
			if severity != "" {
				target += "?severity=" + severity
			}
			req := httptest.NewRequest(http.MethodGet, target, nil)
			w := httptest.NewRecorder()
			router.ServeHTTP(w, req)

			// Without these checks the empty-severity case would pass even if
			// the route never ran, since receivedSeverity starts out as "".
			if w.Code != http.StatusOK {
				t.Fatalf("expected status %d, got %d (body %q)", http.StatusOK, w.Code, w.Body.String())
			}
			if !called {
				t.Fatal("expected ListByScan to be called")
			}
			if receivedSeverity != severity {
				t.Errorf("expected severity %q, got %q", severity, receivedSeverity)
			}
		})
	}
}