mirror of
https://github.com/yyhuni/xingrin.git
synced 2026-01-31 11:46:16 +08:00
feat(worker): implement subdomain discovery workflow and enhance validation
- Rename IsSubdomainMatchTarget to IsSubdomainOfTarget for clarity - Add subdomain discovery workflow with template loader and helpers - Implement workflow registry for managing scan workflows - Add domain validator package for input validation - Create wordlist server component for managing DNS resolver lists - Add template loader activity for dynamic template management - Implement worker configuration module with environment setup - Update worker dependencies to include projectdiscovery/utils and govalidator - Consolidate workspace directory configuration (WORKSPACE_DIR replaces RESULTS_BASE_PATH) - Update seed generator to use standardized bulk-create API endpoint - Update all service layer calls to use renamed validation function
This commit is contained in:
@@ -65,9 +65,9 @@ func IsURLMatchTarget(urlStr, targetName, targetType string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// IsSubdomainMatchTarget checks if subdomain belongs to target domain
|
||||
// IsSubdomainOfTarget checks if subdomain belongs to target domain
|
||||
// Returns true if subdomain is a valid DNS name and equals target or ends with .target
|
||||
func IsSubdomainMatchTarget(subdomain, targetDomain string) bool {
|
||||
func IsSubdomainOfTarget(subdomain, targetDomain string) bool {
|
||||
subdomain = strings.ToLower(strings.TrimSpace(subdomain))
|
||||
targetDomain = strings.ToLower(strings.TrimSpace(targetDomain))
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ func TestIsURLMatchTarget(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSubdomainMatchTarget(t *testing.T) {
|
||||
func TestIsSubdomainOfTarget(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
subdomain string
|
||||
@@ -99,9 +99,9 @@ func TestIsSubdomainMatchTarget(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := IsSubdomainMatchTarget(tt.subdomain, tt.targetDomain)
|
||||
result := IsSubdomainOfTarget(tt.subdomain, tt.targetDomain)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsSubdomainMatchTarget(%q, %q) = %v, want %v",
|
||||
t.Errorf("IsSubdomainOfTarget(%q, %q) = %v, want %v",
|
||||
tt.subdomain, tt.targetDomain, result, tt.expected)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -61,7 +61,7 @@ func (s *SubdomainService) BulkCreate(targetID int, names []string) (int, error)
|
||||
subdomains := make([]model.Subdomain, 0, len(names))
|
||||
for _, name := range names {
|
||||
// Check if subdomain matches target domain (includes DNS name validation)
|
||||
if validator.IsSubdomainMatchTarget(name, target.Name) {
|
||||
if validator.IsSubdomainOfTarget(name, target.Name) {
|
||||
subdomains = append(subdomains, model.Subdomain{
|
||||
TargetID: targetID,
|
||||
Name: name,
|
||||
|
||||
@@ -71,7 +71,7 @@ func (s *SubdomainSnapshotService) SaveAndSync(scanID int, targetID int, items [
|
||||
validNames := make([]string, 0, len(items))
|
||||
|
||||
for _, item := range items {
|
||||
if validator.IsSubdomainMatchTarget(item.Name, target.Name) {
|
||||
if validator.IsSubdomainOfTarget(item.Name, target.Name) {
|
||||
snapshots = append(snapshots, model.SubdomainSnapshot{
|
||||
ScanID: scanID,
|
||||
Name: item.Name,
|
||||
|
||||
@@ -303,7 +303,7 @@ def create_targets(client, progress, error_handler, count):
|
||||
try:
|
||||
result = error_handler.retry_with_backoff(
|
||||
client.post,
|
||||
"/api/targets/batch_create",
|
||||
"/api/targets/bulk-create",
|
||||
{"targets": batch}
|
||||
)
|
||||
|
||||
|
||||
@@ -10,9 +10,8 @@ SERVER_TOKEN=your_server_token_here
|
||||
# ===========================================
|
||||
# Paths
|
||||
# ===========================================
|
||||
# Working directory for scan results
|
||||
WORKSPACE_DIR=/opt/orbit/results
|
||||
RESULTS_BASE_PATH=/opt/orbit/results
|
||||
# Working directory for scan execution (stores outputs, logs, temp files)
|
||||
WORKSPACE_DIR=/opt/orbit/workspace
|
||||
|
||||
# DNS resolvers file path
|
||||
RESOLVERS_PATH=/opt/orbit/wordlists/resolvers.txt
|
||||
|
||||
@@ -5,23 +5,28 @@ go 1.24.0
|
||||
toolchain go1.24.5
|
||||
|
||||
require (
|
||||
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/projectdiscovery/utils v0.8.0
|
||||
github.com/shirou/gopsutil/v3 v3.24.5
|
||||
go.uber.org/zap v1.27.0
|
||||
golang.org/x/net v0.49.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect
|
||||
github.com/aymerick/douceur v0.2.0 // indirect
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/gorilla/css v1.0.1 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||
github.com/saintfish/chardet v0.0.0-20230101081208-5e3ef4b5456d // indirect
|
||||
github.com/shoenig/go-m1cpu v0.1.6 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
go.uber.org/multierr v1.10.0 // indirect
|
||||
golang.org/x/net v0.49.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/sys v0.40.0 // indirect
|
||||
golang.org/x/text v0.33.0 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so=
|
||||
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
|
||||
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
|
||||
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
|
||||
@@ -17,18 +21,24 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk=
|
||||
github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/projectdiscovery/utils v0.8.0 h1:8d79OCs5xGDNXdKxMUKMY/lgQSUWJMYB1B2Sx+oiqkQ=
|
||||
github.com/projectdiscovery/utils v0.8.0/go.mod h1:CU6tjtyTRxBrnNek+GPJplw4IIHcXNZNKO09kWgqTdg=
|
||||
github.com/saintfish/chardet v0.0.0-20230101081208-5e3ef4b5456d h1:hrujxIzL1woJ7AwssoOcM/tq5JjjG2yYOc8odClEiXA=
|
||||
github.com/saintfish/chardet v0.0.0-20230101081208-5e3ef4b5456d/go.mod h1:uugorj2VCxiV1x+LzaIdVa9b4S4qGAcH6cbhh4qVxOU=
|
||||
github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI=
|
||||
github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk=
|
||||
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
|
||||
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
|
||||
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
|
||||
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
||||
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
|
||||
@@ -37,8 +47,8 @@ github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo
|
||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
|
||||
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
|
||||
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
|
||||
@@ -47,8 +57,6 @@ golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
|
||||
|
||||
92
worker/internal/activity/template_loader.go
Normal file
92
worker/internal/activity/template_loader.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package activity
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"fmt"
|
||||
"sync"
|
||||
"text/template"
|
||||
|
||||
"github.com/orbit/worker/internal/pkg"
|
||||
"go.uber.org/zap"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// TemplateLoader loads and caches command templates from embedded YAML
|
||||
type TemplateLoader struct {
|
||||
fs embed.FS
|
||||
filename string
|
||||
once sync.Once
|
||||
cache map[string]CommandTemplate
|
||||
err error
|
||||
}
|
||||
|
||||
// NewTemplateLoader creates a new template loader
|
||||
func NewTemplateLoader(fs embed.FS, filename string) *TemplateLoader {
|
||||
return &TemplateLoader{
|
||||
fs: fs,
|
||||
filename: filename,
|
||||
}
|
||||
}
|
||||
|
||||
// Load loads templates (cached with sync.Once)
|
||||
func (l *TemplateLoader) Load() (map[string]CommandTemplate, error) {
|
||||
l.once.Do(func() {
|
||||
data, err := l.fs.ReadFile(l.filename)
|
||||
if err != nil {
|
||||
l.err = fmt.Errorf("failed to read %s: %w", l.filename, err)
|
||||
pkg.Logger.Error("Failed to load templates",
|
||||
zap.String("file", l.filename),
|
||||
zap.Error(l.err))
|
||||
return
|
||||
}
|
||||
|
||||
l.cache = make(map[string]CommandTemplate)
|
||||
if err := yaml.Unmarshal(data, &l.cache); err != nil {
|
||||
l.err = fmt.Errorf("failed to parse %s: %w", l.filename, err)
|
||||
pkg.Logger.Error("Failed to parse templates",
|
||||
zap.String("file", l.filename),
|
||||
zap.Error(l.err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := l.validate(); err != nil {
|
||||
l.err = err
|
||||
pkg.Logger.Error("Failed to validate templates",
|
||||
zap.String("file", l.filename),
|
||||
zap.Error(l.err))
|
||||
return
|
||||
}
|
||||
|
||||
pkg.Logger.Info("Templates loaded",
|
||||
zap.String("file", l.filename),
|
||||
zap.Int("count", len(l.cache)))
|
||||
})
|
||||
|
||||
return l.cache, l.err
|
||||
}
|
||||
|
||||
// Get returns a specific template by name
|
||||
func (l *TemplateLoader) Get(name string) (CommandTemplate, error) {
|
||||
templates, err := l.Load()
|
||||
if err != nil {
|
||||
return CommandTemplate{}, fmt.Errorf("templates not loaded: %w", err)
|
||||
}
|
||||
tmpl, ok := templates[name]
|
||||
if !ok {
|
||||
return CommandTemplate{}, fmt.Errorf("template not found: %s", name)
|
||||
}
|
||||
return tmpl, nil
|
||||
}
|
||||
|
||||
// validate checks all templates for syntax errors
|
||||
func (l *TemplateLoader) validate() error {
|
||||
for name, tmpl := range l.cache {
|
||||
if tmpl.Base == "" {
|
||||
return fmt.Errorf("template %s: base command is required", name)
|
||||
}
|
||||
if _, err := template.New(name).Parse(tmpl.Base); err != nil {
|
||||
return fmt.Errorf("template %s: invalid base syntax: %w", name, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
124
worker/internal/config/config.go
Normal file
124
worker/internal/config/config.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMissingServerURL = errors.New("missing required configuration: SERVER_URL")
|
||||
ErrMissingServerToken = errors.New("missing required configuration: SERVER_TOKEN")
|
||||
ErrMissingScanID = errors.New("missing required configuration: SCAN_ID")
|
||||
ErrMissingTargetID = errors.New("missing required configuration: TARGET_ID")
|
||||
ErrMissingTargetName = errors.New("missing required configuration: TARGET_NAME")
|
||||
ErrMissingTargetType = errors.New("missing required configuration: TARGET_TYPE")
|
||||
ErrMissingWorkflowName = errors.New("missing required configuration: WORKFLOW_NAME")
|
||||
ErrMissingConfig = errors.New("missing required configuration: CONFIG")
|
||||
)
|
||||
|
||||
// Config holds all configuration for the worker
|
||||
type Config struct {
|
||||
// Server connection
|
||||
ServerURL string
|
||||
ServerToken string
|
||||
|
||||
// Task parameters (from environment variables)
|
||||
ScanID int
|
||||
TargetID int
|
||||
TargetName string
|
||||
TargetType string // "domain", "ip", "cidr", "url"
|
||||
WorkflowName string // e.g., "subdomain_discovery", "website_scan"
|
||||
WorkspaceDir string // Base directory for workflow execution
|
||||
Config map[string]any
|
||||
|
||||
// Paths
|
||||
LogLevel string
|
||||
}
|
||||
|
||||
// Load reads configuration from environment variables
|
||||
func Load() (*Config, error) {
|
||||
// Load .env file if exists (for local development)
|
||||
_ = godotenv.Load()
|
||||
|
||||
cfg := &Config{
|
||||
ServerURL: os.Getenv("SERVER_URL"),
|
||||
ServerToken: os.Getenv("SERVER_TOKEN"),
|
||||
ScanID: getEnvAsInt("SCAN_ID", 0),
|
||||
TargetID: getEnvAsInt("TARGET_ID", 0),
|
||||
TargetName: os.Getenv("TARGET_NAME"),
|
||||
TargetType: os.Getenv("TARGET_TYPE"),
|
||||
WorkflowName: os.Getenv("WORKFLOW_NAME"),
|
||||
WorkspaceDir: getEnvOrDefault("WORKSPACE_DIR", "/opt/orbit/workspace"),
|
||||
LogLevel: getEnvOrDefault("LOG_LEVEL", "info"),
|
||||
}
|
||||
|
||||
// Parse YAML config from environment variable
|
||||
configYAML := os.Getenv("CONFIG")
|
||||
if configYAML != "" {
|
||||
var config map[string]any
|
||||
if err := yaml.Unmarshal([]byte(configYAML), &config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg.Config = config
|
||||
}
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// InputURL returns the URL to download input data for this scan
|
||||
func (c *Config) InputURL() string {
|
||||
return c.ServerURL + "/api/scans/" + strconv.Itoa(c.ScanID) + "/input/"
|
||||
}
|
||||
|
||||
// Validate checks that all required configuration is present
|
||||
func (c *Config) Validate() error {
|
||||
if c.ServerURL == "" {
|
||||
return ErrMissingServerURL
|
||||
}
|
||||
if c.ServerToken == "" {
|
||||
return ErrMissingServerToken
|
||||
}
|
||||
if c.ScanID == 0 {
|
||||
return ErrMissingScanID
|
||||
}
|
||||
if c.TargetID == 0 {
|
||||
return ErrMissingTargetID
|
||||
}
|
||||
if c.TargetName == "" {
|
||||
return ErrMissingTargetName
|
||||
}
|
||||
if c.TargetType == "" {
|
||||
return ErrMissingTargetType
|
||||
}
|
||||
if c.WorkflowName == "" {
|
||||
return ErrMissingWorkflowName
|
||||
}
|
||||
if c.Config == nil {
|
||||
return ErrMissingConfig
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getEnvOrDefault(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvAsInt(key string, defaultValue int) int {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
if intValue, err := strconv.Atoi(value); err == nil {
|
||||
return intValue
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
75
worker/internal/pkg/validator/domain.go
Normal file
75
worker/internal/pkg/validator/domain.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package validator
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/asaskevich/govalidator"
|
||||
iputil "github.com/projectdiscovery/utils/ip"
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
// ValidateDomain checks if the input is a valid domain name
|
||||
// Returns error if invalid
|
||||
// Note: This function only validates, does not normalize
|
||||
func ValidateDomain(domain string) error {
|
||||
domain = strings.TrimSpace(domain)
|
||||
if domain == "" {
|
||||
return ErrEmptyDomain
|
||||
}
|
||||
|
||||
if !govalidator.IsDNSName(domain) {
|
||||
return ErrInvalidDomain
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NormalizeDomain normalizes a domain name:
|
||||
// - Converts to lowercase
|
||||
// - Handles IDN (punycode) conversion
|
||||
// - Removes trailing dots
|
||||
// - Trims whitespace
|
||||
func NormalizeDomain(domain string) (string, error) {
|
||||
domain = strings.TrimSpace(domain)
|
||||
if domain == "" {
|
||||
return "", ErrEmptyDomain
|
||||
}
|
||||
|
||||
// Convert to lowercase
|
||||
domain = strings.ToLower(domain)
|
||||
|
||||
// Remove trailing dot (FQDN format)
|
||||
domain = strings.TrimSuffix(domain, ".")
|
||||
|
||||
// Handle IDN (internationalized domain names)
|
||||
// Convert to ASCII (punycode) if needed
|
||||
ascii, err := idna.Lookup.ToASCII(domain)
|
||||
if err != nil {
|
||||
return "", ErrInvalidDomain
|
||||
}
|
||||
|
||||
return ascii, nil
|
||||
}
|
||||
|
||||
// IsValidSubdomainFormat performs fast basic validation for subdomain format
|
||||
// This is optimized for parsing tool outputs and may allow some edge cases
|
||||
// For strict validation, use ValidateDomain instead
|
||||
func IsValidSubdomainFormat(s string) bool {
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
// Skip comment lines (common in tool outputs)
|
||||
if strings.HasPrefix(s, "#") {
|
||||
return false
|
||||
}
|
||||
// Skip lines with spaces (likely error messages or headers)
|
||||
if strings.Contains(s, " ") {
|
||||
return false
|
||||
}
|
||||
// Skip IP addresses first (faster than DNS validation)
|
||||
if iputil.IsIP(s) {
|
||||
return false
|
||||
}
|
||||
// Use standard DNS name validation
|
||||
return govalidator.IsDNSName(s)
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
@@ -11,7 +13,8 @@ import (
|
||||
// BatchSender handles batched sending of scan results to Server.
|
||||
// It accumulates items and sends them in batches to reduce HTTP overhead.
|
||||
type BatchSender struct {
|
||||
client *Client
|
||||
ctx context.Context
|
||||
client ServerClient
|
||||
scanID int
|
||||
targetID int
|
||||
dataType string // "subdomain", "website", "endpoint", "port"
|
||||
@@ -24,11 +27,12 @@ type BatchSender struct {
|
||||
}
|
||||
|
||||
// NewBatchSender creates a new batch sender
|
||||
func NewBatchSender(client *Client, scanID, targetID int, dataType string, batchSize int) *BatchSender {
|
||||
func NewBatchSender(ctx context.Context, client ServerClient, scanID, targetID int, dataType string, batchSize int) *BatchSender {
|
||||
if batchSize <= 0 {
|
||||
batchSize = 1000 // default batch size
|
||||
}
|
||||
return &BatchSender{
|
||||
ctx: ctx,
|
||||
client: client,
|
||||
scanID: scanID,
|
||||
targetID: targetID,
|
||||
@@ -39,7 +43,15 @@ func NewBatchSender(client *Client, scanID, targetID int, dataType string, batch
|
||||
}
|
||||
|
||||
// Add adds an item to the batch. Automatically sends when batch is full.
|
||||
// Returns context.Canceled or context.DeadlineExceeded if context is done.
|
||||
func (s *BatchSender) Add(item any) error {
|
||||
// Check context before processing
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return s.ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.batch = append(s.batch, item)
|
||||
shouldSend := len(s.batch) >= s.batchSize
|
||||
@@ -72,6 +84,13 @@ func (s *BatchSender) Stats() (items, batches int) {
|
||||
|
||||
// sendBatch sends the current batch to the server
|
||||
func (s *BatchSender) sendBatch() error {
|
||||
// Check context before sending
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return s.ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
if len(s.batch) == 0 {
|
||||
s.mu.Unlock()
|
||||
@@ -84,43 +103,21 @@ func (s *BatchSender) sendBatch() error {
|
||||
s.batch = s.batch[:0] // reset slice but keep capacity
|
||||
s.mu.Unlock()
|
||||
|
||||
// Build URL and body based on data type (RESTful style)
|
||||
var url string
|
||||
var body map[string]any
|
||||
|
||||
switch s.dataType {
|
||||
case "subdomain":
|
||||
url = fmt.Sprintf("%s/api/worker/scans/%d/subdomains/bulk-upsert", s.client.baseURL, s.scanID)
|
||||
body = map[string]any{
|
||||
"targetId": s.targetID,
|
||||
"subdomains": toSend,
|
||||
if err := s.client.PostBatch(s.ctx, s.scanID, s.targetID, s.dataType, toSend); err != nil {
|
||||
// Check if it's a non-retryable error (4xx)
|
||||
var httpErr *HTTPError
|
||||
if errors.As(err, &httpErr) && !httpErr.IsRetryable() {
|
||||
pkg.Logger.Error("Non-retryable error sending batch (data validation issue)",
|
||||
zap.String("type", s.dataType),
|
||||
zap.Int("count", len(toSend)),
|
||||
zap.Int("statusCode", httpErr.StatusCode),
|
||||
zap.String("response", httpErr.Body))
|
||||
} else {
|
||||
pkg.Logger.Error("Failed to send batch after retries",
|
||||
zap.String("type", s.dataType),
|
||||
zap.Int("count", len(toSend)),
|
||||
zap.Error(err))
|
||||
}
|
||||
case "website":
|
||||
url = fmt.Sprintf("%s/api/worker/scans/%d/websites/bulk-upsert", s.client.baseURL, s.scanID)
|
||||
body = map[string]any{
|
||||
"targetId": s.targetID,
|
||||
"websites": toSend,
|
||||
}
|
||||
case "endpoint":
|
||||
url = fmt.Sprintf("%s/api/worker/scans/%d/endpoints/bulk-upsert", s.client.baseURL, s.scanID)
|
||||
body = map[string]any{
|
||||
"targetId": s.targetID,
|
||||
"endpoints": toSend,
|
||||
}
|
||||
default:
|
||||
// Generic fallback
|
||||
url = fmt.Sprintf("%s/api/worker/scans/%d/%ss/bulk-upsert", s.client.baseURL, s.scanID, s.dataType)
|
||||
body = map[string]any{
|
||||
"targetId": s.targetID,
|
||||
"items": toSend,
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.client.postWithRetry(url, body); err != nil {
|
||||
pkg.Logger.Error("Failed to send batch",
|
||||
zap.String("type", s.dataType),
|
||||
zap.Int("count", len(toSend)),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("failed to send %s batch: %w", s.dataType, err)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -13,6 +14,40 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// HTTPError represents an HTTP error with status code
|
||||
type HTTPError struct {
|
||||
StatusCode int
|
||||
Body string
|
||||
}
|
||||
|
||||
func (e *HTTPError) Error() string {
|
||||
return fmt.Sprintf("HTTP %d: %s", e.StatusCode, e.Body)
|
||||
}
|
||||
|
||||
// IsRetryable returns true if the error should be retried
|
||||
func (e *HTTPError) IsRetryable() bool {
|
||||
// 4xx errors (client errors) should not be retried
|
||||
// 5xx errors (server errors) should be retried
|
||||
return e.StatusCode >= 500
|
||||
}
|
||||
|
||||
// isRetryableError checks if an error should be retried
|
||||
func isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if it's an HTTPError
|
||||
var httpErr *HTTPError
|
||||
if errors.As(err, &httpErr) {
|
||||
return httpErr.IsRetryable()
|
||||
}
|
||||
|
||||
// Network errors (connection refused, timeout, etc.) should be retried
|
||||
// These are typically wrapped in url.Error or net.Error
|
||||
return true
|
||||
}
|
||||
|
||||
// Client handles all HTTP communication with Server
|
||||
// Implements Provider, ResultSaver, and StatusUpdater interfaces
|
||||
type Client struct {
|
||||
@@ -28,7 +63,7 @@ func NewClient(baseURL, token string) *Client {
|
||||
baseURL: baseURL,
|
||||
token: token,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 5 * time.Minute,
|
||||
Timeout: 15 * time.Minute,
|
||||
},
|
||||
maxRetries: 3,
|
||||
}
|
||||
@@ -60,39 +95,67 @@ func (c *Client) doWithRetry(ctx context.Context, method, url string, body any)
|
||||
default:
|
||||
}
|
||||
|
||||
// Exponential backoff for retries (but not on first attempt)
|
||||
if i > 0 {
|
||||
// Use select to allow cancellation during sleep
|
||||
backoff := time.Duration(1<<uint(i)) * time.Second
|
||||
pkg.Logger.Info("Retrying after backoff",
|
||||
zap.String("url", url),
|
||||
zap.Int("attempt", i+1),
|
||||
zap.Duration("backoff", backoff))
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(time.Duration(1<<i) * time.Second):
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.doRequest(ctx, method, url, body); err == nil {
|
||||
err := c.doRequest(ctx, method, url, body)
|
||||
if err == nil {
|
||||
if i > 0 {
|
||||
pkg.Logger.Info("Request succeeded after retry",
|
||||
zap.String("url", url),
|
||||
zap.Int("attempts", i+1))
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
lastErr = err
|
||||
pkg.Logger.Warn("API call failed, retrying",
|
||||
zap.String("url", url),
|
||||
zap.Int("attempt", i+1),
|
||||
zap.Error(err))
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
|
||||
// Check if error is retryable
|
||||
if !isRetryableError(err) {
|
||||
pkg.Logger.Error("Non-retryable error, aborting",
|
||||
zap.String("url", url),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// Log retryable error
|
||||
pkg.Logger.Warn("Retryable error occurred",
|
||||
zap.String("url", url),
|
||||
zap.Int("attempt", i+1),
|
||||
zap.Int("maxRetries", c.maxRetries),
|
||||
zap.Error(err))
|
||||
}
|
||||
|
||||
pkg.Logger.Error("All retries failed", zap.String("url", url), zap.Error(lastErr))
|
||||
return lastErr
|
||||
pkg.Logger.Error("All retries exhausted",
|
||||
zap.String("url", url),
|
||||
zap.Int("attempts", c.maxRetries),
|
||||
zap.Error(lastErr))
|
||||
return fmt.Errorf("max retries exceeded: %w", lastErr)
|
||||
}
|
||||
|
||||
func (c *Client) doRequest(ctx context.Context, method, url string, body any) error {
|
||||
jsonBody, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request body: %w", err)
|
||||
// JSON marshal error is not retryable (data problem)
|
||||
return &HTTPError{StatusCode: 0, Body: fmt.Sprintf("marshal error: %v", err)}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, bytes.NewBuffer(jsonBody))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
// Request creation error is not retryable
|
||||
return &HTTPError{StatusCode: 0, Body: fmt.Sprintf("request creation error: %v", err)}
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
@@ -100,13 +163,17 @@ func (c *Client) doRequest(ctx context.Context, method, url string, body any) er
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("request failed: %w", err)
|
||||
// Network errors are retryable (connection issues, timeout, etc.)
|
||||
return fmt.Errorf("network error: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("API error: status=%d, body=%s", resp.StatusCode, string(respBody))
|
||||
return &HTTPError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Body: string(respBody),
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
163
worker/internal/server/wordlist.go
Normal file
163
worker/internal/server/wordlist.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/orbit/worker/internal/pkg"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// WordlistInfo contains wordlist metadata from server
|
||||
type WordlistInfo struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
FilePath string `json:"filePath"`
|
||||
FileHash string `json:"fileHash"`
|
||||
FileSize int64 `json:"fileSize"`
|
||||
LineCount int `json:"lineCount"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// GetWordlistInfo fetches wordlist metadata from server
|
||||
func (c *Client) GetWordlistInfo(ctx context.Context, wordlistName string) (*WordlistInfo, error) {
|
||||
url := fmt.Sprintf("%s/api/worker/wordlists/%s", c.baseURL, wordlistName)
|
||||
return fetchJSON[*WordlistInfo](ctx, c, url)
|
||||
}
|
||||
|
||||
// DownloadWordlist downloads a wordlist file from server with atomic write
|
||||
func (c *Client) DownloadWordlist(ctx context.Context, wordlistName, destPath string) error {
|
||||
url := fmt.Sprintf("%s/api/worker/wordlists/%s/download", c.baseURL, wordlistName)
|
||||
|
||||
pkg.Logger.Info("Downloading wordlist",
|
||||
zap.String("name", wordlistName),
|
||||
zap.String("dest", destPath))
|
||||
|
||||
resp, err := c.get(ctx, url)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download wordlist: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return &HTTPError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Body: fmt.Sprintf("downloading wordlist %s: %s", wordlistName, string(body)),
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure directory exists
|
||||
dir := filepath.Dir(destPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create wordlist directory: %w", err)
|
||||
}
|
||||
|
||||
// Write to temporary file first (atomic write)
|
||||
tempPath := destPath + ".tmp"
|
||||
out, err := os.Create(tempPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create temporary wordlist file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = out.Close()
|
||||
_ = os.Remove(tempPath) // Clean up temp file on error
|
||||
}()
|
||||
|
||||
_, err = io.Copy(out, resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write wordlist file: %w", err)
|
||||
}
|
||||
|
||||
// Close file before rename
|
||||
if err := out.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close temporary file: %w", err)
|
||||
}
|
||||
|
||||
// Atomic rename
|
||||
if err := os.Rename(tempPath, destPath); err != nil {
|
||||
return fmt.Errorf("failed to rename temporary file: %w", err)
|
||||
}
|
||||
|
||||
pkg.Logger.Info("Wordlist downloaded successfully",
|
||||
zap.String("name", wordlistName),
|
||||
zap.String("path", destPath))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnsureWordlistLocal ensures a wordlist file exists locally, downloading if needed
|
||||
// If local file exists but hash doesn't match, re-download and verify
|
||||
func (c *Client) EnsureWordlistLocal(ctx context.Context, wordlistName, basePath string) (string, error) {
|
||||
if wordlistName == "" {
|
||||
return "", fmt.Errorf("wordlist name is empty")
|
||||
}
|
||||
|
||||
// Get wordlist info from server (includes expected hash)
|
||||
info, err := c.GetWordlistInfo(ctx, wordlistName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get wordlist info: %w", err)
|
||||
}
|
||||
|
||||
// Local path: basePath/wordlistName
|
||||
localPath := filepath.Join(basePath, wordlistName)
|
||||
|
||||
// Check if file already exists and hash matches
|
||||
if _, err := os.Stat(localPath); err == nil {
|
||||
localHash, hashErr := calcFileHash(localPath)
|
||||
if hashErr == nil && localHash == info.FileHash {
|
||||
pkg.Logger.Debug("Wordlist hash matches, using local file",
|
||||
zap.String("path", localPath))
|
||||
return localPath, nil
|
||||
}
|
||||
pkg.Logger.Info("Wordlist hash mismatch, re-downloading",
|
||||
zap.String("name", wordlistName),
|
||||
zap.String("expected", info.FileHash),
|
||||
zap.String("local", localHash))
|
||||
}
|
||||
|
||||
// Download from server
|
||||
if err := c.DownloadWordlist(ctx, wordlistName, localPath); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Verify downloaded file hash
|
||||
downloadedHash, err := calcFileHash(localPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to calculate hash of downloaded file: %w", err)
|
||||
}
|
||||
|
||||
if downloadedHash != info.FileHash {
|
||||
// Remove corrupted file
|
||||
_ = os.Remove(localPath)
|
||||
return "", fmt.Errorf("downloaded file hash mismatch: expected=%s, got=%s", info.FileHash, downloadedHash)
|
||||
}
|
||||
|
||||
pkg.Logger.Info("Wordlist verified and ready",
|
||||
zap.String("name", wordlistName),
|
||||
zap.String("path", localPath),
|
||||
zap.String("hash", downloadedHash))
|
||||
|
||||
return localPath, nil
|
||||
}
|
||||
|
||||
// calcFileHash calculates SHA-256 hash of a file
|
||||
func calcFileHash(filePath string) (string, error) {
|
||||
f, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
h := sha256.New()
|
||||
if _, err := io.Copy(h, f); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return hex.EncodeToString(h.Sum(nil)), nil
|
||||
}
|
||||
56
worker/internal/workflow/registry.go
Normal file
56
worker/internal/workflow/registry.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Factory creates a new Workflow instance with the given workDir
|
||||
type Factory func(workDir string) Workflow
|
||||
|
||||
var (
|
||||
registry = make(map[string]Factory)
|
||||
mu sync.RWMutex
|
||||
)
|
||||
|
||||
// Register adds a workflow factory to the registry
|
||||
// Typically called in init() of each workflow implementation
|
||||
func Register(name string, factory Factory) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if _, exists := registry[name]; exists {
|
||||
panic(fmt.Sprintf("workflow %q already registered", name))
|
||||
}
|
||||
registry[name] = factory
|
||||
}
|
||||
|
||||
// Get returns a new Workflow instance for the given name and workDir
|
||||
// Returns nil if the workflow is not registered
|
||||
func Get(name string, workDir string) Workflow {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
factory, exists := registry[name]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
return factory(workDir)
|
||||
}
|
||||
|
||||
// List returns all registered workflow names
|
||||
func List() []string {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
names := make([]string, 0, len(registry))
|
||||
for name := range registry {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// Exists checks if a workflow is registered
|
||||
func Exists(name string) bool {
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
_, exists := registry[name]
|
||||
return exists
|
||||
}
|
||||
115
worker/internal/workflow/subdomain_discovery/helpers.go
Normal file
115
worker/internal/workflow/subdomain_discovery/helpers.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package subdomain_discovery
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/orbit/worker/internal/activity"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTimeout = 86400 // default max timeout: 24 hours
|
||||
)
|
||||
|
||||
// buildCommand gets the template and builds the command string
|
||||
func buildCommand(toolName string, params map[string]string, config map[string]any) (string, error) {
|
||||
tmpl, err := getTemplate(toolName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
builder := activity.NewCommandBuilder()
|
||||
return builder.Build(tmpl, params, config)
|
||||
}
|
||||
|
||||
// isStageEnabled checks if a stage is enabled in the config
|
||||
func isStageEnabled(config map[string]any, stageName string) bool {
|
||||
stageConfig, ok := config[stageName].(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
enabled, ok := stageConfig["enabled"].(bool)
|
||||
return ok && enabled
|
||||
}
|
||||
|
||||
// isToolEnabled checks if a specific tool is enabled within a stage
|
||||
func isToolEnabled(stageConfig map[string]any, toolName string) bool {
|
||||
toolConfig, ok := stageConfig[toolName].(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
enabled, ok := toolConfig["enabled"].(bool)
|
||||
return ok && enabled
|
||||
}
|
||||
|
||||
// getConfigPath retrieves a nested config section by path
|
||||
func getConfigPath(config map[string]any, path string) map[string]any {
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(path, ".")
|
||||
current := config
|
||||
for _, part := range parts {
|
||||
next, ok := current[part].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
current = next
|
||||
}
|
||||
return current
|
||||
}
|
||||
|
||||
// getTimeout extracts timeout from tool config
|
||||
// If not configured, returns default (24 hours)
|
||||
func getTimeout(toolConfig map[string]any) time.Duration {
|
||||
if toolConfig == nil {
|
||||
return time.Duration(defaultTimeout) * time.Second
|
||||
}
|
||||
|
||||
if timeout, ok := toolConfig["timeout"].(int); ok && timeout > 0 {
|
||||
return time.Duration(timeout) * time.Second
|
||||
}
|
||||
if timeout, ok := toolConfig["timeout"].(float64); ok && timeout > 0 {
|
||||
return time.Duration(timeout) * time.Second
|
||||
}
|
||||
|
||||
return time.Duration(defaultTimeout) * time.Second
|
||||
}
|
||||
|
||||
// countFileLines counts non-empty lines in a file
|
||||
func countFileLines(filePath string) int {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
|
||||
count := 0
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
if strings.TrimSpace(scanner.Text()) != "" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// getStringValue extracts a string value from config with a default
|
||||
func getStringValue(config map[string]any, key, defaultValue string) string {
|
||||
if config == nil {
|
||||
return defaultValue
|
||||
}
|
||||
if value, ok := config[key].(string); ok && value != "" {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// sanitizeFilename removes or replaces characters that are invalid in filenames
|
||||
func sanitizeFilename(name string) string {
|
||||
// Replace common problematic characters
|
||||
re := regexp.MustCompile(`[<>:"/\\|?*\s]`)
|
||||
return re.ReplaceAllString(name, "_")
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package subdomain_discovery
|
||||
|
||||
import (
|
||||
"embed"
|
||||
|
||||
"github.com/orbit/worker/internal/activity"
|
||||
)
|
||||
|
||||
//go:embed templates.yaml
|
||||
var templatesFS embed.FS
|
||||
|
||||
// loader is the template loader for subdomain discovery workflow
|
||||
var loader = activity.NewTemplateLoader(templatesFS, "templates.yaml")
|
||||
|
||||
// getTemplate returns the command template for a given tool
|
||||
func getTemplate(toolName string) (activity.CommandTemplate, error) {
|
||||
return loader.Get(toolName)
|
||||
}
|
||||
198
worker/internal/workflow/subdomain_discovery/workflow.go
Normal file
198
worker/internal/workflow/subdomain_discovery/workflow.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package subdomain_discovery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/orbit/worker/internal/activity"
|
||||
"github.com/orbit/worker/internal/pkg"
|
||||
"github.com/orbit/worker/internal/pkg/validator"
|
||||
"github.com/orbit/worker/internal/server"
|
||||
"github.com/orbit/worker/internal/workflow"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const Name = "subdomain_discovery"
|
||||
|
||||
func init() {
|
||||
workflow.Register(Name, func(workDir string) workflow.Workflow {
|
||||
return New(workDir)
|
||||
})
|
||||
}
|
||||
|
||||
// Workflow implements the subdomain discovery scan workflow
|
||||
type Workflow struct {
|
||||
runner *activity.Runner
|
||||
commandBuilder *activity.CommandBuilder
|
||||
parser *Parser
|
||||
workDir string
|
||||
}
|
||||
|
||||
// New creates a new subdomain discovery workflow
|
||||
func New(workDir string) *Workflow {
|
||||
return &Workflow{
|
||||
runner: activity.NewRunner(workDir),
|
||||
commandBuilder: activity.NewCommandBuilder(),
|
||||
parser: NewParser(),
|
||||
workDir: workDir,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Workflow) Name() string {
|
||||
return Name
|
||||
}
|
||||
|
||||
// Execute runs the subdomain discovery workflow
|
||||
func (w *Workflow) Execute(params *workflow.Params) (*workflow.Output, error) {
|
||||
// Initialize and validate
|
||||
ctx, err := w.initialize(params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Run all stages
|
||||
allResults := w.runAllStages(ctx)
|
||||
|
||||
// Store result files for streaming in SaveResults
|
||||
output := &workflow.Output{
|
||||
Data: allResults.files, // Pass file paths instead of parsed data
|
||||
Metrics: &workflow.Metrics{
|
||||
ProcessedCount: 0, // Will be updated after streaming
|
||||
FailedCount: len(allResults.failed),
|
||||
FailedTools: allResults.failed,
|
||||
},
|
||||
}
|
||||
|
||||
// Check for complete failure
|
||||
if len(allResults.failed) > 0 && len(allResults.success) == 0 {
|
||||
return output, fmt.Errorf("all tools failed")
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// SaveResults streams subdomain results to the server in batches
|
||||
func (w *Workflow) SaveResults(ctx context.Context, client server.ServerClient, params *workflow.Params, output *workflow.Output) error {
|
||||
files, ok := output.Data.([]string)
|
||||
if !ok || len(files) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create batch sender with context
|
||||
sender := server.NewBatchSender(ctx, client, params.ScanID, params.TargetID, "subdomain", 5000)
|
||||
|
||||
// Stream and deduplicate from files
|
||||
subdomainCh, errCh := w.parser.StreamAndDeduplicate(files)
|
||||
|
||||
// Send subdomains in batches
|
||||
for subdomain := range subdomainCh {
|
||||
if err := sender.Add(map[string]string{"name": subdomain}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Check for streaming errors
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
return fmt.Errorf("error streaming results: %w", err)
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
// Flush remaining items
|
||||
if err := sender.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update metrics
|
||||
items, batches := sender.Stats()
|
||||
output.Metrics.ProcessedCount = items
|
||||
pkg.Logger.Info("Results saved",
|
||||
zap.Int("subdomains", items),
|
||||
zap.Int("batches", batches))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// initialize validates params and prepares the workflow context
|
||||
func (w *Workflow) initialize(params *workflow.Params) (*workflowContext, error) {
|
||||
// Config can be either nested under workflow name or flat
|
||||
// Try nested first: { "subdomain_discovery": { "passive-tools": ... } }
|
||||
// Then flat: { "passive-tools": ... }
|
||||
flowConfig := getConfigPath(params.ScanConfig, Name)
|
||||
if flowConfig == nil {
|
||||
// Use flat config directly
|
||||
flowConfig = params.ScanConfig
|
||||
}
|
||||
if flowConfig == nil {
|
||||
return nil, fmt.Errorf("missing %s config", Name)
|
||||
}
|
||||
|
||||
workDir := filepath.Join(params.WorkDir, Name)
|
||||
if err := os.MkdirAll(workDir, 0755); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Subdomain discovery only works for domain type targets
|
||||
if params.TargetType != "domain" {
|
||||
return nil, fmt.Errorf("subdomain discovery requires domain target, got %s", params.TargetType)
|
||||
}
|
||||
|
||||
// Normalize domain first
|
||||
normalizedDomain, err := validator.NormalizeDomain(params.TargetName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to normalize domain: %w", err)
|
||||
}
|
||||
|
||||
// Validate normalized domain
|
||||
if err := validator.ValidateDomain(normalizedDomain); err != nil {
|
||||
return nil, fmt.Errorf("invalid target domain: %w", err)
|
||||
}
|
||||
|
||||
// Wrap in slice for compatibility with multi-domain processing
|
||||
domains := []string{normalizedDomain}
|
||||
|
||||
pkg.Logger.Info("Workflow initialized",
|
||||
zap.Int("scanId", params.ScanID),
|
||||
zap.String("targetName", params.TargetName),
|
||||
zap.String("targetType", params.TargetType))
|
||||
|
||||
ctx := context.Background()
|
||||
providerConfigPath, err := w.setupProviderConfig(ctx, params, workDir)
|
||||
if err != nil {
|
||||
// Log warning but continue - provider config is optional (enhances results but not required)
|
||||
pkg.Logger.Warn("Failed to setup provider config, subfinder will run without API keys",
|
||||
zap.Error(err))
|
||||
}
|
||||
|
||||
return &workflowContext{
|
||||
ctx: ctx,
|
||||
domains: domains,
|
||||
config: flowConfig,
|
||||
workDir: workDir,
|
||||
providerConfigPath: providerConfigPath,
|
||||
serverClient: params.ServerClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// setupProviderConfig fetches and writes the subfinder provider config
|
||||
// Returns empty string if no config available, error if fetch/write failed
|
||||
func (w *Workflow) setupProviderConfig(ctx context.Context, params *workflow.Params, workDir string) (string, error) {
|
||||
providerConfig, err := params.ServerClient.GetProviderConfig(ctx, params.ScanID, toolSubfinder)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get provider config: %w", err)
|
||||
}
|
||||
if providerConfig == nil || providerConfig.Content == "" {
|
||||
return "", nil // No config available, not an error
|
||||
}
|
||||
|
||||
configPath := filepath.Join(workDir, "provider-config.yaml")
|
||||
if err := os.WriteFile(configPath, []byte(providerConfig.Content), 0600); err != nil {
|
||||
return "", fmt.Errorf("failed to write provider config: %w", err)
|
||||
}
|
||||
pkg.Logger.Info("Provider config written", zap.String("path", configPath))
|
||||
return configPath, nil
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package workflow
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/orbit/worker/internal/server"
|
||||
@@ -35,5 +36,5 @@ type Metrics struct {
|
||||
type Workflow interface {
|
||||
Name() string
|
||||
Execute(params *Params) (*Output, error)
|
||||
SaveResults(client *server.Client, params *Params, output *Output) error
|
||||
SaveResults(ctx context.Context, client server.ServerClient, params *Params, output *Output) error
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user