diff --git a/cmd/resolver/main.go b/cmd/resolver/main.go index 9d7b1881..5cef7f63 100644 --- a/cmd/resolver/main.go +++ b/cmd/resolver/main.go @@ -25,20 +25,17 @@ func _main(ctx context.Context) int { resolver := newResolver(*resolverAddress) lookupIP := newLookupIP(resolver) - var domain, template string + var domain string var servers []server switch *provider { case "pia": domain = "privateinternetaccess.com" - template = "{Region: models.PIARegion(%q), IPs: []net.IP{%s}}," servers = piaServers() case "windscribe": domain = "windscribe.com" - template = "{Region: models.WindscribeRegion(%q), IPs: []net.IP{%s}}," servers = windscribeServers() case "surfshark": domain = "prod.surfshark.com" - template = "{Region: models.SurfsharkRegion(%q), IPs: []net.IP{%s}}," servers = surfsharkServers() default: fmt.Printf("Provider %q is not supported\n", *provider) @@ -60,17 +57,19 @@ func _main(ctx context.Context) int { stringChannel := make(chan string) errorChannel := make(chan error) + const maxGoroutines = 10 + guard := make(chan struct{}, maxGoroutines) for _, s := range servers { - s := s - go func() { + go func(s server) { + guard <- struct{}{} ips, err := resolveRepeat(ctx, lookupIP, s.subdomain+"."+domain, 3) + <-guard if err != nil { errorChannel <- err return } - ipsString := formatIPs(ips) - stringChannel <- fmt.Sprintf(template, s.region, ipsString) - }() + stringChannel <- formatLine(*provider, s, ips) + }(s) } var lines []string var errors []error @@ -98,6 +97,32 @@ func _main(ctx context.Context) int { return 0 } +func formatLine(provider string, s server, ips []net.IP) string { + ipStrings := make([]string, len(ips)) + for i := range ips { + ipStrings[i] = fmt.Sprintf("{%s}", strings.ReplaceAll(ips[i].String(), ".", ", ")) + } + ipString := strings.Join(ipStrings, ", ") + switch provider { + case "pia": + return fmt.Sprintf( + "{Region: models.PIARegion(%q), IPs: []net.IP{%s}},", + s.region, ipString, + ) + case "windscribe": + return fmt.Sprintf( + "{Region: models.WindscribeRegion(%q), IPs: []net.IP{%s}},", + s.region, ipString, + ) + case "surfshark": + return fmt.Sprintf( + "{Region: models.SurfsharkRegion(%q), IPs: []net.IP{%s}},", + s.region, ipString, + ) + } + return "" +} + type lookupIPFunc func(ctx context.Context, host string) (ips []net.IP, err error) func newLookupIP(r *net.Resolver) lookupIPFunc { @@ -125,29 +150,12 @@ func newResolver(ip string) *net.Resolver { } func resolveRepeat(ctx context.Context, lookupIP lookupIPFunc, host string, n int) (ips []net.IP, err error) { - ipsChannel := make(chan []net.IP) - errorsChannel := make(chan error) for i := 0; i < n; i++ { - go func() { - ips, err := lookupIP(ctx, host) - if err != nil { - errorsChannel <- err - } else { - ipsChannel <- ips - } - }() - } - for i := 0; i < n; i++ { - select { - case err = <-errorsChannel: - case newIPs := <-ipsChannel: - ips = append(ips, newIPs...) + newIPs, err := lookupIP(ctx, host) + if err != nil { + return nil, err } - } - close(errorsChannel) - close(ipsChannel) - if err != nil { - return nil, err + ips = append(ips, newIPs...) } return uniqueSortedIPs(ips), nil } @@ -169,14 +177,6 @@ func uniqueSortedIPs(ips []net.IP) []net.IP { return ips } -func formatIPs(ips []net.IP) (s string) { - ipStrings := make([]string, len(ips)) - for i := range ips { - ipStrings[i] = fmt.Sprintf("{%s}", strings.ReplaceAll(ips[i].String(), ".", ", ")) - } - return strings.Join(ipStrings, ", ") -} - type server struct { subdomain string region string