From c0972540c43f14346df7b1e772b2e2070deba1b5 Mon Sep 17 00:00:00 2001 From: Ice3man Date: Fri, 19 Dec 2025 19:16:04 +0530 Subject: [PATCH] feat: initial try to per-host caching pool tests --- internal/runner/runner.go | 9 + pkg/protocols/common/protocolstate/dialers.go | 1 + pkg/protocols/http/build_request.go | 8 +- pkg/protocols/http/http.go | 1 + .../http/httpclientpool/clientpool.go | 51 ++++ .../http/httpclientpool/perhost_pool.go | 249 ++++++++++++++++++ pkg/protocols/http/request.go | 14 +- 7 files changed, 323 insertions(+), 10 deletions(-) create mode 100644 pkg/protocols/http/httpclientpool/perhost_pool.go diff --git a/internal/runner/runner.go b/internal/runner/runner.go index 236ca3d6d..223c014f6 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -54,6 +54,7 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/hosterrorscache" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/interactsh" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolinit" + "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/uncover" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/utils/excludematchers" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/headless/engine" @@ -743,6 +744,14 @@ func (r *Runner) RunEnumeration() error { r.progress.Stop() timeTaken := time.Since(now) + + // Print per-host pool stats if available + if dialers := protocolstate.GetDialersWithId(r.options.ExecutionId); dialers != nil && dialers.PerHostHTTPPool != nil { + if pool, ok := dialers.PerHostHTTPPool.(interface{ PrintStats() }); ok { + pool.PrintStats() + } + } + // todo: error propagation without canonical straight error check is required by cloud? // use safe dereferencing to avoid potential panics in case of previous unchecked errors if v := ptrutil.Safe(results); !v.Load() { diff --git a/pkg/protocols/common/protocolstate/dialers.go b/pkg/protocols/common/protocolstate/dialers.go index 91bdbae51..6c65ce206 100644 --- a/pkg/protocols/common/protocolstate/dialers.go +++ b/pkg/protocols/common/protocolstate/dialers.go @@ -15,6 +15,7 @@ type Dialers struct { RawHTTPClient *rawhttp.Client DefaultHTTPClient *retryablehttp.Client HTTPClientPool *mapsutil.SyncLockMap[string, *retryablehttp.Client] + PerHostHTTPPool any NetworkPolicy *networkpolicy.NetworkPolicy LocalFileAccessAllowed bool RestrictLocalNetworkAccess bool diff --git a/pkg/protocols/http/build_request.go b/pkg/protocols/http/build_request.go index 980573c96..a233444da 100644 --- a/pkg/protocols/http/build_request.go +++ b/pkg/protocols/http/build_request.go @@ -24,7 +24,6 @@ import ( protocolutils "github.com/projectdiscovery/nuclei/v3/pkg/protocols/utils" httputil "github.com/projectdiscovery/nuclei/v3/pkg/protocols/utils/http" "github.com/projectdiscovery/nuclei/v3/pkg/types" - "github.com/projectdiscovery/nuclei/v3/pkg/types/scanstrategy" "github.com/projectdiscovery/rawhttp" "github.com/projectdiscovery/retryablehttp-go" "github.com/projectdiscovery/utils/errkit" @@ -451,10 +450,9 @@ func (r *requestGenerator) fillRequest(req *retryablehttp.Request, values map[st } } - // In case of multiple threads the underlying connection should remain open to allow reuse - if r.request.Threads <= 0 && req.Header.Get("Connection") == "" && r.options.Options.ScanStrategy != scanstrategy.HostSpray.String() { - req.Close = true - } + // Connection handling is now managed by per-host HTTP client pool + // Only set Close=true if template explicitly requests it via Connection header + // Otherwise, let the per-host pool's keep-alive settings handle it // Check if the user requested a request body if r.request.Body != "" { diff --git a/pkg/protocols/http/http.go b/pkg/protocols/http/http.go index ae3f3f471..b82017453 100644 --- a/pkg/protocols/http/http.go +++ b/pkg/protocols/http/http.go @@ -345,6 +345,7 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error { } request.connConfiguration = connectionConfiguration + // Don't create per-host client during Compile (no target yet) client, err := httpclientpool.Get(options.Options, connectionConfiguration) if err != nil { return errors.Wrap(err, "could not get dns client") diff --git a/pkg/protocols/http/httpclientpool/clientpool.go b/pkg/protocols/http/httpclientpool/clientpool.go index 14f4c8dc3..e094fb4c3 100644 --- a/pkg/protocols/http/httpclientpool/clientpool.go +++ b/pkg/protocols/http/httpclientpool/clientpool.go @@ -212,6 +212,7 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl maxConnsPerHost = 500 } + retryableHttpOptions.ImpersonateChrome = true retryableHttpOptions.RetryWaitMax = 10 * time.Second retryableHttpOptions.RetryMax = options.Retries retryableHttpOptions.Timeout = time.Duration(options.Timeout) * time.Second @@ -289,6 +290,7 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl MaxIdleConns: maxIdleConns, MaxIdleConnsPerHost: maxIdleConnsPerHost, MaxConnsPerHost: maxConnsPerHost, + IdleConnTimeout: 90 * time.Second, TLSClientConfig: tlsConfig, DisableKeepAlives: disableKeepAlives, ResponseHeaderTimeout: responseHeaderTimeout, @@ -366,6 +368,55 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl return client, nil } +// GetForTarget creates or gets a client for a specific target with per-host connection pooling +func GetForTarget(options *types.Options, configuration *Configuration, targetURL string) (*retryablehttp.Client, error) { + if !shouldUsePerHostPooling(options, configuration) { + return Get(options, configuration) + } + + dialers := protocolstate.GetDialersWithId(options.ExecutionId) + if dialers == nil { + return nil, fmt.Errorf("dialers not initialized for %s", options.ExecutionId) + } + + dialers.Lock() + if dialers.PerHostHTTPPool == nil { + dialers.PerHostHTTPPool = NewPerHostClientPool(500, 5*time.Minute, 30*time.Minute) + } + dialers.Unlock() + + pool, ok := dialers.PerHostHTTPPool.(*PerHostClientPool) + if !ok || pool == nil { + return Get(options, configuration) + } + + return pool.GetOrCreate(targetURL, func() (*retryablehttp.Client, error) { + cfg := configuration.Clone() + if cfg.Connection == nil { + cfg.Connection = &ConnectionConfiguration{} + } + cfg.Connection.DisableKeepAlive = false + + // Override Threads to force connection pool settings + // This ensures MaxIdleConnsPerHost and MaxConnsPerHost are set correctly + originalThreads := cfg.Threads + cfg.Threads = 1 + client, err := wrappedGet(options, cfg) + cfg.Threads = originalThreads + + return client, err + }) +} + +// shouldUsePerHostPooling determines if per-host pooling should be enabled +func shouldUsePerHostPooling(options *types.Options, config *Configuration) bool { + // Enable per-host pooling for: + // 1. Templates with threads (parallel requests to same host) + // 2. TemplateSpray mode even without threads (sequential requests benefit from keep-alive) + // Disable only for HostSpray mode (already has keep-alive enabled globally) + return options.ScanStrategy != scanstrategy.HostSpray.String() +} + type RedirectFlow uint8 const ( diff --git a/pkg/protocols/http/httpclientpool/perhost_pool.go b/pkg/protocols/http/httpclientpool/perhost_pool.go new file mode 100644 index 000000000..cf4504480 --- /dev/null +++ b/pkg/protocols/http/httpclientpool/perhost_pool.go @@ -0,0 +1,249 @@ +package httpclientpool + +import ( + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/hashicorp/golang-lru/v2/expirable" + "github.com/projectdiscovery/gologger" + "github.com/projectdiscovery/retryablehttp-go" + urlutil "github.com/projectdiscovery/utils/url" +) + +type PerHostClientPool struct { + cache *expirable.LRU[string, *clientEntry] + capacity int + mu sync.Mutex + + hits atomic.Uint64 + misses atomic.Uint64 + evictions atomic.Uint64 +} + +type clientEntry struct { + client *retryablehttp.Client + createdAt time.Time + accessCount atomic.Uint64 +} + +func NewPerHostClientPool(size int, maxIdleTime, maxLifetime time.Duration) *PerHostClientPool { + if size <= 0 { + size = 500 + } + if maxIdleTime == 0 { + maxIdleTime = 5 * time.Minute + } + if maxLifetime == 0 { + maxLifetime = 30 * time.Minute + } + + ttl := maxIdleTime + if maxLifetime < maxIdleTime { + ttl = maxLifetime + } + + pool := &PerHostClientPool{ + cache: expirable.NewLRU[string, *clientEntry]( + size, + func(key string, value *clientEntry) { + gologger.Debug().Msgf("[perhost-pool] Evicted client for %s (age: %v, accesses: %d)", + key, time.Since(value.createdAt), value.accessCount.Load()) + }, + ttl, + ), + capacity: size, + } + + return pool +} + +func (p *PerHostClientPool) GetOrCreate( + host string, + createFunc func() (*retryablehttp.Client, error), +) (*retryablehttp.Client, error) { + normalizedHost := normalizeHost(host) + + if entry, ok := p.cache.Get(normalizedHost); ok { + entry.accessCount.Add(1) + p.hits.Add(1) + return entry.client, nil + } + + p.mu.Lock() + defer p.mu.Unlock() + + if entry, ok := p.cache.Peek(normalizedHost); ok { + entry.accessCount.Add(1) + p.hits.Add(1) + return entry.client, nil + } + + p.misses.Add(1) + + client, err := createFunc() + if err != nil { + return nil, err + } + + entry := &clientEntry{ + client: client, + createdAt: time.Now(), + } + entry.accessCount.Store(1) + + evicted := p.cache.Add(normalizedHost, entry) + if evicted { + p.evictions.Add(1) + } + + return client, nil +} + +func (p *PerHostClientPool) EvictHost(host string) bool { + normalizedHost := normalizeHost(host) + existed := p.cache.Remove(normalizedHost) + + if existed { + p.evictions.Add(1) + } + return existed +} + +func (p *PerHostClientPool) EvictAll() { + count := p.cache.Len() + p.cache.Purge() + p.evictions.Add(uint64(count)) +} + +func (p *PerHostClientPool) Size() int { + return p.cache.Len() +} + +func (p *PerHostClientPool) Stats() PoolStats { + return PoolStats{ + Hits: p.hits.Load(), + Misses: p.misses.Load(), + Evictions: p.evictions.Load(), + Size: p.Size(), + } +} + +func (p *PerHostClientPool) Close() { + p.EvictAll() +} + +func normalizeHost(rawURL string) string { + if rawURL == "" { + return "" + } + + parsed, err := urlutil.Parse(rawURL) + if err != nil { + return rawURL + } + + scheme := parsed.Scheme + if scheme == "" { + scheme = "http" + } + + host := parsed.Host + if host == "" { + host = parsed.Hostname() + } + + port := parsed.Port() + if port != "" { + return fmt.Sprintf("%s://%s:%s", scheme, parsed.Hostname(), port) + } + + if scheme == "https" && port == "" { + return fmt.Sprintf("%s://%s:443", scheme, parsed.Hostname()) + } + if scheme == "http" && port == "" { + return fmt.Sprintf("%s://%s:80", scheme, parsed.Hostname()) + } + + return fmt.Sprintf("%s://%s", scheme, host) +} + +type PoolStats struct { + Hits uint64 + Misses uint64 + Evictions uint64 + Size int +} + +func (p *PerHostClientPool) GetClientForHost(host string) (*retryablehttp.Client, bool) { + normalizedHost := normalizeHost(host) + + if entry, ok := p.cache.Peek(normalizedHost); ok { + return entry.client, true + } + return nil, false +} + +func (p *PerHostClientPool) ListAllClients() []string { + return p.cache.Keys() +} + +type ClientInfo struct { + Host string + CreatedAt time.Time + AccessCount uint64 + Age time.Duration +} + +func (p *PerHostClientPool) GetClientInfo(host string) *ClientInfo { + normalizedHost := normalizeHost(host) + + entry, ok := p.cache.Peek(normalizedHost) + if !ok { + return nil + } + + now := time.Now() + + return &ClientInfo{ + Host: normalizedHost, + CreatedAt: entry.createdAt, + AccessCount: entry.accessCount.Load(), + Age: now.Sub(entry.createdAt), + } +} + +func (p *PerHostClientPool) GetAllClientInfo() []*ClientInfo { + infos := []*ClientInfo{} + for _, key := range p.cache.Keys() { + if info := p.GetClientInfo(key); info != nil { + infos = append(infos, info) + } + } + return infos +} + +func (p *PerHostClientPool) Resize(size int) int { + evicted := p.cache.Resize(size) + p.capacity = size + return evicted +} + +func (p *PerHostClientPool) Cap() int { + return p.capacity +} + +func (p *PerHostClientPool) PrintStats() { + stats := p.Stats() + if stats.Size == 0 { + return + } + gologger.Verbose().Msgf("[perhost-pool] Connection reuse stats: Hits=%d Misses=%d HitRate=%.1f%% Hosts=%d", + stats.Hits, stats.Misses, + float64(stats.Hits)*100/float64(stats.Hits+stats.Misses+1), + stats.Size) +} + +func (p *PerHostClientPool) PrintTransportStats() { +} diff --git a/pkg/protocols/http/request.go b/pkg/protocols/http/request.go index c538686e1..c57149616 100644 --- a/pkg/protocols/http/request.go +++ b/pkg/protocols/http/request.go @@ -829,12 +829,16 @@ func (request *Request) executeRequest(input *contextargs.Context, generatedRequ modifiedConfig.ResponseHeaderTimeout = updatedTimeout.Timeout } - if modifiedConfig != nil { - client, err := httpclientpool.Get(request.options.Options, modifiedConfig) - if err != nil { - return errors.Wrap(err, "could not get http client") - } + // always prefer per-host pooled client for better reuse + // choose config to use (modified if present else default) + configToUse := modifiedConfig + if configToUse == nil { + configToUse = request.connConfiguration + } + if client, err := httpclientpool.GetForTarget(request.options.Options, configToUse, formedURL); err == nil { httpclient = client + } else { + return errors.Wrap(err, "could not get http client") } resp, err = httpclient.Do(generatedRequest.request)