mirror of
https://github.com/projectdiscovery/nuclei.git
synced 2026-01-31 15:53:10 +08:00
feat: initial try to per-host caching pool tests
This commit is contained in:
@@ -54,6 +54,7 @@ import (
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/hosterrorscache"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/interactsh"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolinit"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/uncover"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/utils/excludematchers"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/headless/engine"
|
||||
@@ -743,6 +744,14 @@ func (r *Runner) RunEnumeration() error {
|
||||
|
||||
r.progress.Stop()
|
||||
timeTaken := time.Since(now)
|
||||
|
||||
// Print per-host pool stats if available
|
||||
if dialers := protocolstate.GetDialersWithId(r.options.ExecutionId); dialers != nil && dialers.PerHostHTTPPool != nil {
|
||||
if pool, ok := dialers.PerHostHTTPPool.(interface{ PrintStats() }); ok {
|
||||
pool.PrintStats()
|
||||
}
|
||||
}
|
||||
|
||||
// todo: error propagation without canonical straight error check is required by cloud?
|
||||
// use safe dereferencing to avoid potential panics in case of previous unchecked errors
|
||||
if v := ptrutil.Safe(results); !v.Load() {
|
||||
|
||||
@@ -15,6 +15,7 @@ type Dialers struct {
|
||||
RawHTTPClient *rawhttp.Client
|
||||
DefaultHTTPClient *retryablehttp.Client
|
||||
HTTPClientPool *mapsutil.SyncLockMap[string, *retryablehttp.Client]
|
||||
PerHostHTTPPool any
|
||||
NetworkPolicy *networkpolicy.NetworkPolicy
|
||||
LocalFileAccessAllowed bool
|
||||
RestrictLocalNetworkAccess bool
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
protocolutils "github.com/projectdiscovery/nuclei/v3/pkg/protocols/utils"
|
||||
httputil "github.com/projectdiscovery/nuclei/v3/pkg/protocols/utils/http"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/types"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/types/scanstrategy"
|
||||
"github.com/projectdiscovery/rawhttp"
|
||||
"github.com/projectdiscovery/retryablehttp-go"
|
||||
"github.com/projectdiscovery/utils/errkit"
|
||||
@@ -451,10 +450,9 @@ func (r *requestGenerator) fillRequest(req *retryablehttp.Request, values map[st
|
||||
}
|
||||
}
|
||||
|
||||
// In case of multiple threads the underlying connection should remain open to allow reuse
|
||||
if r.request.Threads <= 0 && req.Header.Get("Connection") == "" && r.options.Options.ScanStrategy != scanstrategy.HostSpray.String() {
|
||||
req.Close = true
|
||||
}
|
||||
// Connection handling is now managed by per-host HTTP client pool
|
||||
// Only set Close=true if template explicitly requests it via Connection header
|
||||
// Otherwise, let the per-host pool's keep-alive settings handle it
|
||||
|
||||
// Check if the user requested a request body
|
||||
if r.request.Body != "" {
|
||||
|
||||
@@ -345,6 +345,7 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error {
|
||||
}
|
||||
request.connConfiguration = connectionConfiguration
|
||||
|
||||
// Don't create per-host client during Compile (no target yet)
|
||||
client, err := httpclientpool.Get(options.Options, connectionConfiguration)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "could not get dns client")
|
||||
|
||||
@@ -212,6 +212,7 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl
|
||||
maxConnsPerHost = 500
|
||||
}
|
||||
|
||||
retryableHttpOptions.ImpersonateChrome = true
|
||||
retryableHttpOptions.RetryWaitMax = 10 * time.Second
|
||||
retryableHttpOptions.RetryMax = options.Retries
|
||||
retryableHttpOptions.Timeout = time.Duration(options.Timeout) * time.Second
|
||||
@@ -289,6 +290,7 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl
|
||||
MaxIdleConns: maxIdleConns,
|
||||
MaxIdleConnsPerHost: maxIdleConnsPerHost,
|
||||
MaxConnsPerHost: maxConnsPerHost,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSClientConfig: tlsConfig,
|
||||
DisableKeepAlives: disableKeepAlives,
|
||||
ResponseHeaderTimeout: responseHeaderTimeout,
|
||||
@@ -366,6 +368,55 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// GetForTarget creates or gets a client for a specific target with per-host connection pooling
|
||||
func GetForTarget(options *types.Options, configuration *Configuration, targetURL string) (*retryablehttp.Client, error) {
|
||||
if !shouldUsePerHostPooling(options, configuration) {
|
||||
return Get(options, configuration)
|
||||
}
|
||||
|
||||
dialers := protocolstate.GetDialersWithId(options.ExecutionId)
|
||||
if dialers == nil {
|
||||
return nil, fmt.Errorf("dialers not initialized for %s", options.ExecutionId)
|
||||
}
|
||||
|
||||
dialers.Lock()
|
||||
if dialers.PerHostHTTPPool == nil {
|
||||
dialers.PerHostHTTPPool = NewPerHostClientPool(500, 5*time.Minute, 30*time.Minute)
|
||||
}
|
||||
dialers.Unlock()
|
||||
|
||||
pool, ok := dialers.PerHostHTTPPool.(*PerHostClientPool)
|
||||
if !ok || pool == nil {
|
||||
return Get(options, configuration)
|
||||
}
|
||||
|
||||
return pool.GetOrCreate(targetURL, func() (*retryablehttp.Client, error) {
|
||||
cfg := configuration.Clone()
|
||||
if cfg.Connection == nil {
|
||||
cfg.Connection = &ConnectionConfiguration{}
|
||||
}
|
||||
cfg.Connection.DisableKeepAlive = false
|
||||
|
||||
// Override Threads to force connection pool settings
|
||||
// This ensures MaxIdleConnsPerHost and MaxConnsPerHost are set correctly
|
||||
originalThreads := cfg.Threads
|
||||
cfg.Threads = 1
|
||||
client, err := wrappedGet(options, cfg)
|
||||
cfg.Threads = originalThreads
|
||||
|
||||
return client, err
|
||||
})
|
||||
}
|
||||
|
||||
// shouldUsePerHostPooling determines if per-host pooling should be enabled
|
||||
func shouldUsePerHostPooling(options *types.Options, config *Configuration) bool {
|
||||
// Enable per-host pooling for:
|
||||
// 1. Templates with threads (parallel requests to same host)
|
||||
// 2. TemplateSpray mode even without threads (sequential requests benefit from keep-alive)
|
||||
// Disable only for HostSpray mode (already has keep-alive enabled globally)
|
||||
return options.ScanStrategy != scanstrategy.HostSpray.String()
|
||||
}
|
||||
|
||||
type RedirectFlow uint8
|
||||
|
||||
const (
|
||||
|
||||
249
pkg/protocols/http/httpclientpool/perhost_pool.go
Normal file
249
pkg/protocols/http/httpclientpool/perhost_pool.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package httpclientpool
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||
"github.com/projectdiscovery/gologger"
|
||||
"github.com/projectdiscovery/retryablehttp-go"
|
||||
urlutil "github.com/projectdiscovery/utils/url"
|
||||
)
|
||||
|
||||
type PerHostClientPool struct {
|
||||
cache *expirable.LRU[string, *clientEntry]
|
||||
capacity int
|
||||
mu sync.Mutex
|
||||
|
||||
hits atomic.Uint64
|
||||
misses atomic.Uint64
|
||||
evictions atomic.Uint64
|
||||
}
|
||||
|
||||
type clientEntry struct {
|
||||
client *retryablehttp.Client
|
||||
createdAt time.Time
|
||||
accessCount atomic.Uint64
|
||||
}
|
||||
|
||||
func NewPerHostClientPool(size int, maxIdleTime, maxLifetime time.Duration) *PerHostClientPool {
|
||||
if size <= 0 {
|
||||
size = 500
|
||||
}
|
||||
if maxIdleTime == 0 {
|
||||
maxIdleTime = 5 * time.Minute
|
||||
}
|
||||
if maxLifetime == 0 {
|
||||
maxLifetime = 30 * time.Minute
|
||||
}
|
||||
|
||||
ttl := maxIdleTime
|
||||
if maxLifetime < maxIdleTime {
|
||||
ttl = maxLifetime
|
||||
}
|
||||
|
||||
pool := &PerHostClientPool{
|
||||
cache: expirable.NewLRU[string, *clientEntry](
|
||||
size,
|
||||
func(key string, value *clientEntry) {
|
||||
gologger.Debug().Msgf("[perhost-pool] Evicted client for %s (age: %v, accesses: %d)",
|
||||
key, time.Since(value.createdAt), value.accessCount.Load())
|
||||
},
|
||||
ttl,
|
||||
),
|
||||
capacity: size,
|
||||
}
|
||||
|
||||
return pool
|
||||
}
|
||||
|
||||
func (p *PerHostClientPool) GetOrCreate(
|
||||
host string,
|
||||
createFunc func() (*retryablehttp.Client, error),
|
||||
) (*retryablehttp.Client, error) {
|
||||
normalizedHost := normalizeHost(host)
|
||||
|
||||
if entry, ok := p.cache.Get(normalizedHost); ok {
|
||||
entry.accessCount.Add(1)
|
||||
p.hits.Add(1)
|
||||
return entry.client, nil
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if entry, ok := p.cache.Peek(normalizedHost); ok {
|
||||
entry.accessCount.Add(1)
|
||||
p.hits.Add(1)
|
||||
return entry.client, nil
|
||||
}
|
||||
|
||||
p.misses.Add(1)
|
||||
|
||||
client, err := createFunc()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
entry := &clientEntry{
|
||||
client: client,
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
entry.accessCount.Store(1)
|
||||
|
||||
evicted := p.cache.Add(normalizedHost, entry)
|
||||
if evicted {
|
||||
p.evictions.Add(1)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (p *PerHostClientPool) EvictHost(host string) bool {
|
||||
normalizedHost := normalizeHost(host)
|
||||
existed := p.cache.Remove(normalizedHost)
|
||||
|
||||
if existed {
|
||||
p.evictions.Add(1)
|
||||
}
|
||||
return existed
|
||||
}
|
||||
|
||||
func (p *PerHostClientPool) EvictAll() {
|
||||
count := p.cache.Len()
|
||||
p.cache.Purge()
|
||||
p.evictions.Add(uint64(count))
|
||||
}
|
||||
|
||||
func (p *PerHostClientPool) Size() int {
|
||||
return p.cache.Len()
|
||||
}
|
||||
|
||||
func (p *PerHostClientPool) Stats() PoolStats {
|
||||
return PoolStats{
|
||||
Hits: p.hits.Load(),
|
||||
Misses: p.misses.Load(),
|
||||
Evictions: p.evictions.Load(),
|
||||
Size: p.Size(),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PerHostClientPool) Close() {
|
||||
p.EvictAll()
|
||||
}
|
||||
|
||||
func normalizeHost(rawURL string) string {
|
||||
if rawURL == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
parsed, err := urlutil.Parse(rawURL)
|
||||
if err != nil {
|
||||
return rawURL
|
||||
}
|
||||
|
||||
scheme := parsed.Scheme
|
||||
if scheme == "" {
|
||||
scheme = "http"
|
||||
}
|
||||
|
||||
host := parsed.Host
|
||||
if host == "" {
|
||||
host = parsed.Hostname()
|
||||
}
|
||||
|
||||
port := parsed.Port()
|
||||
if port != "" {
|
||||
return fmt.Sprintf("%s://%s:%s", scheme, parsed.Hostname(), port)
|
||||
}
|
||||
|
||||
if scheme == "https" && port == "" {
|
||||
return fmt.Sprintf("%s://%s:443", scheme, parsed.Hostname())
|
||||
}
|
||||
if scheme == "http" && port == "" {
|
||||
return fmt.Sprintf("%s://%s:80", scheme, parsed.Hostname())
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s://%s", scheme, host)
|
||||
}
|
||||
|
||||
// PoolStats is a point-in-time snapshot of a PerHostClientPool.
type PoolStats struct {
	// Hits is the number of lookups served from the cache.
	Hits uint64
	// Misses is the number of lookups that had to create a client.
	Misses uint64
	// Evictions is the number of removals recorded by the pool.
	Evictions uint64
	// Size is the number of entries in the cache when the snapshot was taken.
	Size int
}
|
||||
|
||||
func (p *PerHostClientPool) GetClientForHost(host string) (*retryablehttp.Client, bool) {
|
||||
normalizedHost := normalizeHost(host)
|
||||
|
||||
if entry, ok := p.cache.Peek(normalizedHost); ok {
|
||||
return entry.client, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (p *PerHostClientPool) ListAllClients() []string {
|
||||
return p.cache.Keys()
|
||||
}
|
||||
|
||||
// ClientInfo describes one cached client of a PerHostClientPool.
type ClientInfo struct {
	// Host is the normalized host key the client is cached under.
	Host string
	// CreatedAt is the time the cache entry was created.
	CreatedAt time.Time
	// AccessCount is the entry's access counter at the time of the query.
	AccessCount uint64
	// Age is the time elapsed since CreatedAt at the time of the query.
	Age time.Duration
}
|
||||
|
||||
func (p *PerHostClientPool) GetClientInfo(host string) *ClientInfo {
|
||||
normalizedHost := normalizeHost(host)
|
||||
|
||||
entry, ok := p.cache.Peek(normalizedHost)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
return &ClientInfo{
|
||||
Host: normalizedHost,
|
||||
CreatedAt: entry.createdAt,
|
||||
AccessCount: entry.accessCount.Load(),
|
||||
Age: now.Sub(entry.createdAt),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PerHostClientPool) GetAllClientInfo() []*ClientInfo {
|
||||
infos := []*ClientInfo{}
|
||||
for _, key := range p.cache.Keys() {
|
||||
if info := p.GetClientInfo(key); info != nil {
|
||||
infos = append(infos, info)
|
||||
}
|
||||
}
|
||||
return infos
|
||||
}
|
||||
|
||||
func (p *PerHostClientPool) Resize(size int) int {
|
||||
evicted := p.cache.Resize(size)
|
||||
p.capacity = size
|
||||
return evicted
|
||||
}
|
||||
|
||||
func (p *PerHostClientPool) Cap() int {
|
||||
return p.capacity
|
||||
}
|
||||
|
||||
func (p *PerHostClientPool) PrintStats() {
|
||||
stats := p.Stats()
|
||||
if stats.Size == 0 {
|
||||
return
|
||||
}
|
||||
gologger.Verbose().Msgf("[perhost-pool] Connection reuse stats: Hits=%d Misses=%d HitRate=%.1f%% Hosts=%d",
|
||||
stats.Hits, stats.Misses,
|
||||
float64(stats.Hits)*100/float64(stats.Hits+stats.Misses+1),
|
||||
stats.Size)
|
||||
}
|
||||
|
||||
func (p *PerHostClientPool) PrintTransportStats() {
|
||||
}
|
||||
@@ -829,12 +829,16 @@ func (request *Request) executeRequest(input *contextargs.Context, generatedRequ
|
||||
modifiedConfig.ResponseHeaderTimeout = updatedTimeout.Timeout
|
||||
}
|
||||
|
||||
if modifiedConfig != nil {
|
||||
client, err := httpclientpool.Get(request.options.Options, modifiedConfig)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "could not get http client")
|
||||
}
|
||||
// always prefer per-host pooled client for better reuse
|
||||
// choose config to use (modified if present else default)
|
||||
configToUse := modifiedConfig
|
||||
if configToUse == nil {
|
||||
configToUse = request.connConfiguration
|
||||
}
|
||||
if client, err := httpclientpool.GetForTarget(request.options.Options, configToUse, formedURL); err == nil {
|
||||
httpclient = client
|
||||
} else {
|
||||
return errors.Wrap(err, "could not get http client")
|
||||
}
|
||||
|
||||
resp, err = httpclient.Do(generatedRequest.request)
|
||||
|
||||
Reference in New Issue
Block a user