mirror of
https://github.com/projectdiscovery/nuclei.git
synced 2026-02-01 00:03:15 +08:00
feat(hosterrorscache): add Remove and MarkFailedOrRemove methods (#5984)
* feat(hosterrorscache): add `Remove` and `MarkFailedOrRemove` methods and also deprecating `MarkFailed` Signed-off-by: Dwi Siswanto <git@dw1.io> * refactor(*): unwraps `hosterrorscache\.MarkFailed` invocation Signed-off-by: Dwi Siswanto <git@dw1.io> * feat(hosterrorscache): add sync in `Check` and `MarkFailedOrRemove` methods * test(hosterrorscache): add concurrent test for `Check` method * refactor(hosterrorscache): do NOT change `MarkFailed` behavior Signed-off-by: Dwi Siswanto <git@dw1.io> * feat(*): use `MarkFailedOrRemove` explicitly Signed-off-by: Dwi Siswanto <git@dw1.io> --------- Signed-off-by: Dwi Siswanto <git@dw1.io>
This commit is contained in:
@@ -2,7 +2,7 @@ package hosterrorscache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
@@ -17,28 +17,40 @@ const (
|
||||
|
||||
func TestCacheCheck(t *testing.T) {
|
||||
cache := New(3, DefaultMaxHostsCount, nil)
|
||||
err := errors.New("net/http: timeout awaiting response headers")
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
cache.MarkFailed(protoType, newCtxArgs("test"), fmt.Errorf("could not resolve host"))
|
||||
got := cache.Check(protoType, newCtxArgs("test"))
|
||||
if i < 2 {
|
||||
// till 3 the host is not flagged to skip
|
||||
require.False(t, got)
|
||||
} else {
|
||||
// above 3 it must remain flagged to skip
|
||||
require.True(t, got)
|
||||
t.Run("increment host error", func(t *testing.T) {
|
||||
ctx := newCtxArgs(t.Name())
|
||||
for i := 1; i < 3; i++ {
|
||||
cache.MarkFailed(protoType, ctx, err)
|
||||
got := cache.Check(protoType, ctx)
|
||||
require.Falsef(t, got, "got %v in iteration %d", got, i)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
value := cache.Check(protoType, newCtxArgs("test"))
|
||||
require.Equal(t, true, value, "could not get checked value")
|
||||
t.Run("flagged", func(t *testing.T) {
|
||||
ctx := newCtxArgs(t.Name())
|
||||
for i := 1; i <= 3; i++ {
|
||||
cache.MarkFailed(protoType, ctx, err)
|
||||
}
|
||||
|
||||
got := cache.Check(protoType, ctx)
|
||||
require.True(t, got)
|
||||
})
|
||||
|
||||
t.Run("mark failed or remove", func(t *testing.T) {
|
||||
ctx := newCtxArgs(t.Name())
|
||||
cache.MarkFailedOrRemove(protoType, ctx, nil) // nil error should remove the host from cache
|
||||
got := cache.Check(protoType, ctx)
|
||||
require.False(t, got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTrackErrors(t *testing.T) {
|
||||
cache := New(3, DefaultMaxHostsCount, []string{"custom error"})
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
cache.MarkFailed(protoType, newCtxArgs("custom"), fmt.Errorf("got: nested: custom error"))
|
||||
cache.MarkFailed(protoType, newCtxArgs("custom"), errors.New("got: nested: custom error"))
|
||||
got := cache.Check(protoType, newCtxArgs("custom"))
|
||||
if i < 2 {
|
||||
// till 3 the host is not flagged to skip
|
||||
@@ -74,6 +86,20 @@ func TestCacheItemDo(t *testing.T) {
|
||||
require.Equal(t, count, 1)
|
||||
}
|
||||
|
||||
func TestRemove(t *testing.T) {
|
||||
cache := New(3, DefaultMaxHostsCount, nil)
|
||||
ctx := newCtxArgs(t.Name())
|
||||
err := errors.New("net/http: timeout awaiting response headers")
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
cache.MarkFailed(protoType, ctx, err)
|
||||
}
|
||||
|
||||
require.True(t, cache.Check(protoType, ctx))
|
||||
cache.Remove(ctx)
|
||||
require.False(t, cache.Check(protoType, ctx))
|
||||
}
|
||||
|
||||
func TestCacheMarkFailed(t *testing.T) {
|
||||
cache := New(3, DefaultMaxHostsCount, nil)
|
||||
|
||||
@@ -90,7 +116,7 @@ func TestCacheMarkFailed(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
normalizedCacheValue := cache.GetKeyFromContext(newCtxArgs(test.host), nil)
|
||||
cache.MarkFailed(protoType, newCtxArgs(test.host), fmt.Errorf("no address found for host"))
|
||||
cache.MarkFailed(protoType, newCtxArgs(test.host), errors.New("no address found for host"))
|
||||
failedTarget, err := cache.failedTargets.Get(normalizedCacheValue)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, failedTarget)
|
||||
@@ -126,7 +152,7 @@ func TestCacheMarkFailedConcurrent(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
cache.MarkFailed(protoType, newCtxArgs(currentTest.host), fmt.Errorf("could not resolve host"))
|
||||
cache.MarkFailed(protoType, newCtxArgs(currentTest.host), errors.New("net/http: timeout awaiting response headers"))
|
||||
}()
|
||||
}
|
||||
}
|
||||
@@ -144,6 +170,26 @@ func TestCacheMarkFailedConcurrent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheCheckConcurrent(t *testing.T) {
|
||||
cache := New(3, DefaultMaxHostsCount, nil)
|
||||
ctx := newCtxArgs(t.Name())
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 1; i <= 100; i++ {
|
||||
wg.Add(1)
|
||||
i := i
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
cache.MarkFailed(protoType, ctx, errors.New("no address found for host"))
|
||||
if i >= 3 {
|
||||
got := cache.Check(protoType, ctx)
|
||||
require.True(t, got)
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func newCtxArgs(value string) *contextargs.Context {
|
||||
ctx := contextargs.NewWithInput(context.TODO(), value)
|
||||
return ctx
|
||||
|
||||
Reference in New Issue
Block a user