Code maintenance: storage merging reworked

This commit is contained in:
Quentin McGaw
2020-12-29 17:49:38 +00:00
parent e643ce5b99
commit bedf613cff
6 changed files with 174 additions and 113 deletions

View File

@@ -122,9 +122,8 @@ func _main(background context.Context, buildInfo models.BuildInformation,
} }
// TODO run this in a loop or in openvpn to reload from file without restarting // TODO run this in a loop or in openvpn to reload from file without restarting
storage := storage.New(logger, os) storage := storage.New(logger, os, constants.ServersData)
const updateServerFile = true allServers, err := storage.SyncServers(constants.GetAllServers())
allServers, err := storage.SyncServers(constants.GetAllServers(), updateServerFile)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
return 1 return 1

View File

@@ -70,7 +70,8 @@ func OpenvpnConfig(os os.OS) error {
if err != nil { if err != nil {
return err return err
} }
allServers, err := storage.New(logger, os).SyncServers(constants.GetAllServers(), false) allServers, err := storage.New(logger, os, constants.ServersData).
SyncServers(constants.GetAllServers())
if err != nil { if err != nil {
return err return err
} }
@@ -121,9 +122,8 @@ func Update(args []string, os os.OS) error {
ctx := context.Background() ctx := context.Background()
const clientTimeout = 10 * time.Second const clientTimeout = 10 * time.Second
httpClient := &http.Client{Timeout: clientTimeout} httpClient := &http.Client{Timeout: clientTimeout}
storage := storage.New(logger, os) storage := storage.New(logger, os, constants.ServersData)
const writeSync = false currentServers, err := storage.SyncServers(constants.GetAllServers())
currentServers, err := storage.SyncServers(constants.GetAllServers(), writeSync)
if err != nil { if err != nil {
return fmt.Errorf("cannot update servers: %w", err) return fmt.Errorf("cannot update servers: %w", err)
} }

View File

@@ -29,4 +29,6 @@ const (
ClientKey models.Filepath = "/gluetun/client.key" ClientKey models.Filepath = "/gluetun/client.key"
// Client certificate filepath, used by Cyberghost. // Client certificate filepath, used by Cyberghost.
ClientCertificate models.Filepath = "/gluetun/client.crt" ClientCertificate models.Filepath = "/gluetun/client.crt"
// Servers information filepath.
ServersData = "/gluetun/servers.json"
) )

View File

@@ -14,80 +14,119 @@ func getUnixTimeDifference(unix1, unix2 int64) (difference time.Duration) {
return difference.Truncate(time.Second) return difference.Truncate(time.Second)
} }
func (s *storage) mergeServers(hardcoded, persistent models.AllServers) (merged models.AllServers) { func (s *storage) mergeServers(hardcoded, persisted models.AllServers) models.AllServers {
merged.Version = hardcoded.Version return models.AllServers{
merged.Cyberghost = hardcoded.Cyberghost Version: hardcoded.Version,
if persistent.Cyberghost.Timestamp > hardcoded.Cyberghost.Timestamp { Cyberghost: s.mergeCyberghost(hardcoded.Cyberghost, persisted.Cyberghost),
s.logger.Info("Using Cyberghost servers from file (%s more recent)", Mullvad: s.mergeMullvad(hardcoded.Mullvad, persisted.Mullvad),
getUnixTimeDifference(persistent.Cyberghost.Timestamp, hardcoded.Cyberghost.Timestamp)) Nordvpn: s.mergeNordVPN(hardcoded.Nordvpn, persisted.Nordvpn),
merged.Cyberghost = persistent.Cyberghost Pia: s.mergePIA(hardcoded.Pia, persisted.Pia),
Privado: s.mergePrivado(hardcoded.Privado, persisted.Privado),
Purevpn: s.mergePureVPN(hardcoded.Purevpn, persisted.Purevpn),
Surfshark: s.mergeSurfshark(hardcoded.Surfshark, persisted.Surfshark),
Vyprvpn: s.mergeVyprvpn(hardcoded.Vyprvpn, persisted.Vyprvpn),
Windscribe: s.mergeWindscribe(hardcoded.Windscribe, persisted.Windscribe),
} }
merged.Mullvad = hardcoded.Mullvad }
if persistent.Mullvad.Timestamp > hardcoded.Mullvad.Timestamp {
s.logger.Info("Using Mullvad servers from file (%s more recent)", func (s *storage) mergeCyberghost(hardcoded, persisted models.CyberghostServers) models.CyberghostServers {
getUnixTimeDifference(persistent.Mullvad.Timestamp, hardcoded.Mullvad.Timestamp)) if persisted.Timestamp <= hardcoded.Timestamp {
merged.Mullvad = persistent.Mullvad return hardcoded
} }
merged.Nordvpn = hardcoded.Nordvpn s.logger.Info("Using Cyberghost servers from file (%s more recent)",
if persistent.Nordvpn.Timestamp > hardcoded.Nordvpn.Timestamp { getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp))
s.logger.Info("Using Nordvpn servers from file (%s more recent)", return persisted
getUnixTimeDifference(persistent.Nordvpn.Timestamp, hardcoded.Nordvpn.Timestamp)) }
merged.Nordvpn = persistent.Nordvpn
} func (s *storage) mergeMullvad(hardcoded, persisted models.MullvadServers) models.MullvadServers {
merged.Pia = hardcoded.Pia if persisted.Timestamp <= hardcoded.Timestamp {
if persistent.Pia.Timestamp > hardcoded.Pia.Timestamp { return hardcoded
versionDiff := hardcoded.Pia.Version - persistent.Pia.Version }
if versionDiff > 0 { s.logger.Info("Using Mullvad servers from file (%s more recent)",
s.logger.Info("Private Internet Access servers from file discarded because they are %d versions behind", getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp))
versionDiff) return persisted
merged.Pia = hardcoded.Pia }
} else {
s.logger.Info("Using Private Internet Access servers from file (%s more recent)", func (s *storage) mergeNordVPN(hardcoded, persisted models.NordvpnServers) models.NordvpnServers {
getUnixTimeDifference(persistent.Pia.Timestamp, hardcoded.Pia.Timestamp)) if persisted.Timestamp <= hardcoded.Timestamp {
merged.Pia = persistent.Pia return hardcoded
} }
} s.logger.Info("Using NordVPN servers from file (%s more recent)",
getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp))
merged.Privado = hardcoded.Privado return persisted
versionDiff := int(persistent.Privado.Version) - int(hardcoded.Privado.Version) }
switch {
case versionDiff > 0: func (s *storage) mergePIA(hardcoded, persisted models.PiaServers) models.PiaServers {
s.logger.Info("Using Privado servers from file (%d version(s) more recent)", versionDiff) if persisted.Timestamp <= hardcoded.Timestamp {
merged.Privado = persistent.Privado return hardcoded
case persistent.Privado.Timestamp > hardcoded.Privado.Timestamp: }
s.logger.Info("Using Privado servers from file (%s more recent)", versionDiff := hardcoded.Version - persisted.Version
getUnixTimeDifference(persistent.Privado.Timestamp, hardcoded.Privado.Timestamp)) if versionDiff > 0 {
merged.Privado = persistent.Privado s.logger.Info(
} "PIA servers from file discarded because they are %d versions behind",
versionDiff)
merged.Purevpn = hardcoded.Purevpn return hardcoded
if persistent.Purevpn.Timestamp > hardcoded.Purevpn.Timestamp { }
s.logger.Info("Using Purevpn servers from file (%s more recent)", s.logger.Info("Using PIA servers from file (%s more recent)",
getUnixTimeDifference(persistent.Purevpn.Timestamp, hardcoded.Purevpn.Timestamp)) getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp))
merged.Purevpn = persistent.Purevpn return persisted
} }
merged.Surfshark = hardcoded.Surfshark
if persistent.Surfshark.Timestamp > hardcoded.Surfshark.Timestamp { func (s *storage) mergePrivado(hardcoded, persisted models.PrivadoServers) models.PrivadoServers {
s.logger.Info("Using Surfshark servers from file (%s more recent)", if persisted.Timestamp <= hardcoded.Timestamp {
getUnixTimeDifference(persistent.Surfshark.Timestamp, hardcoded.Surfshark.Timestamp)) return hardcoded
merged.Surfshark = persistent.Surfshark }
} versionDiff := hardcoded.Version - persisted.Version
merged.Vyprvpn = hardcoded.Vyprvpn if versionDiff > 0 {
if persistent.Vyprvpn.Timestamp > hardcoded.Vyprvpn.Timestamp { s.logger.Info(
s.logger.Info("Using Vyprvpn servers from file (%s more recent)", "Privado servers from file discarded because they are %d versions behind",
getUnixTimeDifference(persistent.Vyprvpn.Timestamp, hardcoded.Vyprvpn.Timestamp)) versionDiff)
merged.Vyprvpn = persistent.Vyprvpn return hardcoded
} }
merged.Windscribe = hardcoded.Windscribe s.logger.Info("Using Privado servers from file (%s more recent)",
if persistent.Windscribe.Timestamp > hardcoded.Windscribe.Timestamp { getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp))
if hardcoded.Windscribe.Version == 2 && persistent.Windscribe.Version == 1 { return persisted
s.logger.Info("Windscribe servers from file discarded because they are one version behind") }
merged.Windscribe = hardcoded.Windscribe
} else { func (s *storage) mergePureVPN(hardcoded, persisted models.PurevpnServers) models.PurevpnServers {
s.logger.Info("Using Windscribe servers from file (%s more recent)", if persisted.Timestamp <= hardcoded.Timestamp {
getUnixTimeDifference(persistent.Windscribe.Timestamp, hardcoded.Windscribe.Timestamp)) return hardcoded
merged.Windscribe = persistent.Windscribe }
} s.logger.Info("Using PureVPN servers from file (%s more recent)",
} getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp))
return merged return persisted
}
func (s *storage) mergeSurfshark(hardcoded, persisted models.SurfsharkServers) models.SurfsharkServers {
if persisted.Timestamp <= hardcoded.Timestamp {
return hardcoded
}
s.logger.Info("Using Surfshark servers from file (%s more recent)",
getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp))
return persisted
}
func (s *storage) mergeVyprvpn(hardcoded, persisted models.VyprvpnServers) models.VyprvpnServers {
if persisted.Timestamp <= hardcoded.Timestamp {
return hardcoded
}
s.logger.Info("Using VyprVPN servers from file (%s more recent)",
getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp))
return persisted
}
func (s *storage) mergeWindscribe(hardcoded, persisted models.WindscribeServers) models.WindscribeServers {
if persisted.Timestamp <= hardcoded.Timestamp {
return hardcoded
}
versionDiff := hardcoded.Version - persisted.Version
if versionDiff > 0 {
s.logger.Info(
"Windscribe servers from file discarded because they are %d versions behind",
versionDiff)
return hardcoded
}
s.logger.Info("Using Windscribe servers from file (%s more recent)",
getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp))
return persisted
} }

