diff --git a/pkg/protocols/dns/dns.go b/pkg/protocols/dns/dns.go index d6f462c44..198bb87bd 100644 --- a/pkg/protocols/dns/dns.go +++ b/pkg/protocols/dns/dns.go @@ -185,6 +185,7 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error { func (request *Request) getDnsClient(options *protocols.ExecutorOptions, metadata map[string]interface{}) (*retryabledns.Client, error) { dnsClientOptions := &dnsclientpool.Configuration{ Retries: request.Retries, + Proxy: options.Options.AliveSocksProxy, } if len(request.Resolvers) > 0 { if len(request.Resolvers) > 0 { diff --git a/pkg/protocols/dns/dnsclientpool/clientpool.go b/pkg/protocols/dns/dnsclientpool/clientpool.go index 4f019808f..8eb19b8ba 100644 --- a/pkg/protocols/dns/dnsclientpool/clientpool.go +++ b/pkg/protocols/dns/dnsclientpool/clientpool.go @@ -51,6 +51,8 @@ type Configuration struct { Retries int // Resolvers contains the specific per request resolvers Resolvers []string + // Proxy contains the proxy to use for the dns client + Proxy string } // Hash returns the hash of the configuration to allow client pooling @@ -60,6 +62,8 @@ func (c *Configuration) Hash() string { builder.WriteString(strconv.Itoa(c.Retries)) builder.WriteString("l") builder.WriteString(strings.Join(c.Resolvers, "")) + builder.WriteString("p") + builder.WriteString(c.Proxy) hash := builder.String() return hash } @@ -83,7 +87,11 @@ func Get(options *types.Options, configuration *Configuration) (*retryabledns.Cl } else if len(configuration.Resolvers) > 0 { resolvers = configuration.Resolvers } - client, err := retryabledns.New(resolvers, configuration.Retries) + client, err := retryabledns.NewWithOptions(retryabledns.Options{ + BaseResolvers: resolvers, + MaxRetries: configuration.Retries, + Proxy: options.AliveSocksProxy, + }) if err != nil { return nil, errors.Wrap(err, "could not create dns client") } diff --git a/pkg/protocols/http/http.go b/pkg/protocols/http/http.go index 0b30a7408..7a45fcd0d 100644 --- a/pkg/protocols/http/http.go +++ b/pkg/protocols/http/http.go @@ -11,6 +11,7 @@ import ( json "github.com/json-iterator/go" "github.com/pkg/errors" + "github.com/projectdiscovery/fastdialer/fastdialer" _ "github.com/projectdiscovery/nuclei/v3/pkg/fuzz/analyzers/time" "github.com/projectdiscovery/nuclei/v3/pkg/fuzz" @@ -22,6 +23,7 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/generators" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/http/httpclientpool" + "github.com/projectdiscovery/nuclei/v3/pkg/protocols/network/networkclientpool" httputil "github.com/projectdiscovery/nuclei/v3/pkg/protocols/utils/http" "github.com/projectdiscovery/nuclei/v3/pkg/utils/stats" "github.com/projectdiscovery/rawhttp" @@ -144,6 +146,7 @@ type Request struct { generator *generators.PayloadGenerator // optional, only enabled when using payloads httpClient *retryablehttp.Client rawhttpClient *rawhttp.Client + dialer *fastdialer.Dialer // description: | // SelfContained specifies if the request is self-contained. @@ -348,6 +351,15 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error { } request.customHeaders = make(map[string]string) request.httpClient = client + + dialer, err := networkclientpool.Get(options.Options, &networkclientpool.Configuration{ + CustomDialer: options.CustomFastdialer, + }) + if err != nil { + return errors.Wrap(err, "could not get dialer") + } + request.dialer = dialer + request.options = options for _, option := range request.options.Options.CustomHeaders { parts := strings.SplitN(option, ":", 2) diff --git a/pkg/protocols/http/request.go b/pkg/protocols/http/request.go index 6d8ad3e1d..090de2ed6 100644 --- a/pkg/protocols/http/request.go +++ b/pkg/protocols/http/request.go @@ -841,7 +841,7 @@ func (request *Request) executeRequest(input *contextargs.Context, generatedRequ if input.MetaInput.CustomIP != "" { outputEvent["ip"] = input.MetaInput.CustomIP } else { - outputEvent["ip"] = protocolstate.Dialer.GetDialedIP(hostname) + outputEvent["ip"] = request.dialer.GetDialedIP(hostname) // try getting cname request.addCNameIfAvailable(hostname, outputEvent) } @@ -1085,11 +1085,11 @@ func (request *Request) validateNFixEvent(input *contextargs.Context, gr *genera // addCNameIfAvailable adds the cname to the event if available func (request *Request) addCNameIfAvailable(hostname string, outputEvent map[string]interface{}) { - if protocolstate.Dialer == nil { + if request.dialer == nil { return } - data, err := protocolstate.Dialer.GetDNSData(hostname) + data, err := request.dialer.GetDNSData(hostname) if err == nil { switch len(data.CNAME) { case 0: diff --git a/pkg/protocols/network/network.go b/pkg/protocols/network/network.go index 5c072affc..7aba6244a 100644 --- a/pkg/protocols/network/network.go +++ b/pkg/protocols/network/network.go @@ -237,7 +237,9 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error { } // Create a client for the class - client, err := networkclientpool.Get(options.Options, &networkclientpool.Configuration{}) + client, err := networkclientpool.Get(options.Options, &networkclientpool.Configuration{ + CustomDialer: options.CustomFastdialer, + }) if err != nil { return errors.Wrap(err, "could not get network client") } @@ -259,7 +261,3 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error { func (request *Request) Requests() int { return len(request.Address) } - -func (request *Request) SetDialer(dialer *fastdialer.Dialer) { - request.dialer = dialer -} diff --git a/pkg/protocols/network/networkclientpool/clientpool.go b/pkg/protocols/network/networkclientpool/clientpool.go index a67cee296..6293a931e 100644 --- a/pkg/protocols/network/networkclientpool/clientpool.go +++ b/pkg/protocols/network/networkclientpool/clientpool.go @@ -21,7 +21,9 @@ func Init(options *types.Options) error { } // Configuration contains the custom configuration options for a client -type Configuration struct{} +type Configuration struct { + CustomDialer *fastdialer.Dialer +} // Hash returns the hash of the configuration to allow client pooling func (c *Configuration) Hash() string { @@ -30,5 +32,10 @@ func (c *Configuration) Hash() string { // Get creates or gets a client for the protocol based on custom configuration func Get(options *types.Options, configuration *Configuration /*TODO review unused parameters*/) (*fastdialer.Dialer, error) { + + if configuration != nil && configuration.CustomDialer != nil { + return configuration.CustomDialer, nil + } + return normalClient, nil } diff --git a/pkg/protocols/network/request.go b/pkg/protocols/network/request.go index f7b11fbb5..197ef5332 100644 --- a/pkg/protocols/network/request.go +++ b/pkg/protocols/network/request.go @@ -25,9 +25,9 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/helpers/eventcreator" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/helpers/responsehighlighter" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/interactsh" - "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/replacer" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/utils/vardump" + "github.com/projectdiscovery/nuclei/v3/pkg/protocols/network/networkclientpool" protocolutils "github.com/projectdiscovery/nuclei/v3/pkg/protocols/utils" templateTypes "github.com/projectdiscovery/nuclei/v3/pkg/templates/types" errorutil "github.com/projectdiscovery/utils/errors" @@ -64,7 +64,11 @@ func (request *Request) getOpenPorts(target *contextargs.Context) ([]string, err errs = append(errs, err) continue } - conn, err := protocolstate.Dialer.Dial(target.Context(), "tcp", addr) + if request.dialer == nil { + request.dialer, _ = networkclientpool.Get(request.options.Options, &networkclientpool.Configuration{}) + } + + conn, err := request.dialer.Dial(target.Context(), "tcp", addr) if err != nil { errs = append(errs, err) continue diff --git a/pkg/protocols/protocols.go b/pkg/protocols/protocols.go index 7b7f71d48..6b5c089be 100644 --- a/pkg/protocols/protocols.go +++ b/pkg/protocols/protocols.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "sync/atomic" + "github.com/projectdiscovery/fastdialer/fastdialer" "github.com/projectdiscovery/ratelimit" mapsutil "github.com/projectdiscovery/utils/maps" stringsutil "github.com/projectdiscovery/utils/strings" @@ -132,6 +133,8 @@ type ExecutorOptions struct { ExportReqURLPattern bool // GlobalMatchers is the storage for global matchers with http passive templates GlobalMatchers *globalmatchers.Storage + // CustomFastdialer is a fastdialer dialer instance + CustomFastdialer *fastdialer.Dialer } // todo: centralizing components is not feasible with current clogged architecture diff --git a/pkg/protocols/ssl/ssl.go b/pkg/protocols/ssl/ssl.go index fd0dae83d..8943d597d 100644 --- a/pkg/protocols/ssl/ssl.go +++ b/pkg/protocols/ssl/ssl.go @@ -115,7 +115,9 @@ func (request *Request) IsClusterable() bool { func (request *Request) Compile(options *protocols.ExecutorOptions) error { request.options = options - client, err := networkclientpool.Get(options.Options, &networkclientpool.Configuration{}) + client, err := networkclientpool.Get(options.Options, &networkclientpool.Configuration{ + CustomDialer: options.CustomFastdialer, + }) if err != nil { return errorutil.NewWithTag("ssl", "could not get network client").Wrap(err) } diff --git a/pkg/protocols/websocket/websocket.go b/pkg/protocols/websocket/websocket.go index 8eeeedf21..02cf75190 100644 --- a/pkg/protocols/websocket/websocket.go +++ b/pkg/protocols/websocket/websocket.go @@ -100,7 +100,9 @@ const ( func (request *Request) Compile(options *protocols.ExecutorOptions) error { request.options = options - client, err := networkclientpool.Get(options.Options, &networkclientpool.Configuration{}) + client, err := networkclientpool.Get(options.Options, &networkclientpool.Configuration{ + CustomDialer: options.CustomFastdialer, + }) if err != nil { return errors.Wrap(err, "could not get network client") }