diff --git a/internal/runner/runner.go b/internal/runner/runner.go index 236ca3d6d..0c5573519 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -254,8 +254,23 @@ func New(options *types.Options) (*Runner, error) { os.Exit(0) } + tmpDir, err := os.MkdirTemp("", "nuclei-tmp-*") + if err != nil { + return nil, errors.Wrap(err, "could not create temporary directory") + } + runner.tmpDir = tmpDir + + // Cleanup tmpDir only if initialization fails + // On successful initialization, Close() method will handle cleanup + cleanupOnError := true + defer func() { + if cleanupOnError && runner.tmpDir != "" { + _ = os.RemoveAll(runner.tmpDir) + } + }() + // create the input provider and load the inputs - inputProvider, err := provider.NewInputProvider(provider.InputOptions{Options: options}) + inputProvider, err := provider.NewInputProvider(provider.InputOptions{Options: options, TempDir: runner.tmpDir}) if err != nil { return nil, errors.Wrap(err, "could not create input provider") } @@ -386,10 +401,8 @@ func New(options *types.Options) (*Runner, error) { } runner.rateLimiter = utils.GetRateLimiter(context.Background(), options.RateLimit, options.RateLimitDuration) - if tmpDir, err := os.MkdirTemp("", "nuclei-tmp-*"); err == nil { - runner.tmpDir = tmpDir - } - + // Initialization successful, disable cleanup on error + cleanupOnError = false return runner, nil } diff --git a/pkg/input/formats/formats.go b/pkg/input/formats/formats.go index c7798286a..9de4d0d01 100644 --- a/pkg/input/formats/formats.go +++ b/pkg/input/formats/formats.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/projectdiscovery/nuclei/v3/pkg/input/types" + "github.com/projectdiscovery/retryablehttp-go" fileutil "github.com/projectdiscovery/utils/file" "gopkg.in/yaml.v3" ) @@ -47,6 +48,16 @@ type Format interface { SetOptions(options InputFormatOptions) } +// SpecDownloader is an interface for downloading API specifications from URLs +type SpecDownloader interface { + // Download downloads the spec from the given URL and saves it to tmpDir + // Returns the path to the downloaded file + // httpClient is a retryablehttp.Client instance (can be nil for fallback) + Download(url, tmpDir string, httpClient *retryablehttp.Client) (string, error) + // SupportedExtensions returns the list of supported file extensions + SupportedExtensions() []string +} + var ( DefaultVarDumpFileName = "required_openapi_params.yaml" ErrNoVarsDumpFile = errors.New("no required params file found") diff --git a/pkg/input/formats/openapi/downloader.go b/pkg/input/formats/openapi/downloader.go new file mode 100644 index 000000000..955fdc50c --- /dev/null +++ b/pkg/input/formats/openapi/downloader.go @@ -0,0 +1,136 @@ +package openapi + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "time" + + "github.com/pkg/errors" + "github.com/projectdiscovery/nuclei/v3/pkg/input/formats" + "github.com/projectdiscovery/retryablehttp-go" +) + +// OpenAPIDownloader implements the SpecDownloader interface for OpenAPI 3.0 specs +type OpenAPIDownloader struct{} + +// NewDownloader creates a new OpenAPI downloader +func NewDownloader() formats.SpecDownloader { + return &OpenAPIDownloader{} +} + +// This function downloads an OpenAPI 3.0 spec from the given URL and saves it to tmpDir +func (d *OpenAPIDownloader) Download(urlStr, tmpDir string, httpClient *retryablehttp.Client) (string, error) { + // Validate URL format, OpenAPI 3.0 specs are typically JSON + if !strings.HasSuffix(urlStr, ".json") { + return "", fmt.Errorf("URL does not appear to be an OpenAPI JSON spec") + } + + const maxSpecSizeBytes = 10 * 1024 * 1024 // 10MB + + // Use provided httpClient or create a fallback + var client *http.Client + if httpClient != nil { + client = httpClient.HTTPClient + } else { + // Fallback to simple client if no httpClient provided + client = &http.Client{Timeout: 30 * time.Second} + } + + resp, err := client.Get(urlStr) + if err != nil { + return "", errors.Wrap(err, "failed to download OpenAPI spec") + } + + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("HTTP %d when downloading OpenAPI spec", resp.StatusCode) + } + + bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxSpecSizeBytes)) + if err != nil { + return "", errors.Wrap(err, "failed to read response body") + } + + // Validate it's a valid JSON and has OpenAPI structure + var spec map[string]interface{} + if err := json.Unmarshal(bodyBytes, &spec); err != nil { + return "", fmt.Errorf("downloaded content is not valid JSON: %w", err) + } + + // Check if it's an OpenAPI 3.0 spec + if openapi, exists := spec["openapi"]; exists { + if openapiStr, ok := openapi.(string); ok && strings.HasPrefix(openapiStr, "3.") { + // Valid OpenAPI 3.0 spec + } else { + return "", fmt.Errorf("not a valid OpenAPI 3.0 spec (found version: %v)", openapi) + } + } else { + return "", fmt.Errorf("not an OpenAPI spec (missing 'openapi' field)") + } + + // Extract host from URL for server configuration + parsedURL, err := url.Parse(urlStr) + if err != nil { + return "", errors.Wrap(err, "failed to parse URL") + } + host := parsedURL.Host + scheme := parsedURL.Scheme + if scheme == "" { + scheme = "https" + } + + // Add servers section if missing or empty + servers, exists := spec["servers"] + if !exists || servers == nil { + spec["servers"] = []map[string]interface{}{{"url": scheme + "://" + host}} + } else if serverList, ok := servers.([]interface{}); ok && len(serverList) == 0 { + spec["servers"] = []map[string]interface{}{{"url": scheme + "://" + host}} + } + + // Marshal back to JSON + modifiedJSON, err := json.Marshal(spec) + if err != nil { + return "", errors.Wrap(err, "failed to marshal modified spec") + } + + // Create output directory + openapiDir := filepath.Join(tmpDir, "openapi") + if err := os.MkdirAll(openapiDir, 0755); err != nil { + return "", errors.Wrap(err, "failed to create openapi directory") + } + + // Generate filename + filename := fmt.Sprintf("openapi-spec-%d.json", time.Now().Unix()) + filePath := filepath.Join(openapiDir, filename) + + // Write file + file, err := os.Create(filePath) + if err != nil { + return "", fmt.Errorf("failed to create file: %w", err) + } + + defer func() { + _ = file.Close() + }() + + if _, writeErr := file.Write(modifiedJSON); writeErr != nil { + _ = os.Remove(filePath) + return "", errors.Wrap(writeErr, "failed to write OpenAPI spec to file") + } + + return filePath, nil +} + +// SupportedExtensions returns the list of supported file extensions for OpenAPI +func (d *OpenAPIDownloader) SupportedExtensions() []string { + return []string{".json"} +} diff --git a/pkg/input/formats/openapi/downloader_test.go b/pkg/input/formats/openapi/downloader_test.go new file mode 100644 index 000000000..10ee93817 --- /dev/null +++ b/pkg/input/formats/openapi/downloader_test.go @@ -0,0 +1,278 @@ +package openapi + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" +) + +func TestOpenAPIDownloader_SupportedExtensions(t *testing.T) { + downloader := &OpenAPIDownloader{} + extensions := downloader.SupportedExtensions() + + expected := []string{".json"} + if len(extensions) != len(expected) { + t.Errorf("Expected %d extensions, got %d", len(expected), len(extensions)) + } + + for i, ext := range extensions { + if ext != expected[i] { + t.Errorf("Expected extension %s, got %s", expected[i], ext) + } + } +} + +func TestOpenAPIDownloader_Download_Success(t *testing.T) { + // Create a mock OpenAPI spec + mockSpec := map[string]interface{}{ + "openapi": "3.0.0", + "info": map[string]interface{}{ + "title": "Test API", + "version": "1.0.0", + }, + "paths": map[string]interface{}{ + "/test": map[string]interface{}{ + "get": map[string]interface{}{ + "summary": "Test endpoint", + }, + }, + }, + } + + // Create mock server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(mockSpec); err != nil { + http.Error(w, "failed to encode response", http.StatusInternalServerError) + } + })) + defer server.Close() + + // Create temp directory + tmpDir, err := os.MkdirTemp("", "openapi_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + defer func() { + if err := os.RemoveAll(tmpDir); err != nil { + t.Fatalf("Failed to remove temp dir: %v", err) + } + }() + + // Test download + downloader := &OpenAPIDownloader{} + filePath, err := downloader.Download(server.URL+"/openapi.json", tmpDir, nil) + if err != nil { + t.Fatalf("Download failed: %v", err) + } + + // Verify file exists + if !fileExists(filePath) { + t.Errorf("Downloaded file does not exist: %s", filePath) + } + + // Verify file content + content, err := os.ReadFile(filePath) + if err != nil { + t.Fatalf("Failed to read downloaded file: %v", err) + } + + var downloadedSpec map[string]interface{} + if err := json.Unmarshal(content, &downloadedSpec); err != nil { + t.Fatalf("Failed to parse downloaded JSON: %v", err) + } + + // Verify servers field was added + servers, exists := downloadedSpec["servers"] + if !exists { + t.Error("Servers field was not added to the spec") + } + + if serversList, ok := servers.([]interface{}); ok { + if len(serversList) == 0 { + t.Error("Servers list is empty") + } + } else { + t.Error("Servers field is not a list") + } +} + +func TestOpenAPIDownloader_Download_NonJSONURL(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "openapi_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + defer func() { + if err := os.RemoveAll(tmpDir); err != nil { + t.Fatalf("Failed to remove temp dir: %v", err) + } + }() + + downloader := &OpenAPIDownloader{} + _, err = downloader.Download("http://example.com/spec.yaml", tmpDir, nil) + if err == nil { + t.Error("Expected error for non-JSON URL, but got none") + } + + if !strings.Contains(err.Error(), "URL does not appear to be an OpenAPI JSON spec") { + t.Errorf("Unexpected error message: %v", err) + } +} + +func TestOpenAPIDownloader_Download_HTTPError(t *testing.T) { + // Create mock server that returns 404 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + tmpDir, err := os.MkdirTemp("", "openapi_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + defer func() { + if err := os.RemoveAll(tmpDir); err != nil { + t.Fatalf("Failed to remove temp dir: %v", err) + } + }() + + downloader := &OpenAPIDownloader{} + _, err = downloader.Download(server.URL+"/openapi.json", tmpDir, nil) + if err == nil { + t.Error("Expected error for HTTP 404, but got none") + } +} + +func TestOpenAPIDownloader_Download_InvalidJSON(t *testing.T) { + // Create mock server that returns invalid JSON + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if _, err := w.Write([]byte("invalid json")); err != nil { + http.Error(w, "failed to write response", http.StatusInternalServerError) + } + })) + defer server.Close() + + tmpDir, err := os.MkdirTemp("", "openapi_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + defer func() { + if err := os.RemoveAll(tmpDir); err != nil { + t.Fatalf("Failed to remove temp dir: %v", err) + } + }() + + downloader := &OpenAPIDownloader{} + _, err = downloader.Download(server.URL+"/openapi.json", tmpDir, nil) + if err == nil { + t.Error("Expected error for invalid JSON, but got none") + } +} + +func TestOpenAPIDownloader_Download_Timeout(t *testing.T) { + // Create mock server with delay + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(35 * time.Second) // Longer than 30 second timeout + if err := json.NewEncoder(w).Encode(map[string]interface{}{"test": "data"}); err != nil { + http.Error(w, "failed to encode response", http.StatusInternalServerError) + } + })) + defer server.Close() + + tmpDir, err := os.MkdirTemp("", "openapi_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + defer func() { + if err := os.RemoveAll(tmpDir); err != nil { + t.Fatalf("Failed to remove temp dir: %v", err) + } + }() + + downloader := &OpenAPIDownloader{} + _, err = downloader.Download(server.URL+"/openapi.json", tmpDir, nil) + if err == nil { + t.Error("Expected timeout error, but got none") + } +} + +func TestOpenAPIDownloader_Download_WithExistingServers(t *testing.T) { + // Create a mock OpenAPI spec with existing servers + mockSpec := map[string]interface{}{ + "openapi": "3.0.0", + "info": map[string]interface{}{ + "title": "Test API", + "version": "1.0.0", + }, + "servers": []interface{}{ + map[string]interface{}{ + "url": "https://existing-server.com", + }, + }, + "paths": map[string]interface{}{}, + } + + // Create mock server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(mockSpec); err != nil { + http.Error(w, "failed to encode response", http.StatusInternalServerError) + } + })) + defer server.Close() + + tmpDir, err := os.MkdirTemp("", "openapi_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + defer func() { + if err := os.RemoveAll(tmpDir); err != nil { + t.Fatalf("Failed to remove temp dir: %v", err) + } + }() + + downloader := &OpenAPIDownloader{} + filePath, err := downloader.Download(server.URL+"/openapi.json", tmpDir, nil) + if err != nil { + t.Fatalf("Download failed: %v", err) + } + + // Verify existing servers are preserved + content, err := os.ReadFile(filePath) + if err != nil { + t.Fatalf("Failed to read downloaded file: %v", err) + } + + var downloadedSpec map[string]interface{} + if err := json.Unmarshal(content, &downloadedSpec); err != nil { + t.Fatalf("Failed to parse downloaded JSON: %v", err) + } + + servers, exists := downloadedSpec["servers"] + if !exists { + t.Error("Servers field was removed from the spec") + } + + if serversList, ok := servers.([]interface{}); ok { + if len(serversList) != 1 { + t.Errorf("Expected 1 server, got %d", len(serversList)) + } + } +} + +// Helper function to check if file exists +func fileExists(filename string) bool { + _, err := os.Stat(filename) + return !os.IsNotExist(err) +} diff --git a/pkg/input/formats/swagger/downloader.go b/pkg/input/formats/swagger/downloader.go new file mode 100644 index 000000000..b6b5a333f --- /dev/null +++ b/pkg/input/formats/swagger/downloader.go @@ -0,0 +1,165 @@ +package swagger + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "time" + + "github.com/pkg/errors" + "github.com/projectdiscovery/nuclei/v3/pkg/input/formats" + "github.com/projectdiscovery/retryablehttp-go" + "gopkg.in/yaml.v3" +) + +// SwaggerDownloader implements the SpecDownloader interface for Swagger 2.0 specs +type SwaggerDownloader struct{} + +// NewDownloader creates a new Swagger downloader +func NewDownloader() formats.SpecDownloader { + return &SwaggerDownloader{} +} + +// This function downloads a Swagger 2.0 spec from the given URL and saves it to tmpDir +func (d *SwaggerDownloader) Download(urlStr, tmpDir string, httpClient *retryablehttp.Client) (string, error) { + // Swagger can be JSON or YAML + supportedExts := d.SupportedExtensions() + isSupported := false + for _, ext := range supportedExts { + if strings.HasSuffix(urlStr, ext) { + isSupported = true + break + } + } + if !isSupported { + return "", fmt.Errorf("URL does not appear to be a Swagger spec (supported: %v)", supportedExts) + } + + const maxSpecSizeBytes = 10 * 1024 * 1024 // 10MB + + // Use provided httpClient or create a fallback + var client *http.Client + if httpClient != nil { + client = httpClient.HTTPClient + } else { + // Fallback to simple client if no httpClient provided + client = &http.Client{Timeout: 30 * time.Second} + } + + resp, err := client.Get(urlStr) + if err != nil { + return "", errors.Wrap(err, "failed to download Swagger spec") + } + + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("HTTP %d when downloading Swagger spec", resp.StatusCode) + } + + bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxSpecSizeBytes)) + if err != nil { + return "", errors.Wrap(err, "failed to read response body") + } + + // Determine format and parse + var spec map[string]interface{} + var isYAML bool + + // Try JSON first + if err := json.Unmarshal(bodyBytes, &spec); err != nil { + // Then try YAML + if err := yaml.Unmarshal(bodyBytes, &spec); err != nil { + return "", fmt.Errorf("downloaded content is neither valid JSON nor YAML: %w", err) + } + isYAML = true + } + + // Validate it's a Swagger 2.0 spec + if swagger, exists := spec["swagger"]; exists { + if swaggerStr, ok := swagger.(string); ok && strings.HasPrefix(swaggerStr, "2.") { + // Valid Swagger 2.0 spec + } else { + return "", fmt.Errorf("not a valid Swagger 2.0 spec (found version: %v)", swagger) + } + } else { + return "", fmt.Errorf("not a Swagger spec (missing 'swagger' field)") + } + + // Extract host from URL for host configuration + parsedURL, err := url.Parse(urlStr) + if err != nil { + return "", errors.Wrap(err, "failed to parse URL") + } + + host := parsedURL.Host + scheme := parsedURL.Scheme + if scheme == "" { + scheme = "https" + } + + // Add host if missing + if _, exists := spec["host"]; !exists { + spec["host"] = host + } + + // Add schemes if missing + if _, exists := spec["schemes"]; !exists { + spec["schemes"] = []string{scheme} + } + + // Create output directory + swaggerDir := filepath.Join(tmpDir, "swagger") + if err := os.MkdirAll(swaggerDir, 0755); err != nil { + return "", errors.Wrap(err, "failed to create swagger directory") + } + + // Generate filename and content based on original format + var filename string + var content []byte + + if isYAML { + filename = fmt.Sprintf("swagger-spec-%d.yaml", time.Now().Unix()) + content, err = yaml.Marshal(spec) + if err != nil { + return "", errors.Wrap(err, "failed to marshal modified YAML spec") + } + } else { + filename = fmt.Sprintf("swagger-spec-%d.json", time.Now().Unix()) + content, err = json.Marshal(spec) + if err != nil { + return "", errors.Wrap(err, "failed to marshal modified JSON spec") + } + } + + filePath := filepath.Join(swaggerDir, filename) + + // Write file + file, err := os.Create(filePath) + if err != nil { + return "", errors.Wrap(err, "failed to create file") + } + + defer func() { + _ = file.Close() + }() + + if _, writeErr := file.Write(content); writeErr != nil { + _ = os.Remove(filePath) + return "", errors.Wrap(writeErr, "failed to write file") + } + + return filePath, nil +} + +// SupportedExtensions returns the list of supported file extensions for Swagger +func (d *SwaggerDownloader) SupportedExtensions() []string { + return []string{".json", ".yaml", ".yml"} +} diff --git a/pkg/input/formats/swagger/downloader_test.go b/pkg/input/formats/swagger/downloader_test.go new file mode 100644 index 000000000..d55b57395 --- /dev/null +++ b/pkg/input/formats/swagger/downloader_test.go @@ -0,0 +1,359 @@ +package swagger + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "gopkg.in/yaml.v3" +) + +func TestSwaggerDownloader_SupportedExtensions(t *testing.T) { + downloader := &SwaggerDownloader{} + extensions := downloader.SupportedExtensions() + + expected := []string{".json", ".yaml", ".yml"} + if len(extensions) != len(expected) { + t.Errorf("Expected %d extensions, got %d", len(expected), len(extensions)) + } + + for i, ext := range extensions { + if ext != expected[i] { + t.Errorf("Expected extension %s, got %s", expected[i], ext) + } + } +} + +func TestSwaggerDownloader_Download_JSON_Success(t *testing.T) { + // Create a mock Swagger spec (JSON) + mockSpec := map[string]interface{}{ + "swagger": "2.0", + "info": map[string]interface{}{ + "title": "Test API", + "version": "1.0.0", + }, + "paths": map[string]interface{}{ + "/test": map[string]interface{}{ + "get": map[string]interface{}{ + "summary": "Test endpoint", + }, + }, + }, + } + + // Create mock server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(mockSpec); err != nil { + http.Error(w, "failed to encode response", http.StatusInternalServerError) + } + })) + defer server.Close() + + // Create temp directory + tmpDir, err := os.MkdirTemp("", "swagger_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + defer func() { + if err := os.RemoveAll(tmpDir); err != nil { + t.Fatalf("Failed to remove temp dir: %v", err) + } + }() + + // Test download + downloader := &SwaggerDownloader{} + filePath, err := downloader.Download(server.URL+"/swagger.json", tmpDir, nil) + if err != nil { + t.Fatalf("Download failed: %v", err) + } + + // Verify file exists + if !fileExists(filePath) { + t.Errorf("Downloaded file does not exist: %s", filePath) + } + + // Verify file content + content, err := os.ReadFile(filePath) + if err != nil { + t.Fatalf("Failed to read downloaded file: %v", err) + } + + var downloadedSpec map[string]interface{} + if err := json.Unmarshal(content, &downloadedSpec); err != nil { + t.Fatalf("Failed to parse downloaded JSON: %v", err) + } + + // Verify host field was added + _, exists := downloadedSpec["host"] + if !exists { + t.Error("Host field was not added to the spec") + } +} + +func TestSwaggerDownloader_Download_YAML_Success(t *testing.T) { + // Create a mock Swagger spec (YAML) + mockSpecYAML := ` +swagger: "2.0" +info: + title: "Test API" + version: "1.0.0" +paths: + /test: + get: + summary: "Test endpoint" +` + + // Create mock server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/yaml") + if _, err := w.Write([]byte(mockSpecYAML)); err != nil { + http.Error(w, "failed to write response", http.StatusInternalServerError) + } + })) + + defer server.Close() + + // Create temp directory + tmpDir, err := os.MkdirTemp("", "swagger_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + defer func() { + if err := os.RemoveAll(tmpDir); err != nil { + t.Fatalf("Failed to remove temp dir: %v", err) + } + }() + + // Test download + downloader := &SwaggerDownloader{} + filePath, err := downloader.Download(server.URL+"/swagger.yaml", tmpDir, nil) + if err != nil { + t.Fatalf("Download failed: %v", err) + } + + // Verify file exists + if !fileExists(filePath) { + t.Errorf("Downloaded file does not exist: %s", filePath) + } + + // Verify file content + content, err := os.ReadFile(filePath) + if err != nil { + t.Fatalf("Failed to read downloaded file: %v", err) + } + + var downloadedSpec map[string]interface{} + if err := yaml.Unmarshal(content, &downloadedSpec); err != nil { + t.Fatalf("Failed to parse downloaded YAML: %v", err) + } + + // Verify host field was added + _, exists := downloadedSpec["host"] + if !exists { + t.Error("Host field was not added to the spec") + } +} + +func TestSwaggerDownloader_Download_UnsupportedExtension(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "swagger_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + defer func() { + if err := os.RemoveAll(tmpDir); err != nil { + t.Fatalf("Failed to remove temp dir: %v", err) + } + }() + + downloader := &SwaggerDownloader{} + _, err = downloader.Download("http://example.com/spec.xml", tmpDir, nil) + if err == nil { + t.Error("Expected error for unsupported extension, but got none") + } + + if !strings.Contains(err.Error(), "URL does not appear to be a Swagger spec") { + t.Errorf("Unexpected error message: %v", err) + } +} + +func TestSwaggerDownloader_Download_HTTPError(t *testing.T) { + // Create mock server that returns 404 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + tmpDir, err := os.MkdirTemp("", "swagger_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + defer func() { + if err := os.RemoveAll(tmpDir); err != nil { + t.Fatalf("Failed to remove temp dir: %v", err) + } + }() + + downloader := &SwaggerDownloader{} + _, err = downloader.Download(server.URL+"/swagger.json", tmpDir, nil) + if err == nil { + t.Error("Expected error for HTTP 404, but got none") + } +} + +func TestSwaggerDownloader_Download_InvalidJSON(t *testing.T) { + // Create mock server that returns invalid JSON + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if _, err := w.Write([]byte("invalid json")); err != nil { + http.Error(w, "failed to write response", http.StatusInternalServerError) + } + })) + defer server.Close() + + tmpDir, err := os.MkdirTemp("", "swagger_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + defer func() { + if err := os.RemoveAll(tmpDir); err != nil { + t.Fatalf("Failed to remove temp dir: %v", err) + } + }() + + downloader := &SwaggerDownloader{} + _, err = downloader.Download(server.URL+"/swagger.json", tmpDir, nil) + if err == nil { + t.Error("Expected error for invalid JSON, but got none") + } +} + +func TestSwaggerDownloader_Download_InvalidYAML(t *testing.T) { + // Create mock server that returns invalid YAML + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/yaml") + if _, err := w.Write([]byte("invalid: yaml: content: [")); err != nil { + http.Error(w, "failed to write response", http.StatusInternalServerError) + } + })) + defer server.Close() + + tmpDir, err := os.MkdirTemp("", "swagger_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + defer func() { + if err := os.RemoveAll(tmpDir); err != nil { + t.Fatalf("Failed to remove temp dir: %v", err) + } + }() + + downloader := &SwaggerDownloader{} + _, err = downloader.Download(server.URL+"/swagger.yaml", tmpDir, nil) + if err == nil { + t.Error("Expected error for invalid YAML, but got none") + } +} + +func TestSwaggerDownloader_Download_Timeout(t *testing.T) { + // Create mock server with delay + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(35 * time.Second) // Longer than 30 second timeout + if err := json.NewEncoder(w).Encode(map[string]interface{}{"test": "data"}); err != nil { + http.Error(w, "failed to encode response", http.StatusInternalServerError) + } + })) + defer server.Close() + + tmpDir, err := os.MkdirTemp("", "swagger_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + defer func() { + if err := os.RemoveAll(tmpDir); err != nil { + t.Fatalf("Failed to remove temp dir: %v", err) + } + }() + + downloader := &SwaggerDownloader{} + _, err = downloader.Download(server.URL+"/swagger.json", tmpDir, nil) + if err == nil { + t.Error("Expected timeout error, but got none") + } +} + +func TestSwaggerDownloader_Download_WithExistingHost(t *testing.T) { + // Create a mock Swagger spec with existing host + mockSpec := map[string]interface{}{ + "swagger": "2.0", + "info": map[string]interface{}{ + "title": "Test API", + "version": "1.0.0", + }, + "host": "existing-host.com", + "paths": map[string]interface{}{}, + } + + // Create mock server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(mockSpec); err != nil { + http.Error(w, "failed to encode response", http.StatusInternalServerError) + } + })) + defer server.Close() + + tmpDir, err := os.MkdirTemp("", "swagger_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + + defer func() { + if err := os.RemoveAll(tmpDir); err != nil { + t.Fatalf("Failed to remove temp dir: %v", err) + } + }() + + downloader := &SwaggerDownloader{} + filePath, err := downloader.Download(server.URL+"/swagger.json", tmpDir, nil) + if err != nil { + t.Fatalf("Download failed: %v", err) + } + + // Verify existing host is preserved + content, err := os.ReadFile(filePath) + if err != nil { + t.Fatalf("Failed to read downloaded file: %v", err) + } + + var downloadedSpec map[string]interface{} + if err := json.Unmarshal(content, &downloadedSpec); err != nil { + t.Fatalf("Failed to parse downloaded JSON: %v", err) + } + + host, exists := downloadedSpec["host"] + if !exists { + t.Error("Host field was removed from the spec") + } + + if hostStr, ok := host.(string); !ok || hostStr != "existing-host.com" { + t.Errorf("Expected host 'existing-host.com', got '%v'", host) + } +} + +// Helper function to check if file exists +func fileExists(filename string) bool { + _, err := os.Stat(filename) + return !os.IsNotExist(err) +} diff --git a/pkg/input/provider/interface.go b/pkg/input/provider/interface.go index 9e1d09ab2..33cfbee7f 100644 --- a/pkg/input/provider/interface.go +++ b/pkg/input/provider/interface.go @@ -7,12 +7,16 @@ import ( "github.com/projectdiscovery/gologger" "github.com/projectdiscovery/nuclei/v3/pkg/input/formats" + "github.com/projectdiscovery/nuclei/v3/pkg/input/formats/openapi" + "github.com/projectdiscovery/nuclei/v3/pkg/input/formats/swagger" "github.com/projectdiscovery/nuclei/v3/pkg/input/provider/http" "github.com/projectdiscovery/nuclei/v3/pkg/input/provider/list" "github.com/projectdiscovery/nuclei/v3/pkg/input/types" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/generators" + "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate" configTypes "github.com/projectdiscovery/nuclei/v3/pkg/types" + "github.com/projectdiscovery/retryablehttp-go" "github.com/projectdiscovery/utils/errkit" stringsutil "github.com/projectdiscovery/utils/strings" ) @@ -74,6 +78,8 @@ type InputProvider interface { type InputOptions struct { // Options for global config Options *configTypes.Options + // TempDir is the temporary directory for storing files + TempDir string // NotFoundCallback is the callback to call when input is not found // only supported in list input provider NotFoundCallback func(template string) bool @@ -107,20 +113,58 @@ func NewInputProvider(opts InputOptions) (InputProvider, error) { Options: opts.Options, NotFoundCallback: opts.NotFoundCallback, }) - } else { - // use HttpInputProvider - return http.NewHttpInputProvider(&http.HttpMultiFormatOptions{ - InputFile: opts.Options.TargetsFilePath, - InputMode: opts.Options.InputFileMode, - Options: formats.InputFormatOptions{ - Variables: generators.MergeMaps(extraVars, opts.Options.Vars.AsMap()), - SkipFormatValidation: opts.Options.SkipFormatValidation, - RequiredOnly: opts.Options.FormatUseRequiredOnly, - VarsTextTemplating: opts.Options.VarsTextTemplating, - VarsFilePaths: opts.Options.VarsFilePaths, - }, - }) + } else if len(opts.Options.Targets) > 0 && + (strings.EqualFold(opts.Options.InputFileMode, "openapi") || strings.EqualFold(opts.Options.InputFileMode, "swagger")) { + + if len(opts.Options.Targets) > 1 { + return nil, fmt.Errorf("only one target URL is supported in %s input mode", opts.Options.InputFileMode) + } + + target := opts.Options.Targets[0] + if strings.HasPrefix(target, "http://") || strings.HasPrefix(target, "https://") { + var downloader formats.SpecDownloader + var tempFile string + var err error + + // Get HttpClient from protocolstate if available + var httpClient *retryablehttp.Client + if opts.Options.ExecutionId != "" { + dialers := protocolstate.GetDialersWithId(opts.Options.ExecutionId) + if dialers != nil { + httpClient = dialers.DefaultHTTPClient + } + } + + switch strings.ToLower(opts.Options.InputFileMode) { + case "openapi": + downloader = openapi.NewDownloader() + tempFile, err = downloader.Download(target, opts.TempDir, httpClient) + case "swagger": + downloader = swagger.NewDownloader() + tempFile, err = downloader.Download(target, opts.TempDir, httpClient) + default: + return nil, fmt.Errorf("unsupported input mode: %s", opts.Options.InputFileMode) + } + + if err != nil { + return nil, fmt.Errorf("failed to download %s spec from url %s: %w", opts.Options.InputFileMode, target, err) + } + + opts.Options.TargetsFilePath = tempFile + } } + + return http.NewHttpInputProvider(&http.HttpMultiFormatOptions{ + InputFile: opts.Options.TargetsFilePath, + InputMode: opts.Options.InputFileMode, + Options: formats.InputFormatOptions{ + Variables: generators.MergeMaps(extraVars, opts.Options.Vars.AsMap()), + SkipFormatValidation: opts.Options.SkipFormatValidation, + RequiredOnly: opts.Options.FormatUseRequiredOnly, + VarsTextTemplating: opts.Options.VarsTextTemplating, + VarsFilePaths: opts.Options.VarsFilePaths, + }, + }) } // SupportedInputFormats returns all supported input formats of nuclei