Merge pull request #6542 from roiswd/feat-openapi-direct-fuzzing

feat(openapi/swagger): direct fuzzing using target url
This commit is contained in:
Dogan Can Bakir
2025-11-23 23:35:30 +09:00
committed by GitHub
7 changed files with 1024 additions and 18 deletions

View File

@@ -254,8 +254,23 @@ func New(options *types.Options) (*Runner, error) {
os.Exit(0)
}
tmpDir, err := os.MkdirTemp("", "nuclei-tmp-*")
if err != nil {
return nil, errors.Wrap(err, "could not create temporary directory")
}
runner.tmpDir = tmpDir
// Cleanup tmpDir only if initialization fails
// On successful initialization, Close() method will handle cleanup
cleanupOnError := true
defer func() {
if cleanupOnError && runner.tmpDir != "" {
_ = os.RemoveAll(runner.tmpDir)
}
}()
// create the input provider and load the inputs
inputProvider, err := provider.NewInputProvider(provider.InputOptions{Options: options})
inputProvider, err := provider.NewInputProvider(provider.InputOptions{Options: options, TempDir: runner.tmpDir})
if err != nil {
return nil, errors.Wrap(err, "could not create input provider")
}
@@ -386,10 +401,8 @@ func New(options *types.Options) (*Runner, error) {
}
runner.rateLimiter = utils.GetRateLimiter(context.Background(), options.RateLimit, options.RateLimitDuration)
if tmpDir, err := os.MkdirTemp("", "nuclei-tmp-*"); err == nil {
runner.tmpDir = tmpDir
}
// Initialization successful, disable cleanup on error
cleanupOnError = false
return runner, nil
}

View File

