Maint: internal/storage rework

- No more global variables
- Inject merged servers to configuration package
- Fix #566: configuration parsing to use persisted servers.json
- Move server data files from `internal/constants` to `internal/storage`
This commit is contained in:
Quentin McGaw (desktop)
2021-08-27 19:10:03 +00:00
parent b1cfc03fc5
commit 3863cc439e
59 changed files with 850 additions and 490 deletions

View File

@@ -1,7 +1,38 @@
package storage
import "github.com/qdm12/gluetun/internal/models"
import (
"encoding/json"
"os"
"path/filepath"
func (s *storage) FlushToFile(allServers models.AllServers) error {
"github.com/qdm12/gluetun/internal/models"
)
var _ Flusher = (*Storage)(nil)
type Flusher interface {
FlushToFile(allServers models.AllServers) error
}
func (s *Storage) FlushToFile(allServers models.AllServers) error {
return flushToFile(s.filepath, allServers)
}
func flushToFile(path string, servers models.AllServers) error {
dirPath := filepath.Dir(path)
if err := os.MkdirAll(dirPath, 0644); err != nil {
return err
}
file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
return err
}
encoder := json.NewEncoder(file)
encoder.SetIndent("", " ")
if err := encoder.Encode(servers); err != nil {
_ = file.Close()
return err
}
return file.Close()
}

View File

@@ -0,0 +1,21 @@
package storage
import (
"embed"
"encoding/json"
"github.com/qdm12/gluetun/internal/models"
)
//go:embed servers.json
var allServersEmbedFS embed.FS //nolint:gochecknoglobals
func parseHardcodedServers() (allServers models.AllServers, err error) {
f, err := allServersEmbedFS.Open("servers.json")
if err != nil {
return allServers, err
}
decoder := json.NewDecoder(f)
err = decoder.Decode(&allServers)
return allServers, err
}

View File

@@ -0,0 +1,147 @@
package storage
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"testing"
"github.com/qdm12/gluetun/internal/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_parseHardcodedServers(t *testing.T) {
t.Parallel()
servers, err := parseHardcodedServers()
require.NoError(t, err)
require.NotEmpty(t, len(servers.Cyberghost.Servers))
}
func digestServerModelVersion(t *testing.T, server interface{}, version uint16) string {
t.Helper()
bytes, err := json.Marshal(server)
if err != nil {
t.Fatal(err)
}
bytes = append(bytes, []byte(fmt.Sprintf("%d", version))...)
arr := sha256.Sum256(bytes)
hexString := hex.EncodeToString(arr[:])
if len(hexString) > 8 {
hexString = hexString[:8]
}
return hexString
}
func Test_versions(t *testing.T) {
t.Parallel()
allServers, err := parseHardcodedServers()
require.NoError(t, err)
const format = "you forgot to update the version for %s"
testCases := map[string]struct {
model interface{}
version uint16
digest string
}{
"Cyberghost": {
model: models.CyberghostServer{},
version: allServers.Cyberghost.Version,
digest: "229828de",
},
"Fastestvpn": {
model: models.FastestvpnServer{},
version: allServers.Fastestvpn.Version,
digest: "8825919b",
},
"HideMyAss": {
model: models.HideMyAssServer{},
version: allServers.HideMyAss.Version,
digest: "a93b4057",
},
"Ipvanish": {
model: models.IpvanishServer{},
version: allServers.Ipvanish.Version,
digest: "2eb80d28",
},
"Ivpn": {
model: models.IvpnServer{},
version: allServers.Ivpn.Version,
digest: "88074ceb",
},
"Mullvad": {
model: models.MullvadServer{},
version: allServers.Mullvad.Version,
digest: "ec56f19d",
},
"Nordvpn": {
model: models.NordvpnServer{},
version: allServers.Nordvpn.Version,
digest: "a3b5d609",
},
"Privado": {
model: models.PrivadoServer{},
version: allServers.Privado.Version,
digest: "dba6736c",
},
"Private Internet Access": {
model: models.PIAServer{},
version: allServers.Pia.Version,
digest: "91db9bc9",
},
"Privatevpn": {
model: models.PrivatevpnServer{},
version: allServers.Privatevpn.Version,
digest: "cba13d78",
},
"Protonvpn": {
model: models.ProtonvpnServer{},
version: allServers.Protonvpn.Version,
digest: "b964085b",
},
"Purevpn": {
model: models.PurevpnServer{},
version: allServers.Purevpn.Version,
digest: "23f2d422",
},
"Surfshark": {
model: models.SurfsharkServer{},
version: allServers.Surfshark.Version,
digest: "3ccaa772",
},
"Torguard": {
model: models.TorguardServer{},
version: allServers.Torguard.Version,
digest: "6eb9028e",
},
"VPN Unlimited": {
model: models.VPNUnlimitedServer{},
version: allServers.VPNUnlimited.Version,
digest: "5cb51319",
},
"Vyprvpn": {
model: models.VyprvpnServer{},
version: allServers.Vyprvpn.Version,
digest: "58de06d8",
},
"Windscribe": {
model: models.WindscribeServer{},
version: allServers.Windscribe.Version,
digest: "4bd0fc4f",
},
}
for name, testCase := range testCases {
name := name
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
digest := digestServerModelVersion(t, testCase.model, testCase.version)
failureMessage := fmt.Sprintf(format, name)
assert.Equal(t, testCase.digest, digest, failureMessage)
})
}
}

