diff --git a/internal/publicip/multi.go b/internal/publicip/multi.go new file mode 100644 index 00000000..089f5747 --- /dev/null +++ b/internal/publicip/multi.go @@ -0,0 +1,55 @@ +package publicip + +import ( + "context" + "net" + "net/http" +) + +// MultiInfo obtains the public IP address information for every IP +// addresses provided and returns a slice of results with the corresponding +// order as to the IP addresses slice order. +// If an error is encountered, all the operations are canceled and +// an error is returned, so the results returned should be considered +// incomplete in this case. +func MultiInfo(ctx context.Context, client *http.Client, ips []net.IP) ( + results []Result, err error) { + ctx, cancel := context.WithCancel(ctx) + + type asyncResult struct { + index int + result Result + err error + } + resultsCh := make(chan asyncResult) + + for i, ip := range ips { + go func(index int, ip net.IP) { + aResult := asyncResult{ + index: index, + } + aResult.result, aResult.err = Info(ctx, client, ip) + resultsCh <- aResult + }(i, ip) + } + + results = make([]Result, len(ips)) + for i := 0; i < len(ips); i++ { + aResult := <-resultsCh + if aResult.err != nil { + if err == nil { + // Cancel on the first error encountered + err = aResult.err + cancel() + } + continue // ignore errors after the first one + } + + results[aResult.index] = aResult.result + } + + close(resultsCh) + cancel() + + return results, err +} diff --git a/internal/updater/providers/purevpn/servers.go b/internal/updater/providers/purevpn/servers.go index 5d14971d..4009f0cb 100644 --- a/internal/updater/providers/purevpn/servers.go +++ b/internal/updater/providers/purevpn/servers.go @@ -6,6 +6,7 @@ import ( "context" "errors" "fmt" + "net" "net/http" "strings" @@ -73,15 +74,21 @@ func GetServers(ctx context.Context, client *http.Client, ErrNotEnoughServers, len(servers), minServers) } + // Get public IP address information + ipsToGetInfo := make([]net.IP, len(servers)) + for i := range servers { + ipsToGetInfo[i] = servers[i].IPs[0] + } + ipsInfo, err := publicip.MultiInfo(ctx, client, ipsToGetInfo) + if err != nil { + return nil, warnings, err + } + // Dedup by location lts := make(locationToServer) - for _, server := range servers { - ipInfo, err := publicip.Info(ctx, client, server.IPs[0]) - if err != nil { - return nil, warnings, err - } - + for i, server := range servers { // TODO split servers by host + ipInfo := ipsInfo[i] lts.add(ipInfo.Country, ipInfo.Region, ipInfo.City, server.IPs) }