@@ -7,6 +7,7 @@ import (
"strings"
"github.com/projectdiscovery/nuclei/v3/pkg/input/types"
"github.com/projectdiscovery/retryablehttp-go"
fileutil "github.com/projectdiscovery/utils/file"
"gopkg.in/yaml.v3"
)
@@ -47,6 +48,16 @@ type Format interface {
SetOptions(options InputFormatOptions)
}
// SpecDownloader is an interface for downloading API specifications from URLs
type SpecDownloader interface {
// Download downloads the spec from the given URL and saves it to tmpDir
// Returns the path to the downloaded file
// httpClient is a retryablehttp.Client instance (can be nil for fallback)
Download(url, tmpDir string, httpClient *retryablehttp.Client) (string, error)
// SupportedExtensions returns the list of supported file extensions
SupportedExtensions() []string
}
var (
DefaultVarDumpFileName = "required_openapi_params.yaml"
ErrNoVarsDumpFile = errors.New("no required params file found")

View File

@@ -0,0 +1,136 @@
package openapi
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"github.com/pkg/errors"
"github.com/projectdiscovery/nuclei/v3/pkg/input/formats"
"github.com/projectdiscovery/retryablehttp-go"
)
// OpenAPIDownloader implements the SpecDownloader interface for OpenAPI 3.0 specs
type OpenAPIDownloader struct{}
// NewDownloader creates a new OpenAPI downloader
func NewDownloader() formats.SpecDownloader {
return &OpenAPIDownloader{}
}
// This function downloads an OpenAPI 3.0 spec from the given URL and saves it to tmpDir
func (d *OpenAPIDownloader) Download(urlStr, tmpDir string, httpClient *retryablehttp.Client) (string, error) {
// Validate URL format, OpenAPI 3.0 specs are typically JSON
if !strings.HasSuffix(urlStr, ".json") {
return "", fmt.Errorf("URL does not appear to be an OpenAPI JSON spec")
}
const maxSpecSizeBytes = 10 * 1024 * 1024 // 10MB
// Use provided httpClient or create a fallback
var client *http.Client
if httpClient != nil {
client = httpClient.HTTPClient
} else {
// Fallback to simple client if no httpClient provided
client = &http.Client{Timeout: 30 * time.Second}
}
resp, err := client.Get(urlStr)
if err != nil {
return "", errors.Wrap(err, "failed to download OpenAPI spec")
}
defer func() {
_ = resp.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("HTTP %d when downloading OpenAPI spec", resp.StatusCode)
}
bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxSpecSizeBytes))
if err != nil {
return "", errors.Wrap(err, "failed to read response body")
}
// Validate it's a valid JSON and has OpenAPI structure
var spec map[string]interface{}
if err := json.Unmarshal(bodyBytes, &spec); err != nil {
return "", fmt.Errorf("downloaded content is not valid JSON: %w", err)
}
// Check if it's an OpenAPI 3.0 spec
if openapi, exists := spec["openapi"]; exists {
if openapiStr, ok := openapi.(string); ok && strings.HasPrefix(openapiStr, "3.") {
// Valid OpenAPI 3.0 spec
} else {
return "", fmt.Errorf("not a valid OpenAPI 3.0 spec (found version: %v)", openapi)
}
} else {
return "", fmt.Errorf("not an OpenAPI spec (missing 'openapi' field)")
}
// Extract host from URL for server configuration
parsedURL, err := url.Parse(urlStr)
if err != nil {
return "", errors.Wrap(err, "failed to parse URL")
}
host := parsedURL.Host
scheme := parsedURL.Scheme
if scheme == "" {
scheme = "https"
}
// Add servers section if missing or empty
servers, exists := spec["servers"]
if !exists || servers == nil {
spec["servers"] = []map[string]interface{}{{"url": scheme + "://" + host}}
} else if serverList, ok := servers.([]interface{}); ok && len(serverList) == 0 {
spec["servers"] = []map[string]interface{}{{"url": scheme + "://" + host}}
}
// Marshal back to JSON
modifiedJSON, err := json.Marshal(spec)
if err != nil {
return "", errors.Wrap(err, "failed to marshal modified spec")
}
// Create output directory
openapiDir := filepath.Join(tmpDir, "openapi")
if err := os.MkdirAll(openapiDir, 0755); err != nil {
return "", errors.Wrap(err, "failed to create openapi directory")
}
// Generate filename
filename := fmt.Sprintf("openapi-spec-%d.json", time.Now().Unix())
filePath := filepath.Join(openapiDir, filename)
// Write file
file, err := os.Create(filePath)
if err != nil {
return "", fmt.Errorf("failed to create file: %w", err)
}
defer func() {
_ = file.Close()
}()
if _, writeErr := file.Write(modifiedJSON); writeErr != nil {
_ = os.Remove(filePath)
return "", errors.Wrap(writeErr, "failed to write OpenAPI spec to file")
}
return filePath, nil
}
// SupportedExtensions returns the list of supported file extensions for OpenAPI
func (d *OpenAPIDownloader) SupportedExtensions() []string {
return []string{".json"}
}

View File

