feat: initial try to per-host caching pool tests

This commit is contained in:
Ice3man
2025-12-19 19:16:04 +05:30
parent 141f34a8ae
commit c0972540c4
7 changed files with 323 additions and 10 deletions

View File

@@ -54,6 +54,7 @@ import (
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/hosterrorscache"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/interactsh"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolinit"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/uncover"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/utils/excludematchers"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/headless/engine"
@@ -743,6 +744,14 @@ func (r *Runner) RunEnumeration() error {
r.progress.Stop()
timeTaken := time.Since(now)
// Print per-host pool stats if available
if dialers := protocolstate.GetDialersWithId(r.options.ExecutionId); dialers != nil && dialers.PerHostHTTPPool != nil {
if pool, ok := dialers.PerHostHTTPPool.(interface{ PrintStats() }); ok {
pool.PrintStats()
}
}
// todo: error propagation without canonical straight error check is required by cloud?
// use safe dereferencing to avoid potential panics in case of previous unchecked errors
if v := ptrutil.Safe(results); !v.Load() {

View File

@@ -15,6 +15,7 @@ type Dialers struct {
RawHTTPClient *rawhttp.Client
DefaultHTTPClient *retryablehttp.Client
HTTPClientPool *mapsutil.SyncLockMap[string, *retryablehttp.Client]
PerHostHTTPPool any
NetworkPolicy *networkpolicy.NetworkPolicy
LocalFileAccessAllowed bool
RestrictLocalNetworkAccess bool

View File

@@ -24,7 +24,6 @@ import (
protocolutils "github.com/projectdiscovery/nuclei/v3/pkg/protocols/utils"
httputil "github.com/projectdiscovery/nuclei/v3/pkg/protocols/utils/http"
"github.com/projectdiscovery/nuclei/v3/pkg/types"
"github.com/projectdiscovery/nuclei/v3/pkg/types/scanstrategy"
"github.com/projectdiscovery/rawhttp"
"github.com/projectdiscovery/retryablehttp-go"
"github.com/projectdiscovery/utils/errkit"
@@ -451,10 +450,9 @@ func (r *requestGenerator) fillRequest(req *retryablehttp.Request, values map[st
}
}
// In case of multiple threads the underlying connection should remain open to allow reuse
if r.request.Threads <= 0 && req.Header.Get("Connection") == "" && r.options.Options.ScanStrategy != scanstrategy.HostSpray.String() {
req.Close = true
}
// Connection handling is now managed by per-host HTTP client pool
// Only set Close=true if template explicitly requests it via Connection header
// Otherwise, let the per-host pool's keep-alive settings handle it
// Check if the user requested a request body
if r.request.Body != "" {

View File

@@ -345,6 +345,7 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error {
}
request.connConfiguration = connectionConfiguration
// Don't create per-host client during Compile (no target yet)
client, err := httpclientpool.Get(options.Options, connectionConfiguration)
if err != nil {
return errors.Wrap(err, "could not get dns client")

View File

@@ -212,6 +212,7 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl
maxConnsPerHost = 500
}
retryableHttpOptions.ImpersonateChrome = true
retryableHttpOptions.RetryWaitMax = 10 * time.Second
retryableHttpOptions.RetryMax = options.Retries
retryableHttpOptions.Timeout = time.Duration(options.Timeout) * time.Second
@@ -289,6 +290,7 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl
MaxIdleConns: maxIdleConns,
MaxIdleConnsPerHost: maxIdleConnsPerHost,
MaxConnsPerHost: maxConnsPerHost,
IdleConnTimeout: 90 * time.Second,
TLSClientConfig: tlsConfig,
DisableKeepAlives: disableKeepAlives,
ResponseHeaderTimeout: responseHeaderTimeout,
@@ -366,6 +368,55 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl
return client, nil
}
// GetForTarget creates or gets a client for a specific target with per-host connection pooling
func GetForTarget(options *types.Options, configuration *Configuration, targetURL string) (*retryablehttp.Client, error) {
if !shouldUsePerHostPooling(options, configuration) {
return Get(options, configuration)
}
dialers := protocolstate.GetDialersWithId(options.ExecutionId)
if dialers == nil {
return nil, fmt.Errorf("dialers not initialized for %s", options.ExecutionId)
}
dialers.Lock()
if dialers.PerHostHTTPPool == nil {
dialers.PerHostHTTPPool = NewPerHostClientPool(500, 5*time.Minute, 30*time.Minute)
}
dialers.Unlock()
pool, ok := dialers.PerHostHTTPPool.(*PerHostClientPool)
if !ok || pool == nil {
return Get(options, configuration)
}
return pool.GetOrCreate(targetURL, func() (*retryablehttp.Client, error) {
cfg := configuration.Clone()
if cfg.Connection == nil {
cfg.Connection = &ConnectionConfiguration{}
}
cfg.Connection.DisableKeepAlive = false
// Override Threads to force connection pool settings
// This ensures MaxIdleConnsPerHost and MaxConnsPerHost are set correctly
originalThreads := cfg.Threads
cfg.Threads = 1
client, err := wrappedGet(options, cfg)
cfg.Threads = originalThreads
return client, err
})
}
// shouldUsePerHostPooling determines if per-host pooling should be enabled
func shouldUsePerHostPooling(options *types.Options, config *Configuration) bool {
// Enable per-host pooling for:
// 1. Templates with threads (parallel requests to same host)
// 2. TemplateSpray mode even without threads (sequential requests benefit from keep-alive)
// Disable only for HostSpray mode (already has keep-alive enabled globally)
return options.ScanStrategy != scanstrategy.HostSpray.String()
}
type RedirectFlow uint8
const (

View File

@@ -0,0 +1,249 @@
package httpclientpool
import (
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/hashicorp/golang-lru/v2/expirable"
"github.com/projectdiscovery/gologger"
"github.com/projectdiscovery/retryablehttp-go"
urlutil "github.com/projectdiscovery/utils/url"
)
type PerHostClientPool struct {
cache *expirable.LRU[string, *clientEntry]
capacity int
mu sync.Mutex
hits atomic.Uint64
misses atomic.Uint64
evictions atomic.Uint64
}
type clientEntry struct {
client *retryablehttp.Client
createdAt time.Time
accessCount atomic.Uint64
}
func NewPerHostClientPool(size int, maxIdleTime, maxLifetime time.Duration) *PerHostClientPool {
if size <= 0 {
size = 500
}
if maxIdleTime == 0 {
maxIdleTime = 5 * time.Minute
}
if maxLifetime == 0 {
maxLifetime = 30 * time.Minute
}
ttl := maxIdleTime
if maxLifetime < maxIdleTime {
ttl = maxLifetime
}
pool := &PerHostClientPool{
cache: expirable.NewLRU[string, *clientEntry](
size,
func(key string, value *clientEntry) {
gologger.Debug().Msgf("[perhost-pool] Evicted client for %s (age: %v, accesses: %d)",
key, time.Since(value.createdAt), value.accessCount.Load())
},
ttl,
),
capacity: size,
}
return pool
}
func (p *PerHostClientPool) GetOrCreate(
host string,
createFunc func() (*retryablehttp.Client, error),
) (*retryablehttp.Client, error) {
normalizedHost := normalizeHost(host)
if entry, ok := p.cache.Get(normalizedHost); ok {
entry.accessCount.Add(1)
p.hits.Add(1)
return entry.client, nil
}
p.mu.Lock()
defer p.mu.Unlock()
if entry, ok := p.cache.Peek(normalizedHost); ok {
entry.accessCount.Add(1)
p.hits.Add(1)
return entry.client, nil
}
p.misses.Add(1)
client, err := createFunc()
if err != nil {
return nil, err
}
entry := &clientEntry{
client: client,
createdAt: time.Now(),
}
entry.accessCount.Store(1)
evicted := p.cache.Add(normalizedHost, entry)
if evicted {
p.evictions.Add(1)
}
return client, nil
}
func (p *PerHostClientPool) EvictHost(host string) bool {
normalizedHost := normalizeHost(host)
existed := p.cache.Remove(normalizedHost)
if existed {
p.evictions.Add(1)
}
return existed
}
func (p *PerHostClientPool) EvictAll() {
count := p.cache.Len()
p.cache.Purge()
p.evictions.Add(uint64(count))
}
func (p *PerHostClientPool) Size() int {
return p.cache.Len()
}
func (p *PerHostClientPool) Stats() PoolStats {
return PoolStats{
Hits: p.hits.Load(),
Misses: p.misses.Load(),
Evictions: p.evictions.Load(),
Size: p.Size(),
}
}
func (p *PerHostClientPool) Close() {
p.EvictAll()
}
func normalizeHost(rawURL string) string {
if rawURL == "" {
return ""
}
parsed, err := urlutil.Parse(rawURL)
if err != nil {
return rawURL
}
scheme := parsed.Scheme
if scheme == "" {
scheme = "http"
}
host := parsed.Host
if host == "" {
host = parsed.Hostname()
}
port := parsed.Port()
if port != "" {
return fmt.Sprintf("%s://%s:%s", scheme, parsed.Hostname(), port)
}
if scheme == "https" && port == "" {
return fmt.Sprintf("%s://%s:443", scheme, parsed.Hostname())
}
if scheme == "http" && port == "" {
return fmt.Sprintf("%s://%s:80", scheme, parsed.Hostname())
}
return fmt.Sprintf("%s://%s", scheme, host)
}
type PoolStats struct {
Hits uint64
Misses uint64
Evictions uint64
Size int
}
func (p *PerHostClientPool) GetClientForHost(host string) (*retryablehttp.Client, bool) {
normalizedHost := normalizeHost(host)
if entry, ok := p.cache.Peek(normalizedHost); ok {
return entry.client, true
}
return nil, false
}
func (p *PerHostClientPool) ListAllClients() []string {
return p.cache.Keys()
}
type ClientInfo struct {
Host string
CreatedAt time.Time
AccessCount uint64
Age time.Duration
}
func (p *PerHostClientPool) GetClientInfo(host string) *ClientInfo {
normalizedHost := normalizeHost(host)
entry, ok := p.cache.Peek(normalizedHost)
if !ok {
return nil
}
now := time.Now()
return &ClientInfo{
Host: normalizedHost,
CreatedAt: entry.createdAt,
AccessCount: entry.accessCount.Load(),
Age: now.Sub(entry.createdAt),
}
}
func (p *PerHostClientPool) GetAllClientInfo() []*ClientInfo {
infos := []*ClientInfo{}
for _, key := range p.cache.Keys() {
if info := p.GetClientInfo(key); info != nil {
infos = append(infos, info)
}
}
return infos
}
func (p *PerHostClientPool) Resize(size int) int {
evicted := p.cache.Resize(size)
p.capacity = size
return evicted
}
func (p *PerHostClientPool) Cap() int {
return p.capacity
}
func (p *PerHostClientPool) PrintStats() {
stats := p.Stats()
if stats.Size == 0 {
return
}
gologger.Verbose().Msgf("[perhost-pool] Connection reuse stats: Hits=%d Misses=%d HitRate=%.1f%% Hosts=%d",
stats.Hits, stats.Misses,
float64(stats.Hits)*100/float64(stats.Hits+stats.Misses+1),
stats.Size)
}
func (p *PerHostClientPool) PrintTransportStats() {
}

View File

@@ -829,12 +829,16 @@ func (request *Request) executeRequest(input *contextargs.Context, generatedRequ
modifiedConfig.ResponseHeaderTimeout = updatedTimeout.Timeout
}
if modifiedConfig != nil {
client, err := httpclientpool.Get(request.options.Options, modifiedConfig)
if err != nil {
return errors.Wrap(err, "could not get http client")
}
// always prefer per-host pooled client for better reuse
// choose config to use (modified if present else default)
configToUse := modifiedConfig
if configToUse == nil {
configToUse = request.connConfiguration
}
if client, err := httpclientpool.GetForTarget(request.options.Options, configToUse, formedURL); err == nil {
httpclient = client
} else {
return errors.Wrap(err, "could not get http client")
}
resp, err = httpclient.Do(generatedRequest.request)