View File

@@ -4,11 +4,10 @@ import (
"strconv"
"time"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
)
func (s *storage) logVersionDiff(provider string, diff int) {
func (s *Storage) logVersionDiff(provider string, diff int) {
diffString := strconv.Itoa(diff)
message := provider + " servers from file discarded because they are " +
@@ -20,7 +19,7 @@ func (s *storage) logVersionDiff(provider string, diff int) {
s.logger.Info(message)
}
func (s *storage) logTimeDiff(provider string, persistedUnix, hardcodedUnix int64) {
func (s *Storage) logTimeDiff(provider string, persistedUnix, hardcodedUnix int64) {
diff := time.Unix(persistedUnix, 0).Sub(time.Unix(hardcodedUnix, 0))
if diff < 0 {
diff = -diff
@@ -31,7 +30,7 @@ func (s *storage) logTimeDiff(provider string, persistedUnix, hardcodedUnix int6
s.logger.Info(message)
}
func (s *storage) mergeServers(hardcoded, persisted models.AllServers) models.AllServers {
func (s *Storage) mergeServers(hardcoded, persisted models.AllServers) models.AllServers {
return models.AllServers{
Version: hardcoded.Version,
Cyberghost: s.mergeCyberghost(hardcoded.Cyberghost, persisted.Cyberghost),
@@ -54,7 +53,7 @@ func (s *storage) mergeServers(hardcoded, persisted models.AllServers) models.Al
}
}
func (s *storage) mergeCyberghost(hardcoded, persisted models.CyberghostServers) models.CyberghostServers {
func (s *Storage) mergeCyberghost(hardcoded, persisted models.CyberghostServers) models.CyberghostServers {
if persisted.Timestamp <= hardcoded.Timestamp {
return hardcoded
}
@@ -69,7 +68,7 @@ func (s *storage) mergeCyberghost(hardcoded, persisted models.CyberghostServers)
return persisted
}
func (s *storage) mergeFastestvpn(hardcoded, persisted models.FastestvpnServers) models.FastestvpnServers {
func (s *Storage) mergeFastestvpn(hardcoded, persisted models.FastestvpnServers) models.FastestvpnServers {
if persisted.Timestamp <= hardcoded.Timestamp {
return hardcoded
}
@@ -82,7 +81,7 @@ func (s *storage) mergeFastestvpn(hardcoded, persisted models.FastestvpnServers)
return persisted
}
func (s *storage) mergeHideMyAss(hardcoded, persisted models.HideMyAssServers) models.HideMyAssServers {
func (s *Storage) mergeHideMyAss(hardcoded, persisted models.HideMyAssServers) models.HideMyAssServers {
if persisted.Timestamp <= hardcoded.Timestamp {
return hardcoded
}
@@ -95,7 +94,7 @@ func (s *storage) mergeHideMyAss(hardcoded, persisted models.HideMyAssServers) m
return persisted
}
func (s *storage) mergeIpvanish(hardcoded, persisted models.IpvanishServers) models.IpvanishServers {
func (s *Storage) mergeIpvanish(hardcoded, persisted models.IpvanishServers) models.IpvanishServers {
if persisted.Timestamp <= hardcoded.Timestamp {
return hardcoded
}
@@ -108,7 +107,7 @@ func (s *storage) mergeIpvanish(hardcoded, persisted models.IpvanishServers) mod
return persisted
}
func (s *storage) mergeIvpn(hardcoded, persisted models.IvpnServers) models.IvpnServers {
func (s *Storage) mergeIvpn(hardcoded, persisted models.IvpnServers) models.IvpnServers {
if persisted.Timestamp <= hardcoded.Timestamp {
return hardcoded
}
@@ -121,7 +120,7 @@ func (s *storage) mergeIvpn(hardcoded, persisted models.IvpnServers) models.Ivpn
return persisted
}
func (s *storage) mergeMullvad(hardcoded, persisted models.MullvadServers) models.MullvadServers {
func (s *Storage) mergeMullvad(hardcoded, persisted models.MullvadServers) models.MullvadServers {
if persisted.Timestamp <= hardcoded.Timestamp {
return hardcoded
}
@@ -136,7 +135,7 @@ func (s *storage) mergeMullvad(hardcoded, persisted models.MullvadServers) model
return persisted
}
func (s *storage) mergeNordVPN(hardcoded, persisted models.NordvpnServers) models.NordvpnServers {
func (s *Storage) mergeNordVPN(hardcoded, persisted models.NordvpnServers) models.NordvpnServers {
if persisted.Timestamp <= hardcoded.Timestamp {
return hardcoded
}
@@ -151,7 +150,7 @@ func (s *storage) mergeNordVPN(hardcoded, persisted models.NordvpnServers) model
return persisted
}
func (s *storage) mergePrivado(hardcoded, persisted models.PrivadoServers) models.PrivadoServers {
func (s *Storage) mergePrivado(hardcoded, persisted models.PrivadoServers) models.PrivadoServers {
if persisted.Timestamp <= hardcoded.Timestamp {
return hardcoded
}
@@ -165,7 +164,7 @@ func (s *storage) mergePrivado(hardcoded, persisted models.PrivadoServers) model
return persisted
}
func (s *storage) mergePIA(hardcoded, persisted models.PiaServers) models.PiaServers {
func (s *Storage) mergePIA(hardcoded, persisted models.PiaServers) models.PiaServers {
if persisted.Timestamp <= hardcoded.Timestamp {
return hardcoded
}
@@ -179,7 +178,7 @@ func (s *storage) mergePIA(hardcoded, persisted models.PiaServers) models.PiaSer
return persisted
}
func (s *storage) mergePrivatevpn(hardcoded, persisted models.PrivatevpnServers) models.PrivatevpnServers {
func (s *Storage) mergePrivatevpn(hardcoded, persisted models.PrivatevpnServers) models.PrivatevpnServers {
if persisted.Timestamp <= hardcoded.Timestamp {
return hardcoded
}
@@ -193,7 +192,7 @@ func (s *storage) mergePrivatevpn(hardcoded, persisted models.PrivatevpnServers)
return persisted
}
func (s *storage) mergeProtonvpn(hardcoded, persisted models.ProtonvpnServers) models.ProtonvpnServers {
func (s *Storage) mergeProtonvpn(hardcoded, persisted models.ProtonvpnServers) models.ProtonvpnServers {
if persisted.Timestamp <= hardcoded.Timestamp {
return hardcoded
}
@@ -207,7 +206,7 @@ func (s *storage) mergeProtonvpn(hardcoded, persisted models.ProtonvpnServers) m
return persisted
}
func (s *storage) mergePureVPN(hardcoded, persisted models.PurevpnServers) models.PurevpnServers {
func (s *Storage) mergePureVPN(hardcoded, persisted models.PurevpnServers) models.PurevpnServers {
if persisted.Timestamp <= hardcoded.Timestamp {
return hardcoded
}
@@ -222,7 +221,7 @@ func (s *storage) mergePureVPN(hardcoded, persisted models.PurevpnServers) model
return persisted
}
func (s *storage) mergeSurfshark(hardcoded, persisted models.SurfsharkServers) models.SurfsharkServers {
func (s *Storage) mergeSurfshark(hardcoded, persisted models.SurfsharkServers) models.SurfsharkServers {
if persisted.Timestamp <= hardcoded.Timestamp {
return hardcoded
}
@@ -237,7 +236,7 @@ func (s *storage) mergeSurfshark(hardcoded, persisted models.SurfsharkServers) m
return persisted
}
func (s *storage) mergeTorguard(hardcoded, persisted models.TorguardServers) models.TorguardServers {
func (s *Storage) mergeTorguard(hardcoded, persisted models.TorguardServers) models.TorguardServers {
if persisted.Timestamp <= hardcoded.Timestamp {
return hardcoded
}
@@ -251,21 +250,21 @@ func (s *storage) mergeTorguard(hardcoded, persisted models.TorguardServers) mod
return persisted
}
func (s *storage) mergeVPNUnlimited(hardcoded, persisted models.VPNUnlimitedServers) models.VPNUnlimitedServers {
func (s *Storage) mergeVPNUnlimited(hardcoded, persisted models.VPNUnlimitedServers) models.VPNUnlimitedServers {
if persisted.Timestamp <= hardcoded.Timestamp {
return hardcoded
}
versionDiff := int(hardcoded.Version) - int(persisted.Version)
if versionDiff > 0 {
s.logVersionDiff(constants.VPNUnlimited, versionDiff)
s.logVersionDiff("VPN Unlimited", versionDiff)
return hardcoded
}
s.logTimeDiff(constants.VPNUnlimited, persisted.Timestamp, hardcoded.Timestamp)
s.logTimeDiff("VPN Unlimited", persisted.Timestamp, hardcoded.Timestamp)
return persisted
}
func (s *storage) mergeVyprvpn(hardcoded, persisted models.VyprvpnServers) models.VyprvpnServers {
func (s *Storage) mergeVyprvpn(hardcoded, persisted models.VyprvpnServers) models.VyprvpnServers {
if persisted.Timestamp <= hardcoded.Timestamp {
return hardcoded
}
@@ -280,7 +279,7 @@ func (s *storage) mergeVyprvpn(hardcoded, persisted models.VyprvpnServers) model
return persisted
}
func (s *storage) mergeWindscribe(hardcoded, persisted models.WindscribeServers) models.WindscribeServers {
func (s *Storage) mergeWindscribe(hardcoded, persisted models.WindscribeServers) models.WindscribeServers {
if persisted.Timestamp <= hardcoded.Timestamp {
return hardcoded
}

28
internal/storage/read.go Normal file
View File

@@ -0,0 +1,28 @@
package storage
import (
"encoding/json"
"errors"
"io"
"os"
"github.com/qdm12/gluetun/internal/models"
)
func readFromFile(filepath string) (servers models.AllServers, err error) {
file, err := os.Open(filepath)
if os.IsNotExist(err) {
return servers, nil
} else if err != nil {
return servers, err
}
decoder := json.NewDecoder(file)
if err := decoder.Decode(&servers); err != nil {
_ = file.Close()
if errors.Is(err, io.EOF) {
return servers, nil
}
return servers, err
}
return servers, file.Close()
}

View File

@@ -0,0 +1,7 @@
package storage
import "github.com/qdm12/gluetun/internal/models"
func (s *Storage) GetServers() models.AllServers {
return s.mergedServers.GetCopy()
}

119517
internal/storage/servers.json Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -6,20 +6,29 @@ import (
"github.com/qdm12/golibs/logging"
)
type Storage interface {
// Passing an empty filepath disables writing to a file
SyncServers(hardcodedServers models.AllServers) (allServers models.AllServers, err error)
FlushToFile(servers models.AllServers) error
type Storage struct {
mergedServers models.AllServers
hardcodedServers models.AllServers
logger logging.Logger
filepath string
}
type storage struct {
logger logging.Logger
filepath string
}
// New creates a new storage and reads the servers from the
// embedded servers file and the file on disk.
// Passing an empty filepath disables writing servers to a file.
func New(logger logging.Logger, filepath string) (storage *Storage, err error) {
// error returned covered by unit test
harcodedServers, _ := parseHardcodedServers()
func New(logger logging.Logger, filepath string) Storage {
return &storage{
logger: logger,
filepath: filepath,
storage = &Storage{
hardcodedServers: harcodedServers,
logger: logger,
filepath: filepath,
}
if err := storage.SyncServers(); err != nil {
return nil, err
}
return storage, nil
}

View File

@@ -1,14 +1,9 @@
package storage
import (
"encoding/json"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"reflect"
"strconv"
"github.com/qdm12/gluetun/internal/models"
)
@@ -38,70 +33,35 @@ func countServers(allServers models.AllServers) int {
len(allServers.Windscribe.Servers)
}
func (s *storage) SyncServers(hardcodedServers models.AllServers) (
allServers models.AllServers, err error) {
func (s *Storage) SyncServers() (err error) {
serversOnFile, err := readFromFile(s.filepath)
if err != nil {
return allServers, fmt.Errorf("%w: %s", ErrCannotReadFile, err)
return fmt.Errorf("%w: %s", ErrCannotReadFile, err)
}
hardcodedCount := countServers(hardcodedServers)
hardcodedCount := countServers(s.hardcodedServers)
countOnFile := countServers(serversOnFile)
if countOnFile == 0 {
s.logger.Info("creating " + s.filepath + " with " + strconv.Itoa(hardcodedCount) + " hardcoded servers")
allServers = hardcodedServers
s.logger.Info(fmt.Sprintf(
"creating %s with %d hardcoded servers",
s.filepath, hardcodedCount))
s.mergedServers = s.hardcodedServers
} else {
s.logger.Info("merging by most recent " +
strconv.Itoa(hardcodedCount) + " hardcoded servers and " +
strconv.Itoa(countOnFile) + " servers read from " + s.filepath)
allServers = s.mergeServers(hardcodedServers, serversOnFile)
s.logger.Info(fmt.Sprintf(
"merging by most recent %d hardcoded servers and %d servers read from %s",
hardcodedCount, countOnFile, s.filepath))
s.mergedServers = s.mergeServers(s.hardcodedServers, serversOnFile)
}
// Eventually write file
if s.filepath == "" || reflect.DeepEqual(serversOnFile, allServers) {
return allServers, nil
if s.filepath == "" || reflect.DeepEqual(serversOnFile, s.mergedServers) {
return nil
}
if err := flushToFile(s.filepath, allServers); err != nil {
return allServers, fmt.Errorf("%w: %s", ErrCannotWriteFile, err)
if err := flushToFile(s.filepath, s.mergedServers); err != nil {
return fmt.Errorf("%w: %s", ErrCannotWriteFile, err)
}
return allServers, nil
}
func readFromFile(filepath string) (servers models.AllServers, err error) {
file, err := os.Open(filepath)
if os.IsNotExist(err) {
return servers, nil
} else if err != nil {
return servers, err
}
decoder := json.NewDecoder(file)
if err := decoder.Decode(&servers); err != nil {
_ = file.Close()
if errors.Is(err, io.EOF) {
return servers, nil
}
return servers, err
}
return servers, file.Close()
}
func flushToFile(path string, servers models.AllServers) error {
dirPath := filepath.Dir(path)
if err := os.MkdirAll(dirPath, 0644); err != nil {
return err
}
file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
return err
}
encoder := json.NewEncoder(file)
encoder.SetIndent("", " ")
if err := encoder.Encode(servers); err != nil {
_ = file.Close()
return err
}
return file.Close()
return nil
}