@@ -0,0 +1,278 @@
package openapi
import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
)
func TestOpenAPIDownloader_SupportedExtensions(t *testing.T) {
downloader := &OpenAPIDownloader{}
extensions := downloader.SupportedExtensions()
expected := []string{".json"}
if len(extensions) != len(expected) {
t.Errorf("Expected %d extensions, got %d", len(expected), len(extensions))
}
for i, ext := range extensions {
if ext != expected[i] {
t.Errorf("Expected extension %s, got %s", expected[i], ext)
}
}
}
func TestOpenAPIDownloader_Download_Success(t *testing.T) {
// Create a mock OpenAPI spec
mockSpec := map[string]interface{}{
"openapi": "3.0.0",
"info": map[string]interface{}{
"title": "Test API",
"version": "1.0.0",
},
"paths": map[string]interface{}{
"/test": map[string]interface{}{
"get": map[string]interface{}{
"summary": "Test endpoint",
},
},
},
}
// Create mock server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(mockSpec); err != nil {
http.Error(w, "failed to encode response", http.StatusInternalServerError)
}
}))
defer server.Close()
// Create temp directory
tmpDir, err := os.MkdirTemp("", "openapi_test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer func() {
if err := os.RemoveAll(tmpDir); err != nil {
t.Fatalf("Failed to remove temp dir: %v", err)
}
}()
// Test download
downloader := &OpenAPIDownloader{}
filePath, err := downloader.Download(server.URL+"/openapi.json", tmpDir, nil)
if err != nil {
t.Fatalf("Download failed: %v", err)
}
// Verify file exists
if !fileExists(filePath) {
t.Errorf("Downloaded file does not exist: %s", filePath)
}
// Verify file content
content, err := os.ReadFile(filePath)
if err != nil {
t.Fatalf("Failed to read downloaded file: %v", err)
}
var downloadedSpec map[string]interface{}
if err := json.Unmarshal(content, &downloadedSpec); err != nil {
t.Fatalf("Failed to parse downloaded JSON: %v", err)
}
// Verify servers field was added
servers, exists := downloadedSpec["servers"]
if !exists {
t.Error("Servers field was not added to the spec")
}
if serversList, ok := servers.([]interface{}); ok {
if len(serversList) == 0 {
t.Error("Servers list is empty")
}
} else {
t.Error("Servers field is not a list")
}
}
func TestOpenAPIDownloader_Download_NonJSONURL(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "openapi_test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer func() {
if err := os.RemoveAll(tmpDir); err != nil {
t.Fatalf("Failed to remove temp dir: %v", err)
}
}()
downloader := &OpenAPIDownloader{}
_, err = downloader.Download("http://example.com/spec.yaml", tmpDir, nil)
if err == nil {
t.Error("Expected error for non-JSON URL, but got none")
}
if !strings.Contains(err.Error(), "URL does not appear to be an OpenAPI JSON spec") {
t.Errorf("Unexpected error message: %v", err)
}
}
func TestOpenAPIDownloader_Download_HTTPError(t *testing.T) {
// Create mock server that returns 404
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
defer server.Close()
tmpDir, err := os.MkdirTemp("", "openapi_test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer func() {
if err := os.RemoveAll(tmpDir); err != nil {
t.Fatalf("Failed to remove temp dir: %v", err)
}
}()
downloader := &OpenAPIDownloader{}
_, err = downloader.Download(server.URL+"/openapi.json", tmpDir, nil)
if err == nil {
t.Error("Expected error for HTTP 404, but got none")
}
}
func TestOpenAPIDownloader_Download_InvalidJSON(t *testing.T) {
// Create mock server that returns invalid JSON
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if _, err := w.Write([]byte("invalid json")); err != nil {
http.Error(w, "failed to write response", http.StatusInternalServerError)
}
}))
defer server.Close()
tmpDir, err := os.MkdirTemp("", "openapi_test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer func() {
if err := os.RemoveAll(tmpDir); err != nil {
t.Fatalf("Failed to remove temp dir: %v", err)
}
}()
downloader := &OpenAPIDownloader{}
_, err = downloader.Download(server.URL+"/openapi.json", tmpDir, nil)
if err == nil {
t.Error("Expected error for invalid JSON, but got none")
}
}
func TestOpenAPIDownloader_Download_Timeout(t *testing.T) {
// Create mock server with delay
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(35 * time.Second) // Longer than 30 second timeout
if err := json.NewEncoder(w).Encode(map[string]interface{}{"test": "data"}); err != nil {
http.Error(w, "failed to encode response", http.StatusInternalServerError)
}
}))
defer server.Close()
tmpDir, err := os.MkdirTemp("", "openapi_test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer func() {
if err := os.RemoveAll(tmpDir); err != nil {
t.Fatalf("Failed to remove temp dir: %v", err)
}
}()
downloader := &OpenAPIDownloader{}
_, err = downloader.Download(server.URL+"/openapi.json", tmpDir, nil)
if err == nil {
t.Error("Expected timeout error, but got none")
}
}
func TestOpenAPIDownloader_Download_WithExistingServers(t *testing.T) {
// Create a mock OpenAPI spec with existing servers
mockSpec := map[string]interface{}{
"openapi": "3.0.0",
"info": map[string]interface{}{
"title": "Test API",
"version": "1.0.0",
},
"servers": []interface{}{
map[string]interface{}{
"url": "https://existing-server.com",
},
},
"paths": map[string]interface{}{},
}
// Create mock server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(mockSpec); err != nil {
http.Error(w, "failed to encode response", http.StatusInternalServerError)
}
}))
defer server.Close()
tmpDir, err := os.MkdirTemp("", "openapi_test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer func() {
if err := os.RemoveAll(tmpDir); err != nil {
t.Fatalf("Failed to remove temp dir: %v", err)
}
}()
downloader := &OpenAPIDownloader{}
filePath, err := downloader.Download(server.URL+"/openapi.json", tmpDir, nil)
if err != nil {
t.Fatalf("Download failed: %v", err)
}
// Verify existing servers are preserved
content, err := os.ReadFile(filePath)
if err != nil {
t.Fatalf("Failed to read downloaded file: %v", err)
}
var downloadedSpec map[string]interface{}
if err := json.Unmarshal(content, &downloadedSpec); err != nil {
t.Fatalf("Failed to parse downloaded JSON: %v", err)
}
servers, exists := downloadedSpec["servers"]
if !exists {
t.Error("Servers field was removed from the spec")
}
if serversList, ok := servers.([]interface{}); ok {
if len(serversList) != 1 {
t.Errorf("Expected 1 server, got %d", len(serversList))
}
}
}
// Helper function to check if file exists
func fileExists(filename string) bool {
_, err := os.Stat(filename)
return !os.IsNotExist(err)
}