View File

@@ -7,18 +7,21 @@ import (
) )
type Storage interface { type Storage interface {
SyncServers(hardcodedServers models.AllServers, write bool) (allServers models.AllServers, err error) // Passing an empty filepath disables writing to a file
SyncServers(hardcodedServers models.AllServers) (allServers models.AllServers, err error)
FlushToFile(servers models.AllServers) error FlushToFile(servers models.AllServers) error
} }
type storage struct { type storage struct {
os os.OS os os.OS
logger logging.Logger logger logging.Logger
filepath string
} }
func New(logger logging.Logger, os os.OS) Storage { func New(logger logging.Logger, os os.OS, filepath string) Storage {
return &storage{ return &storage{
os: os, os: os,
logger: logger.WithPrefix("storage: "), logger: logger.WithPrefix("storage: "),
filepath: filepath,
} }
} }

View File

@@ -2,6 +2,7 @@ package storage
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"reflect" "reflect"
@@ -9,8 +10,9 @@ import (
"github.com/qdm12/gluetun/internal/os" "github.com/qdm12/gluetun/internal/os"
) )
const ( var (
jsonFilepath = "/gluetun/servers.json" ErrCannotReadFile = errors.New("cannot read servers from file")
ErrCannotWriteFile = errors.New("cannot write servers to file")
) )
func countServers(allServers models.AllServers) int { func countServers(allServers models.AllServers) int {
@@ -25,38 +27,54 @@ func countServers(allServers models.AllServers) int {
len(allServers.Windscribe.Servers) len(allServers.Windscribe.Servers)
} }
func (s *storage) SyncServers(hardcodedServers models.AllServers, write bool) ( func (s *storage) SyncServers(hardcodedServers models.AllServers) (
allServers models.AllServers, err error) { allServers models.AllServers, err error) {
// Eventually read file serversOnFile, err := s.readFromFile(s.filepath)
var serversOnFile models.AllServers if err != nil {
file, err := s.os.OpenFile(jsonFilepath, os.O_RDONLY, 0) return allServers, fmt.Errorf("%w: %s", ErrCannotReadFile, err)
if err != nil && !os.IsNotExist(err) {
return allServers, err
}
if err == nil {
var serversOnFile models.AllServers
decoder := json.NewDecoder(file)
if err := decoder.Decode(&serversOnFile); err != nil {
_ = file.Close()
return allServers, err
}
return allServers, file.Close()
} }
// Merge data from file and hardcoded hardcodedCount := countServers(hardcodedServers)
s.logger.Info("Merging by most recent %d hardcoded servers and %d servers read from %s", countOnFile := countServers(serversOnFile)
countServers(hardcodedServers), countServers(serversOnFile), jsonFilepath)
allServers = s.mergeServers(hardcodedServers, serversOnFile) if countOnFile == 0 {
s.logger.Info("creating %s with %d hardcoded servers", s.filepath, hardcodedCount)
allServers = hardcodedServers
} else {
s.logger.Info(
"merging by most recent %d hardcoded servers and %d servers read from %s",
hardcodedCount, countOnFile, s.filepath)
allServers = s.mergeServers(hardcodedServers, serversOnFile)
}
// Eventually write file // Eventually write file
if !write || reflect.DeepEqual(serversOnFile, allServers) { if s.filepath == "" || reflect.DeepEqual(serversOnFile, allServers) {
return allServers, nil return allServers, nil
} }
return allServers, s.FlushToFile(allServers)
if err := s.FlushToFile(allServers); err != nil {
return allServers, fmt.Errorf("%w: %s", ErrCannotWriteFile, err)
}
return allServers, nil
}
func (s *storage) readFromFile(filepath string) (servers models.AllServers, err error) {
file, err := s.os.OpenFile(filepath, os.O_RDONLY, 0)
if os.IsNotExist(err) {
return servers, nil
} else if err != nil {
return servers, err
}
decoder := json.NewDecoder(file)
if err := decoder.Decode(&servers); err != nil {
_ = file.Close()
return servers, err
}
return servers, file.Close()
} }
func (s *storage) FlushToFile(servers models.AllServers) error { func (s *storage) FlushToFile(servers models.AllServers) error {
file, err := s.os.OpenFile(jsonFilepath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) file, err := s.os.OpenFile(s.filepath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil { if err != nil {
return err return err
} }
@@ -64,7 +82,7 @@ func (s *storage) FlushToFile(servers models.AllServers) error {
encoder.SetIndent("", " ") encoder.SetIndent("", " ")
if err := encoder.Encode(servers); err != nil { if err := encoder.Encode(servers); err != nil {
_ = file.Close() _ = file.Close()
return fmt.Errorf("cannot write to file: %w", err) return err
} }
return file.Close() return file.Close()
} }