diff --git a/internal/updater/alias.go b/internal/updater/alias.go deleted file mode 100644 index 8ba87c5b..00000000 --- a/internal/updater/alias.go +++ /dev/null @@ -1,10 +0,0 @@ -package updater - -import ( - "context" - "net" -) - -type ( - lookupIPFunc func(ctx context.Context, host string) (ips []net.IP, err error) -) diff --git a/internal/updater/cyberghost.go b/internal/updater/cyberghost.go index 10842823..bae91f5b 100644 --- a/internal/updater/cyberghost.go +++ b/internal/updater/cyberghost.go @@ -2,16 +2,16 @@ package updater import ( "context" - "fmt" "sort" "time" "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/updater/resolver" ) func (u *updater) updateCyberghost(ctx context.Context) (err error) { - servers, err := findCyberghostServers(ctx, u.lookupIP) + servers, err := findCyberghostServers(ctx, u.presolver) if err != nil { return err } @@ -23,62 +23,82 @@ func (u *updater) updateCyberghost(ctx context.Context) (err error) { return nil } -func findCyberghostServers(ctx context.Context, lookupIP lookupIPFunc) (servers []models.CyberghostServer, err error) { +func findCyberghostServers(ctx context.Context, presolver resolver.Parallel) ( + servers []models.CyberghostServer, err error) { groups := getCyberghostGroups() allCountryCodes := constants.CountryCodes() cyberghostCountryCodes := getCyberghostSubdomainToRegion() possibleCountryCodes := mergeCountryCodes(cyberghostCountryCodes, allCountryCodes) - results := make(chan models.CyberghostServer) - const maxGoroutines = 10 - guard := make(chan struct{}, maxGoroutines) - defer close(guard) + // key is the host + possibleServers := make(map[string]models.CyberghostServer, len(groups)*len(possibleCountryCodes)) + possibleHosts := make([]string, 0, len(groups)*len(possibleCountryCodes)) for groupID, groupName := range groups { for countryCode, region := range possibleCountryCodes { - if err := ctx.Err(); err != nil { - return nil, err - } const domain = "cg-dialup.net" - host := fmt.Sprintf("%s-%s.%s", groupID, countryCode, domain) - go tryCyberghostHostname(ctx, lookupIP, host, groupName, region, results, guard) + possibleHost := groupID + "-" + countryCode + "." + domain + possibleHosts = append(possibleHosts, possibleHost) + possibleServer := models.CyberghostServer{ + Region: region, + Group: groupName, + } + possibleServers[possibleHost] = possibleServer } } - for i := 0; i < len(groups)*len(possibleCountryCodes); i++ { - server := <-results - if server.IPs == nil { - continue + + const ( + maxFailRatio = 1 + minFound = 100 + maxDuration = 10 * time.Second + maxNoNew = 2 + maxFails = 1 + ) + settings := resolver.ParallelSettings{ + MaxFailRatio: maxFailRatio, + MinFound: minFound, + Repeat: resolver.RepeatSettings{ + MaxDuration: maxDuration, + BetweenDuration: time.Second, + MaxNoNew: maxNoNew, + MaxFails: maxFails, + SortIPs: true, + }, + } + hostToIPs, _, err := presolver.Resolve(ctx, possibleHosts, settings) + if err != nil { + return nil, err + } + + if err := ctx.Err(); err != nil { + return nil, err + } + + // Set IPs for servers found + for host, IPs := range hostToIPs { + server := possibleServers[host] + server.IPs = IPs + possibleServers[host] = server + } + + // Remove servers with no IPs (aka not found) + for host, server := range possibleServers { + if len(server.IPs) == 0 { + delete(possibleServers, host) } + } + + // Flatten possibleServers to a slice + servers = make([]models.CyberghostServer, 0, len(possibleServers)) + for _, server := range possibleServers { servers = append(servers, server) } - if err := ctx.Err(); err != nil { - return servers, err - } + sort.Slice(servers, func(i, j int) bool { return servers[i].Region < servers[j].Region }) return servers, nil } -func tryCyberghostHostname(ctx context.Context, lookupIP lookupIPFunc, - host, groupName, region string, - results chan<- models.CyberghostServer, guard chan struct{}) { - guard <- struct{}{} - defer func() { - <-guard - }() - const repetition = 10 - IPs, err := resolveRepeat(ctx, lookupIP, host, repetition, time.Second) - if err != nil || len(IPs) == 0 { - results <- models.CyberghostServer{} - return - } - results <- models.CyberghostServer{ - Region: region, - Group: groupName, - IPs: IPs, - } -} - //nolint:goconst func stringifyCyberghostServers(servers []models.CyberghostServer) (s string) { s = "func CyberghostServers() []models.CyberghostServer {\n" diff --git a/internal/updater/fastestvpn.go b/internal/updater/fastestvpn.go index 895576b2..3f344165 100644 --- a/internal/updater/fastestvpn.go +++ b/internal/updater/fastestvpn.go @@ -7,12 +7,14 @@ import ( "regexp" "sort" "strings" + "time" "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/updater/resolver" ) func (u *updater) updateFastestvpn(ctx context.Context) (err error) { - servers, warnings, err := findFastestvpnServersFromZip(ctx, u.client, u.lookupIP) + servers, warnings, err := findFastestvpnServersFromZip(ctx, u.client, u.presolver) if u.options.CLI { for _, warning := range warnings { u.logger.Warn("FastestVPN: %s", warning) @@ -29,7 +31,7 @@ func (u *updater) updateFastestvpn(ctx context.Context) (err error) { return nil } -func findFastestvpnServersFromZip(ctx context.Context, client *http.Client, lookupIP lookupIPFunc) ( +func findFastestvpnServersFromZip(ctx context.Context, client *http.Client, presolver resolver.Parallel) ( servers []models.FastestvpnServer, warnings []string, err error) { const zipURL = "https://support.fastestvpn.com/download/openvpn-tcp-udp-config-files" contents, err := fetchAndExtractFiles(ctx, client, zipURL) @@ -98,10 +100,22 @@ func findFastestvpnServersFromZip(ctx context.Context, client *http.Client, look i++ } - const repetition = 1 - const timeBetween = 0 - const failOnErr = true - hostToIPs, _, err := parallelResolve(ctx, lookupIP, hosts, repetition, timeBetween, failOnErr) + const ( + maxFailRatio = 0.1 + maxNoNew = 1 + maxFails = 2 + ) + settings := resolver.ParallelSettings{ + MaxFailRatio: maxFailRatio, + Repeat: resolver.RepeatSettings{ + MaxDuration: time.Second, + MaxNoNew: maxNoNew, + MaxFails: maxFails, + SortIPs: true, + }, + } + hostToIPs, newWarnings, err := presolver.Resolve(ctx, hosts, settings) + warnings = append(warnings, newWarnings...) if err != nil { return nil, warnings, err } @@ -120,7 +134,7 @@ func findFastestvpnServersFromZip(ctx context.Context, client *http.Client, look TCP: data.TCP, UDP: data.UDP, Country: data.Country, - IPs: uniqueSortedIPs(IPs), + IPs: IPs, } servers = append(servers, server) } diff --git a/internal/updater/hma.go b/internal/updater/hma.go index 62e633ce..9461db64 100644 --- a/internal/updater/hma.go +++ b/internal/updater/hma.go @@ -12,10 +12,11 @@ import ( "unicode" "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/updater/resolver" ) func (u *updater) updateHideMyAss(ctx context.Context) (err error) { - servers, warnings, err := findHideMyAssServers(ctx, u.client, u.lookupIP) + servers, warnings, err := findHideMyAssServers(ctx, u.client, u.presolver) if u.options.CLI { for _, warning := range warnings { u.logger.Warn("HideMyAss: %s", warning) @@ -32,7 +33,7 @@ func (u *updater) updateHideMyAss(ctx context.Context) (err error) { return nil } -func findHideMyAssServers(ctx context.Context, client *http.Client, lookupIP lookupIPFunc) ( +func findHideMyAssServers(ctx context.Context, client *http.Client, presolver resolver.Parallel) ( servers []models.HideMyAssServer, warnings []string, err error) { TCPhostToURL, err := findHideMyAssHostToURLForProto(ctx, client, "TCP") if err != nil { @@ -59,10 +60,27 @@ func findHideMyAssServers(ctx context.Context, client *http.Client, lookupIP loo i++ } - const failOnErr = false - const resolveRepetition = 5 - const timeBetween = 2 * time.Second - hostToIPs, warnings, _ := parallelResolve(ctx, lookupIP, hosts, resolveRepetition, timeBetween, failOnErr) + const ( + maxFailRatio = 0.1 + maxDuration = 15 * time.Second + betweenDuration = 2 * time.Second + maxNoNew = 2 + maxFails = 2 + ) + settings := resolver.ParallelSettings{ + MaxFailRatio: maxFailRatio, + Repeat: resolver.RepeatSettings{ + MaxDuration: maxDuration, + BetweenDuration: betweenDuration, + MaxNoNew: maxNoNew, + MaxFails: maxFails, + SortIPs: true, + }, + } + hostToIPs, warnings, err := presolver.Resolve(ctx, hosts, settings) + if err != nil { + return nil, warnings, err + } servers = make([]models.HideMyAssServer, 0, len(hostToIPs)) for host, IPs := range hostToIPs { diff --git a/internal/updater/privado.go b/internal/updater/privado.go index b83a70d6..5ca3f915 100644 --- a/internal/updater/privado.go +++ b/internal/updater/privado.go @@ -5,12 +5,14 @@ import ( "fmt" "net/http" "sort" + "time" "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/updater/resolver" ) func (u *updater) updatePrivado(ctx context.Context) (err error) { - servers, warnings, err := findPrivadoServersFromZip(ctx, u.client, u.lookupIP) + servers, warnings, err := findPrivadoServersFromZip(ctx, u.client, u.presolver) if u.options.CLI { for _, warning := range warnings { u.logger.Warn("Privado: %s", warning) @@ -27,7 +29,7 @@ func (u *updater) updatePrivado(ctx context.Context) (err error) { return nil } -func findPrivadoServersFromZip(ctx context.Context, client *http.Client, lookupIP lookupIPFunc) ( +func findPrivadoServersFromZip(ctx context.Context, client *http.Client, presolver resolver.Parallel) ( servers []models.PrivadoServer, warnings []string, err error) { const zipURL = "https://privado.io/apps/ovpn_configs.zip" contents, err := fetchAndExtractFiles(ctx, client, zipURL) @@ -47,11 +49,26 @@ func findPrivadoServersFromZip(ctx context.Context, client *http.Client, lookupI hosts = append(hosts, hostname) } - const repetition = 1 - const timeBetween = 1 - const failOnErr = false - hostToIPs, newWarnings, _ := parallelResolve(ctx, lookupIP, hosts, repetition, timeBetween, failOnErr) + const ( + maxFailRatio = 0.1 + maxDuration = 3 * time.Second + maxNoNew = 1 + maxFails = 2 + ) + settings := resolver.ParallelSettings{ + MaxFailRatio: maxFailRatio, + Repeat: resolver.RepeatSettings{ + MaxDuration: maxDuration, + MaxNoNew: maxNoNew, + MaxFails: maxFails, + SortIPs: true, + }, + } + hostToIPs, newWarnings, err := presolver.Resolve(ctx, hosts, settings) warnings = append(warnings, newWarnings...) + if err != nil { + return nil, warnings, err + } for hostname, IPs := range hostToIPs { switch len(IPs) { diff --git a/internal/updater/privatevpn.go b/internal/updater/privatevpn.go index 2cb54643..4e755c83 100644 --- a/internal/updater/privatevpn.go +++ b/internal/updater/privatevpn.go @@ -11,10 +11,11 @@ import ( "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/updater/resolver" ) func (u *updater) updatePrivatevpn(ctx context.Context) (err error) { - servers, warnings, err := findPrivatevpnServersFromZip(ctx, u.client, u.lookupIP) + servers, warnings, err := findPrivatevpnServersFromZip(ctx, u.client, u.presolver) if u.options.CLI { for _, warning := range warnings { u.logger.Warn("Privatevpn: %s", warning) @@ -31,7 +32,7 @@ func (u *updater) updatePrivatevpn(ctx context.Context) (err error) { return nil } -func findPrivatevpnServersFromZip(ctx context.Context, client *http.Client, lookupIP lookupIPFunc) ( +func findPrivatevpnServersFromZip(ctx context.Context, client *http.Client, presolver resolver.Parallel) ( servers []models.PrivatevpnServer, warnings []string, err error) { // Note: all servers do both TCP and UDP const zipURL = "https://privatevpn.com/client/PrivateVPN-TUN.zip" @@ -93,10 +94,27 @@ func findPrivatevpnServersFromZip(ctx context.Context, client *http.Client, look i++ } - const failOnError = false - hostToIPs, newWarnings, _ := parallelResolve(ctx, lookupIP, hostnames, 5, time.Second, failOnError) - if len(newWarnings) > 0 { - warnings = append(warnings, newWarnings...) + const ( + maxFailRatio = 0.1 + maxDuration = 6 * time.Second + betweenDuration = time.Second + maxNoNew = 2 + maxFails = 2 + ) + settings := resolver.ParallelSettings{ + MaxFailRatio: maxFailRatio, + Repeat: resolver.RepeatSettings{ + MaxDuration: maxDuration, + BetweenDuration: betweenDuration, + MaxNoNew: maxNoNew, + MaxFails: maxFails, + SortIPs: true, + }, + } + hostToIPs, newWarnings, err := presolver.Resolve(ctx, hostnames, settings) + warnings = append(warnings, newWarnings...) + if err != nil { + return nil, warnings, err } for hostname, server := range uniqueServers { diff --git a/internal/updater/purevpn.go b/internal/updater/purevpn.go index 7125ca8a..b3a5e9ba 100644 --- a/internal/updater/purevpn.go +++ b/internal/updater/purevpn.go @@ -10,10 +10,11 @@ import ( "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/publicip" + "github.com/qdm12/gluetun/internal/updater/resolver" ) func (u *updater) updatePurevpn(ctx context.Context) (err error) { - servers, warnings, err := findPurevpnServers(ctx, u.client, u.lookupIP) + servers, warnings, err := findPurevpnServers(ctx, u.client, u.presolver) if u.options.CLI { for _, warning := range warnings { u.logger.Warn("PureVPN: %s", warning) @@ -30,7 +31,7 @@ func (u *updater) updatePurevpn(ctx context.Context) (err error) { return nil } -func findPurevpnServers(ctx context.Context, client *http.Client, lookupIP lookupIPFunc) ( +func findPurevpnServers(ctx context.Context, client *http.Client, presolver resolver.Parallel) ( servers []models.PurevpnServer, warnings []string, err error) { const zipURL = "https://s3-us-west-1.amazonaws.com/heartbleed/windows/New+OVPN+Files.zip" contents, err := fetchAndExtractFiles(ctx, client, zipURL) @@ -53,10 +54,25 @@ func findPurevpnServers(ctx context.Context, client *http.Client, lookupIP looku hosts = append(hosts, host) } - const repetition = 20 - const timeBetween = time.Second - const failOnErr = true - hostToIPs, _, err := parallelResolve(ctx, lookupIP, hosts, repetition, timeBetween, failOnErr) + const ( + maxFailRatio = 0.1 + maxDuration = 20 * time.Second + betweenDuration = time.Second + maxNoNew = 2 + maxFails = 2 + ) + settings := resolver.ParallelSettings{ + MaxFailRatio: maxFailRatio, + Repeat: resolver.RepeatSettings{ + MaxDuration: maxDuration, + BetweenDuration: betweenDuration, + MaxNoNew: maxNoNew, + MaxFails: maxFails, + SortIPs: true, + }, + } + hostToIPs, newWarnings, err := presolver.Resolve(ctx, hosts, settings) + warnings = append(warnings, newWarnings...) if err != nil { return nil, warnings, err } diff --git a/internal/updater/resolver.go b/internal/updater/resolver.go deleted file mode 100644 index 0407baa0..00000000 --- a/internal/updater/resolver.go +++ /dev/null @@ -1,134 +0,0 @@ -package updater - -import ( - "bytes" - "context" - "net" - "sort" - "time" -) - -func newResolver(resolverAddress string) *net.Resolver { - d := net.Dialer{} - resolverAddress = net.JoinHostPort(resolverAddress, "53") - return &net.Resolver{ - PreferGo: true, - Dial: func(ctx context.Context, network, address string) (net.Conn, error) { - return d.DialContext(ctx, "udp", resolverAddress) - }, - } -} - -func newLookupIP(r *net.Resolver) lookupIPFunc { - return func(ctx context.Context, host string) (ips []net.IP, err error) { - addresses, err := r.LookupIPAddr(ctx, host) - if err != nil { - return nil, err - } - ips = make([]net.IP, len(addresses)) - for i := range addresses { - ips[i] = addresses[i].IP - } - return ips, nil - } -} - -func parallelResolve(ctx context.Context, lookupIP lookupIPFunc, hosts []string, - repetition int, timeBetween time.Duration, failOnErr bool) ( - hostToIPs map[string][]net.IP, warnings []string, err error) { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - type result struct { - host string - ips []net.IP - } - - results := make(chan result) - defer close(results) - errors := make(chan error) - defer close(errors) - - for _, host := range hosts { - go func(host string) { - ips, err := resolveRepeat(ctx, lookupIP, host, repetition, timeBetween) - if err != nil { - errors <- err - return - } - results <- result{ - host: host, - ips: ips, - } - }(host) - } - - hostToIPs = make(map[string][]net.IP, len(hosts)) - - for range hosts { - select { - case newErr := <-errors: - if !failOnErr { - warnings = append(warnings, newErr.Error()) - } else if err == nil { - err = newErr - cancel() - } - case r := <-results: - hostToIPs[r.host] = r.ips - } - } - - return hostToIPs, warnings, err -} - -func resolveRepeat(ctx context.Context, lookupIP lookupIPFunc, host string, - repetition int, timeBetween time.Duration) (ips []net.IP, err error) { - uniqueIPs := make(map[string]struct{}) - - i := 0 - for { - newIPs, newErr := lookupIP(ctx, host) - if err == nil { - err = newErr // it's fine to fail some of the resolutions - } - for _, ip := range newIPs { - key := ip.String() - uniqueIPs[key] = struct{}{} - } - - i++ - if i == repetition { - break - } - - timer := time.NewTimer(timeBetween) - select { - case <-timer.C: - case <-ctx.Done(): - if !timer.Stop() { - <-timer.C - } - return nil, ctx.Err() - } - } - - if len(uniqueIPs) == 0 { - return nil, err - } - - ips = make([]net.IP, 0, len(uniqueIPs)) - for key := range uniqueIPs { - ip := net.ParseIP(key) - if ipv4 := ip.To4(); ipv4 != nil { - ip = ipv4 - } - ips = append(ips, ip) - } - - sort.Slice(ips, func(i, j int) bool { - return bytes.Compare(ips[i], ips[j]) < 1 - }) - - return ips, err -} diff --git a/internal/updater/resolver/ips.go b/internal/updater/resolver/ips.go new file mode 100644 index 00000000..2cf5cf53 --- /dev/null +++ b/internal/updater/resolver/ips.go @@ -0,0 +1,15 @@ +package resolver + +import "net" + +func uniqueIPsToSlice(uniqueIPs map[string]struct{}) (ips []net.IP) { + ips = make([]net.IP, 0, len(uniqueIPs)) + for key := range uniqueIPs { + IP := net.ParseIP(key) + if IPv4 := IP.To4(); IPv4 != nil { + IP = IPv4 + } + ips = append(ips, IP) + } + return ips +} diff --git a/internal/updater/resolver/ips_test.go b/internal/updater/resolver/ips_test.go new file mode 100644 index 00000000..7879b8ac --- /dev/null +++ b/internal/updater/resolver/ips_test.go @@ -0,0 +1,41 @@ +package resolver + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_uniqueIPsToSlice(t *testing.T) { + t.Parallel() + testCases := map[string]struct { + inputIPs map[string]struct{} + outputIPs []net.IP + }{ + "nil": { + inputIPs: nil, + outputIPs: []net.IP{}, + }, + "empty": { + inputIPs: map[string]struct{}{}, + outputIPs: []net.IP{}, + }, + "single IPv4": { + inputIPs: map[string]struct{}{"1.1.1.1": {}}, + outputIPs: []net.IP{{1, 1, 1, 1}}, + }, + "two IPv4s": { + inputIPs: map[string]struct{}{"1.1.1.1": {}, "1.1.2.1": {}}, + outputIPs: []net.IP{{1, 1, 1, 1}, {1, 1, 2, 1}}, + }, + } + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + outputIPs := uniqueIPsToSlice(testCase.inputIPs) + assert.Equal(t, testCase.outputIPs, outputIPs) + }) + } +} diff --git a/internal/updater/resolver/net.go b/internal/updater/resolver/net.go new file mode 100644 index 00000000..1e8a2971 --- /dev/null +++ b/internal/updater/resolver/net.go @@ -0,0 +1,17 @@ +package resolver + +import ( + "context" + "net" +) + +func newResolver(resolverAddress string) *net.Resolver { + d := net.Dialer{} + resolverAddress = net.JoinHostPort(resolverAddress, "53") + return &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + return d.DialContext(ctx, "udp", resolverAddress) + }, + } +} diff --git a/internal/updater/resolver/parallel.go b/internal/updater/resolver/parallel.go new file mode 100644 index 00000000..be33be46 --- /dev/null +++ b/internal/updater/resolver/parallel.go @@ -0,0 +1,127 @@ +package resolver + +import ( + "context" + "errors" + "fmt" + "net" +) + +type Parallel interface { + Resolve(ctx context.Context, hosts []string, settings ParallelSettings) ( + hostToIPs map[string][]net.IP, warnings []string, err error) +} + +type parallel struct { + repeatResolver Repeat +} + +func NewParallelResolver(address string) Parallel { + return ¶llel{ + repeatResolver: NewRepeat(address), + } +} + +type ParallelSettings struct { + Repeat RepeatSettings + FailEarly bool + // Maximum ratio of the hosts failing DNS resolution + // divided by the total number of hosts requested. + // This value is between 0 and 1. Note this is only + // applicable if FailEarly is not set to true. + MaxFailRatio float64 + // MinFound is the minimum number of hosts to be found. + // If it is bigger than the number of hosts given, it + // is set to the number of hosts given. + MinFound int +} + +type parallelResult struct { + host string + IPs []net.IP +} + +var ( + ErrMinFound = errors.New("not enough hosts found") + ErrMaxFailRatio = errors.New("maximum failure ratio reached") +) + +func (pr *parallel) Resolve(ctx context.Context, hosts []string, + settings ParallelSettings) (hostToIPs map[string][]net.IP, warnings []string, err error) { + minFound := settings.MinFound + if minFound > len(hosts) { + minFound = len(hosts) + } + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + results := make(chan parallelResult) + defer close(results) + errors := make(chan error) + defer close(errors) + + for _, host := range hosts { + go pr.resolveAsync(ctx, host, settings.Repeat, results, errors) + } + + hostToIPs = make(map[string][]net.IP, len(hosts)) + maxFails := int(settings.MaxFailRatio * float64(len(hosts))) + + for range hosts { + select { + case newErr := <-errors: + if settings.FailEarly { + if err == nil { + // only set the error to the first error encountered + // and not the context canceled errors coming after. + err = newErr + cancel() + } + break + } + + // do not add warnings coming from the call to cancel() + if len(warnings) < maxFails { + warnings = append(warnings, newErr.Error()) + } + + if len(warnings) == maxFails { + cancel() // cancel only once when we reach maxFails + } + case result := <-results: + hostToIPs[result.host] = result.IPs + } + } + + if err != nil { // fail early + return nil, warnings, err + } + + if len(hostToIPs) < minFound { + return nil, warnings, + fmt.Errorf("%w: found %d hosts but expected at least %d", + ErrMinFound, len(hostToIPs), minFound) + } + + failureRatio := float64(len(warnings)) / float64(len(hosts)) + if failureRatio > settings.MaxFailRatio { + return hostToIPs, warnings, + fmt.Errorf("%w: %.2f failure ratio reached", ErrMaxFailRatio, failureRatio) + } + + return hostToIPs, warnings, nil +} + +func (pr *parallel) resolveAsync(ctx context.Context, host string, + settings RepeatSettings, results chan<- parallelResult, errors chan<- error) { + IPs, err := pr.repeatResolver.Resolve(ctx, host, settings) + if err != nil { + errors <- err + return + } + results <- parallelResult{ + host: host, + IPs: IPs, + } +} diff --git a/internal/updater/resolver/repeat.go b/internal/updater/resolver/repeat.go new file mode 100644 index 00000000..a394b042 --- /dev/null +++ b/internal/updater/resolver/repeat.go @@ -0,0 +1,141 @@ +package resolver + +import ( + "bytes" + "context" + "errors" + "fmt" + "net" + "sort" + "time" +) + +type Repeat interface { + Resolve(ctx context.Context, host string, settings RepeatSettings) (IPs []net.IP, err error) +} + +type repeat struct { + resolver *net.Resolver +} + +func NewRepeat(address string) Repeat { + return &repeat{ + resolver: newResolver(address), + } +} + +type RepeatSettings struct { + MaxDuration time.Duration + BetweenDuration time.Duration + MaxNoNew int + // Maximum consecutive DNS resolution failures + MaxFails int + SortIPs bool +} + +func (r *repeat) Resolve(ctx context.Context, host string, settings RepeatSettings) (ips []net.IP, err error) { + timedCtx, cancel := context.WithTimeout(ctx, settings.MaxDuration) + defer cancel() + + noNewCounter := 0 + failCounter := 0 + uniqueIPs := make(map[string]struct{}) + + for err == nil { + // TODO + // - one resolving every 100ms for round robin DNS responses + // - one every second for time based DNS cycling responses + noNewCounter, failCounter, err = r.resolveOnce(ctx, timedCtx, host, settings, uniqueIPs, noNewCounter, failCounter) + } + + if len(uniqueIPs) == 0 { + return nil, err + } + + ips = uniqueIPsToSlice(uniqueIPs) + + if settings.SortIPs { + sort.Slice(ips, func(i, j int) bool { + return bytes.Compare(ips[i], ips[j]) < 1 + }) + } + + return ips, nil +} + +var ( + ErrMaxNoNew = errors.New("reached the maximum number of no new update") + ErrMaxFails = errors.New("reached the maximum number of consecutive failures") + ErrTimeout = errors.New("reached the timeout") +) + +func (r *repeat) resolveOnce(ctx, timedCtx context.Context, host string, + settings RepeatSettings, uniqueIPs map[string]struct{}, noNewCounter, failCounter int) ( + newNoNewCounter, newFailCounter int, err error) { + IPs, err := r.lookupIPs(timedCtx, host) + if err != nil { + failCounter++ + if settings.MaxFails > 0 && failCounter == settings.MaxFails { + return noNewCounter, failCounter, fmt.Errorf("%w: %d failed attempts resolving %s: %s", + ErrMaxFails, settings.MaxFails, host, err) + } + // it's fine to fail some of the resolutions + return noNewCounter, failCounter, nil + } + failCounter = 0 // reset the counter if we had no error + + anyNew := false + for _, IP := range IPs { + key := IP.String() + if _, ok := uniqueIPs[key]; !ok { + anyNew = true + uniqueIPs[key] = struct{}{} + } + } + + if !anyNew { + noNewCounter++ + } + + if settings.MaxNoNew > 0 && noNewCounter == settings.MaxNoNew { + // we reached the maximum number of resolutions without + // finding any new IP address to our unique IP addresses set. + return noNewCounter, failCounter, + fmt.Errorf("%w: %d times no updated for %d IP addresses found", + ErrMaxNoNew, noNewCounter, len(uniqueIPs)) + } + + timer := time.NewTimer(settings.BetweenDuration) + select { + case <-timer.C: + return noNewCounter, failCounter, nil + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return noNewCounter, failCounter, ctx.Err() + case <-timedCtx.Done(): + if err := ctx.Err(); err != nil { + // timedCtx was canceled from its parent context + return noNewCounter, failCounter, err + } + return noNewCounter, failCounter, + fmt.Errorf("%w: %s", ErrTimeout, timedCtx.Err()) + } +} + +func (r *repeat) lookupIPs(ctx context.Context, host string) (ips []net.IP, err error) { + addresses, err := r.resolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, err + } + ips = make([]net.IP, 0, len(addresses)) + for i := range addresses { + ip := addresses[i].IP + if ip == nil { + continue + } + ips = append(ips, ip) + } + return ips, nil +} diff --git a/internal/updater/resolver/resolver.go b/internal/updater/resolver/resolver.go new file mode 100644 index 00000000..b305835e --- /dev/null +++ b/internal/updater/resolver/resolver.go @@ -0,0 +1,3 @@ +// Package resolver defines custom resolvers to resolve +// hosts multiple times with adjustable settings. +package resolver diff --git a/internal/updater/resolver_test.go b/internal/updater/resolver_test.go deleted file mode 100644 index 5afb9434..00000000 --- a/internal/updater/resolver_test.go +++ /dev/null @@ -1,83 +0,0 @@ -package updater - -import ( - "context" - "fmt" - "net" - "sync" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func Test_resolveRepeat(t *testing.T) { - t.Parallel() - testCases := map[string]struct { - lookupIPResult [][]net.IP - lookupIPErr error - n int - ips []net.IP - err error - }{ - "failure twice": { - lookupIPResult: [][]net.IP{{}, {}}, - lookupIPErr: fmt.Errorf("feeling sick"), - n: 2, - err: fmt.Errorf("feeling sick"), - }, - "failure once": { - lookupIPResult: [][]net.IP{{}, {{1, 1, 1, 1}}}, - lookupIPErr: fmt.Errorf("feeling sick"), - n: 2, - ips: []net.IP{{1, 1, 1, 1}}, - err: fmt.Errorf("feeling sick"), - }, - "successful": { - lookupIPResult: [][]net.IP{ - {{1, 1, 1, 1}, {1, 1, 1, 2}}, - {{2, 1, 1, 1}, {2, 1, 1, 2}}, - {{2, 1, 1, 3}, {2, 1, 1, 2}}, - }, - n: 3, - ips: []net.IP{ - {1, 1, 1, 1}, - {1, 1, 1, 2}, - {2, 1, 1, 1}, - {2, 1, 1, 2}, - {2, 1, 1, 3}, - }, - }, - } - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - if testCase.lookupIPErr == nil { - require.Len(t, testCase.lookupIPResult, testCase.n) - } - const host = "blabla" - i := 0 - mutex := &sync.Mutex{} - lookupIP := func(ctx context.Context, argHost string) ( - ips []net.IP, err error) { - assert.Equal(t, host, argHost) - mutex.Lock() - result := testCase.lookupIPResult[i] - i++ - mutex.Unlock() - return result, testCase.err - } - - ips, err := resolveRepeat( - context.Background(), lookupIP, host, testCase.n, 0) - if testCase.err != nil { - require.Error(t, err) - assert.Equal(t, testCase.err.Error(), err.Error()) - } else { - assert.NoError(t, err) - } - assert.Equal(t, testCase.ips, ips) - }) - } -} diff --git a/internal/updater/surfshark.go b/internal/updater/surfshark.go index 0b7311dd..93cfd74e 100644 --- a/internal/updater/surfshark.go +++ b/internal/updater/surfshark.go @@ -10,10 +10,11 @@ import ( "time" "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/updater/resolver" ) func (u *updater) updateSurfshark(ctx context.Context) (err error) { - servers, warnings, err := findSurfsharkServersFromZip(ctx, u.client, u.lookupIP) + servers, warnings, err := findSurfsharkServersFromZip(ctx, u.client, u.presolver) if u.options.CLI { for _, warning := range warnings { u.logger.Warn("Surfshark: %s", warning) @@ -31,7 +32,7 @@ func (u *updater) updateSurfshark(ctx context.Context) (err error) { } //nolint:deadcode,unused -func findSurfsharkServersFromAPI(ctx context.Context, client *http.Client, lookupIP lookupIPFunc) ( +func findSurfsharkServersFromAPI(ctx context.Context, client *http.Client, presolver resolver.Parallel) ( servers []models.SurfsharkServer, warnings []string, err error) { const url = "https://my.surfshark.com/vpn/api/v4/server/clusters" @@ -69,12 +70,25 @@ func findSurfsharkServersFromAPI(ctx context.Context, client *http.Client, looku hosts[i] = jsonServers[i].Host } - const repetition = 20 - const timeBetween = time.Second - const failOnErr = true - hostToIPs, _, err := parallelResolve(ctx, lookupIP, hosts, repetition, timeBetween, failOnErr) + const ( + maxFailRatio = 0.1 + maxDuration = 20 * time.Second + betweenDuration = time.Second + maxNoNew = 2 + maxFails = 2 + ) + settings := resolver.ParallelSettings{ + MaxFailRatio: maxFailRatio, + Repeat: resolver.RepeatSettings{ + MaxDuration: maxDuration, + BetweenDuration: betweenDuration, + MaxNoNew: maxNoNew, + MaxFails: maxFails, + }, + } + hostToIPs, warnings, err := presolver.Resolve(ctx, hosts, settings) if err != nil { - return nil, nil, err + return nil, warnings, err } for _, jsonServer := range jsonServers { @@ -87,14 +101,14 @@ func findSurfsharkServersFromAPI(ctx context.Context, client *http.Client, looku } server := models.SurfsharkServer{ Region: jsonServer.Country + " " + jsonServer.Location, - IPs: uniqueSortedIPs(IPs), + IPs: IPs, } servers = append(servers, server) } return servers, warnings, nil } -func findSurfsharkServersFromZip(ctx context.Context, client *http.Client, lookupIP lookupIPFunc) ( +func findSurfsharkServersFromZip(ctx context.Context, client *http.Client, presolver resolver.Parallel) ( servers []models.SurfsharkServer, warnings []string, err error) { const zipURL = "https://my.surfshark.com/vpn/api/v1/server/configurations" contents, err := fetchAndExtractFiles(ctx, client, zipURL) @@ -119,10 +133,24 @@ func findSurfsharkServersFromZip(ctx context.Context, client *http.Client, looku hosts = append(hosts, host) } - const repetition = 20 - const timeBetween = time.Second - const failOnErr = true - hostToIPs, _, err := parallelResolve(ctx, lookupIP, hosts, repetition, timeBetween, failOnErr) + const ( + maxFailRatio = 0.1 + maxDuration = 20 * time.Second + betweenDuration = time.Second + maxNoNew = 2 + maxFails = 2 + ) + settings := resolver.ParallelSettings{ + MaxFailRatio: maxFailRatio, + Repeat: resolver.RepeatSettings{ + MaxDuration: maxDuration, + BetweenDuration: betweenDuration, + MaxNoNew: maxNoNew, + MaxFails: maxFails, + }, + } + hostToIPs, newWarnings, err := presolver.Resolve(ctx, hosts, settings) + warnings = append(warnings, newWarnings...) if err != nil { return nil, warnings, err } @@ -144,14 +172,18 @@ func findSurfsharkServersFromZip(ctx context.Context, client *http.Client, looku } server := models.SurfsharkServer{ Region: region, - IPs: uniqueSortedIPs(IPs), + IPs: IPs, } servers = append(servers, server) } // process entries in mapping that were not in zip file - remainingServers, newWarnings := getRemainingServers(ctx, mapping, lookupIP) + remainingServers, newWarnings, err := getRemainingServers(ctx, mapping, presolver) warnings = append(warnings, newWarnings...) + if err != nil { + return nil, warnings, err + } + servers = append(servers, remainingServers...) sort.Slice(servers, func(i, j int) bool { @@ -160,28 +192,46 @@ func findSurfsharkServersFromZip(ctx context.Context, client *http.Client, looku return servers, warnings, nil } -func getRemainingServers(ctx context.Context, mapping map[string]string, lookupIP lookupIPFunc) ( - servers []models.SurfsharkServer, warnings []string) { +func getRemainingServers(ctx context.Context, mapping map[string]string, presolver resolver.Parallel) ( + servers []models.SurfsharkServer, warnings []string, err error) { hosts := make([]string, 0, len(mapping)) for subdomain := range mapping { hosts = append(hosts, subdomain+".prod.surfshark.com") } - const repetition = 20 - const timeBetween = time.Second - const failOnErr = false - hostToIPs, warnings, _ := parallelResolve(ctx, lookupIP, hosts, repetition, timeBetween, failOnErr) + const ( + maxFailRatio = 0.3 + maxDuration = 20 * time.Second + betweenDuration = time.Second + maxNoNew = 2 + maxFails = 2 + ) + settings := resolver.ParallelSettings{ + MaxFailRatio: maxFailRatio, + Repeat: resolver.RepeatSettings{ + MaxDuration: maxDuration, + BetweenDuration: betweenDuration, + MaxNoNew: maxNoNew, + MaxFails: maxFails, + SortIPs: true, + }, + } + hostToIPs, warnings, err := presolver.Resolve(ctx, hosts, settings) + if err != nil { + return nil, warnings, err + } + servers = make([]models.SurfsharkServer, 0, len(hostToIPs)) for host, IPs := range hostToIPs { subdomain := strings.TrimSuffix(host, ".prod.surfshark.com") server := models.SurfsharkServer{ Region: mapping[subdomain], - IPs: uniqueSortedIPs(IPs), + IPs: IPs, } servers = append(servers, server) } - return servers, warnings + return servers, warnings, nil } func stringifySurfsharkServers(servers []models.SurfsharkServer) (s string) { diff --git a/internal/updater/updater.go b/internal/updater/updater.go index 49753007..2562583b 100644 --- a/internal/updater/updater.go +++ b/internal/updater/updater.go @@ -9,6 +9,7 @@ import ( "github.com/qdm12/gluetun/internal/configuration" "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/updater/resolver" "github.com/qdm12/golibs/logging" ) @@ -24,11 +25,11 @@ type updater struct { servers models.AllServers // Functions for tests - logger logging.Logger - timeNow func() time.Time - println func(s string) - lookupIP lookupIPFunc - client *http.Client + logger logging.Logger + timeNow func() time.Time + println func(s string) + presolver resolver.Parallel + client *http.Client } func New(settings configuration.Updater, httpClient *http.Client, @@ -36,15 +37,14 @@ func New(settings configuration.Updater, httpClient *http.Client, if len(settings.DNSAddress) == 0 { settings.DNSAddress = "1.1.1.1" } - resolver := newResolver(settings.DNSAddress) return &updater{ - logger: logger, - timeNow: time.Now, - println: func(s string) { fmt.Println(s) }, - lookupIP: newLookupIP(resolver), - client: httpClient, - options: settings, - servers: currentServers, + logger: logger, + timeNow: time.Now, + println: func(s string) { fmt.Println(s) }, + presolver: resolver.NewParallelResolver(settings.DNSAddress), + client: httpClient, + options: settings, + servers: currentServers, } } diff --git a/internal/updater/vyprvpn.go b/internal/updater/vyprvpn.go index 288ce59c..b3ff8156 100644 --- a/internal/updater/vyprvpn.go +++ b/internal/updater/vyprvpn.go @@ -6,12 +6,14 @@ import ( "net/http" "sort" "strings" + "time" "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/updater/resolver" ) func (u *updater) updateVyprvpn(ctx context.Context) (err error) { - servers, warnings, err := findVyprvpnServers(ctx, u.client, u.lookupIP) + servers, warnings, err := findVyprvpnServers(ctx, u.client, u.presolver) if u.options.CLI { for _, warning := range warnings { u.logger.Warn("Vyprvpn: %s", warning) @@ -28,7 +30,7 @@ func (u *updater) updateVyprvpn(ctx context.Context) (err error) { return nil } -func findVyprvpnServers(ctx context.Context, client *http.Client, lookupIP lookupIPFunc) ( +func findVyprvpnServers(ctx context.Context, client *http.Client, presolver resolver.Parallel) ( servers []models.VyprvpnServer, warnings []string, err error) { const zipURL = "https://support.vyprvpn.com/hc/article_attachments/360052617332/Vypr_OpenVPN_20200320.zip" contents, err := fetchAndExtractFiles(ctx, client, zipURL) @@ -60,18 +62,31 @@ func findVyprvpnServers(ctx context.Context, client *http.Client, lookupIP looku i++ } - const repetition = 1 - const timeBetween = 1 - const failOnErr = true - hostToIPs, _, err := parallelResolve(ctx, lookupIP, hosts, repetition, timeBetween, failOnErr) + const ( + maxFailRatio = 0.1 + maxNoNew = 2 + maxFails = 2 + ) + settings := resolver.ParallelSettings{ + MaxFailRatio: maxFailRatio, + Repeat: resolver.RepeatSettings{ + MaxDuration: time.Second, + MaxNoNew: maxNoNew, + MaxFails: maxFails, + SortIPs: true, + }, + } + hostToIPs, newWarnings, err := presolver.Resolve(ctx, hosts, settings) + warnings = append(warnings, newWarnings...) if err != nil { return nil, warnings, err } + servers = make([]models.VyprvpnServer, 0, len(hostToIPs)) for host, IPs := range hostToIPs { server := models.VyprvpnServer{ Region: hostToRegion[host], - IPs: uniqueSortedIPs(IPs), + IPs: IPs, } servers = append(servers, server) }