From af7450737acfbf2ffd918d8a566596bcc9519250 Mon Sep 17 00:00:00 2001 From: mzack Date: Wed, 3 Apr 2024 23:06:08 +0200 Subject: [PATCH] making payload concurrency dynamic via direct int change --- pkg/protocols/dns/request.go | 9 +++++++++ pkg/protocols/http/httputils/spm.go | 10 ++++++++++ pkg/protocols/http/request.go | 9 +++++++++ pkg/protocols/javascript/js.go | 9 +++++++++ pkg/protocols/network/request.go | 9 +++++++++ pkg/protocols/protocols.go | 14 ++------------ 6 files changed, 48 insertions(+), 12 deletions(-) diff --git a/pkg/protocols/dns/request.go b/pkg/protocols/dns/request.go index 8aa8c4a58..d4e70e13e 100644 --- a/pkg/protocols/dns/request.go +++ b/pkg/protocols/dns/request.go @@ -62,6 +62,9 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, metadata, variablesMap := request.options.Variables.Evaluate(vars) vars = generators.MergeMaps(vars, variablesMap, request.options.Constants) + // if request threads matches global payload concurrency we follow it + shouldFollowGlobal := request.Threads == request.options.Options.PayloadConcurrency + if request.generator != nil { iterator := request.generator.NewIterator() swg, err := syncutil.New(syncutil.WithSize(request.Threads)) @@ -76,6 +79,12 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, metadata, if !ok { break } + + // resize check point - nop if there are no changes + if shouldFollowGlobal && swg.Size != request.options.Options.PayloadConcurrency { + swg.Resize(request.options.Options.PayloadConcurrency) + } + value = generators.MergeMaps(vars, value) swg.Add() go func(newVars map[string]interface{}) { diff --git a/pkg/protocols/http/httputils/spm.go b/pkg/protocols/http/httputils/spm.go index bca6c2ee5..52d13f06f 100644 --- a/pkg/protocols/http/httputils/spm.go +++ b/pkg/protocols/http/httputils/spm.go @@ -143,6 +143,16 @@ func (h *StopAtFirstMatchHandler[T]) Release() { } } +func (h *StopAtFirstMatchHandler[T]) Resize(size int) { + if h.sgPool.Size != size { + h.sgPool.Resize(size) + } +} + +func (h *StopAtFirstMatchHandler[T]) Size() int { + return h.sgPool.Size +} + // Wait waits for all work to be done func (h *StopAtFirstMatchHandler[T]) Wait() { switch h.poolType { diff --git a/pkg/protocols/http/request.go b/pkg/protocols/http/request.go index 1b27aa7bf..5d64216f6 100644 --- a/pkg/protocols/http/request.go +++ b/pkg/protocols/http/request.go @@ -165,6 +165,9 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV // Workers that keeps enqueuing new requests maxWorkers := request.Threads + // if request threads matches global payload concurrency we follow it + shouldFollowGlobal := maxWorkers == request.options.Options.PayloadConcurrency + if protocolstate.IsLowOnMemory() { maxWorkers = protocolstate.GuardThreadsOrDefault(request.Threads) } @@ -198,6 +201,12 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV if !ok { break } + + // resize check point - nop if there are no changes + if shouldFollowGlobal && spmHandler.Size() != request.options.Options.PayloadConcurrency { + spmHandler.Resize(request.options.Options.PayloadConcurrency) + } + ctx := request.newContext(input) generatedHttpRequest, err := generator.Make(ctx, input, inputData, payloads, dynamicValues) if err != nil { diff --git a/pkg/protocols/javascript/js.go b/pkg/protocols/javascript/js.go index 48a1be485..efabd503c 100644 --- a/pkg/protocols/javascript/js.go +++ b/pkg/protocols/javascript/js.go @@ -404,6 +404,9 @@ func (request *Request) executeRequestParallel(ctxParent context.Context, hostPo requestOptions := request.options gotmatches := &atomic.Bool{} + // if request threads matches global payload concurrency we follow it + shouldFollowGlobal := threads == request.options.Options.PayloadConcurrency + sg, _ := syncutil.New(syncutil.WithSize(threads)) if request.generator != nil { @@ -413,6 +416,12 @@ func (request *Request) executeRequestParallel(ctxParent context.Context, hostPo if !ok { break } + + // resize check point - nop if there are no changes + if shouldFollowGlobal && sg.Size != request.options.Options.PayloadConcurrency { + sg.Resize(request.options.Options.PayloadConcurrency) + } + sg.Add() go func() { defer sg.Done() diff --git a/pkg/protocols/network/request.go b/pkg/protocols/network/request.go index ef0ff01c7..146c65707 100644 --- a/pkg/protocols/network/request.go +++ b/pkg/protocols/network/request.go @@ -174,6 +174,9 @@ func (request *Request) executeAddress(variables map[string]interface{}, actualA return err } + // if request threads matches global payload concurrency we follow it + shouldFollowGlobal := request.Threads == request.options.Options.PayloadConcurrency + if request.generator != nil { iterator := request.generator.NewIterator() var multiErr error @@ -188,6 +191,12 @@ func (request *Request) executeAddress(variables map[string]interface{}, actualA if !ok { break } + + // resize check point - nop if there are no changes + if shouldFollowGlobal && swg.Size != request.options.Options.PayloadConcurrency { + swg.Resize(request.options.Options.PayloadConcurrency) + } + value = generators.MergeMaps(value, payloads) swg.Add() go func(vars map[string]interface{}) { diff --git a/pkg/protocols/protocols.go b/pkg/protocols/protocols.go index 5ac0a2f0b..6328bc36a 100644 --- a/pkg/protocols/protocols.go +++ b/pkg/protocols/protocols.go @@ -34,9 +34,6 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/types" ) -// Optional Callback to update Thread count in payloads across all requests -type PayloadThreadSetterCallback func(opts *ExecutorOptions, totalRequests, currentThreads int) int - var ( MaxTemplateFileSizeForEncoding = 1024 * 1024 ) @@ -114,10 +111,6 @@ type ExecutorOptions struct { // JsCompiler is abstracted javascript compiler which adds node modules and provides execution // environment for javascript templates JsCompiler *compiler.Compiler - // Optional Callback function to update Thread count in payloads across all protocols - // based on given logic. by default nuclei reverts to using value of `-c` when threads count - // is not specified or is 0 in template - OverrideThreadsCount PayloadThreadSetterCallback // AuthProvider is a provider for auth strategies AuthProvider authprovider.AuthProvider //TemporaryDirectory is the directory to store temporary files @@ -142,14 +135,11 @@ func (eo *ExecutorOptions) RateLimitTake() { // GetThreadsForPayloadRequests returns the number of threads to use as default for // given max-request of payloads func (e *ExecutorOptions) GetThreadsForNPayloadRequests(totalRequests int, currentThreads int) int { - if e.OverrideThreadsCount != nil { - return e.OverrideThreadsCount(e, totalRequests, currentThreads) - } if currentThreads > 0 { return currentThreads - } else { - return e.Options.PayloadConcurrency } + + return e.Options.PayloadConcurrency } // CreateTemplateCtxStore creates template context store (which contains templateCtx for every scan)