Code maintenance: storage merging reworked
This commit is contained in:
@@ -122,9 +122,8 @@ func _main(background context.Context, buildInfo models.BuildInformation,
|
||||
}
|
||||
|
||||
// TODO run this in a loop or in openvpn to reload from file without restarting
|
||||
storage := storage.New(logger, os)
|
||||
const updateServerFile = true
|
||||
allServers, err := storage.SyncServers(constants.GetAllServers(), updateServerFile)
|
||||
storage := storage.New(logger, os, constants.ServersData)
|
||||
allServers, err := storage.SyncServers(constants.GetAllServers())
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
return 1
|
||||
|
||||
@@ -70,7 +70,8 @@ func OpenvpnConfig(os os.OS) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
allServers, err := storage.New(logger, os).SyncServers(constants.GetAllServers(), false)
|
||||
allServers, err := storage.New(logger, os, constants.ServersData).
|
||||
SyncServers(constants.GetAllServers())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -121,9 +122,8 @@ func Update(args []string, os os.OS) error {
|
||||
ctx := context.Background()
|
||||
const clientTimeout = 10 * time.Second
|
||||
httpClient := &http.Client{Timeout: clientTimeout}
|
||||
storage := storage.New(logger, os)
|
||||
const writeSync = false
|
||||
currentServers, err := storage.SyncServers(constants.GetAllServers(), writeSync)
|
||||
storage := storage.New(logger, os, constants.ServersData)
|
||||
currentServers, err := storage.SyncServers(constants.GetAllServers())
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot update servers: %w", err)
|
||||
}
|
||||
|
||||
@@ -29,4 +29,6 @@ const (
|
||||
ClientKey models.Filepath = "/gluetun/client.key"
|
||||
// Client certificate filepath, used by Cyberghost.
|
||||
ClientCertificate models.Filepath = "/gluetun/client.crt"
|
||||
// Servers information filepath.
|
||||
ServersData = "/gluetun/servers.json"
|
||||
)
|
||||
|
||||
@@ -14,80 +14,119 @@ func getUnixTimeDifference(unix1, unix2 int64) (difference time.Duration) {
|
||||
return difference.Truncate(time.Second)
|
||||
}
|
||||
|
||||
func (s *storage) mergeServers(hardcoded, persistent models.AllServers) (merged models.AllServers) {
|
||||
merged.Version = hardcoded.Version
|
||||
merged.Cyberghost = hardcoded.Cyberghost
|
||||
if persistent.Cyberghost.Timestamp > hardcoded.Cyberghost.Timestamp {
|
||||
s.logger.Info("Using Cyberghost servers from file (%s more recent)",
|
||||
getUnixTimeDifference(persistent.Cyberghost.Timestamp, hardcoded.Cyberghost.Timestamp))
|
||||
merged.Cyberghost = persistent.Cyberghost
|
||||
func (s *storage) mergeServers(hardcoded, persisted models.AllServers) models.AllServers {
|
||||
return models.AllServers{
|
||||
Version: hardcoded.Version,
|
||||
Cyberghost: s.mergeCyberghost(hardcoded.Cyberghost, persisted.Cyberghost),
|
||||
Mullvad: s.mergeMullvad(hardcoded.Mullvad, persisted.Mullvad),
|
||||
Nordvpn: s.mergeNordVPN(hardcoded.Nordvpn, persisted.Nordvpn),
|
||||
Pia: s.mergePIA(hardcoded.Pia, persisted.Pia),
|
||||
Privado: s.mergePrivado(hardcoded.Privado, persisted.Privado),
|
||||
Purevpn: s.mergePureVPN(hardcoded.Purevpn, persisted.Purevpn),
|
||||
Surfshark: s.mergeSurfshark(hardcoded.Surfshark, persisted.Surfshark),
|
||||
Vyprvpn: s.mergeVyprvpn(hardcoded.Vyprvpn, persisted.Vyprvpn),
|
||||
Windscribe: s.mergeWindscribe(hardcoded.Windscribe, persisted.Windscribe),
|
||||
}
|
||||
merged.Mullvad = hardcoded.Mullvad
|
||||
if persistent.Mullvad.Timestamp > hardcoded.Mullvad.Timestamp {
|
||||
s.logger.Info("Using Mullvad servers from file (%s more recent)",
|
||||
getUnixTimeDifference(persistent.Mullvad.Timestamp, hardcoded.Mullvad.Timestamp))
|
||||
merged.Mullvad = persistent.Mullvad
|
||||
}
|
||||
merged.Nordvpn = hardcoded.Nordvpn
|
||||
if persistent.Nordvpn.Timestamp > hardcoded.Nordvpn.Timestamp {
|
||||
s.logger.Info("Using Nordvpn servers from file (%s more recent)",
|
||||
getUnixTimeDifference(persistent.Nordvpn.Timestamp, hardcoded.Nordvpn.Timestamp))
|
||||
merged.Nordvpn = persistent.Nordvpn
|
||||
}
|
||||
merged.Pia = hardcoded.Pia
|
||||
if persistent.Pia.Timestamp > hardcoded.Pia.Timestamp {
|
||||
versionDiff := hardcoded.Pia.Version - persistent.Pia.Version
|
||||
if versionDiff > 0 {
|
||||
s.logger.Info("Private Internet Access servers from file discarded because they are %d versions behind",
|
||||
versionDiff)
|
||||
merged.Pia = hardcoded.Pia
|
||||
} else {
|
||||
s.logger.Info("Using Private Internet Access servers from file (%s more recent)",
|
||||
getUnixTimeDifference(persistent.Pia.Timestamp, hardcoded.Pia.Timestamp))
|
||||
merged.Pia = persistent.Pia
|
||||
}
|
||||
}
|
||||
|
||||
merged.Privado = hardcoded.Privado
|
||||
versionDiff := int(persistent.Privado.Version) - int(hardcoded.Privado.Version)
|
||||
switch {
|
||||
case versionDiff > 0:
|
||||
s.logger.Info("Using Privado servers from file (%d version(s) more recent)", versionDiff)
|
||||
merged.Privado = persistent.Privado
|
||||
case persistent.Privado.Timestamp > hardcoded.Privado.Timestamp:
|
||||
s.logger.Info("Using Privado servers from file (%s more recent)",
|
||||
getUnixTimeDifference(persistent.Privado.Timestamp, hardcoded.Privado.Timestamp))
|
||||
merged.Privado = persistent.Privado
|
||||
}
|
||||
|
||||
merged.Purevpn = hardcoded.Purevpn
|
||||
if persistent.Purevpn.Timestamp > hardcoded.Purevpn.Timestamp {
|
||||
s.logger.Info("Using Purevpn servers from file (%s more recent)",
|
||||
getUnixTimeDifference(persistent.Purevpn.Timestamp, hardcoded.Purevpn.Timestamp))
|
||||
merged.Purevpn = persistent.Purevpn
|
||||
}
|
||||
merged.Surfshark = hardcoded.Surfshark
|
||||
if persistent.Surfshark.Timestamp > hardcoded.Surfshark.Timestamp {
|
||||
s.logger.Info("Using Surfshark servers from file (%s more recent)",
|
||||
getUnixTimeDifference(persistent.Surfshark.Timestamp, hardcoded.Surfshark.Timestamp))
|
||||
merged.Surfshark = persistent.Surfshark
|
||||
}
|
||||
merged.Vyprvpn = hardcoded.Vyprvpn
|
||||
if persistent.Vyprvpn.Timestamp > hardcoded.Vyprvpn.Timestamp {
|
||||
s.logger.Info("Using Vyprvpn servers from file (%s more recent)",
|
||||
getUnixTimeDifference(persistent.Vyprvpn.Timestamp, hardcoded.Vyprvpn.Timestamp))
|
||||
merged.Vyprvpn = persistent.Vyprvpn
|
||||
}
|
||||
merged.Windscribe = hardcoded.Windscribe
|
||||
if persistent.Windscribe.Timestamp > hardcoded.Windscribe.Timestamp {
|
||||
if hardcoded.Windscribe.Version == 2 && persistent.Windscribe.Version == 1 {
|
||||
s.logger.Info("Windscribe servers from file discarded because they are one version behind")
|
||||
merged.Windscribe = hardcoded.Windscribe
|
||||
} else {
|
||||
s.logger.Info("Using Windscribe servers from file (%s more recent)",
|
||||
getUnixTimeDifference(persistent.Windscribe.Timestamp, hardcoded.Windscribe.Timestamp))
|
||||
merged.Windscribe = persistent.Windscribe
|
||||
}
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
func (s *storage) mergeCyberghost(hardcoded, persisted models.CyberghostServers) models.CyberghostServers {
|
||||
if persisted.Timestamp <= hardcoded.Timestamp {
|
||||
return hardcoded
|
||||
}
|
||||
s.logger.Info("Using Cyberghost servers from file (%s more recent)",
|
||||
getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp))
|
||||
return persisted
|
||||
}
|
||||
|
||||
func (s *storage) mergeMullvad(hardcoded, persisted models.MullvadServers) models.MullvadServers {
|
||||
if persisted.Timestamp <= hardcoded.Timestamp {
|
||||
return hardcoded
|
||||
}
|
||||
s.logger.Info("Using Mullvad servers from file (%s more recent)",
|
||||
getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp))
|
||||
return persisted
|
||||
}
|
||||
|
||||
func (s *storage) mergeNordVPN(hardcoded, persisted models.NordvpnServers) models.NordvpnServers {
|
||||
if persisted.Timestamp <= hardcoded.Timestamp {
|
||||
return hardcoded
|
||||
}
|
||||
s.logger.Info("Using NordVPN servers from file (%s more recent)",
|
||||
getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp))
|
||||
return persisted
|
||||
}
|
||||
|
||||
func (s *storage) mergePIA(hardcoded, persisted models.PiaServers) models.PiaServers {
|
||||
if persisted.Timestamp <= hardcoded.Timestamp {
|
||||
return hardcoded
|
||||
}
|
||||
versionDiff := hardcoded.Version - persisted.Version
|
||||
if versionDiff > 0 {
|
||||
s.logger.Info(
|
||||
"PIA servers from file discarded because they are %d versions behind",
|
||||
versionDiff)
|
||||
return hardcoded
|
||||
}
|
||||
s.logger.Info("Using PIA servers from file (%s more recent)",
|
||||
getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp))
|
||||
return persisted
|
||||
}
|
||||
|
||||
func (s *storage) mergePrivado(hardcoded, persisted models.PrivadoServers) models.PrivadoServers {
|
||||
if persisted.Timestamp <= hardcoded.Timestamp {
|
||||
return hardcoded
|
||||
}
|
||||
versionDiff := hardcoded.Version - persisted.Version
|
||||
if versionDiff > 0 {
|
||||
s.logger.Info(
|
||||
"Privado servers from file discarded because they are %d versions behind",
|
||||
versionDiff)
|
||||
return hardcoded
|
||||
}
|
||||
s.logger.Info("Using Privado servers from file (%s more recent)",
|
||||
getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp))
|
||||
return persisted
|
||||
}
|
||||
|
||||
func (s *storage) mergePureVPN(hardcoded, persisted models.PurevpnServers) models.PurevpnServers {
|
||||
if persisted.Timestamp <= hardcoded.Timestamp {
|
||||
return hardcoded
|
||||
}
|
||||
s.logger.Info("Using PureVPN servers from file (%s more recent)",
|
||||
getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp))
|
||||
return persisted
|
||||
}
|
||||
|
||||
func (s *storage) mergeSurfshark(hardcoded, persisted models.SurfsharkServers) models.SurfsharkServers {
|
||||
if persisted.Timestamp <= hardcoded.Timestamp {
|
||||
return hardcoded
|
||||
}
|
||||
s.logger.Info("Using Surfshark servers from file (%s more recent)",
|
||||
getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp))
|
||||
return persisted
|
||||
}
|
||||
|
||||
func (s *storage) mergeVyprvpn(hardcoded, persisted models.VyprvpnServers) models.VyprvpnServers {
|
||||
if persisted.Timestamp <= hardcoded.Timestamp {
|
||||
return hardcoded
|
||||
}
|
||||
s.logger.Info("Using VyprVPN servers from file (%s more recent)",
|
||||
getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp))
|
||||
return persisted
|
||||
}
|
||||
|
||||
func (s *storage) mergeWindscribe(hardcoded, persisted models.WindscribeServers) models.WindscribeServers {
|
||||
if persisted.Timestamp <= hardcoded.Timestamp {
|
||||
return hardcoded
|
||||
}
|
||||
versionDiff := hardcoded.Version - persisted.Version
|
||||
if versionDiff > 0 {
|
||||
s.logger.Info(
|
||||
"Windscribe servers from file discarded because they are %d versions behind",
|
||||
versionDiff)
|
||||
return hardcoded
|
||||
}
|
||||
s.logger.Info("Using Windscribe servers from file (%s more recent)",
|
||||
getUnixTimeDifference(persisted.Timestamp, hardcoded.Timestamp))
|
||||
return persisted
|
||||
}
|
||||
|
||||
@@ -7,18 +7,21 @@ import (
|
||||
)
|
||||
|
||||
type Storage interface {
|
||||
SyncServers(hardcodedServers models.AllServers, write bool) (allServers models.AllServers, err error)
|
||||
// Passing an empty filepath disables writing to a file
|
||||
SyncServers(hardcodedServers models.AllServers) (allServers models.AllServers, err error)
|
||||
FlushToFile(servers models.AllServers) error
|
||||
}
|
||||
|
||||
type storage struct {
|
||||
os os.OS
|
||||
logger logging.Logger
|
||||
os os.OS
|
||||
logger logging.Logger
|
||||
filepath string
|
||||
}
|
||||
|
||||
func New(logger logging.Logger, os os.OS) Storage {
|
||||
func New(logger logging.Logger, os os.OS, filepath string) Storage {
|
||||
return &storage{
|
||||
os: os,
|
||||
logger: logger.WithPrefix("storage: "),
|
||||
os: os,
|
||||
logger: logger.WithPrefix("storage: "),
|
||||
filepath: filepath,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package storage
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
@@ -9,8 +10,9 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
)
|
||||
|
||||
const (
|
||||
jsonFilepath = "/gluetun/servers.json"
|
||||
var (
|
||||
ErrCannotReadFile = errors.New("cannot read servers from file")
|
||||
ErrCannotWriteFile = errors.New("cannot write servers to file")
|
||||
)
|
||||
|
||||
func countServers(allServers models.AllServers) int {
|
||||
@@ -25,38 +27,54 @@ func countServers(allServers models.AllServers) int {
|
||||
len(allServers.Windscribe.Servers)
|
||||
}
|
||||
|
||||
func (s *storage) SyncServers(hardcodedServers models.AllServers, write bool) (
|
||||
func (s *storage) SyncServers(hardcodedServers models.AllServers) (
|
||||
allServers models.AllServers, err error) {
|
||||
// Eventually read file
|
||||
var serversOnFile models.AllServers
|
||||
file, err := s.os.OpenFile(jsonFilepath, os.O_RDONLY, 0)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return allServers, err
|
||||
}
|
||||
if err == nil {
|
||||
var serversOnFile models.AllServers
|
||||
decoder := json.NewDecoder(file)
|
||||
if err := decoder.Decode(&serversOnFile); err != nil {
|
||||
_ = file.Close()
|
||||
return allServers, err
|
||||
}
|
||||
return allServers, file.Close()
|
||||
serversOnFile, err := s.readFromFile(s.filepath)
|
||||
if err != nil {
|
||||
return allServers, fmt.Errorf("%w: %s", ErrCannotReadFile, err)
|
||||
}
|
||||
|
||||
// Merge data from file and hardcoded
|
||||
s.logger.Info("Merging by most recent %d hardcoded servers and %d servers read from %s",
|
||||
countServers(hardcodedServers), countServers(serversOnFile), jsonFilepath)
|
||||
allServers = s.mergeServers(hardcodedServers, serversOnFile)
|
||||
hardcodedCount := countServers(hardcodedServers)
|
||||
countOnFile := countServers(serversOnFile)
|
||||
|
||||
if countOnFile == 0 {
|
||||
s.logger.Info("creating %s with %d hardcoded servers", s.filepath, hardcodedCount)
|
||||
allServers = hardcodedServers
|
||||
} else {
|
||||
s.logger.Info(
|
||||
"merging by most recent %d hardcoded servers and %d servers read from %s",
|
||||
hardcodedCount, countOnFile, s.filepath)
|
||||
allServers = s.mergeServers(hardcodedServers, serversOnFile)
|
||||
}
|
||||
|
||||
// Eventually write file
|
||||
if !write || reflect.DeepEqual(serversOnFile, allServers) {
|
||||
if s.filepath == "" || reflect.DeepEqual(serversOnFile, allServers) {
|
||||
return allServers, nil
|
||||
}
|
||||
return allServers, s.FlushToFile(allServers)
|
||||
|
||||
if err := s.FlushToFile(allServers); err != nil {
|
||||
return allServers, fmt.Errorf("%w: %s", ErrCannotWriteFile, err)
|
||||
}
|
||||
return allServers, nil
|
||||
}
|
||||
|
||||
func (s *storage) readFromFile(filepath string) (servers models.AllServers, err error) {
|
||||
file, err := s.os.OpenFile(filepath, os.O_RDONLY, 0)
|
||||
if os.IsNotExist(err) {
|
||||
return servers, nil
|
||||
} else if err != nil {
|
||||
return servers, err
|
||||
}
|
||||
decoder := json.NewDecoder(file)
|
||||
if err := decoder.Decode(&servers); err != nil {
|
||||
_ = file.Close()
|
||||
return servers, err
|
||||
}
|
||||
return servers, file.Close()
|
||||
}
|
||||
|
||||
func (s *storage) FlushToFile(servers models.AllServers) error {
|
||||
file, err := s.os.OpenFile(jsonFilepath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
|
||||
file, err := s.os.OpenFile(s.filepath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -64,7 +82,7 @@ func (s *storage) FlushToFile(servers models.AllServers) error {
|
||||
encoder.SetIndent("", " ")
|
||||
if err := encoder.Encode(servers); err != nil {
|
||||
_ = file.Close()
|
||||
return fmt.Errorf("cannot write to file: %w", err)
|
||||
return err
|
||||
}
|
||||
return file.Close()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user