Resolver cli changes

- Max of 10 simultaneous goroutines to avoid being throttled by DNS
- All template formatting moved to formatLine function
- resolveRepeat is synchronous to avoid being throttled by DNS
This commit is contained in:
Quentin McGaw
2020-06-02 23:10:04 +00:00
parent 20a3327815
commit f262ee6454

View File

@@ -25,20 +25,17 @@ func _main(ctx context.Context) int {
resolver := newResolver(*resolverAddress) resolver := newResolver(*resolverAddress)
lookupIP := newLookupIP(resolver) lookupIP := newLookupIP(resolver)
var domain, template string var domain string
var servers []server var servers []server
switch *provider { switch *provider {
case "pia": case "pia":
domain = "privateinternetaccess.com" domain = "privateinternetaccess.com"
template = "{Region: models.PIARegion(%q), IPs: []net.IP{%s}},"
servers = piaServers() servers = piaServers()
case "windscribe": case "windscribe":
domain = "windscribe.com" domain = "windscribe.com"
template = "{Region: models.WindscribeRegion(%q), IPs: []net.IP{%s}},"
servers = windscribeServers() servers = windscribeServers()
case "surfshark": case "surfshark":
domain = "prod.surfshark.com" domain = "prod.surfshark.com"
template = "{Region: models.SurfsharkRegion(%q), IPs: []net.IP{%s}},"
servers = surfsharkServers() servers = surfsharkServers()
default: default:
fmt.Printf("Provider %q is not supported\n", *provider) fmt.Printf("Provider %q is not supported\n", *provider)
@@ -60,17 +57,19 @@ func _main(ctx context.Context) int {
stringChannel := make(chan string) stringChannel := make(chan string)
errorChannel := make(chan error) errorChannel := make(chan error)
const maxGoroutines = 10
guard := make(chan struct{}, maxGoroutines)
for _, s := range servers { for _, s := range servers {
s := s go func(s server) {
go func() { guard <- struct{}{}
ips, err := resolveRepeat(ctx, lookupIP, s.subdomain+"."+domain, 3) ips, err := resolveRepeat(ctx, lookupIP, s.subdomain+"."+domain, 3)
<-guard
if err != nil { if err != nil {
errorChannel <- err errorChannel <- err
return return
} }
ipsString := formatIPs(ips) stringChannel <- formatLine(*provider, s, ips)
stringChannel <- fmt.Sprintf(template, s.region, ipsString) }(s)
}()
} }
var lines []string var lines []string
var errors []error var errors []error
@@ -98,6 +97,32 @@ func _main(ctx context.Context) int {
return 0 return 0
} }
func formatLine(provider string, s server, ips []net.IP) string {
ipStrings := make([]string, len(ips))
for i := range ips {
ipStrings[i] = fmt.Sprintf("{%s}", strings.ReplaceAll(ips[i].String(), ".", ", "))
}
ipString := strings.Join(ipStrings, ", ")
switch provider {
case "pia":
return fmt.Sprintf(
"{Region: models.PIARegion(%q), IPs: []net.IP{%s}},",
s.region, ipString,
)
case "windscribe":
return fmt.Sprintf(
"{Region: models.WindscribeRegion(%q), IPs: []net.IP{%s}},",
s.region, ipString,
)
case "surfshark":
return fmt.Sprintf(
"{Region: models.SurfsharkRegion(%q), IPs: []net.IP{%s}},",
s.region, ipString,
)
}
return ""
}
type lookupIPFunc func(ctx context.Context, host string) (ips []net.IP, err error) type lookupIPFunc func(ctx context.Context, host string) (ips []net.IP, err error)
func newLookupIP(r *net.Resolver) lookupIPFunc { func newLookupIP(r *net.Resolver) lookupIPFunc {
@@ -125,29 +150,12 @@ func newResolver(ip string) *net.Resolver {
} }
func resolveRepeat(ctx context.Context, lookupIP lookupIPFunc, host string, n int) (ips []net.IP, err error) { func resolveRepeat(ctx context.Context, lookupIP lookupIPFunc, host string, n int) (ips []net.IP, err error) {
ipsChannel := make(chan []net.IP)
errorsChannel := make(chan error)
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
go func() { newIPs, err := lookupIP(ctx, host)
ips, err := lookupIP(ctx, host) if err != nil {
if err != nil { return nil, err
errorsChannel <- err
} else {
ipsChannel <- ips
}
}()
}
for i := 0; i < n; i++ {
select {
case err = <-errorsChannel:
case newIPs := <-ipsChannel:
ips = append(ips, newIPs...)
} }
} ips = append(ips, newIPs...)
close(errorsChannel)
close(ipsChannel)
if err != nil {
return nil, err
} }
return uniqueSortedIPs(ips), nil return uniqueSortedIPs(ips), nil
} }
@@ -169,14 +177,6 @@ func uniqueSortedIPs(ips []net.IP) []net.IP {
return ips return ips
} }
func formatIPs(ips []net.IP) (s string) {
ipStrings := make([]string, len(ips))
for i := range ips {
ipStrings[i] = fmt.Sprintf("{%s}", strings.ReplaceAll(ips[i].String(), ".", ", "))
}
return strings.Join(ipStrings, ", ")
}
type server struct { type server struct {
subdomain string subdomain string
region string region string