diff --git a/internal/updater/cyberghost.go b/internal/updater/cyberghost.go index 57e0fcb7..10842823 100644 --- a/internal/updater/cyberghost.go +++ b/internal/updater/cyberghost.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "sort" + "time" "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/models" @@ -65,7 +66,8 @@ func tryCyberghostHostname(ctx context.Context, lookupIP lookupIPFunc, defer func() { <-guard }() - IPs, err := resolveRepeat(ctx, lookupIP, host, 2) + const repetition = 10 + IPs, err := resolveRepeat(ctx, lookupIP, host, repetition, time.Second) if err != nil || len(IPs) == 0 { results <- models.CyberghostServer{} return diff --git a/internal/updater/privado.go b/internal/updater/privado.go index 7141f582..a4113727 100644 --- a/internal/updater/privado.go +++ b/internal/updater/privado.go @@ -34,10 +34,9 @@ func findPrivadoServersFromZip(ctx context.Context, client network.Client, looku if err != nil { return nil, nil, err } + + hosts := make([]string, 0, len(contents)) for fileName, content := range contents { - if err := ctx.Err(); err != nil { - return nil, warnings, err - } hostname, warning, err := extractHostFromOVPN(content) if len(warning) > 0 { warnings = append(warnings, warning) @@ -45,16 +44,23 @@ func findPrivadoServersFromZip(ctx context.Context, client network.Client, looku if err != nil { return nil, warnings, fmt.Errorf("%w in %q", err, fileName) } - const repetition = 1 - IPs, err := resolveRepeat(ctx, lookupIP, hostname, repetition) - switch { - case err != nil: - return nil, warnings, err - case len(IPs) == 0: + hosts = append(hosts, hostname) + } + + const repetition = 1 + const timeBetween = 1 + const failOnErr = false + hostToIPs, newWarnings, _ := parallelResolve(ctx, lookupIP, hosts, repetition, timeBetween, failOnErr) + warnings = append(warnings, newWarnings...) + + for hostname, IPs := range hostToIPs { + switch len(IPs) { + case 0: warning := fmt.Sprintf("no IP address found for host %q", hostname) warnings = append(warnings, warning) continue - case len(IPs) > 1: + case 1: + default: warning := fmt.Sprintf("more than one IP address found for host %q", hostname) warnings = append(warnings, warning) } diff --git a/internal/updater/purevpn.go b/internal/updater/purevpn.go index 4cbee0a8..3b4b8732 100644 --- a/internal/updater/purevpn.go +++ b/internal/updater/purevpn.go @@ -38,11 +38,9 @@ func findPurevpnServers(ctx context.Context, client network.Client, lookupIP loo if err != nil { return nil, nil, err } - uniqueServers := map[string]models.PurevpnServer{} + + hosts := make([]string, 0, len(contents)) for fileName, content := range contents { - if err := ctx.Err(); err != nil { - return nil, warnings, err - } if strings.HasSuffix(fileName, "-tcp.ovpn") { continue // only parse UDP files } @@ -53,12 +51,20 @@ func findPurevpnServers(ctx context.Context, client network.Client, lookupIP loo if err != nil { return nil, warnings, fmt.Errorf("%w in %q", err, fileName) } - const repetition = 5 - IPs, err := resolveRepeat(ctx, lookupIP, host, repetition) - switch { - case err != nil: - return nil, warnings, err - case len(IPs) == 0: + hosts = append(hosts, host) + } + + const repetition = 20 + const timeBetween = time.Second + const failOnErr = true + hostToIPs, _, err := parallelResolve(ctx, lookupIP, hosts, repetition, timeBetween, failOnErr) + if err != nil { + return nil, warnings, err + } + + uniqueServers := make(map[string]models.PurevpnServer, len(hostToIPs)) + for host, IPs := range hostToIPs { + if len(IPs) == 0 { warning := fmt.Sprintf("no IP address found for host %q", host) warnings = append(warnings, warning) continue diff --git a/internal/updater/resolver.go b/internal/updater/resolver.go index 45bea5a8..09990f1d 100644 --- a/internal/updater/resolver.go +++ b/internal/updater/resolver.go @@ -5,14 +5,16 @@ import ( "context" "net" "sort" + "time" ) func newResolver(resolverAddress string) *net.Resolver { + d := net.Dialer{} + resolverAddress = net.JoinHostPort(resolverAddress, "53") return &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, network, address string) (net.Conn, error) { - d := net.Dialer{} - return d.DialContext(ctx, "udp", net.JoinHostPort(resolverAddress, "53")) + return d.DialContext(ctx, "udp", resolverAddress) }, } } @@ -31,35 +33,76 @@ func newLookupIP(r *net.Resolver) lookupIPFunc { } } -func resolveRepeat(ctx context.Context, lookupIP lookupIPFunc, host string, n int) (ips []net.IP, err error) { - foundIPs := make(chan []net.IP) - errors := make(chan error) +func parallelResolve(ctx context.Context, lookupIP lookupIPFunc, hosts []string, + repetition int, timeBetween time.Duration, failOnErr bool) ( + hostToIPs map[string][]net.IP, warnings []string, err error) { ctx, cancel := context.WithCancel(ctx) defer cancel() - for i := 0; i < n; i++ { - go func() { - newIPs, err := lookupIP(ctx, host) - if err != nil { - errors <- err - } else { - foundIPs <- newIPs - } - }() + + type result struct { + host string + ips []net.IP } - uniqueIPs := make(map[string]struct{}) - for i := 0; i < n; i++ { - select { - case newIPs := <-foundIPs: - for _, ip := range newIPs { - key := ip.String() - uniqueIPs[key] = struct{}{} + results := make(chan result) + defer close(results) + errors := make(chan error) + defer close(errors) + + for _, host := range hosts { + go func(host string) { + ips, err := resolveRepeat(ctx, lookupIP, host, repetition, timeBetween) + if err != nil { + errors <- err + return } + results <- result{ + host: host, + ips: ips, + } + }(host) + } + + hostToIPs = make(map[string][]net.IP, len(hosts)) + + for range hosts { + select { case newErr := <-errors: - if err == nil { + if !failOnErr { + warnings = append(warnings, newErr.Error()) + } else if err == nil { err = newErr cancel() } + case r := <-results: + hostToIPs[r.host] = r.ips + } + } + + return hostToIPs, warnings, err +} + +func resolveRepeat(ctx context.Context, lookupIP lookupIPFunc, host string, + repetition int, timeBetween time.Duration) (ips []net.IP, err error) { + uniqueIPs := make(map[string]struct{}) + + for i := 0; i < repetition; i++ { + newIPs, err := lookupIP(ctx, host) + if err != nil { + return nil, err + } + for _, ip := range newIPs { + key := ip.String() + uniqueIPs[key] = struct{}{} + } + timer := time.NewTimer(timeBetween) + select { + case <-timer.C: + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return nil, ctx.Err() } } @@ -76,5 +119,5 @@ func resolveRepeat(ctx context.Context, lookupIP lookupIPFunc, host string, n in return bytes.Compare(ips[i], ips[j]) < 1 }) - return ips, err + return ips, nil } diff --git a/internal/updater/resolver_test.go b/internal/updater/resolver_test.go index 13fa6b5a..eea7aec8 100644 --- a/internal/updater/resolver_test.go +++ b/internal/updater/resolver_test.go @@ -26,7 +26,6 @@ func Test_resolveRepeat(t *testing.T) { }, lookupIPErr: fmt.Errorf("feeling sick"), n: 1, - ips: []net.IP{}, err: fmt.Errorf("feeling sick"), }, "successful": { @@ -66,7 +65,7 @@ func Test_resolveRepeat(t *testing.T) { } ips, err := resolveRepeat( - context.Background(), lookupIP, host, testCase.n) + context.Background(), lookupIP, host, testCase.n, 0) if testCase.err != nil { require.Error(t, err) assert.Equal(t, testCase.err.Error(), err.Error()) diff --git a/internal/updater/surfshark.go b/internal/updater/surfshark.go index 96ecfc74..02c46084 100644 --- a/internal/updater/surfshark.go +++ b/internal/updater/surfshark.go @@ -7,6 +7,7 @@ import ( "net/http" "sort" "strings" + "time" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/golibs/network" @@ -48,13 +49,24 @@ func findSurfsharkServersFromAPI(ctx context.Context, client network.Client, loo if err := json.Unmarshal(b, &jsonServers); err != nil { return nil, nil, err } + + hosts := make([]string, len(jsonServers)) + for i := range jsonServers { + hosts[i] = jsonServers[i].Host + } + + const repetition = 20 + const timeBetween = time.Second + const failOnErr = true + hostToIPs, _, err := parallelResolve(ctx, lookupIP, hosts, repetition, timeBetween, failOnErr) + if err != nil { + return nil, nil, err + } + for _, jsonServer := range jsonServers { host := jsonServer.Host - const repetition = 5 - IPs, err := resolveRepeat(ctx, lookupIP, host, repetition) - if err != nil { - return nil, warnings, err - } else if len(IPs) == 0 { + IPs := hostToIPs[host] + if len(IPs) == 0 { warning := fmt.Sprintf("no IP address found for host %q", host) warnings = append(warnings, warning) continue @@ -76,10 +88,8 @@ func findSurfsharkServersFromZip(ctx context.Context, client network.Client, loo return nil, nil, err } mapping := surfsharkSubdomainToRegion() + hosts := make([]string, 0, len(contents)) for fileName, content := range contents { - if err := ctx.Err(); err != nil { - return nil, warnings, err - } if strings.HasSuffix(fileName, "_tcp.ovpn") { continue // only parse UDP files } @@ -92,11 +102,19 @@ func findSurfsharkServersFromZip(ctx context.Context, client network.Client, loo warnings = append(warnings, err.Error()+" in "+fileName) continue } - const repetition = 5 - IPs, err := resolveRepeat(ctx, lookupIP, host, repetition) - if err != nil { - return nil, warnings, err - } else if len(IPs) == 0 { + hosts = append(hosts, host) + } + + const repetition = 20 + const timeBetween = time.Second + const failOnErr = true + hostToIPs, _, err := parallelResolve(ctx, lookupIP, hosts, repetition, timeBetween, failOnErr) + if err != nil { + return nil, warnings, err + } + + for host, IPs := range hostToIPs { + if len(IPs) == 0 { warning := fmt.Sprintf("no IP address found for host %q", host) warnings = append(warnings, warning) continue @@ -118,11 +136,8 @@ func findSurfsharkServersFromZip(ctx context.Context, client network.Client, loo } // process entries in mapping that were not in zip file - remainingServers, newWarnings, err := getRemainingServers(ctx, mapping, lookupIP) + remainingServers, newWarnings := getRemainingServers(ctx, mapping, lookupIP) warnings = append(warnings, newWarnings...) - if err != nil { - return nil, warnings, err - } servers = append(servers, remainingServers...) sort.Slice(servers, func(i, j int) bool { @@ -132,31 +147,28 @@ func findSurfsharkServersFromZip(ctx context.Context, client network.Client, loo } func getRemainingServers(ctx context.Context, mapping map[string]string, lookupIP lookupIPFunc) ( - servers []models.SurfsharkServer, warnings []string, err error) { - for subdomain, region := range mapping { - if err := ctx.Err(); err != nil { - return servers, warnings, err - } - host := subdomain + ".prod.surfshark.com" - const repetition = 3 - IPs, err := resolveRepeat(ctx, lookupIP, host, repetition) - if err != nil { - warning := fmt.Sprintf("subdomain %q for region %q from mapping: %s", subdomain, region, err) - warnings = append(warnings, warning) - continue - } else if len(IPs) == 0 { - warning := fmt.Sprintf("subdomain %q for region %q from mapping did not resolve to any IP address", - subdomain, region) - warnings = append(warnings, warning) - continue - } + servers []models.SurfsharkServer, warnings []string) { + hosts := make([]string, len(mapping)) + i := 0 + for subdomain := range mapping { + hosts[i] = subdomain + ".prod.surfshark.com" + } + + const repetition = 20 + const timeBetween = time.Second + const failOnErr = false + hostToIPs, warnings, _ := parallelResolve(ctx, lookupIP, hosts, repetition, timeBetween, failOnErr) + + for host, IPs := range hostToIPs { + subdomain := strings.TrimSuffix(host, ".prod.surfshark.com") server := models.SurfsharkServer{ - Region: region, + Region: mapping[subdomain], IPs: uniqueSortedIPs(IPs), } servers = append(servers, server) } - return servers, warnings, nil + + return servers, warnings } func stringifySurfsharkServers(servers []models.SurfsharkServer) (s string) { diff --git a/internal/updater/vyprvpn.go b/internal/updater/vyprvpn.go index bfacba67..aabd8c10 100644 --- a/internal/updater/vyprvpn.go +++ b/internal/updater/vyprvpn.go @@ -35,6 +35,8 @@ func findVyprvpnServers(ctx context.Context, client network.Client, lookupIP loo if err != nil { return nil, nil, err } + + hostToRegion := make(map[string]string, len(contents)) for fileName, content := range contents { if err := ctx.Err(); err != nil { return nil, warnings, err @@ -46,15 +48,29 @@ func findVyprvpnServers(ctx context.Context, client network.Client, lookupIP loo if err != nil { return nil, warnings, fmt.Errorf("%w in %s", err, fileName) } - const repetitions = 1 - IPs, err := resolveRepeat(ctx, lookupIP, host, repetitions) - if err != nil { - return nil, warnings, err - } region := strings.TrimSuffix(fileName, ".ovpn") region = strings.ReplaceAll(region, " - ", " ") + hostToRegion[host] = region + } + + hosts := make([]string, len(hostToRegion)) + i := 0 + for host := range hostToRegion { + hosts[i] = host + i++ + } + + const repetition = 1 + const timeBetween = 1 + const failOnErr = true + hostToIPs, _, err := parallelResolve(ctx, lookupIP, hosts, repetition, timeBetween, failOnErr) + if err != nil { + return nil, warnings, err + } + + for host, IPs := range hostToIPs { server := models.VyprvpnServer{ - Region: region, + Region: hostToRegion[host], IPs: uniqueSortedIPs(IPs), } servers = append(servers, server)