Updater loop with period and http route (#240)

* Updater loop with period and http route
* Using DNS over TLS to update servers
* Better logging
* Remove goroutines for cyberghost updater
* Respects context for servers update (quite slow overall)
* Increase shutdown grace period to 5 seconds
* Update announcement
* Add log lines for each provider update start
This commit is contained in:
Quentin McGaw
2020-09-12 14:04:54 -04:00
committed by GitHub
parent ee64cbf1fd
commit a19efbd923
19 changed files with 358 additions and 82 deletions

View File

@@ -8,53 +8,46 @@ import (
"github.com/qdm12/gluetun/internal/models"
)
func (u *updater) updateCyberghost(ctx context.Context) {
servers := findCyberghostServers(ctx, u.lookupIP)
func (u *updater) updateCyberghost(ctx context.Context) (err error) {
servers, err := findCyberghostServers(ctx, u.lookupIP)
if err != nil {
return err
}
if u.options.Stdout {
u.println(stringifyCyberghostServers(servers))
}
u.servers.Cyberghost.Timestamp = u.timeNow().Unix()
u.servers.Cyberghost.Servers = servers
return nil
}
func findCyberghostServers(ctx context.Context, lookupIP lookupIPFunc) (servers []models.CyberghostServer) {
func findCyberghostServers(ctx context.Context, lookupIP lookupIPFunc) (servers []models.CyberghostServer, err error) {
groups := getCyberghostGroups()
allCountryCodes := getCountryCodes()
cyberghostCountryCodes := getCyberghostSubdomainToRegion()
possibleCountryCodes := mergeCountryCodes(cyberghostCountryCodes, allCountryCodes)
resultsChannel := make(chan models.CyberghostServer)
const maxGoroutines = 10
guard := make(chan struct{}, maxGoroutines)
for groupID, groupName := range groups {
for countryCode, region := range possibleCountryCodes {
go func(groupName, groupID, region, countryCode string) {
host := fmt.Sprintf("%s-%s.cg-dialup.net", groupID, countryCode)
guard <- struct{}{}
IPs, err := resolveRepeat(ctx, lookupIP, host, 2)
if err != nil {
IPs = nil
}
<-guard
resultsChannel <- models.CyberghostServer{
Region: region,
Group: groupName,
IPs: IPs,
}
}(groupName, groupID, region, countryCode)
if err := ctx.Err(); err != nil {
return nil, err
}
host := fmt.Sprintf("%s-%s.cg-dialup.net", groupID, countryCode)
IPs, err := resolveRepeat(ctx, lookupIP, host, 2)
if err != nil || len(IPs) == 0 {
continue
}
servers = append(servers, models.CyberghostServer{
Region: region,
Group: groupName,
IPs: IPs,
})
}
}
for i := 0; i < len(groups)*len(possibleCountryCodes); i++ {
server := <-resultsChannel
if server.IPs == nil {
continue
}
servers = append(servers, server)
}
sort.Slice(servers, func(i, j int) bool {
return servers[i].Region < servers[j].Region
})
return servers
return servers, nil
}
//nolint:goconst

146
internal/updater/loop.go Normal file
View File

@@ -0,0 +1,146 @@
package updater
import (
"context"
"net/http"
"sync"
"time"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/storage"
"github.com/qdm12/golibs/logging"
)
type Looper interface {
Run(ctx context.Context, wg *sync.WaitGroup)
RunRestartTicker(ctx context.Context)
Restart()
Stop()
GetPeriod() (period time.Duration)
SetPeriod(period time.Duration)
}
type looper struct {
period time.Duration
periodMutex sync.RWMutex
updater Updater
storage storage.Storage
setAllServers func(allServers models.AllServers)
logger logging.Logger
restart chan struct{}
stop chan struct{}
updateTicker chan struct{}
}
func NewLooper(options Options, period time.Duration, currentServers models.AllServers,
storage storage.Storage, setAllServers func(allServers models.AllServers),
client *http.Client, logger logging.Logger) Looper {
loggerWithPrefix := logger.WithPrefix("updater: ")
return &looper{
period: period,
updater: New(options, client, currentServers, loggerWithPrefix),
storage: storage,
setAllServers: setAllServers,
logger: loggerWithPrefix,
restart: make(chan struct{}),
stop: make(chan struct{}),
updateTicker: make(chan struct{}),
}
}
func (l *looper) Restart() { l.restart <- struct{}{} }
func (l *looper) Stop() { l.stop <- struct{}{} }
func (l *looper) GetPeriod() (period time.Duration) {
l.periodMutex.RLock()
defer l.periodMutex.RUnlock()
return l.period
}
func (l *looper) SetPeriod(period time.Duration) {
l.periodMutex.Lock()
l.period = period
l.periodMutex.Unlock()
l.updateTicker <- struct{}{}
}
func (l *looper) logAndWait(ctx context.Context, err error) {
l.logger.Error(err)
l.logger.Info("retrying in 5 minutes")
ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
defer cancel() // just for the linter
<-ctx.Done()
}
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
select {
case <-l.restart:
l.logger.Info("starting...")
case <-ctx.Done():
return
}
defer l.logger.Warn("loop exited")
enabled := true
for ctx.Err() == nil {
for !enabled {
// wait for a signal to re-enable
select {
case <-l.stop:
l.logger.Info("already disabled")
case <-l.restart:
enabled = true
case <-ctx.Done():
return
}
}
// Enabled and has a period set
servers, err := l.updater.UpdateServers(ctx)
if err != nil {
if ctx.Err() != nil {
return
}
l.logAndWait(ctx, err)
continue
}
l.setAllServers(servers)
if err := l.storage.FlushToFile(servers); err != nil {
l.logger.Error(err)
}
l.logger.Info("Updated servers information")
select {
case <-l.restart: // triggered restart
case <-l.stop:
enabled = false
case <-ctx.Done():
return
}
}
}
func (l *looper) RunRestartTicker(ctx context.Context) {
ticker := time.NewTicker(time.Hour)
period := l.GetPeriod()
if period > 0 {
ticker = time.NewTicker(period)
} else {
ticker.Stop()
}
for {
select {
case <-ctx.Done():
ticker.Stop()
return
case <-ticker.C:
l.restart <- struct{}{}
case <-l.updateTicker:
ticker.Stop()
ticker = time.NewTicker(l.GetPeriod())
}
}
}

View File

@@ -15,8 +15,10 @@ import (
func (u *updater) updateNordvpn() (err error) {
servers, warnings, err := findNordvpnServers(u.httpGet)
for _, warning := range warnings {
u.println(warning)
if u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("Nordvpn: %s", warning)
}
}
if err != nil {
return fmt.Errorf("cannot update Nordvpn servers: %w", err)

View File

@@ -10,7 +10,24 @@ type Options struct {
Surfshark bool
Vyprvpn bool
Windscribe bool
File bool // update JSON file (user side)
Stdout bool // update constants file (maintainer side)
Stdout bool // in order to update constants file (maintainer side)
CLI bool
DNSAddress string
}
func NewOptions(dnsAddress string) Options {
return Options{
Cyberghost: true,
Mullvad: true,
Nordvpn: true,
PIA: true,
PIAold: true,
Purevpn: true,
Surfshark: true,
Vyprvpn: true,
Windscribe: true,
Stdout: false,
CLI: false,
DNSAddress: dnsAddress,
}
}

View File

@@ -52,6 +52,9 @@ func (u *updater) updatePIAOld(ctx context.Context) (err error) {
}
servers := make([]models.PIAServer, 0, len(contents))
for fileName, content := range contents {
if err := ctx.Err(); err != nil {
return err
}
remoteLines := extractRemoteLinesFromOpenvpn(content)
if len(remoteLines) == 0 {
return fmt.Errorf("cannot find any remote lines in %s", fileName)

View File

@@ -14,8 +14,10 @@ import (
func (u *updater) updatePurevpn(ctx context.Context) (err error) {
servers, warnings, err := findPurevpnServers(ctx, u.httpGet, u.lookupIP)
for _, warning := range warnings {
u.println(warning)
if u.options.CLI {
for _, warning := range warnings {
u.logger.Warn("PureVPN: %s", warning)
}
}
if err != nil {
return fmt.Errorf("cannot update Purevpn servers: %w", err)
@@ -76,6 +78,9 @@ func findPurevpnServers(ctx context.Context, httpGet httpGetFunc, lookupIP looku
return data[i].Region < data[j].Region
})
for _, jsonServer := range data {
if err := ctx.Err(); err != nil {
return nil, warnings, err
}
if jsonServer.UDP == "" && jsonServer.TCP == "" {
warnings = append(warnings, fmt.Sprintf("server %s %s %s does not support TCP and UDP for openvpn", jsonServer.Region, jsonServer.Country, jsonServer.City))
continue

View File

@@ -30,6 +30,9 @@ func findSurfsharkServers(ctx context.Context, lookupIP lookupIPFunc) (servers [
return nil, err
}
for fileName, content := range contents {
if err := ctx.Err(); err != nil {
return nil, err
}
if strings.HasSuffix(fileName, "_tcp.ovpn") {
continue // only parse UDP files
}

View File

@@ -6,110 +6,138 @@ import (
"net/http"
"time"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/storage"
"github.com/qdm12/golibs/logging"
)
type Updater interface {
UpdateServers(ctx context.Context) error
UpdateServers(ctx context.Context) (allServers models.AllServers, err error)
}
type updater struct {
// configuration
options Options
storage storage.Storage
// state
servers models.AllServers
// Functions for tests
logger logging.Logger
timeNow func() time.Time
println func(s string)
httpGet httpGetFunc
lookupIP lookupIPFunc
}
func New(options Options, storage storage.Storage, httpClient *http.Client) Updater {
func New(options Options, httpClient *http.Client, currentServers models.AllServers, logger logging.Logger) Updater {
if len(options.DNSAddress) == 0 {
options.DNSAddress = "1.1.1.1"
}
resolver := newResolver(options.DNSAddress)
return &updater{
storage: storage,
logger: logger,
timeNow: time.Now,
println: func(s string) { fmt.Println(s) },
httpGet: httpClient.Get,
lookupIP: newLookupIP(resolver),
options: options,
servers: currentServers,
}
}
// TODO parallelize DNS resolution
func (u *updater) UpdateServers(ctx context.Context) (err error) {
const writeSync = false
u.servers, err = u.storage.SyncServers(constants.GetAllServers(), writeSync)
if err != nil {
return fmt.Errorf("cannot update servers: %w", err)
}
func (u *updater) UpdateServers(ctx context.Context) (allServers models.AllServers, err error) { //nolint:gocognit
if u.options.Cyberghost {
u.updateCyberghost(ctx)
u.logger.Info("updating Cyberghost servers...")
if err := u.updateCyberghost(ctx); err != nil {
if ctxErr := ctx.Err(); ctxErr != nil {
return allServers, ctxErr
}
u.logger.Error(err)
}
}
if u.options.Mullvad {
u.logger.Info("updating Mullvad servers...")
if err := u.updateMullvad(); err != nil {
return err
u.logger.Error(err)
}
if err := ctx.Err(); err != nil {
return allServers, err
}
}
if u.options.Nordvpn {
// TODO support servers offering only TCP or only UDP
u.logger.Info("updating NordVPN servers...")
if err := u.updateNordvpn(); err != nil {
return err
u.logger.Error(err)
}
if err := ctx.Err(); err != nil {
return allServers, err
}
}
if u.options.PIA {
u.logger.Info("updating Private Internet Access (v4) servers...")
if err := u.updatePIA(); err != nil {
return err
u.logger.Error(err)
}
if ctx.Err() != nil {
return allServers, ctx.Err()
}
}
if u.options.PIAold {
u.logger.Info("updating Private Internet Access old (v3) servers...")
if err := u.updatePIAOld(ctx); err != nil {
return err
if ctxErr := ctx.Err(); ctxErr != nil {
return allServers, ctxErr
}
u.logger.Error(err)
}
}
if u.options.Purevpn {
u.logger.Info("updating PureVPN servers...")
// TODO support servers offering only TCP or only UDP
if err := u.updatePurevpn(ctx); err != nil {
return err
if ctxErr := ctx.Err(); ctxErr != nil {
return allServers, ctxErr
}
u.logger.Error(err)
}
}
if u.options.Surfshark {
u.logger.Info("updating Surfshark servers...")
if err := u.updateSurfshark(ctx); err != nil {
return err
if ctxErr := ctx.Err(); ctxErr != nil {
return allServers, ctxErr
}
u.logger.Error(err)
}
}
if u.options.Vyprvpn {
u.logger.Info("updating Vyprvpn servers...")
if err := u.updateVyprvpn(ctx); err != nil {
return err
if ctxErr := ctx.Err(); ctxErr != nil {
return allServers, ctxErr
}
u.logger.Error(err)
}
}
if u.options.Windscribe {
u.updateWindscribe(ctx)
}
if u.options.File {
if err := u.storage.FlushToFile(u.servers); err != nil {
return fmt.Errorf("cannot update servers: %w", err)
u.logger.Info("updating Windscribe servers...")
if err := u.updateWindscribe(ctx); err != nil {
if ctxErr := ctx.Err(); ctxErr != nil {
return allServers, ctxErr
}
u.logger.Error(err)
}
}
return nil
return u.servers, nil
}

View File

@@ -30,6 +30,9 @@ func findVyprvpnServers(ctx context.Context, lookupIP lookupIPFunc) (servers []m
return nil, err
}
for fileName, content := range contents {
if err := ctx.Err(); err != nil {
return nil, err
}
remoteLines := extractRemoteLinesFromOpenvpn(content)
if len(remoteLines) == 0 {
return nil, fmt.Errorf("cannot find any remote lines in %s", fileName)

View File

@@ -2,26 +2,34 @@ package updater
import (
"context"
"fmt"
"sort"
"github.com/qdm12/gluetun/internal/models"
)
func (u *updater) updateWindscribe(ctx context.Context) {
servers := findWindscribeServers(ctx, u.lookupIP)
func (u *updater) updateWindscribe(ctx context.Context) (err error) {
servers, err := findWindscribeServers(ctx, u.lookupIP)
if err != nil {
return fmt.Errorf("cannot update Windscribe servers: %w", err)
}
if u.options.Stdout {
u.println(stringifyWindscribeServers(servers))
}
u.servers.Windscribe.Timestamp = u.timeNow().Unix()
u.servers.Windscribe.Servers = servers
return nil
}
func findWindscribeServers(ctx context.Context, lookupIP lookupIPFunc) (servers []models.WindscribeServer) {
func findWindscribeServers(ctx context.Context, lookupIP lookupIPFunc) (servers []models.WindscribeServer, err error) {
allCountryCodes := getCountryCodes()
windscribeCountryCodes := getWindscribeSubdomainToRegion()
possibleCountryCodes := mergeCountryCodes(windscribeCountryCodes, allCountryCodes)
const domain = "windscribe.com"
for countryCode, region := range possibleCountryCodes {
if err := ctx.Err(); err != nil {
return nil, err
}
host := countryCode + "." + domain
ips, err := resolveRepeat(ctx, lookupIP, host, 2)
if err != nil || len(ips) == 0 {
@@ -35,7 +43,7 @@ func findWindscribeServers(ctx context.Context, lookupIP lookupIPFunc) (servers
sort.Slice(servers, func(i, j int) bool {
return servers[i].Region < servers[j].Region
})
return servers
return servers, nil
}
func mergeCountryCodes(base, extend map[string]string) (merged map[string]string) {