View File

@@ -0,0 +1,165 @@
package swagger
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"github.com/pkg/errors"
"github.com/projectdiscovery/nuclei/v3/pkg/input/formats"
"github.com/projectdiscovery/retryablehttp-go"
"gopkg.in/yaml.v3"
)
// SwaggerDownloader implements the SpecDownloader interface for Swagger 2.0 specs
type SwaggerDownloader struct{}
// NewDownloader creates a new Swagger downloader
func NewDownloader() formats.SpecDownloader {
return &SwaggerDownloader{}
}
// This function downloads a Swagger 2.0 spec from the given URL and saves it to tmpDir
func (d *SwaggerDownloader) Download(urlStr, tmpDir string, httpClient *retryablehttp.Client) (string, error) {
// Swagger can be JSON or YAML
supportedExts := d.SupportedExtensions()
isSupported := false
for _, ext := range supportedExts {
if strings.HasSuffix(urlStr, ext) {
isSupported = true
break
}
}
if !isSupported {
return "", fmt.Errorf("URL does not appear to be a Swagger spec (supported: %v)", supportedExts)
}
const maxSpecSizeBytes = 10 * 1024 * 1024 // 10MB
// Use provided httpClient or create a fallback
var client *http.Client
if httpClient != nil {
client = httpClient.HTTPClient
} else {
// Fallback to simple client if no httpClient provided
client = &http.Client{Timeout: 30 * time.Second}
}
resp, err := client.Get(urlStr)
if err != nil {
return "", errors.Wrap(err, "failed to download Swagger spec")
}
defer func() {
_ = resp.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("HTTP %d when downloading Swagger spec", resp.StatusCode)
}
bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxSpecSizeBytes))
if err != nil {
return "", errors.Wrap(err, "failed to read response body")
}
// Determine format and parse
var spec map[string]interface{}
var isYAML bool
// Try JSON first
if err := json.Unmarshal(bodyBytes, &spec); err != nil {
// Then try YAML
if err := yaml.Unmarshal(bodyBytes, &spec); err != nil {
return "", fmt.Errorf("downloaded content is neither valid JSON nor YAML: %w", err)
}
isYAML = true
}
// Validate it's a Swagger 2.0 spec
if swagger, exists := spec["swagger"]; exists {
if swaggerStr, ok := swagger.(string); ok && strings.HasPrefix(swaggerStr, "2.") {
// Valid Swagger 2.0 spec
} else {
return "", fmt.Errorf("not a valid Swagger 2.0 spec (found version: %v)", swagger)
}
} else {
return "", fmt.Errorf("not a Swagger spec (missing 'swagger' field)")
}
// Extract host from URL for host configuration
parsedURL, err := url.Parse(urlStr)
if err != nil {
return "", errors.Wrap(err, "failed to parse URL")
}
host := parsedURL.Host
scheme := parsedURL.Scheme
if scheme == "" {
scheme = "https"
}
// Add host if missing
if _, exists := spec["host"]; !exists {
spec["host"] = host
}
// Add schemes if missing
if _, exists := spec["schemes"]; !exists {
spec["schemes"] = []string{scheme}
}
// Create output directory
swaggerDir := filepath.Join(tmpDir, "swagger")
if err := os.MkdirAll(swaggerDir, 0755); err != nil {
return "", errors.Wrap(err, "failed to create swagger directory")
}
// Generate filename and content based on original format
var filename string
var content []byte
if isYAML {
filename = fmt.Sprintf("swagger-spec-%d.yaml", time.Now().Unix())
content, err = yaml.Marshal(spec)
if err != nil {
return "", errors.Wrap(err, "failed to marshal modified YAML spec")
}
} else {
filename = fmt.Sprintf("swagger-spec-%d.json", time.Now().Unix())
content, err = json.Marshal(spec)
if err != nil {
return "", errors.Wrap(err, "failed to marshal modified JSON spec")
}
}
filePath := filepath.Join(swaggerDir, filename)
// Write file
file, err := os.Create(filePath)
if err != nil {
return "", errors.Wrap(err, "failed to create file")
}
defer func() {
_ = file.Close()
}()
if _, writeErr := file.Write(content); writeErr != nil {
_ = os.Remove(filePath)
return "", errors.Wrap(writeErr, "failed to write file")
}
return filePath, nil
}
// SupportedExtensions returns the list of supported file extensions for Swagger
func (d *SwaggerDownloader) SupportedExtensions() []string {
return []string{".json", ".yaml", ".yml"}
}

