// Mirror of https://github.com/yyhuni/xingrin.git
// Synced 2026-01-31 11:46:16 +08:00
// 668 lines, 20 KiB, Go
package handler
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/shopspring/decimal"
|
|
"github.com/xingrin/server/internal/dto"
|
|
"github.com/xingrin/server/internal/model"
|
|
"github.com/xingrin/server/internal/service"
|
|
"gorm.io/datatypes"
|
|
)
|
|
|
|
// MockVulnerabilitySnapshotService is a mock implementation for testing
|
|
type MockVulnerabilitySnapshotService struct {
|
|
SaveAndSyncFunc func(scanID int, items []dto.VulnerabilitySnapshotItem) (int64, int64, error)
|
|
ListByScanFunc func(scanID int, query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error)
|
|
ListAllFunc func(query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error)
|
|
GetByIDFunc func(id int) (*model.VulnerabilitySnapshot, error)
|
|
StreamByScanFunc func(scanID int) (*sql.Rows, error)
|
|
CountByScanFunc func(scanID int) (int64, error)
|
|
ScanRowFunc func(rows *sql.Rows) (*model.VulnerabilitySnapshot, error)
|
|
}
|
|
|
|
func (m *MockVulnerabilitySnapshotService) SaveAndSync(scanID int, items []dto.VulnerabilitySnapshotItem) (int64, int64, error) {
|
|
if m.SaveAndSyncFunc != nil {
|
|
return m.SaveAndSyncFunc(scanID, items)
|
|
}
|
|
return 0, 0, nil
|
|
}
|
|
|
|
func (m *MockVulnerabilitySnapshotService) ListByScan(scanID int, query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error) {
|
|
if m.ListByScanFunc != nil {
|
|
return m.ListByScanFunc(scanID, query)
|
|
}
|
|
return nil, 0, nil
|
|
}
|
|
|
|
func (m *MockVulnerabilitySnapshotService) ListAll(query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error) {
|
|
if m.ListAllFunc != nil {
|
|
return m.ListAllFunc(query)
|
|
}
|
|
return nil, 0, nil
|
|
}
|
|
|
|
func (m *MockVulnerabilitySnapshotService) GetByID(id int) (*model.VulnerabilitySnapshot, error) {
|
|
if m.GetByIDFunc != nil {
|
|
return m.GetByIDFunc(id)
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
func (m *MockVulnerabilitySnapshotService) CountByScan(scanID int) (int64, error) {
|
|
if m.CountByScanFunc != nil {
|
|
return m.CountByScanFunc(scanID)
|
|
}
|
|
return 0, nil
|
|
}
|
|
|
|
// TestVulnerabilitySnapshotBulkCreate tests the BulkCreate endpoint
|
|
func TestVulnerabilitySnapshotBulkCreate(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
tests := []struct {
|
|
name string
|
|
scanID string
|
|
body string
|
|
mockFunc func(scanID int, items []dto.VulnerabilitySnapshotItem) (int64, int64, error)
|
|
expectedStatus int
|
|
expectedBody string
|
|
}{
|
|
{
|
|
name: "successful bulk create",
|
|
scanID: "1",
|
|
body: `{"vulnerabilities":[{"url":"https://example.com/vuln","vulnType":"XSS","severity":"high"}]}`,
|
|
mockFunc: func(scanID int, items []dto.VulnerabilitySnapshotItem) (int64, int64, error) {
|
|
return 1, 1, nil
|
|
},
|
|
expectedStatus: http.StatusOK,
|
|
expectedBody: `"snapshotCount":1,"assetCount":1`,
|
|
},
|
|
{
|
|
name: "invalid scan ID",
|
|
scanID: "invalid",
|
|
body: `{"vulnerabilities":[]}`,
|
|
expectedStatus: http.StatusBadRequest,
|
|
expectedBody: `"message":"Invalid scan ID"`,
|
|
},
|
|
{
|
|
name: "scan not found",
|
|
scanID: "999",
|
|
body: `{"vulnerabilities":[{"url":"https://example.com/vuln","vulnType":"XSS","severity":"high"}]}`,
|
|
mockFunc: func(scanID int, items []dto.VulnerabilitySnapshotItem) (int64, int64, error) {
|
|
return 0, 0, service.ErrScanNotFoundForSnapshot
|
|
},
|
|
expectedStatus: http.StatusNotFound,
|
|
expectedBody: `"message":"Scan not found"`,
|
|
},
|
|
{
|
|
name: "multiple vulnerabilities",
|
|
scanID: "1",
|
|
body: `{"vulnerabilities":[{"url":"https://example.com/vuln1","vulnType":"XSS","severity":"high"},{"url":"https://example.com/vuln2","vulnType":"SQLi","severity":"critical"}]}`,
|
|
mockFunc: func(scanID int, items []dto.VulnerabilitySnapshotItem) (int64, int64, error) {
|
|
return 2, 2, nil
|
|
},
|
|
expectedStatus: http.StatusOK,
|
|
expectedBody: `"snapshotCount":2,"assetCount":2`,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockSvc := &MockVulnerabilitySnapshotService{
|
|
SaveAndSyncFunc: tt.mockFunc,
|
|
}
|
|
|
|
router := gin.New()
|
|
router.POST("/api/scans/:id/vulnerabilities/bulk-create", func(c *gin.Context) {
|
|
scanID := c.Param("id")
|
|
if scanID == "invalid" {
|
|
dto.BadRequest(c, "Invalid scan ID")
|
|
return
|
|
}
|
|
|
|
var req dto.BulkCreateVulnerabilitySnapshotsRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
dto.BadRequest(c, "Invalid request body")
|
|
return
|
|
}
|
|
|
|
snapshotCount, assetCount, err := mockSvc.SaveAndSync(1, req.Vulnerabilities)
|
|
if err != nil {
|
|
if err == service.ErrScanNotFoundForSnapshot {
|
|
dto.NotFound(c, "Scan not found")
|
|
return
|
|
}
|
|
dto.InternalError(c, "Failed to save vulnerability snapshots")
|
|
return
|
|
}
|
|
|
|
dto.Success(c, dto.BulkCreateVulnerabilitySnapshotsResponse{
|
|
SnapshotCount: int(snapshotCount),
|
|
AssetCount: int(assetCount),
|
|
})
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/api/scans/"+tt.scanID+"/vulnerabilities/bulk-create", strings.NewReader(tt.body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
w := httptest.NewRecorder()
|
|
|
|
router.ServeHTTP(w, req)
|
|
|
|
if w.Code != tt.expectedStatus {
|
|
t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code)
|
|
}
|
|
|
|
if !strings.Contains(w.Body.String(), tt.expectedBody) {
|
|
t.Errorf("expected body to contain %q, got %q", tt.expectedBody, w.Body.String())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestVulnerabilitySnapshotListByScan tests the ListByScan endpoint
|
|
func TestVulnerabilitySnapshotListByScan(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
now := time.Now()
|
|
score := decimal.NewFromFloat(7.5)
|
|
mockSnapshots := []model.VulnerabilitySnapshot{
|
|
{ID: 1, ScanID: 1, URL: "https://example.com/vuln1", VulnType: "XSS", Severity: "high", CVSSScore: &score, CreatedAt: now},
|
|
{ID: 2, ScanID: 1, URL: "https://example.com/vuln2", VulnType: "SQLi", Severity: "critical", CreatedAt: now},
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
scanID string
|
|
queryParams string
|
|
mockFunc func(scanID int, query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error)
|
|
expectedStatus int
|
|
checkResponse func(t *testing.T, body string)
|
|
}{
|
|
{
|
|
name: "list with default pagination",
|
|
scanID: "1",
|
|
queryParams: "",
|
|
mockFunc: func(scanID int, query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error) {
|
|
if query.GetPage() != 1 {
|
|
t.Errorf("expected page 1, got %d", query.GetPage())
|
|
}
|
|
if query.GetPageSize() != 20 {
|
|
t.Errorf("expected pageSize 20, got %d", query.GetPageSize())
|
|
}
|
|
return mockSnapshots, 2, nil
|
|
},
|
|
expectedStatus: http.StatusOK,
|
|
checkResponse: func(t *testing.T, body string) {
|
|
var resp dto.PaginatedResponse[dto.VulnerabilitySnapshotResponse]
|
|
if err := json.Unmarshal([]byte(body), &resp); err != nil {
|
|
t.Fatalf("failed to unmarshal response: %v", err)
|
|
}
|
|
if resp.Total != 2 {
|
|
t.Errorf("expected total 2, got %d", resp.Total)
|
|
}
|
|
if resp.Page != 1 {
|
|
t.Errorf("expected page 1, got %d", resp.Page)
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "list with custom pagination",
|
|
scanID: "1",
|
|
queryParams: "?page=2&pageSize=10",
|
|
mockFunc: func(scanID int, query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error) {
|
|
if query.GetPage() != 2 {
|
|
t.Errorf("expected page 2, got %d", query.GetPage())
|
|
}
|
|
if query.GetPageSize() != 10 {
|
|
t.Errorf("expected pageSize 10, got %d", query.GetPageSize())
|
|
}
|
|
return []model.VulnerabilitySnapshot{}, 30, nil
|
|
},
|
|
expectedStatus: http.StatusOK,
|
|
checkResponse: func(t *testing.T, body string) {
|
|
var resp dto.PaginatedResponse[dto.VulnerabilitySnapshotResponse]
|
|
if err := json.Unmarshal([]byte(body), &resp); err != nil {
|
|
t.Fatalf("failed to unmarshal response: %v", err)
|
|
}
|
|
if resp.Page != 2 {
|
|
t.Errorf("expected page 2, got %d", resp.Page)
|
|
}
|
|
if resp.PageSize != 10 {
|
|
t.Errorf("expected pageSize 10, got %d", resp.PageSize)
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "list with severity filter",
|
|
scanID: "1",
|
|
queryParams: "?severity=critical",
|
|
mockFunc: func(scanID int, query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error) {
|
|
if query.Severity != "critical" {
|
|
t.Errorf("expected severity 'critical', got %q", query.Severity)
|
|
}
|
|
return mockSnapshots[1:], 1, nil
|
|
},
|
|
expectedStatus: http.StatusOK,
|
|
checkResponse: func(t *testing.T, body string) {
|
|
var resp dto.PaginatedResponse[dto.VulnerabilitySnapshotResponse]
|
|
if err := json.Unmarshal([]byte(body), &resp); err != nil {
|
|
t.Fatalf("failed to unmarshal response: %v", err)
|
|
}
|
|
if resp.Total != 1 {
|
|
t.Errorf("expected total 1, got %d", resp.Total)
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "scan not found",
|
|
scanID: "999",
|
|
queryParams: "",
|
|
mockFunc: func(scanID int, query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error) {
|
|
return nil, 0, service.ErrScanNotFoundForSnapshot
|
|
},
|
|
expectedStatus: http.StatusNotFound,
|
|
checkResponse: nil,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockSvc := &MockVulnerabilitySnapshotService{
|
|
ListByScanFunc: tt.mockFunc,
|
|
}
|
|
|
|
router := gin.New()
|
|
router.GET("/api/scans/:id/vulnerabilities/", func(c *gin.Context) {
|
|
scanID := c.Param("id")
|
|
if scanID == "invalid" {
|
|
dto.BadRequest(c, "Invalid scan ID")
|
|
return
|
|
}
|
|
|
|
var query dto.VulnerabilitySnapshotListQuery
|
|
if err := c.ShouldBindQuery(&query); err != nil {
|
|
dto.BadRequest(c, "Invalid query parameters")
|
|
return
|
|
}
|
|
|
|
snapshots, total, err := mockSvc.ListByScan(1, &query)
|
|
if err != nil {
|
|
if err == service.ErrScanNotFoundForSnapshot {
|
|
dto.NotFound(c, "Scan not found")
|
|
return
|
|
}
|
|
dto.InternalError(c, "Failed to list vulnerability snapshots")
|
|
return
|
|
}
|
|
|
|
var resp []dto.VulnerabilitySnapshotResponse
|
|
for _, s := range snapshots {
|
|
resp = append(resp, toVulnerabilitySnapshotResponse(&s))
|
|
}
|
|
|
|
dto.Paginated(c, resp, total, query.GetPage(), query.GetPageSize())
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/scans/"+tt.scanID+"/vulnerabilities/"+tt.queryParams, nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
router.ServeHTTP(w, req)
|
|
|
|
if w.Code != tt.expectedStatus {
|
|
t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code)
|
|
}
|
|
|
|
if tt.checkResponse != nil {
|
|
tt.checkResponse(t, w.Body.String())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
|
|
// TestVulnerabilitySnapshotListAll tests the ListAll endpoint
|
|
func TestVulnerabilitySnapshotListAll(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
now := time.Now()
|
|
mockSnapshots := []model.VulnerabilitySnapshot{
|
|
{ID: 1, ScanID: 1, URL: "https://example.com/vuln1", VulnType: "XSS", Severity: "high", CreatedAt: now},
|
|
{ID: 2, ScanID: 2, URL: "https://example.com/vuln2", VulnType: "SQLi", Severity: "critical", CreatedAt: now},
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
queryParams string
|
|
mockFunc func(query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error)
|
|
expectedStatus int
|
|
checkResponse func(t *testing.T, body string)
|
|
}{
|
|
{
|
|
name: "list all with default pagination",
|
|
queryParams: "",
|
|
mockFunc: func(query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error) {
|
|
return mockSnapshots, 2, nil
|
|
},
|
|
expectedStatus: http.StatusOK,
|
|
checkResponse: func(t *testing.T, body string) {
|
|
var resp dto.PaginatedResponse[dto.VulnerabilitySnapshotResponse]
|
|
if err := json.Unmarshal([]byte(body), &resp); err != nil {
|
|
t.Fatalf("failed to unmarshal response: %v", err)
|
|
}
|
|
if resp.Total != 2 {
|
|
t.Errorf("expected total 2, got %d", resp.Total)
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "list all with filter",
|
|
queryParams: "?filter=XSS",
|
|
mockFunc: func(query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error) {
|
|
if query.Filter != "XSS" {
|
|
t.Errorf("expected filter 'XSS', got %q", query.Filter)
|
|
}
|
|
return mockSnapshots[:1], 1, nil
|
|
},
|
|
expectedStatus: http.StatusOK,
|
|
checkResponse: func(t *testing.T, body string) {
|
|
var resp dto.PaginatedResponse[dto.VulnerabilitySnapshotResponse]
|
|
if err := json.Unmarshal([]byte(body), &resp); err != nil {
|
|
t.Fatalf("failed to unmarshal response: %v", err)
|
|
}
|
|
if resp.Total != 1 {
|
|
t.Errorf("expected total 1, got %d", resp.Total)
|
|
}
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockSvc := &MockVulnerabilitySnapshotService{
|
|
ListAllFunc: tt.mockFunc,
|
|
}
|
|
|
|
router := gin.New()
|
|
router.GET("/api/vulnerability-snapshots/", func(c *gin.Context) {
|
|
var query dto.VulnerabilitySnapshotListQuery
|
|
if err := c.ShouldBindQuery(&query); err != nil {
|
|
dto.BadRequest(c, "Invalid query parameters")
|
|
return
|
|
}
|
|
|
|
snapshots, total, err := mockSvc.ListAll(&query)
|
|
if err != nil {
|
|
dto.InternalError(c, "Failed to list vulnerability snapshots")
|
|
return
|
|
}
|
|
|
|
var resp []dto.VulnerabilitySnapshotResponse
|
|
for _, s := range snapshots {
|
|
resp = append(resp, toVulnerabilitySnapshotResponse(&s))
|
|
}
|
|
|
|
dto.Paginated(c, resp, total, query.GetPage(), query.GetPageSize())
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/vulnerability-snapshots/"+tt.queryParams, nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
router.ServeHTTP(w, req)
|
|
|
|
if w.Code != tt.expectedStatus {
|
|
t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code)
|
|
}
|
|
|
|
if tt.checkResponse != nil {
|
|
tt.checkResponse(t, w.Body.String())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestVulnerabilitySnapshotGetByID tests the GetByID endpoint
|
|
func TestVulnerabilitySnapshotGetByID(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
now := time.Now()
|
|
score := decimal.NewFromFloat(7.5)
|
|
mockSnapshot := &model.VulnerabilitySnapshot{
|
|
ID: 1,
|
|
ScanID: 1,
|
|
URL: "https://example.com/vuln",
|
|
VulnType: "XSS",
|
|
Severity: "high",
|
|
Source: "nuclei",
|
|
CVSSScore: &score,
|
|
Description: "XSS vulnerability found",
|
|
RawOutput: datatypes.JSON(`{"template":"xss-test"}`),
|
|
CreatedAt: now,
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
id string
|
|
mockFunc func(id int) (*model.VulnerabilitySnapshot, error)
|
|
expectedStatus int
|
|
checkResponse func(t *testing.T, body string)
|
|
}{
|
|
{
|
|
name: "get existing snapshot",
|
|
id: "1",
|
|
mockFunc: func(id int) (*model.VulnerabilitySnapshot, error) {
|
|
return mockSnapshot, nil
|
|
},
|
|
expectedStatus: http.StatusOK,
|
|
checkResponse: func(t *testing.T, body string) {
|
|
var resp dto.VulnerabilitySnapshotResponse
|
|
if err := json.Unmarshal([]byte(body), &resp); err != nil {
|
|
t.Fatalf("failed to unmarshal response: %v", err)
|
|
}
|
|
if resp.ID != 1 {
|
|
t.Errorf("expected ID 1, got %d", resp.ID)
|
|
}
|
|
if resp.URL != "https://example.com/vuln" {
|
|
t.Errorf("expected URL 'https://example.com/vuln', got %q", resp.URL)
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "invalid ID",
|
|
id: "invalid",
|
|
expectedStatus: http.StatusBadRequest,
|
|
checkResponse: nil,
|
|
},
|
|
{
|
|
name: "snapshot not found",
|
|
id: "999",
|
|
mockFunc: func(id int) (*model.VulnerabilitySnapshot, error) {
|
|
return nil, service.ErrVulnerabilitySnapshotNotFound
|
|
},
|
|
expectedStatus: http.StatusNotFound,
|
|
checkResponse: nil,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockSvc := &MockVulnerabilitySnapshotService{
|
|
GetByIDFunc: tt.mockFunc,
|
|
}
|
|
|
|
router := gin.New()
|
|
router.GET("/api/vulnerability-snapshots/:id/", func(c *gin.Context) {
|
|
idStr := c.Param("id")
|
|
if idStr == "invalid" {
|
|
dto.BadRequest(c, "Invalid vulnerability snapshot ID")
|
|
return
|
|
}
|
|
|
|
var id int
|
|
switch idStr {
|
|
case "1":
|
|
id = 1
|
|
case "999":
|
|
id = 999
|
|
}
|
|
|
|
snapshot, err := mockSvc.GetByID(id)
|
|
if err != nil {
|
|
if err == service.ErrVulnerabilitySnapshotNotFound {
|
|
dto.NotFound(c, "Vulnerability snapshot not found")
|
|
return
|
|
}
|
|
dto.InternalError(c, "Failed to get vulnerability snapshot")
|
|
return
|
|
}
|
|
|
|
dto.OK(c, toVulnerabilitySnapshotResponse(snapshot))
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/api/vulnerability-snapshots/"+tt.id+"/", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
router.ServeHTTP(w, req)
|
|
|
|
if w.Code != tt.expectedStatus {
|
|
t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code)
|
|
}
|
|
|
|
if tt.checkResponse != nil {
|
|
tt.checkResponse(t, w.Body.String())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestVulnerabilitySnapshotPaginationProperties checks the page-count rule
// totalPages = ceil(total / pageSize) against a fixed table of inputs.
//
// NOTE: the page count is computed inside this test; no production helper is
// called here.
func TestVulnerabilitySnapshotPaginationProperties(t *testing.T) {
	cases := []struct {
		total     int64
		pageSize  int
		wantPages int
	}{
		{total: 0, pageSize: 20, wantPages: 0},
		{total: 1, pageSize: 20, wantPages: 1},
		{total: 20, pageSize: 20, wantPages: 1},
		{total: 21, pageSize: 20, wantPages: 2},
		{total: 100, pageSize: 10, wantPages: 10},
		{total: 101, pageSize: 10, wantPages: 11},
	}

	for _, tc := range cases {
		// Ceiling division. Every table entry has a non-negative total and a
		// positive pageSize, so an empty result set naturally yields 0 pages.
		got := (int(tc.total) + tc.pageSize - 1) / tc.pageSize
		if got == tc.wantPages {
			continue
		}
		t.Errorf("total=%d, pageSize=%d: expected totalPages=%d, got %d",
			tc.total, tc.pageSize, tc.wantPages, got)
	}
}
|
|
|
|
// TestVulnerabilitySnapshotFilterProperties tests filter correctness
|
|
func TestVulnerabilitySnapshotFilterProperties(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
filterTests := []string{
|
|
"",
|
|
"XSS",
|
|
"SQLi",
|
|
}
|
|
|
|
for _, filter := range filterTests {
|
|
t.Run("filter_"+filter, func(t *testing.T) {
|
|
var receivedFilter string
|
|
mockSvc := &MockVulnerabilitySnapshotService{
|
|
ListAllFunc: func(query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error) {
|
|
receivedFilter = query.Filter
|
|
return nil, 0, nil
|
|
},
|
|
}
|
|
|
|
router := gin.New()
|
|
router.GET("/api/vulnerability-snapshots/", func(c *gin.Context) {
|
|
var query dto.VulnerabilitySnapshotListQuery
|
|
_ = c.ShouldBindQuery(&query)
|
|
_, _, _ = mockSvc.ListAll(&query)
|
|
dto.Paginated(c, []dto.VulnerabilitySnapshotResponse{}, 0, 1, 20)
|
|
})
|
|
|
|
url := "/api/vulnerability-snapshots/"
|
|
if filter != "" {
|
|
url += "?filter=" + filter
|
|
}
|
|
req := httptest.NewRequest(http.MethodGet, url, nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
router.ServeHTTP(w, req)
|
|
|
|
if receivedFilter != filter {
|
|
t.Errorf("expected filter %q, got %q", filter, receivedFilter)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestVulnerabilitySnapshotSeverityFilterProperties tests severity filter correctness
|
|
func TestVulnerabilitySnapshotSeverityFilterProperties(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
severityTests := []string{
|
|
"",
|
|
"unknown",
|
|
"info",
|
|
"low",
|
|
"medium",
|
|
"high",
|
|
"critical",
|
|
}
|
|
|
|
for _, severity := range severityTests {
|
|
t.Run("severity_"+severity, func(t *testing.T) {
|
|
var receivedSeverity string
|
|
mockSvc := &MockVulnerabilitySnapshotService{
|
|
ListByScanFunc: func(scanID int, query *dto.VulnerabilitySnapshotListQuery) ([]model.VulnerabilitySnapshot, int64, error) {
|
|
receivedSeverity = query.Severity
|
|
return nil, 0, nil
|
|
},
|
|
}
|
|
|
|
router := gin.New()
|
|
router.GET("/api/scans/:id/vulnerabilities/", func(c *gin.Context) {
|
|
var query dto.VulnerabilitySnapshotListQuery
|
|
_ = c.ShouldBindQuery(&query)
|
|
_, _, _ = mockSvc.ListByScan(1, &query)
|
|
dto.Paginated(c, []dto.VulnerabilitySnapshotResponse{}, 0, 1, 20)
|
|
})
|
|
|
|
url := "/api/scans/1/vulnerabilities/"
|
|
if severity != "" {
|
|
url += "?severity=" + severity
|
|
}
|
|
req := httptest.NewRequest(http.MethodGet, url, nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
router.ServeHTTP(w, req)
|
|
|
|
if receivedSeverity != severity {
|
|
t.Errorf("expected severity %q, got %q", severity, receivedSeverity)
|
|
}
|
|
})
|
|
}
|
|
}
|