mirror of
https://github.com/projectdiscovery/nuclei.git
synced 2026-01-31 15:53:10 +08:00
Merge pull request #6542 from roiswd/feat-openapi-direct-fuzzing
feat(openapi/swagger): direct fuzzing using target url
This commit is contained in:
@@ -254,8 +254,23 @@ func New(options *types.Options) (*Runner, error) {
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "nuclei-tmp-*")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not create temporary directory")
|
||||
}
|
||||
runner.tmpDir = tmpDir
|
||||
|
||||
// Cleanup tmpDir only if initialization fails
|
||||
// On successful initialization, Close() method will handle cleanup
|
||||
cleanupOnError := true
|
||||
defer func() {
|
||||
if cleanupOnError && runner.tmpDir != "" {
|
||||
_ = os.RemoveAll(runner.tmpDir)
|
||||
}
|
||||
}()
|
||||
|
||||
// create the input provider and load the inputs
|
||||
inputProvider, err := provider.NewInputProvider(provider.InputOptions{Options: options})
|
||||
inputProvider, err := provider.NewInputProvider(provider.InputOptions{Options: options, TempDir: runner.tmpDir})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not create input provider")
|
||||
}
|
||||
@@ -386,10 +401,8 @@ func New(options *types.Options) (*Runner, error) {
|
||||
}
|
||||
runner.rateLimiter = utils.GetRateLimiter(context.Background(), options.RateLimit, options.RateLimitDuration)
|
||||
|
||||
if tmpDir, err := os.MkdirTemp("", "nuclei-tmp-*"); err == nil {
|
||||
runner.tmpDir = tmpDir
|
||||
}
|
||||
|
||||
// Initialization successful, disable cleanup on error
|
||||
cleanupOnError = false
|
||||
return runner, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/input/types"
|
||||
"github.com/projectdiscovery/retryablehttp-go"
|
||||
fileutil "github.com/projectdiscovery/utils/file"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
@@ -47,6 +48,16 @@ type Format interface {
|
||||
SetOptions(options InputFormatOptions)
|
||||
}
|
||||
|
||||
// SpecDownloader is an interface for downloading API specifications from URLs
|
||||
type SpecDownloader interface {
|
||||
// Download downloads the spec from the given URL and saves it to tmpDir
|
||||
// Returns the path to the downloaded file
|
||||
// httpClient is a retryablehttp.Client instance (can be nil for fallback)
|
||||
Download(url, tmpDir string, httpClient *retryablehttp.Client) (string, error)
|
||||
// SupportedExtensions returns the list of supported file extensions
|
||||
SupportedExtensions() []string
|
||||
}
|
||||
|
||||
var (
|
||||
DefaultVarDumpFileName = "required_openapi_params.yaml"
|
||||
ErrNoVarsDumpFile = errors.New("no required params file found")
|
||||
|
||||
136
pkg/input/formats/openapi/downloader.go
Normal file
136
pkg/input/formats/openapi/downloader.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package openapi
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/input/formats"
|
||||
"github.com/projectdiscovery/retryablehttp-go"
|
||||
)
|
||||
|
||||
// OpenAPIDownloader implements the SpecDownloader interface for OpenAPI 3.0 specs
|
||||
type OpenAPIDownloader struct{}
|
||||
|
||||
// NewDownloader creates a new OpenAPI downloader
|
||||
func NewDownloader() formats.SpecDownloader {
|
||||
return &OpenAPIDownloader{}
|
||||
}
|
||||
|
||||
// This function downloads an OpenAPI 3.0 spec from the given URL and saves it to tmpDir
|
||||
func (d *OpenAPIDownloader) Download(urlStr, tmpDir string, httpClient *retryablehttp.Client) (string, error) {
|
||||
// Validate URL format, OpenAPI 3.0 specs are typically JSON
|
||||
if !strings.HasSuffix(urlStr, ".json") {
|
||||
return "", fmt.Errorf("URL does not appear to be an OpenAPI JSON spec")
|
||||
}
|
||||
|
||||
const maxSpecSizeBytes = 10 * 1024 * 1024 // 10MB
|
||||
|
||||
// Use provided httpClient or create a fallback
|
||||
var client *http.Client
|
||||
if httpClient != nil {
|
||||
client = httpClient.HTTPClient
|
||||
} else {
|
||||
// Fallback to simple client if no httpClient provided
|
||||
client = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
|
||||
resp, err := client.Get(urlStr)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to download OpenAPI spec")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("HTTP %d when downloading OpenAPI spec", resp.StatusCode)
|
||||
}
|
||||
|
||||
bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxSpecSizeBytes))
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to read response body")
|
||||
}
|
||||
|
||||
// Validate it's a valid JSON and has OpenAPI structure
|
||||
var spec map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &spec); err != nil {
|
||||
return "", fmt.Errorf("downloaded content is not valid JSON: %w", err)
|
||||
}
|
||||
|
||||
// Check if it's an OpenAPI 3.0 spec
|
||||
if openapi, exists := spec["openapi"]; exists {
|
||||
if openapiStr, ok := openapi.(string); ok && strings.HasPrefix(openapiStr, "3.") {
|
||||
// Valid OpenAPI 3.0 spec
|
||||
} else {
|
||||
return "", fmt.Errorf("not a valid OpenAPI 3.0 spec (found version: %v)", openapi)
|
||||
}
|
||||
} else {
|
||||
return "", fmt.Errorf("not an OpenAPI spec (missing 'openapi' field)")
|
||||
}
|
||||
|
||||
// Extract host from URL for server configuration
|
||||
parsedURL, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to parse URL")
|
||||
}
|
||||
host := parsedURL.Host
|
||||
scheme := parsedURL.Scheme
|
||||
if scheme == "" {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
// Add servers section if missing or empty
|
||||
servers, exists := spec["servers"]
|
||||
if !exists || servers == nil {
|
||||
spec["servers"] = []map[string]interface{}{{"url": scheme + "://" + host}}
|
||||
} else if serverList, ok := servers.([]interface{}); ok && len(serverList) == 0 {
|
||||
spec["servers"] = []map[string]interface{}{{"url": scheme + "://" + host}}
|
||||
}
|
||||
|
||||
// Marshal back to JSON
|
||||
modifiedJSON, err := json.Marshal(spec)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to marshal modified spec")
|
||||
}
|
||||
|
||||
// Create output directory
|
||||
openapiDir := filepath.Join(tmpDir, "openapi")
|
||||
if err := os.MkdirAll(openapiDir, 0755); err != nil {
|
||||
return "", errors.Wrap(err, "failed to create openapi directory")
|
||||
}
|
||||
|
||||
// Generate filename
|
||||
filename := fmt.Sprintf("openapi-spec-%d.json", time.Now().Unix())
|
||||
filePath := filepath.Join(openapiDir, filename)
|
||||
|
||||
// Write file
|
||||
file, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create file: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = file.Close()
|
||||
}()
|
||||
|
||||
if _, writeErr := file.Write(modifiedJSON); writeErr != nil {
|
||||
_ = os.Remove(filePath)
|
||||
return "", errors.Wrap(writeErr, "failed to write OpenAPI spec to file")
|
||||
}
|
||||
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
// SupportedExtensions returns the list of supported file extensions for OpenAPI
|
||||
func (d *OpenAPIDownloader) SupportedExtensions() []string {
|
||||
return []string{".json"}
|
||||
}
|
||||
278
pkg/input/formats/openapi/downloader_test.go
Normal file
278
pkg/input/formats/openapi/downloader_test.go
Normal file
@@ -0,0 +1,278 @@
|
||||
package openapi
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestOpenAPIDownloader_SupportedExtensions(t *testing.T) {
|
||||
downloader := &OpenAPIDownloader{}
|
||||
extensions := downloader.SupportedExtensions()
|
||||
|
||||
expected := []string{".json"}
|
||||
if len(extensions) != len(expected) {
|
||||
t.Errorf("Expected %d extensions, got %d", len(expected), len(extensions))
|
||||
}
|
||||
|
||||
for i, ext := range extensions {
|
||||
if ext != expected[i] {
|
||||
t.Errorf("Expected extension %s, got %s", expected[i], ext)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAPIDownloader_Download_Success(t *testing.T) {
|
||||
// Create a mock OpenAPI spec
|
||||
mockSpec := map[string]interface{}{
|
||||
"openapi": "3.0.0",
|
||||
"info": map[string]interface{}{
|
||||
"title": "Test API",
|
||||
"version": "1.0.0",
|
||||
},
|
||||
"paths": map[string]interface{}{
|
||||
"/test": map[string]interface{}{
|
||||
"get": map[string]interface{}{
|
||||
"summary": "Test endpoint",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create mock server
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(mockSpec); err != nil {
|
||||
http.Error(w, "failed to encode response", http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create temp directory
|
||||
tmpDir, err := os.MkdirTemp("", "openapi_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := os.RemoveAll(tmpDir); err != nil {
|
||||
t.Fatalf("Failed to remove temp dir: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Test download
|
||||
downloader := &OpenAPIDownloader{}
|
||||
filePath, err := downloader.Download(server.URL+"/openapi.json", tmpDir, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Download failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify file exists
|
||||
if !fileExists(filePath) {
|
||||
t.Errorf("Downloaded file does not exist: %s", filePath)
|
||||
}
|
||||
|
||||
// Verify file content
|
||||
content, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read downloaded file: %v", err)
|
||||
}
|
||||
|
||||
var downloadedSpec map[string]interface{}
|
||||
if err := json.Unmarshal(content, &downloadedSpec); err != nil {
|
||||
t.Fatalf("Failed to parse downloaded JSON: %v", err)
|
||||
}
|
||||
|
||||
// Verify servers field was added
|
||||
servers, exists := downloadedSpec["servers"]
|
||||
if !exists {
|
||||
t.Error("Servers field was not added to the spec")
|
||||
}
|
||||
|
||||
if serversList, ok := servers.([]interface{}); ok {
|
||||
if len(serversList) == 0 {
|
||||
t.Error("Servers list is empty")
|
||||
}
|
||||
} else {
|
||||
t.Error("Servers field is not a list")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAPIDownloader_Download_NonJSONURL(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "openapi_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := os.RemoveAll(tmpDir); err != nil {
|
||||
t.Fatalf("Failed to remove temp dir: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
downloader := &OpenAPIDownloader{}
|
||||
_, err = downloader.Download("http://example.com/spec.yaml", tmpDir, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for non-JSON URL, but got none")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "URL does not appear to be an OpenAPI JSON spec") {
|
||||
t.Errorf("Unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAPIDownloader_Download_HTTPError(t *testing.T) {
|
||||
// Create mock server that returns 404
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "openapi_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := os.RemoveAll(tmpDir); err != nil {
|
||||
t.Fatalf("Failed to remove temp dir: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
downloader := &OpenAPIDownloader{}
|
||||
_, err = downloader.Download(server.URL+"/openapi.json", tmpDir, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for HTTP 404, but got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAPIDownloader_Download_InvalidJSON(t *testing.T) {
|
||||
// Create mock server that returns invalid JSON
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if _, err := w.Write([]byte("invalid json")); err != nil {
|
||||
http.Error(w, "failed to write response", http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "openapi_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := os.RemoveAll(tmpDir); err != nil {
|
||||
t.Fatalf("Failed to remove temp dir: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
downloader := &OpenAPIDownloader{}
|
||||
_, err = downloader.Download(server.URL+"/openapi.json", tmpDir, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid JSON, but got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAPIDownloader_Download_Timeout(t *testing.T) {
|
||||
// Create mock server with delay
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(35 * time.Second) // Longer than 30 second timeout
|
||||
if err := json.NewEncoder(w).Encode(map[string]interface{}{"test": "data"}); err != nil {
|
||||
http.Error(w, "failed to encode response", http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "openapi_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := os.RemoveAll(tmpDir); err != nil {
|
||||
t.Fatalf("Failed to remove temp dir: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
downloader := &OpenAPIDownloader{}
|
||||
_, err = downloader.Download(server.URL+"/openapi.json", tmpDir, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected timeout error, but got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAPIDownloader_Download_WithExistingServers(t *testing.T) {
|
||||
// Create a mock OpenAPI spec with existing servers
|
||||
mockSpec := map[string]interface{}{
|
||||
"openapi": "3.0.0",
|
||||
"info": map[string]interface{}{
|
||||
"title": "Test API",
|
||||
"version": "1.0.0",
|
||||
},
|
||||
"servers": []interface{}{
|
||||
map[string]interface{}{
|
||||
"url": "https://existing-server.com",
|
||||
},
|
||||
},
|
||||
"paths": map[string]interface{}{},
|
||||
}
|
||||
|
||||
// Create mock server
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(mockSpec); err != nil {
|
||||
http.Error(w, "failed to encode response", http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "openapi_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := os.RemoveAll(tmpDir); err != nil {
|
||||
t.Fatalf("Failed to remove temp dir: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
downloader := &OpenAPIDownloader{}
|
||||
filePath, err := downloader.Download(server.URL+"/openapi.json", tmpDir, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Download failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify existing servers are preserved
|
||||
content, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read downloaded file: %v", err)
|
||||
}
|
||||
|
||||
var downloadedSpec map[string]interface{}
|
||||
if err := json.Unmarshal(content, &downloadedSpec); err != nil {
|
||||
t.Fatalf("Failed to parse downloaded JSON: %v", err)
|
||||
}
|
||||
|
||||
servers, exists := downloadedSpec["servers"]
|
||||
if !exists {
|
||||
t.Error("Servers field was removed from the spec")
|
||||
}
|
||||
|
||||
if serversList, ok := servers.([]interface{}); ok {
|
||||
if len(serversList) != 1 {
|
||||
t.Errorf("Expected 1 server, got %d", len(serversList))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to check if file exists
|
||||
func fileExists(filename string) bool {
|
||||
_, err := os.Stat(filename)
|
||||
return !os.IsNotExist(err)
|
||||
}
|
||||
165
pkg/input/formats/swagger/downloader.go
Normal file
165
pkg/input/formats/swagger/downloader.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package swagger
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/input/formats"
|
||||
"github.com/projectdiscovery/retryablehttp-go"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// SwaggerDownloader implements the SpecDownloader interface for Swagger 2.0 specs
|
||||
type SwaggerDownloader struct{}
|
||||
|
||||
// NewDownloader creates a new Swagger downloader
|
||||
func NewDownloader() formats.SpecDownloader {
|
||||
return &SwaggerDownloader{}
|
||||
}
|
||||
|
||||
// This function downloads a Swagger 2.0 spec from the given URL and saves it to tmpDir
|
||||
func (d *SwaggerDownloader) Download(urlStr, tmpDir string, httpClient *retryablehttp.Client) (string, error) {
|
||||
// Swagger can be JSON or YAML
|
||||
supportedExts := d.SupportedExtensions()
|
||||
isSupported := false
|
||||
for _, ext := range supportedExts {
|
||||
if strings.HasSuffix(urlStr, ext) {
|
||||
isSupported = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isSupported {
|
||||
return "", fmt.Errorf("URL does not appear to be a Swagger spec (supported: %v)", supportedExts)
|
||||
}
|
||||
|
||||
const maxSpecSizeBytes = 10 * 1024 * 1024 // 10MB
|
||||
|
||||
// Use provided httpClient or create a fallback
|
||||
var client *http.Client
|
||||
if httpClient != nil {
|
||||
client = httpClient.HTTPClient
|
||||
} else {
|
||||
// Fallback to simple client if no httpClient provided
|
||||
client = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
|
||||
resp, err := client.Get(urlStr)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to download Swagger spec")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("HTTP %d when downloading Swagger spec", resp.StatusCode)
|
||||
}
|
||||
|
||||
bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxSpecSizeBytes))
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to read response body")
|
||||
}
|
||||
|
||||
// Determine format and parse
|
||||
var spec map[string]interface{}
|
||||
var isYAML bool
|
||||
|
||||
// Try JSON first
|
||||
if err := json.Unmarshal(bodyBytes, &spec); err != nil {
|
||||
// Then try YAML
|
||||
if err := yaml.Unmarshal(bodyBytes, &spec); err != nil {
|
||||
return "", fmt.Errorf("downloaded content is neither valid JSON nor YAML: %w", err)
|
||||
}
|
||||
isYAML = true
|
||||
}
|
||||
|
||||
// Validate it's a Swagger 2.0 spec
|
||||
if swagger, exists := spec["swagger"]; exists {
|
||||
if swaggerStr, ok := swagger.(string); ok && strings.HasPrefix(swaggerStr, "2.") {
|
||||
// Valid Swagger 2.0 spec
|
||||
} else {
|
||||
return "", fmt.Errorf("not a valid Swagger 2.0 spec (found version: %v)", swagger)
|
||||
}
|
||||
} else {
|
||||
return "", fmt.Errorf("not a Swagger spec (missing 'swagger' field)")
|
||||
}
|
||||
|
||||
// Extract host from URL for host configuration
|
||||
parsedURL, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to parse URL")
|
||||
}
|
||||
|
||||
host := parsedURL.Host
|
||||
scheme := parsedURL.Scheme
|
||||
if scheme == "" {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
// Add host if missing
|
||||
if _, exists := spec["host"]; !exists {
|
||||
spec["host"] = host
|
||||
}
|
||||
|
||||
// Add schemes if missing
|
||||
if _, exists := spec["schemes"]; !exists {
|
||||
spec["schemes"] = []string{scheme}
|
||||
}
|
||||
|
||||
// Create output directory
|
||||
swaggerDir := filepath.Join(tmpDir, "swagger")
|
||||
if err := os.MkdirAll(swaggerDir, 0755); err != nil {
|
||||
return "", errors.Wrap(err, "failed to create swagger directory")
|
||||
}
|
||||
|
||||
// Generate filename and content based on original format
|
||||
var filename string
|
||||
var content []byte
|
||||
|
||||
if isYAML {
|
||||
filename = fmt.Sprintf("swagger-spec-%d.yaml", time.Now().Unix())
|
||||
content, err = yaml.Marshal(spec)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to marshal modified YAML spec")
|
||||
}
|
||||
} else {
|
||||
filename = fmt.Sprintf("swagger-spec-%d.json", time.Now().Unix())
|
||||
content, err = json.Marshal(spec)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to marshal modified JSON spec")
|
||||
}
|
||||
}
|
||||
|
||||
filePath := filepath.Join(swaggerDir, filename)
|
||||
|
||||
// Write file
|
||||
file, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to create file")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = file.Close()
|
||||
}()
|
||||
|
||||
if _, writeErr := file.Write(content); writeErr != nil {
|
||||
_ = os.Remove(filePath)
|
||||
return "", errors.Wrap(writeErr, "failed to write file")
|
||||
}
|
||||
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
// SupportedExtensions returns the list of supported file extensions for Swagger
|
||||
func (d *SwaggerDownloader) SupportedExtensions() []string {
|
||||
return []string{".json", ".yaml", ".yml"}
|
||||
}
|
||||
359
pkg/input/formats/swagger/downloader_test.go
Normal file
359
pkg/input/formats/swagger/downloader_test.go
Normal file
@@ -0,0 +1,359 @@
|
||||
package swagger
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestSwaggerDownloader_SupportedExtensions(t *testing.T) {
|
||||
downloader := &SwaggerDownloader{}
|
||||
extensions := downloader.SupportedExtensions()
|
||||
|
||||
expected := []string{".json", ".yaml", ".yml"}
|
||||
if len(extensions) != len(expected) {
|
||||
t.Errorf("Expected %d extensions, got %d", len(expected), len(extensions))
|
||||
}
|
||||
|
||||
for i, ext := range extensions {
|
||||
if ext != expected[i] {
|
||||
t.Errorf("Expected extension %s, got %s", expected[i], ext)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwaggerDownloader_Download_JSON_Success(t *testing.T) {
|
||||
// Create a mock Swagger spec (JSON)
|
||||
mockSpec := map[string]interface{}{
|
||||
"swagger": "2.0",
|
||||
"info": map[string]interface{}{
|
||||
"title": "Test API",
|
||||
"version": "1.0.0",
|
||||
},
|
||||
"paths": map[string]interface{}{
|
||||
"/test": map[string]interface{}{
|
||||
"get": map[string]interface{}{
|
||||
"summary": "Test endpoint",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create mock server
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(mockSpec); err != nil {
|
||||
http.Error(w, "failed to encode response", http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create temp directory
|
||||
tmpDir, err := os.MkdirTemp("", "swagger_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := os.RemoveAll(tmpDir); err != nil {
|
||||
t.Fatalf("Failed to remove temp dir: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Test download
|
||||
downloader := &SwaggerDownloader{}
|
||||
filePath, err := downloader.Download(server.URL+"/swagger.json", tmpDir, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Download failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify file exists
|
||||
if !fileExists(filePath) {
|
||||
t.Errorf("Downloaded file does not exist: %s", filePath)
|
||||
}
|
||||
|
||||
// Verify file content
|
||||
content, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read downloaded file: %v", err)
|
||||
}
|
||||
|
||||
var downloadedSpec map[string]interface{}
|
||||
if err := json.Unmarshal(content, &downloadedSpec); err != nil {
|
||||
t.Fatalf("Failed to parse downloaded JSON: %v", err)
|
||||
}
|
||||
|
||||
// Verify host field was added
|
||||
_, exists := downloadedSpec["host"]
|
||||
if !exists {
|
||||
t.Error("Host field was not added to the spec")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwaggerDownloader_Download_YAML_Success(t *testing.T) {
|
||||
// Create a mock Swagger spec (YAML)
|
||||
mockSpecYAML := `
|
||||
swagger: "2.0"
|
||||
info:
|
||||
title: "Test API"
|
||||
version: "1.0.0"
|
||||
paths:
|
||||
/test:
|
||||
get:
|
||||
summary: "Test endpoint"
|
||||
`
|
||||
|
||||
// Create mock server
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/yaml")
|
||||
if _, err := w.Write([]byte(mockSpecYAML)); err != nil {
|
||||
http.Error(w, "failed to write response", http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
|
||||
defer server.Close()
|
||||
|
||||
// Create temp directory
|
||||
tmpDir, err := os.MkdirTemp("", "swagger_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := os.RemoveAll(tmpDir); err != nil {
|
||||
t.Fatalf("Failed to remove temp dir: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Test download
|
||||
downloader := &SwaggerDownloader{}
|
||||
filePath, err := downloader.Download(server.URL+"/swagger.yaml", tmpDir, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Download failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify file exists
|
||||
if !fileExists(filePath) {
|
||||
t.Errorf("Downloaded file does not exist: %s", filePath)
|
||||
}
|
||||
|
||||
// Verify file content
|
||||
content, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read downloaded file: %v", err)
|
||||
}
|
||||
|
||||
var downloadedSpec map[string]interface{}
|
||||
if err := yaml.Unmarshal(content, &downloadedSpec); err != nil {
|
||||
t.Fatalf("Failed to parse downloaded YAML: %v", err)
|
||||
}
|
||||
|
||||
// Verify host field was added
|
||||
_, exists := downloadedSpec["host"]
|
||||
if !exists {
|
||||
t.Error("Host field was not added to the spec")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwaggerDownloader_Download_UnsupportedExtension(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "swagger_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := os.RemoveAll(tmpDir); err != nil {
|
||||
t.Fatalf("Failed to remove temp dir: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
downloader := &SwaggerDownloader{}
|
||||
_, err = downloader.Download("http://example.com/spec.xml", tmpDir, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for unsupported extension, but got none")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "URL does not appear to be a Swagger spec") {
|
||||
t.Errorf("Unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwaggerDownloader_Download_HTTPError(t *testing.T) {
|
||||
// Create mock server that returns 404
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "swagger_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := os.RemoveAll(tmpDir); err != nil {
|
||||
t.Fatalf("Failed to remove temp dir: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
downloader := &SwaggerDownloader{}
|
||||
_, err = downloader.Download(server.URL+"/swagger.json", tmpDir, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for HTTP 404, but got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwaggerDownloader_Download_InvalidJSON(t *testing.T) {
|
||||
// Create mock server that returns invalid JSON
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if _, err := w.Write([]byte("invalid json")); err != nil {
|
||||
http.Error(w, "failed to write response", http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "swagger_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := os.RemoveAll(tmpDir); err != nil {
|
||||
t.Fatalf("Failed to remove temp dir: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
downloader := &SwaggerDownloader{}
|
||||
_, err = downloader.Download(server.URL+"/swagger.json", tmpDir, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid JSON, but got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwaggerDownloader_Download_InvalidYAML(t *testing.T) {
|
||||
// Create mock server that returns invalid YAML
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/yaml")
|
||||
if _, err := w.Write([]byte("invalid: yaml: content: [")); err != nil {
|
||||
http.Error(w, "failed to write response", http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "swagger_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := os.RemoveAll(tmpDir); err != nil {
|
||||
t.Fatalf("Failed to remove temp dir: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
downloader := &SwaggerDownloader{}
|
||||
_, err = downloader.Download(server.URL+"/swagger.yaml", tmpDir, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid YAML, but got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwaggerDownloader_Download_Timeout(t *testing.T) {
|
||||
// Create mock server with delay
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(35 * time.Second) // Longer than 30 second timeout
|
||||
if err := json.NewEncoder(w).Encode(map[string]interface{}{"test": "data"}); err != nil {
|
||||
http.Error(w, "failed to encode response", http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "swagger_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := os.RemoveAll(tmpDir); err != nil {
|
||||
t.Fatalf("Failed to remove temp dir: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
downloader := &SwaggerDownloader{}
|
||||
_, err = downloader.Download(server.URL+"/swagger.json", tmpDir, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected timeout error, but got none")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwaggerDownloader_Download_WithExistingHost(t *testing.T) {
|
||||
// Create a mock Swagger spec with existing host
|
||||
mockSpec := map[string]interface{}{
|
||||
"swagger": "2.0",
|
||||
"info": map[string]interface{}{
|
||||
"title": "Test API",
|
||||
"version": "1.0.0",
|
||||
},
|
||||
"host": "existing-host.com",
|
||||
"paths": map[string]interface{}{},
|
||||
}
|
||||
|
||||
// Create mock server
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(mockSpec); err != nil {
|
||||
http.Error(w, "failed to encode response", http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "swagger_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := os.RemoveAll(tmpDir); err != nil {
|
||||
t.Fatalf("Failed to remove temp dir: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
downloader := &SwaggerDownloader{}
|
||||
filePath, err := downloader.Download(server.URL+"/swagger.json", tmpDir, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Download failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify existing host is preserved
|
||||
content, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read downloaded file: %v", err)
|
||||
}
|
||||
|
||||
var downloadedSpec map[string]interface{}
|
||||
if err := json.Unmarshal(content, &downloadedSpec); err != nil {
|
||||
t.Fatalf("Failed to parse downloaded JSON: %v", err)
|
||||
}
|
||||
|
||||
host, exists := downloadedSpec["host"]
|
||||
if !exists {
|
||||
t.Error("Host field was removed from the spec")
|
||||
}
|
||||
|
||||
if hostStr, ok := host.(string); !ok || hostStr != "existing-host.com" {
|
||||
t.Errorf("Expected host 'existing-host.com', got '%v'", host)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to check if file exists
|
||||
func fileExists(filename string) bool {
|
||||
_, err := os.Stat(filename)
|
||||
return !os.IsNotExist(err)
|
||||
}
|
||||
@@ -7,12 +7,16 @@ import (
|
||||
|
||||
"github.com/projectdiscovery/gologger"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/input/formats"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/input/formats/openapi"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/input/formats/swagger"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/input/provider/http"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/input/provider/list"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/input/types"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/generators"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate"
|
||||
configTypes "github.com/projectdiscovery/nuclei/v3/pkg/types"
|
||||
"github.com/projectdiscovery/retryablehttp-go"
|
||||
"github.com/projectdiscovery/utils/errkit"
|
||||
stringsutil "github.com/projectdiscovery/utils/strings"
|
||||
)
|
||||
@@ -74,6 +78,8 @@ type InputProvider interface {
|
||||
type InputOptions struct {
|
||||
// Options for global config
|
||||
Options *configTypes.Options
|
||||
// TempDir is the temporary directory for storing files
|
||||
TempDir string
|
||||
// NotFoundCallback is the callback to call when input is not found
|
||||
// only supported in list input provider
|
||||
NotFoundCallback func(template string) bool
|
||||
@@ -107,20 +113,58 @@ func NewInputProvider(opts InputOptions) (InputProvider, error) {
|
||||
Options: opts.Options,
|
||||
NotFoundCallback: opts.NotFoundCallback,
|
||||
})
|
||||
} else {
|
||||
// use HttpInputProvider
|
||||
return http.NewHttpInputProvider(&http.HttpMultiFormatOptions{
|
||||
InputFile: opts.Options.TargetsFilePath,
|
||||
InputMode: opts.Options.InputFileMode,
|
||||
Options: formats.InputFormatOptions{
|
||||
Variables: generators.MergeMaps(extraVars, opts.Options.Vars.AsMap()),
|
||||
SkipFormatValidation: opts.Options.SkipFormatValidation,
|
||||
RequiredOnly: opts.Options.FormatUseRequiredOnly,
|
||||
VarsTextTemplating: opts.Options.VarsTextTemplating,
|
||||
VarsFilePaths: opts.Options.VarsFilePaths,
|
||||
},
|
||||
})
|
||||
} else if len(opts.Options.Targets) > 0 &&
|
||||
(strings.EqualFold(opts.Options.InputFileMode, "openapi") || strings.EqualFold(opts.Options.InputFileMode, "swagger")) {
|
||||
|
||||
if len(opts.Options.Targets) > 1 {
|
||||
return nil, fmt.Errorf("only one target URL is supported in %s input mode", opts.Options.InputFileMode)
|
||||
}
|
||||
|
||||
target := opts.Options.Targets[0]
|
||||
if strings.HasPrefix(target, "http://") || strings.HasPrefix(target, "https://") {
|
||||
var downloader formats.SpecDownloader
|
||||
var tempFile string
|
||||
var err error
|
||||
|
||||
// Get HttpClient from protocolstate if available
|
||||
var httpClient *retryablehttp.Client
|
||||
if opts.Options.ExecutionId != "" {
|
||||
dialers := protocolstate.GetDialersWithId(opts.Options.ExecutionId)
|
||||
if dialers != nil {
|
||||
httpClient = dialers.DefaultHTTPClient
|
||||
}
|
||||
}
|
||||
|
||||
switch strings.ToLower(opts.Options.InputFileMode) {
|
||||
case "openapi":
|
||||
downloader = openapi.NewDownloader()
|
||||
tempFile, err = downloader.Download(target, opts.TempDir, httpClient)
|
||||
case "swagger":
|
||||
downloader = swagger.NewDownloader()
|
||||
tempFile, err = downloader.Download(target, opts.TempDir, httpClient)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported input mode: %s", opts.Options.InputFileMode)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to download %s spec from url %s: %w", opts.Options.InputFileMode, target, err)
|
||||
}
|
||||
|
||||
opts.Options.TargetsFilePath = tempFile
|
||||
}
|
||||
}
|
||||
|
||||
return http.NewHttpInputProvider(&http.HttpMultiFormatOptions{
|
||||
InputFile: opts.Options.TargetsFilePath,
|
||||
InputMode: opts.Options.InputFileMode,
|
||||
Options: formats.InputFormatOptions{
|
||||
Variables: generators.MergeMaps(extraVars, opts.Options.Vars.AsMap()),
|
||||
SkipFormatValidation: opts.Options.SkipFormatValidation,
|
||||
RequiredOnly: opts.Options.FormatUseRequiredOnly,
|
||||
VarsTextTemplating: opts.Options.VarsTextTemplating,
|
||||
VarsFilePaths: opts.Options.VarsFilePaths,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// SupportedInputFormats returns all supported input formats of nuclei
|
||||
|
||||
Reference in New Issue
Block a user