consider protocolType in max host error (#5668)

* consider protocolType in max host error

* add mutex when updating internal-event
This commit is contained in:
Tarun Koyalwar
2024-09-28 17:20:35 +04:00
committed by GitHub
parent e4dae52d5a
commit 1f945d6d50
10 changed files with 37 additions and 29 deletions

View File

@@ -20,10 +20,10 @@ import (
// CacheInterface defines the signature of the hosterrorscache so that
// users of Nuclei as embedded lib may implement their own cache
type CacheInterface interface {
SetVerbose(verbose bool) // log verbosely
Close() // close the cache
Check(ctx *contextargs.Context) bool // return true if the host should be skipped
MarkFailed(ctx *contextargs.Context, err error) // record a failure (and cause) for the host
SetVerbose(verbose bool) // log verbosely
Close() // close the cache
Check(protoType string, ctx *contextargs.Context) bool // return true if the host should be skipped
MarkFailed(protoType string, ctx *contextargs.Context, err error) // record a failure (and cause) for the host
}
var (
@@ -115,7 +115,7 @@ func (c *Cache) NormalizeCacheValue(value string) string {
// - URL: https?:// type
// - Host:port type
// - host type
func (c *Cache) Check(ctx *contextargs.Context) bool {
func (c *Cache) Check(protoType string, ctx *contextargs.Context) bool {
finalValue := c.GetKeyFromContext(ctx, nil)
existingCacheItem, err := c.failedTargets.GetIFPresent(finalValue)
@@ -138,8 +138,8 @@ func (c *Cache) Check(ctx *contextargs.Context) bool {
}
// MarkFailed marks a host as failed previously
func (c *Cache) MarkFailed(ctx *contextargs.Context, err error) {
if !c.checkError(err) {
func (c *Cache) MarkFailed(protoType string, ctx *contextargs.Context, err error) {
if !c.checkError(protoType, err) {
return
}
finalValue := c.GetKeyFromContext(ctx, err)
@@ -186,11 +186,13 @@ var reCheckError = regexp.MustCompile(`(no address found for host|could not reso
// added to the host skipping table.
// it first parses error and extracts the cause and checks for blacklisted
// or common errors that should be skipped
func (c *Cache) checkError(err error) bool {
func (c *Cache) checkError(protoType string, err error) bool {
if err == nil {
return false
}
if protoType != "http" {
return false
}
kind := errkit.GetErrorKind(err, nucleierr.ErrTemplateLogic)
switch kind {
case nucleierr.ErrTemplateLogic:

View File

@@ -11,12 +11,16 @@ import (
"github.com/stretchr/testify/require"
)
const (
protoType = "http"
)
func TestCacheCheck(t *testing.T) {
cache := New(3, DefaultMaxHostsCount, nil)
for i := 0; i < 100; i++ {
cache.MarkFailed(newCtxArgs("test"), fmt.Errorf("could not resolve host"))
got := cache.Check(newCtxArgs("test"))
cache.MarkFailed(protoType, newCtxArgs("test"), fmt.Errorf("could not resolve host"))
got := cache.Check(protoType, newCtxArgs("test"))
if i < 2 {
// till 3 the host is not flagged to skip
require.False(t, got)
@@ -26,7 +30,7 @@ func TestCacheCheck(t *testing.T) {
}
}
value := cache.Check(newCtxArgs("test"))
value := cache.Check(protoType, newCtxArgs("test"))
require.Equal(t, true, value, "could not get checked value")
}
@@ -34,8 +38,8 @@ func TestTrackErrors(t *testing.T) {
cache := New(3, DefaultMaxHostsCount, []string{"custom error"})
for i := 0; i < 100; i++ {
cache.MarkFailed(newCtxArgs("custom"), fmt.Errorf("got: nested: custom error"))
got := cache.Check(newCtxArgs("custom"))
cache.MarkFailed(protoType, newCtxArgs("custom"), fmt.Errorf("got: nested: custom error"))
got := cache.Check(protoType, newCtxArgs("custom"))
if i < 2 {
// till 3 the host is not flagged to skip
require.False(t, got)
@@ -44,7 +48,7 @@ func TestTrackErrors(t *testing.T) {
require.True(t, got)
}
}
value := cache.Check(newCtxArgs("custom"))
value := cache.Check(protoType, newCtxArgs("custom"))
require.Equal(t, true, value, "could not get checked value")
}
@@ -86,7 +90,7 @@ func TestCacheMarkFailed(t *testing.T) {
for _, test := range tests {
normalizedCacheValue := cache.GetKeyFromContext(newCtxArgs(test.host), nil)
cache.MarkFailed(newCtxArgs(test.host), fmt.Errorf("no address found for host"))
cache.MarkFailed(protoType, newCtxArgs(test.host), fmt.Errorf("no address found for host"))
failedTarget, err := cache.failedTargets.Get(normalizedCacheValue)
require.Nil(t, err)
require.NotNil(t, failedTarget)
@@ -122,14 +126,14 @@ func TestCacheMarkFailedConcurrent(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
cache.MarkFailed(newCtxArgs(currentTest.host), fmt.Errorf("could not resolve host"))
cache.MarkFailed(protoType, newCtxArgs(currentTest.host), fmt.Errorf("could not resolve host"))
}()
}
}
wg.Wait()
for _, test := range tests {
require.True(t, cache.Check(newCtxArgs(test.host)))
require.True(t, cache.Check(protoType, newCtxArgs(test.host)))
normalizedCacheValue := cache.NormalizeCacheValue(test.host)
failedTarget, err := cache.failedTargets.Get(normalizedCacheValue)