View File

@@ -0,0 +1,359 @@
package swagger
import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
"gopkg.in/yaml.v3"
)
func TestSwaggerDownloader_SupportedExtensions(t *testing.T) {
downloader := &SwaggerDownloader{}
extensions := downloader.SupportedExtensions()
expected := []string{".json", ".yaml", ".yml"}
if len(extensions) != len(expected) {
t.Errorf("Expected %d extensions, got %d", len(expected), len(extensions))
}
for i, ext := range extensions {
if ext != expected[i] {
t.Errorf("Expected extension %s, got %s", expected[i], ext)
}
}
}
func TestSwaggerDownloader_Download_JSON_Success(t *testing.T) {
// Create a mock Swagger spec (JSON)
mockSpec := map[string]interface{}{
"swagger": "2.0",
"info": map[string]interface{}{
"title": "Test API",
"version": "1.0.0",
},
"paths": map[string]interface{}{
"/test": map[string]interface{}{
"get": map[string]interface{}{
"summary": "Test endpoint",
},
},
},
}
// Create mock server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(mockSpec); err != nil {
http.Error(w, "failed to encode response", http.StatusInternalServerError)
}
}))
defer server.Close()
// Create temp directory
tmpDir, err := os.MkdirTemp("", "swagger_test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer func() {
if err := os.RemoveAll(tmpDir); err != nil {
t.Fatalf("Failed to remove temp dir: %v", err)
}
}()
// Test download
downloader := &SwaggerDownloader{}
filePath, err := downloader.Download(server.URL+"/swagger.json", tmpDir, nil)
if err != nil {
t.Fatalf("Download failed: %v", err)
}
// Verify file exists
if !fileExists(filePath) {
t.Errorf("Downloaded file does not exist: %s", filePath)
}
// Verify file content
content, err := os.ReadFile(filePath)
if err != nil {
t.Fatalf("Failed to read downloaded file: %v", err)
}
var downloadedSpec map[string]interface{}
if err := json.Unmarshal(content, &downloadedSpec); err != nil {
t.Fatalf("Failed to parse downloaded JSON: %v", err)
}
// Verify host field was added
_, exists := downloadedSpec["host"]
if !exists {
t.Error("Host field was not added to the spec")
}
}
func TestSwaggerDownloader_Download_YAML_Success(t *testing.T) {
// Create a mock Swagger spec (YAML)
mockSpecYAML := `
swagger: "2.0"
info:
title: "Test API"
version: "1.0.0"
paths:
/test:
get:
summary: "Test endpoint"
`
// Create mock server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/yaml")
if _, err := w.Write([]byte(mockSpecYAML)); err != nil {
http.Error(w, "failed to write response", http.StatusInternalServerError)
}
}))
defer server.Close()
// Create temp directory
tmpDir, err := os.MkdirTemp("", "swagger_test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer func() {
if err := os.RemoveAll(tmpDir); err != nil {
t.Fatalf("Failed to remove temp dir: %v", err)
}
}()
// Test download
downloader := &SwaggerDownloader{}
filePath, err := downloader.Download(server.URL+"/swagger.yaml", tmpDir, nil)
if err != nil {
t.Fatalf("Download failed: %v", err)
}
// Verify file exists
if !fileExists(filePath) {
t.Errorf("Downloaded file does not exist: %s", filePath)
}
// Verify file content
content, err := os.ReadFile(filePath)
if err != nil {
t.Fatalf("Failed to read downloaded file: %v", err)
}
var downloadedSpec map[string]interface{}
if err := yaml.Unmarshal(content, &downloadedSpec); err != nil {
t.Fatalf("Failed to parse downloaded YAML: %v", err)
}
// Verify host field was added
_, exists := downloadedSpec["host"]
if !exists {
t.Error("Host field was not added to the spec")
}
}
func TestSwaggerDownloader_Download_UnsupportedExtension(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "swagger_test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer func() {
if err := os.RemoveAll(tmpDir); err != nil {
t.Fatalf("Failed to remove temp dir: %v", err)
}
}()
downloader := &SwaggerDownloader{}
_, err = downloader.Download("http://example.com/spec.xml", tmpDir, nil)
if err == nil {
t.Error("Expected error for unsupported extension, but got none")
}
if !strings.Contains(err.Error(), "URL does not appear to be a Swagger spec") {
t.Errorf("Unexpected error message: %v", err)
}
}
func TestSwaggerDownloader_Download_HTTPError(t *testing.T) {
// Create mock server that returns 404
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
defer server.Close()
tmpDir, err := os.MkdirTemp("", "swagger_test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer func() {
if err := os.RemoveAll(tmpDir); err != nil {
t.Fatalf("Failed to remove temp dir: %v", err)
}
}()
downloader := &SwaggerDownloader{}
_, err = downloader.Download(server.URL+"/swagger.json", tmpDir, nil)
if err == nil {
t.Error("Expected error for HTTP 404, but got none")
}
}
func TestSwaggerDownloader_Download_InvalidJSON(t *testing.T) {
// Create mock server that returns invalid JSON
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if _, err := w.Write([]byte("invalid json")); err != nil {
http.Error(w, "failed to write response", http.StatusInternalServerError)
}
}))
defer server.Close()
tmpDir, err := os.MkdirTemp("", "swagger_test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer func() {
if err := os.RemoveAll(tmpDir); err != nil {
t.Fatalf("Failed to remove temp dir: %v", err)
}
}()
downloader := &SwaggerDownloader{}
_, err = downloader.Download(server.URL+"/swagger.json", tmpDir, nil)
if err == nil {
t.Error("Expected error for invalid JSON, but got none")
}
}
func TestSwaggerDownloader_Download_InvalidYAML(t *testing.T) {
// Create mock server that returns invalid YAML
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/yaml")
if _, err := w.Write([]byte("invalid: yaml: content: [")); err != nil {
http.Error(w, "failed to write response", http.StatusInternalServerError)
}
}))
defer server.Close()
tmpDir, err := os.MkdirTemp("", "swagger_test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer func() {
if err := os.RemoveAll(tmpDir); err != nil {
t.Fatalf("Failed to remove temp dir: %v", err)
}
}()
downloader := &SwaggerDownloader{}
_, err = downloader.Download(server.URL+"/swagger.yaml", tmpDir, nil)
if err == nil {
t.Error("Expected error for invalid YAML, but got none")
}
}
func TestSwaggerDownloader_Download_Timeout(t *testing.T) {
// Create mock server with delay
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(35 * time.Second) // Longer than 30 second timeout
if err := json.NewEncoder(w).Encode(map[string]interface{}{"test": "data"}); err != nil {
http.Error(w, "failed to encode response", http.StatusInternalServerError)
}
}))
defer server.Close()
tmpDir, err := os.MkdirTemp("", "swagger_test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer func() {
if err := os.RemoveAll(tmpDir); err != nil {
t.Fatalf("Failed to remove temp dir: %v", err)
}
}()
downloader := &SwaggerDownloader{}
_, err = downloader.Download(server.URL+"/swagger.json", tmpDir, nil)
if err == nil {
t.Error("Expected timeout error, but got none")
}
}
func TestSwaggerDownloader_Download_WithExistingHost(t *testing.T) {
// Create a mock Swagger spec with existing host
mockSpec := map[string]interface{}{
"swagger": "2.0",
"info": map[string]interface{}{
"title": "Test API",
"version": "1.0.0",
},
"host": "existing-host.com",
"paths": map[string]interface{}{},
}
// Create mock server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(mockSpec); err != nil {
http.Error(w, "failed to encode response", http.StatusInternalServerError)
}
}))
defer server.Close()
tmpDir, err := os.MkdirTemp("", "swagger_test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer func() {
if err := os.RemoveAll(tmpDir); err != nil {
t.Fatalf("Failed to remove temp dir: %v", err)
}
}()
downloader := &SwaggerDownloader{}
filePath, err := downloader.Download(server.URL+"/swagger.json", tmpDir, nil)
if err != nil {
t.Fatalf("Download failed: %v", err)
}
// Verify existing host is preserved
content, err := os.ReadFile(filePath)
if err != nil {
t.Fatalf("Failed to read downloaded file: %v", err)
}
var downloadedSpec map[string]interface{}
if err := json.Unmarshal(content, &downloadedSpec); err != nil {
t.Fatalf("Failed to parse downloaded JSON: %v", err)
}
host, exists := downloadedSpec["host"]
if !exists {
t.Error("Host field was removed from the spec")
}
if hostStr, ok := host.(string); !ok || hostStr != "existing-host.com" {
t.Errorf("Expected host 'existing-host.com', got '%v'", host)
}
}
// Helper function to check if file exists
func fileExists(filename string) bool {
_, err := os.Stat(filename)
return !os.IsNotExist(err)
}

View File

@@ -7,12 +7,16 @@ import (
"github.com/projectdiscovery/gologger"
"github.com/projectdiscovery/nuclei/v3/pkg/input/formats"
"github.com/projectdiscovery/nuclei/v3/pkg/input/formats/openapi"
"github.com/projectdiscovery/nuclei/v3/pkg/input/formats/swagger"
"github.com/projectdiscovery/nuclei/v3/pkg/input/provider/http"
"github.com/projectdiscovery/nuclei/v3/pkg/input/provider/list"
"github.com/projectdiscovery/nuclei/v3/pkg/input/types"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/generators"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate"
configTypes "github.com/projectdiscovery/nuclei/v3/pkg/types"
"github.com/projectdiscovery/retryablehttp-go"
"github.com/projectdiscovery/utils/errkit"
stringsutil "github.com/projectdiscovery/utils/strings"
)
@@ -74,6 +78,8 @@ type InputProvider interface {
type InputOptions struct {
// Options for global config
Options *configTypes.Options
// TempDir is the temporary directory for storing files
TempDir string
// NotFoundCallback is the callback to call when input is not found
// only supported in list input provider
NotFoundCallback func(template string) bool
@@ -107,20 +113,58 @@ func NewInputProvider(opts InputOptions) (InputProvider, error) {
Options: opts.Options,
NotFoundCallback: opts.NotFoundCallback,
})
} else {
// use HttpInputProvider
return http.NewHttpInputProvider(&http.HttpMultiFormatOptions{
InputFile: opts.Options.TargetsFilePath,
InputMode: opts.Options.InputFileMode,
Options: formats.InputFormatOptions{
Variables: generators.MergeMaps(extraVars, opts.Options.Vars.AsMap()),
SkipFormatValidation: opts.Options.SkipFormatValidation,
RequiredOnly: opts.Options.FormatUseRequiredOnly,
VarsTextTemplating: opts.Options.VarsTextTemplating,
VarsFilePaths: opts.Options.VarsFilePaths,
},
})
} else if len(opts.Options.Targets) > 0 &&
(strings.EqualFold(opts.Options.InputFileMode, "openapi") || strings.EqualFold(opts.Options.InputFileMode, "swagger")) {
if len(opts.Options.Targets) > 1 {
return nil, fmt.Errorf("only one target URL is supported in %s input mode", opts.Options.InputFileMode)
}
target := opts.Options.Targets[0]
if strings.HasPrefix(target, "http://") || strings.HasPrefix(target, "https://") {
var downloader formats.SpecDownloader
var tempFile string
var err error
// Get HttpClient from protocolstate if available
var httpClient *retryablehttp.Client
if opts.Options.ExecutionId != "" {
dialers := protocolstate.GetDialersWithId(opts.Options.ExecutionId)
if dialers != nil {
httpClient = dialers.DefaultHTTPClient
}
}
switch strings.ToLower(opts.Options.InputFileMode) {
case "openapi":
downloader = openapi.NewDownloader()
tempFile, err = downloader.Download(target, opts.TempDir, httpClient)
case "swagger":
downloader = swagger.NewDownloader()
tempFile, err = downloader.Download(target, opts.TempDir, httpClient)
default:
return nil, fmt.Errorf("unsupported input mode: %s", opts.Options.InputFileMode)
}
if err != nil {
return nil, fmt.Errorf("failed to download %s spec from url %s: %w", opts.Options.InputFileMode, target, err)
}
opts.Options.TargetsFilePath = tempFile
}
}
return http.NewHttpInputProvider(&http.HttpMultiFormatOptions{
InputFile: opts.Options.TargetsFilePath,
InputMode: opts.Options.InputFileMode,
Options: formats.InputFormatOptions{
Variables: generators.MergeMaps(extraVars, opts.Options.Vars.AsMap()),
SkipFormatValidation: opts.Options.SkipFormatValidation,
RequiredOnly: opts.Options.FormatUseRequiredOnly,
VarsTextTemplating: opts.Options.VarsTextTemplating,
VarsFilePaths: opts.Options.VarsFilePaths,
},
})
}
// SupportedInputFormats returns all supported input formats of nuclei