From 1d25a0e18c15f2fed7f70a6a7ecb43e9fc48eda2 Mon Sep 17 00:00:00 2001 From: "Quentin McGaw (desktop)" Date: Thu, 30 Sep 2021 15:22:57 +0000 Subject: [PATCH] Fix: server data version diff when reading file --- internal/storage/infoerrorer_mock_test.go | 46 ++++ internal/storage/merge.go | 122 +--------- internal/storage/read.go | 270 +++++++++++++++++++++- internal/storage/read_test.go | 169 ++++++++++++++ internal/storage/storage.go | 2 + internal/storage/sync.go | 2 +- 6 files changed, 490 insertions(+), 121 deletions(-) create mode 100644 internal/storage/infoerrorer_mock_test.go create mode 100644 internal/storage/read_test.go diff --git a/internal/storage/infoerrorer_mock_test.go b/internal/storage/infoerrorer_mock_test.go new file mode 100644 index 00000000..9c9c176c --- /dev/null +++ b/internal/storage/infoerrorer_mock_test.go @@ -0,0 +1,46 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/qdm12/gluetun/internal/storage (interfaces: InfoErrorer) + +// Package storage is a generated GoMock package. +package storage + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockInfoErrorer is a mock of InfoErrorer interface. +type MockInfoErrorer struct { + ctrl *gomock.Controller + recorder *MockInfoErrorerMockRecorder +} + +// MockInfoErrorerMockRecorder is the mock recorder for MockInfoErrorer. +type MockInfoErrorerMockRecorder struct { + mock *MockInfoErrorer +} + +// NewMockInfoErrorer creates a new mock instance. +func NewMockInfoErrorer(ctrl *gomock.Controller) *MockInfoErrorer { + mock := &MockInfoErrorer{ctrl: ctrl} + mock.recorder = &MockInfoErrorerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockInfoErrorer) EXPECT() *MockInfoErrorerMockRecorder { + return m.recorder +} + +// Info mocks base method. +func (m *MockInfoErrorer) Info(arg0 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Info", arg0) +} + +// Info indicates an expected call of Info. +func (mr *MockInfoErrorerMockRecorder) Info(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockInfoErrorer)(nil).Info), arg0) +} diff --git a/internal/storage/merge.go b/internal/storage/merge.go index 6ffe5b83..24e1d7ef 100644 --- a/internal/storage/merge.go +++ b/internal/storage/merge.go @@ -7,15 +7,11 @@ import ( "github.com/qdm12/gluetun/internal/models" ) -func (s *Storage) logVersionDiff(provider string, diff int) { - diffString := strconv.Itoa(diff) - - message := provider + " servers from file discarded because they are " + - diffString + " version" - if diff > 1 { - message += "s" - } - message += " behind" +func (s *Storage) logVersionDiff(provider string, hardcodedVersion, persistedVersion uint16) { + message := provider + " servers from file discarded because they have version " + + strconv.Itoa(int(persistedVersion)) + + " and hardcoded servers have version " + + strconv.Itoa(int(hardcodedVersion)) s.logger.Info(message) } @@ -60,12 +56,6 @@ func (s *Storage) mergeCyberghost(hardcoded, persisted models.CyberghostServers) return hardcoded } - versionDiff := int(hardcoded.Version) - int(persisted.Version) - if versionDiff > 0 { - s.logVersionDiff("Cyberghost", versionDiff) - return hardcoded - } - s.logTimeDiff("Cyberghost", persisted.Timestamp, hardcoded.Timestamp) return persisted } @@ -74,11 +64,7 @@ func (s *Storage) mergeExpressvpn(hardcoded, persisted models.ExpressvpnServers) if persisted.Timestamp <= hardcoded.Timestamp { return hardcoded } - versionDiff := int(hardcoded.Version) - int(persisted.Version) - if versionDiff > 0 { - s.logVersionDiff("ExpressVPN", versionDiff) - return hardcoded - } + s.logTimeDiff("ExpressVPN", persisted.Timestamp, hardcoded.Timestamp) return persisted } @@ -87,11 +73,7 @@ func (s *Storage) mergeFastestvpn(hardcoded, persisted models.FastestvpnServers) if persisted.Timestamp <= hardcoded.Timestamp { return hardcoded } - versionDiff := int(hardcoded.Version) - int(persisted.Version) - if versionDiff > 0 { - s.logVersionDiff("FastestVPN", versionDiff) - return hardcoded - } + s.logTimeDiff("FastestVPN", persisted.Timestamp, hardcoded.Timestamp) return persisted } @@ -100,11 +82,7 @@ func (s *Storage) mergeHideMyAss(hardcoded, persisted models.HideMyAssServers) m if persisted.Timestamp <= hardcoded.Timestamp { return hardcoded } - versionDiff := int(hardcoded.Version) - int(persisted.Version) - if versionDiff > 0 { - s.logVersionDiff("HideMyAss", versionDiff) - return hardcoded - } + s.logTimeDiff("HideMyAss", persisted.Timestamp, hardcoded.Timestamp) return persisted } @@ -113,11 +91,7 @@ func (s *Storage) mergeIpvanish(hardcoded, persisted models.IpvanishServers) mod if persisted.Timestamp <= hardcoded.Timestamp { return hardcoded } - versionDiff := int(hardcoded.Version) - int(persisted.Version) - if versionDiff > 0 { - s.logVersionDiff("Ipvanish", versionDiff) - return hardcoded - } + s.logTimeDiff("Ipvanish", persisted.Timestamp, hardcoded.Timestamp) return persisted } @@ -126,11 +100,7 @@ func (s *Storage) mergeIvpn(hardcoded, persisted models.IvpnServers) models.Ivpn if persisted.Timestamp <= hardcoded.Timestamp { return hardcoded } - versionDiff := int(hardcoded.Version) - int(persisted.Version) - if versionDiff > 0 { - s.logVersionDiff("Ivpn", versionDiff) - return hardcoded - } + s.logTimeDiff("Ivpn", persisted.Timestamp, hardcoded.Timestamp) return persisted } @@ -140,12 +110,6 @@ func (s *Storage) mergeMullvad(hardcoded, persisted models.MullvadServers) model return hardcoded } - versionDiff := int(hardcoded.Version) - int(persisted.Version) - if versionDiff > 0 { - s.logVersionDiff("Mullvad", versionDiff) - return hardcoded - } - s.logTimeDiff("Mullvad", persisted.Timestamp, hardcoded.Timestamp) return persisted } @@ -155,12 +119,6 @@ func (s *Storage) mergeNordVPN(hardcoded, persisted models.NordvpnServers) model return hardcoded } - versionDiff := int(hardcoded.Version) - int(persisted.Version) - if versionDiff > 0 { - s.logVersionDiff("NordVPN", versionDiff) - return hardcoded - } - s.logTimeDiff("NordVPN", persisted.Timestamp, hardcoded.Timestamp) return persisted } @@ -169,11 +127,6 @@ func (s *Storage) mergePrivado(hardcoded, persisted models.PrivadoServers) model if persisted.Timestamp <= hardcoded.Timestamp { return hardcoded } - versionDiff := int(hardcoded.Version) - int(persisted.Version) - if versionDiff > 0 { - s.logVersionDiff("Privado", versionDiff) - return hardcoded - } s.logTimeDiff("Privado", persisted.Timestamp, hardcoded.Timestamp) return persisted @@ -183,11 +136,6 @@ func (s *Storage) mergePIA(hardcoded, persisted models.PiaServers) models.PiaSer if persisted.Timestamp <= hardcoded.Timestamp { return hardcoded } - versionDiff := int(hardcoded.Version) - int(persisted.Version) - if versionDiff > 0 { - s.logVersionDiff("Private Internet Access", versionDiff) - return hardcoded - } s.logTimeDiff("Private Internet Access", persisted.Timestamp, hardcoded.Timestamp) return persisted @@ -197,11 +145,6 @@ func (s *Storage) mergePrivatevpn(hardcoded, persisted models.PrivatevpnServers) if persisted.Timestamp <= hardcoded.Timestamp { return hardcoded } - versionDiff := int(hardcoded.Version) - int(persisted.Version) - if versionDiff > 0 { - s.logVersionDiff("PrivateVPN", versionDiff) - return hardcoded - } s.logTimeDiff("PrivateVPN", persisted.Timestamp, hardcoded.Timestamp) return persisted @@ -211,11 +154,6 @@ func (s *Storage) mergeProtonvpn(hardcoded, persisted models.ProtonvpnServers) m if persisted.Timestamp <= hardcoded.Timestamp { return hardcoded } - versionDiff := int(hardcoded.Version) - int(persisted.Version) - if versionDiff > 0 { - s.logVersionDiff("ProtonVPN", versionDiff) - return hardcoded - } s.logTimeDiff("ProtonVPN", persisted.Timestamp, hardcoded.Timestamp) return persisted @@ -226,12 +164,6 @@ func (s *Storage) mergePureVPN(hardcoded, persisted models.PurevpnServers) model return hardcoded } - versionDiff := int(hardcoded.Version) - int(persisted.Version) - if versionDiff > 0 { - s.logVersionDiff("PureVPN", versionDiff) - return hardcoded - } - s.logTimeDiff("PureVPN", persisted.Timestamp, hardcoded.Timestamp) return persisted } @@ -241,12 +173,6 @@ func (s *Storage) mergeSurfshark(hardcoded, persisted models.SurfsharkServers) m return hardcoded } - versionDiff := int(hardcoded.Version) - int(persisted.Version) - if versionDiff > 0 { - s.logVersionDiff("Surfshark", versionDiff) - return hardcoded - } - s.logTimeDiff("Surfshark", persisted.Timestamp, hardcoded.Timestamp) return persisted } @@ -255,11 +181,6 @@ func (s *Storage) mergeTorguard(hardcoded, persisted models.TorguardServers) mod if persisted.Timestamp <= hardcoded.Timestamp { return hardcoded } - versionDiff := int(hardcoded.Version) - int(persisted.Version) - if versionDiff > 0 { - s.logVersionDiff("Torguard", versionDiff) - return hardcoded - } s.logTimeDiff("Torguard", persisted.Timestamp, hardcoded.Timestamp) return persisted @@ -269,11 +190,6 @@ func (s *Storage) mergeVPNUnlimited(hardcoded, persisted models.VPNUnlimitedServ if persisted.Timestamp <= hardcoded.Timestamp { return hardcoded } - versionDiff := int(hardcoded.Version) - int(persisted.Version) - if versionDiff > 0 { - s.logVersionDiff("VPN Unlimited", versionDiff) - return hardcoded - } s.logTimeDiff("VPN Unlimited", persisted.Timestamp, hardcoded.Timestamp) return persisted @@ -284,12 +200,6 @@ func (s *Storage) mergeVyprvpn(hardcoded, persisted models.VyprvpnServers) model return hardcoded } - versionDiff := int(hardcoded.Version) - int(persisted.Version) - if versionDiff > 0 { - s.logVersionDiff("VyprVPN", versionDiff) - return hardcoded - } - s.logTimeDiff("VyprVPN", persisted.Timestamp, hardcoded.Timestamp) return persisted } @@ -299,12 +209,6 @@ func (s *Storage) mergeWevpn(hardcoded, persisted models.WevpnServers) models.We return hardcoded } - versionDiff := int(hardcoded.Version) - int(persisted.Version) - if versionDiff > 0 { - s.logVersionDiff("WeVPN", versionDiff) - return hardcoded - } - s.logTimeDiff("WeVPN", persisted.Timestamp, hardcoded.Timestamp) return persisted } @@ -314,12 +218,6 @@ func (s *Storage) mergeWindscribe(hardcoded, persisted models.WindscribeServers) return hardcoded } - versionDiff := int(hardcoded.Version) - int(persisted.Version) - if versionDiff > 0 { - s.logVersionDiff("Windscribe", versionDiff) - return hardcoded - } - s.logTimeDiff("Windscribe", persisted.Timestamp, hardcoded.Timestamp) return persisted } diff --git a/internal/storage/read.go b/internal/storage/read.go index 6ff80055..f18b6f3e 100644 --- a/internal/storage/read.go +++ b/internal/storage/read.go @@ -3,26 +3,280 @@ package storage import ( "encoding/json" "errors" + "fmt" "io" "os" "github.com/qdm12/gluetun/internal/models" ) -func readFromFile(filepath string) (servers models.AllServers, err error) { +// readFromFile reads the servers from server.json. +// It only reads servers that have the same version as the hardcoded servers version +// to avoid JSON unmarshaling errors. +func (s *Storage) readFromFile(filepath string, hardcoded models.AllServers) ( + servers models.AllServers, err error) { file, err := os.Open(filepath) if os.IsNotExist(err) { return servers, nil } else if err != nil { return servers, err } - decoder := json.NewDecoder(file) - if err := decoder.Decode(&servers); err != nil { - _ = file.Close() - if errors.Is(err, io.EOF) { - return servers, nil - } + + b, err := io.ReadAll(file) + if err != nil { return servers, err } - return servers, file.Close() + + if err := file.Close(); err != nil { + return servers, err + } + + return s.extractServersFromBytes(b, hardcoded) +} + +var ( + errDecodeVersions = errors.New("cannot decode versions") + errDecodeServers = errors.New("cannot decode servers") + errDecodeProvider = errors.New("cannot decode servers for provider") +) + +func (s *Storage) extractServersFromBytes(b []byte, hardcoded models.AllServers) ( //nolint:gocognit,gocyclo + servers models.AllServers, err error) { + var versions allVersions + if err := json.Unmarshal(b, &versions); err != nil { + return servers, fmt.Errorf("%w: %s", errDecodeVersions, err) + } + + var rawMessages allJSONRawMessages + if err := json.Unmarshal(b, &rawMessages); err != nil { + return servers, fmt.Errorf("%w: %s", errDecodeServers, err) + } + + // TODO simplify with generics in Go 1.18 + + if hardcoded.Cyberghost.Version != versions.Cyberghost.Version { + s.logVersionDiff("Cyberghost", hardcoded.Cyberghost.Version, versions.Cyberghost.Version) + } else { + err = json.Unmarshal(rawMessages.Cyberghost, &servers.Cyberghost) + if err != nil { + return servers, fmt.Errorf("%w: %s: %s", errDecodeProvider, "Cyberghost", err) + } + } + + if hardcoded.Expressvpn.Version != versions.Expressvpn.Version { + s.logVersionDiff("Expressvpn", hardcoded.Expressvpn.Version, versions.Expressvpn.Version) + } else { + err = json.Unmarshal(rawMessages.Expressvpn, &servers.Expressvpn) + if err != nil { + return servers, fmt.Errorf("%w: %s: %s", errDecodeProvider, "Expressvpn", err) + } + } + + if hardcoded.Fastestvpn.Version != versions.Fastestvpn.Version { + s.logVersionDiff("Fastestvpn", hardcoded.Fastestvpn.Version, versions.Fastestvpn.Version) + } else { + err = json.Unmarshal(rawMessages.Fastestvpn, &servers.Fastestvpn) + if err != nil { + return servers, fmt.Errorf("%w: %s: %s", errDecodeProvider, "Fastestvpn", err) + } + } + + if hardcoded.HideMyAss.Version != versions.HideMyAss.Version { + s.logVersionDiff("HideMyAss", hardcoded.HideMyAss.Version, versions.HideMyAss.Version) + } else { + err = json.Unmarshal(rawMessages.HideMyAss, &servers.HideMyAss) + if err != nil { + return servers, fmt.Errorf("%w: %s: %s", errDecodeProvider, "HideMyAss", err) + } + } + + if hardcoded.Ipvanish.Version != versions.Ipvanish.Version { + s.logVersionDiff("Ipvanish", hardcoded.Ipvanish.Version, versions.Ipvanish.Version) + } else { + err = json.Unmarshal(rawMessages.Ipvanish, &servers.Ipvanish) + if err != nil { + return servers, fmt.Errorf("%w: %s: %s", errDecodeProvider, "Ipvanish", err) + } + } + + if hardcoded.Ivpn.Version != versions.Ivpn.Version { + s.logVersionDiff("Ivpn", hardcoded.Ivpn.Version, versions.Ivpn.Version) + } else { + err = json.Unmarshal(rawMessages.Ivpn, &servers.Ivpn) + if err != nil { + return servers, fmt.Errorf("%w: %s: %s", errDecodeProvider, "Ivpn", err) + } + } + + if hardcoded.Mullvad.Version != versions.Mullvad.Version { + s.logVersionDiff("Mullvad", hardcoded.Mullvad.Version, versions.Mullvad.Version) + } else { + err = json.Unmarshal(rawMessages.Mullvad, &servers.Mullvad) + if err != nil { + return servers, fmt.Errorf("%w: %s: %s", errDecodeProvider, "Mullvad", err) + } + } + + if hardcoded.Nordvpn.Version != versions.Nordvpn.Version { + s.logVersionDiff("Nordvpn", hardcoded.Nordvpn.Version, versions.Nordvpn.Version) + } else { + err = json.Unmarshal(rawMessages.Nordvpn, &servers.Nordvpn) + if err != nil { + return servers, fmt.Errorf("%w: %s: %s", errDecodeProvider, "Nordvpn", err) + } + } + + if hardcoded.Privado.Version != versions.Privado.Version { + s.logVersionDiff("Privado", hardcoded.Privado.Version, versions.Privado.Version) + } else { + err = json.Unmarshal(rawMessages.Privado, &servers.Privado) + if err != nil { + return servers, fmt.Errorf("%w: %s: %s", errDecodeProvider, "Privado", err) + } + } + + if hardcoded.Pia.Version != versions.Pia.Version { + s.logVersionDiff("Pia", hardcoded.Pia.Version, versions.Pia.Version) + } else { + err = json.Unmarshal(rawMessages.Pia, &servers.Pia) + if err != nil { + return servers, fmt.Errorf("%w: %s: %s", errDecodeProvider, "Pia", err) + } + } + + if hardcoded.Privatevpn.Version != versions.Privatevpn.Version { + s.logVersionDiff("Privatevpn", hardcoded.Privatevpn.Version, versions.Privatevpn.Version) + } else { + err = json.Unmarshal(rawMessages.Privatevpn, &servers.Privatevpn) + if err != nil { + return servers, fmt.Errorf("%w: %s: %s", errDecodeProvider, "Privatevpn", err) + } + } + + if hardcoded.Protonvpn.Version != versions.Protonvpn.Version { + s.logVersionDiff("Protonvpn", hardcoded.Protonvpn.Version, versions.Protonvpn.Version) + } else { + err = json.Unmarshal(rawMessages.Protonvpn, &servers.Protonvpn) + if err != nil { + return servers, fmt.Errorf("%w: %s: %s", errDecodeProvider, "Protonvpn", err) + } + } + + if hardcoded.Purevpn.Version != versions.Purevpn.Version { + s.logVersionDiff("Purevpn", hardcoded.Purevpn.Version, versions.Purevpn.Version) + } else { + err = json.Unmarshal(rawMessages.Purevpn, &servers.Purevpn) + if err != nil { + return servers, fmt.Errorf("%w: %s: %s", errDecodeProvider, "Purevpn", err) + } + } + + if hardcoded.Surfshark.Version != versions.Surfshark.Version { + s.logVersionDiff("Surfshark", hardcoded.Surfshark.Version, versions.Surfshark.Version) + } else { + err = json.Unmarshal(rawMessages.Surfshark, &servers.Surfshark) + if err != nil { + return servers, fmt.Errorf("%w: %s: %s", errDecodeProvider, "Surfshark", err) + } + } + + if hardcoded.Torguard.Version != versions.Torguard.Version { + s.logVersionDiff("Torguard", hardcoded.Torguard.Version, versions.Torguard.Version) + } else { + err = json.Unmarshal(rawMessages.Torguard, &servers.Torguard) + if err != nil { + return servers, fmt.Errorf("%w: %s: %s", errDecodeProvider, "Torguard", err) + } + } + + if hardcoded.VPNUnlimited.Version != versions.VPNUnlimited.Version { + s.logVersionDiff("VPNUnlimited", hardcoded.VPNUnlimited.Version, versions.VPNUnlimited.Version) + } else { + err = json.Unmarshal(rawMessages.VPNUnlimited, &servers.VPNUnlimited) + if err != nil { + return servers, fmt.Errorf("%w: %s: %s", errDecodeProvider, "VPNUnlimited", err) + } + } + + if hardcoded.Vyprvpn.Version != versions.Vyprvpn.Version { + s.logVersionDiff("Vyprvpn", hardcoded.Vyprvpn.Version, versions.Vyprvpn.Version) + } else { + err = json.Unmarshal(rawMessages.Vyprvpn, &servers.Vyprvpn) + if err != nil { + return servers, fmt.Errorf("%w: %s: %s", errDecodeProvider, "Vyprvpn", err) + } + } + + if hardcoded.Wevpn.Version != versions.Wevpn.Version { + s.logVersionDiff("Wevpn", hardcoded.Wevpn.Version, versions.Wevpn.Version) + } else { + err = json.Unmarshal(rawMessages.Wevpn, &servers.Wevpn) + if err != nil { + return servers, fmt.Errorf("%w: %s: %s", errDecodeProvider, "Wevpn", err) + } + } + + if hardcoded.Windscribe.Version != versions.Windscribe.Version { + s.logVersionDiff("Windscribe", hardcoded.Windscribe.Version, versions.Windscribe.Version) + } else { + err = json.Unmarshal(rawMessages.Windscribe, &servers.Windscribe) + if err != nil { + return servers, fmt.Errorf("%w: %s: %s", errDecodeProvider, "Windscribe", err) + } + } + + return servers, nil +} + +// allVersions is a subset of models.AllServers structure used to track +// versions to avoid unmarshaling errors. +type allVersions struct { + Version uint16 `json:"version"` // used for migration of the top level scheme + Cyberghost serverVersion `json:"cyberghost"` + Expressvpn serverVersion `json:"expressvpn"` + Fastestvpn serverVersion `json:"fastestvpn"` + HideMyAss serverVersion `json:"hidemyass"` + Ipvanish serverVersion `json:"ipvanish"` + Ivpn serverVersion `json:"ivpn"` + Mullvad serverVersion `json:"mullvad"` + Nordvpn serverVersion `json:"nordvpn"` + Privado serverVersion `json:"privado"` + Pia serverVersion `json:"pia"` + Privatevpn serverVersion `json:"privatevpn"` + Protonvpn serverVersion `json:"protonvpn"` + Purevpn serverVersion `json:"purevpn"` + Surfshark serverVersion `json:"surfshark"` + Torguard serverVersion `json:"torguard"` + VPNUnlimited serverVersion `json:"vpnunlimited"` + Vyprvpn serverVersion `json:"vyprvpn"` + Wevpn serverVersion `json:"wevpn"` + Windscribe serverVersion `json:"windscribe"` +} + +type serverVersion struct { + Version uint16 `json:"version"` +} + +// allJSONRawMessages is to delay decoding of each provider servers. +type allJSONRawMessages struct { + Version uint16 `json:"version"` // used for migration of the top level scheme + Cyberghost json.RawMessage `json:"cyberghost"` + Expressvpn json.RawMessage `json:"expressvpn"` + Fastestvpn json.RawMessage `json:"fastestvpn"` + HideMyAss json.RawMessage `json:"hidemyass"` + Ipvanish json.RawMessage `json:"ipvanish"` + Ivpn json.RawMessage `json:"ivpn"` + Mullvad json.RawMessage `json:"mullvad"` + Nordvpn json.RawMessage `json:"nordvpn"` + Privado json.RawMessage `json:"privado"` + Pia json.RawMessage `json:"pia"` + Privatevpn json.RawMessage `json:"privatevpn"` + Protonvpn json.RawMessage `json:"protonvpn"` + Purevpn json.RawMessage `json:"purevpn"` + Surfshark json.RawMessage `json:"surfshark"` + Torguard json.RawMessage `json:"torguard"` + VPNUnlimited json.RawMessage `json:"vpnunlimited"` + Vyprvpn json.RawMessage `json:"vyprvpn"` + Wevpn json.RawMessage `json:"wevpn"` + Windscribe json.RawMessage `json:"windscribe"` } diff --git a/internal/storage/read_test.go b/internal/storage/read_test.go new file mode 100644 index 00000000..c4ace4ad --- /dev/null +++ b/internal/storage/read_test.go @@ -0,0 +1,169 @@ +package storage + +import ( + "errors" + "testing" + + gomock "github.com/golang/mock/gomock" + "github.com/qdm12/gluetun/internal/models" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_extractServersFromBytes(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + b []byte + hardcoded models.AllServers + logged []string + persisted models.AllServers + err error + }{ + "no data": { + err: errors.New("cannot decode versions: unexpected end of JSON input"), + }, + "empty JSON": { + b: []byte("{}"), + err: errors.New("cannot decode servers for provider: Cyberghost: unexpected end of JSON input"), + }, + "different versions": { + b: []byte(`{}`), + hardcoded: models.AllServers{ + Cyberghost: models.CyberghostServers{Version: 1}, + Expressvpn: models.ExpressvpnServers{Version: 1}, + Fastestvpn: models.FastestvpnServers{Version: 1}, + HideMyAss: models.HideMyAssServers{Version: 1}, + Ipvanish: models.IpvanishServers{Version: 1}, + Ivpn: models.IvpnServers{Version: 1}, + Mullvad: models.MullvadServers{Version: 1}, + Nordvpn: models.NordvpnServers{Version: 1}, + Privado: models.PrivadoServers{Version: 1}, + Pia: models.PiaServers{Version: 1}, + Privatevpn: models.PrivatevpnServers{Version: 1}, + Protonvpn: models.ProtonvpnServers{Version: 1}, + Purevpn: models.PurevpnServers{Version: 1}, + Surfshark: models.SurfsharkServers{Version: 1}, + Torguard: models.TorguardServers{Version: 1}, + VPNUnlimited: models.VPNUnlimitedServers{Version: 1}, + Vyprvpn: models.VyprvpnServers{Version: 1}, + Wevpn: models.WevpnServers{Version: 1}, + Windscribe: models.WindscribeServers{Version: 1}, + }, + logged: []string{ + "Cyberghost servers from file discarded because they have version 0 and hardcoded servers have version 1", + "Expressvpn servers from file discarded because they have version 0 and hardcoded servers have version 1", + "Fastestvpn servers from file discarded because they have version 0 and hardcoded servers have version 1", + "HideMyAss servers from file discarded because they have version 0 and hardcoded servers have version 1", + "Ipvanish servers from file discarded because they have version 0 and hardcoded servers have version 1", + "Ivpn servers from file discarded because they have version 0 and hardcoded servers have version 1", + "Mullvad servers from file discarded because they have version 0 and hardcoded servers have version 1", + "Nordvpn servers from file discarded because they have version 0 and hardcoded servers have version 1", + "Privado servers from file discarded because they have version 0 and hardcoded servers have version 1", + "Pia servers from file discarded because they have version 0 and hardcoded servers have version 1", + "Privatevpn servers from file discarded because they have version 0 and hardcoded servers have version 1", + "Protonvpn servers from file discarded because they have version 0 and hardcoded servers have version 1", + "Purevpn servers from file discarded because they have version 0 and hardcoded servers have version 1", + "Surfshark servers from file discarded because they have version 0 and hardcoded servers have version 1", + "Torguard servers from file discarded because they have version 0 and hardcoded servers have version 1", + "VPNUnlimited servers from file discarded because they have version 0 and hardcoded servers have version 1", + "Vyprvpn servers from file discarded because they have version 0 and hardcoded servers have version 1", + "Wevpn servers from file discarded because they have version 0 and hardcoded servers have version 1", + "Windscribe servers from file discarded because they have version 0 and hardcoded servers have version 1", + }, + }, + "same versions": { + b: []byte(`{ + "cyberghost": {"version": 1, "timestamp": 1}, + "expressvpn": {"version": 1, "timestamp": 1}, + "fastestvpn": {"version": 1, "timestamp": 1}, + "hidemyass": {"version": 1, "timestamp": 1}, + "ipvanish": {"version": 1, "timestamp": 1}, + "ivpn": {"version": 1, "timestamp": 1}, + "mullvad": {"version": 1, "timestamp": 1}, + "nordvpn": {"version": 1, "timestamp": 1}, + "privado": {"version": 1, "timestamp": 1}, + "pia": {"version": 1, "timestamp": 1}, + "privatevpn": {"version": 1, "timestamp": 1}, + "protonvpn": {"version": 1, "timestamp": 1}, + "purevpn": {"version": 1, "timestamp": 1}, + "surfshark": {"version": 1, "timestamp": 1}, + "torguard": {"version": 1, "timestamp": 1}, + "vpnunlimited": {"version": 1, "timestamp": 1}, + "vyprvpn": {"version": 1, "timestamp": 1}, + "wevpn": {"version": 1, "timestamp": 1}, + "windscribe": {"version": 1, "timestamp": 1} + }`), + hardcoded: models.AllServers{ + Cyberghost: models.CyberghostServers{Version: 1}, + Expressvpn: models.ExpressvpnServers{Version: 1}, + Fastestvpn: models.FastestvpnServers{Version: 1}, + HideMyAss: models.HideMyAssServers{Version: 1}, + Ipvanish: models.IpvanishServers{Version: 1}, + Ivpn: models.IvpnServers{Version: 1}, + Mullvad: models.MullvadServers{Version: 1}, + Nordvpn: models.NordvpnServers{Version: 1}, + Privado: models.PrivadoServers{Version: 1}, + Pia: models.PiaServers{Version: 1}, + Privatevpn: models.PrivatevpnServers{Version: 1}, + Protonvpn: models.ProtonvpnServers{Version: 1}, + Purevpn: models.PurevpnServers{Version: 1}, + Surfshark: models.SurfsharkServers{Version: 1}, + Torguard: models.TorguardServers{Version: 1}, + VPNUnlimited: models.VPNUnlimitedServers{Version: 1}, + Vyprvpn: models.VyprvpnServers{Version: 1}, + Wevpn: models.WevpnServers{Version: 1}, + Windscribe: models.WindscribeServers{Version: 1}, + }, + persisted: models.AllServers{ + Cyberghost: models.CyberghostServers{Version: 1, Timestamp: 1}, + Expressvpn: models.ExpressvpnServers{Version: 1, Timestamp: 1}, + Fastestvpn: models.FastestvpnServers{Version: 1, Timestamp: 1}, + HideMyAss: models.HideMyAssServers{Version: 1, Timestamp: 1}, + Ipvanish: models.IpvanishServers{Version: 1, Timestamp: 1}, + Ivpn: models.IvpnServers{Version: 1, Timestamp: 1}, + Mullvad: models.MullvadServers{Version: 1, Timestamp: 1}, + Nordvpn: models.NordvpnServers{Version: 1, Timestamp: 1}, + Privado: models.PrivadoServers{Version: 1, Timestamp: 1}, + Pia: models.PiaServers{Version: 1, Timestamp: 1}, + Privatevpn: models.PrivatevpnServers{Version: 1, Timestamp: 1}, + Protonvpn: models.ProtonvpnServers{Version: 1, Timestamp: 1}, + Purevpn: models.PurevpnServers{Version: 1, Timestamp: 1}, + Surfshark: models.SurfsharkServers{Version: 1, Timestamp: 1}, + Torguard: models.TorguardServers{Version: 1, Timestamp: 1}, + VPNUnlimited: models.VPNUnlimitedServers{Version: 1, Timestamp: 1}, + Vyprvpn: models.VyprvpnServers{Version: 1, Timestamp: 1}, + Wevpn: models.WevpnServers{Version: 1, Timestamp: 1}, + Windscribe: models.WindscribeServers{Version: 1, Timestamp: 1}, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + logger := NewMockInfoErrorer(ctrl) + for _, logged := range testCase.logged { + logger.EXPECT().Info(logged) + } + + s := &Storage{ + logger: logger, + } + + servers, err := s.extractServersFromBytes(testCase.b, testCase.hardcoded) + + if testCase.err != nil { + require.Error(t, err) + assert.Equal(t, testCase.err.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + + assert.Equal(t, testCase.persisted, servers) + }) + } +} diff --git a/internal/storage/storage.go b/internal/storage/storage.go index 2c97077a..6a0c72aa 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -5,6 +5,8 @@ import ( "github.com/qdm12/gluetun/internal/models" ) +//go:generate mockgen -destination=infoerrorer_mock_test.go -package $GOPACKAGE . InfoErrorer + type Storage struct { mergedServers models.AllServers hardcodedServers models.AllServers diff --git a/internal/storage/sync.go b/internal/storage/sync.go index 1d65740f..584b329b 100644 --- a/internal/storage/sync.go +++ b/internal/storage/sync.go @@ -36,7 +36,7 @@ func countServers(allServers models.AllServers) int { } func (s *Storage) SyncServers() (err error) { - serversOnFile, err := readFromFile(s.filepath) + serversOnFile, err := s.readFromFile(s.filepath, s.hardcodedServers) if err != nil { return fmt.Errorf("%w: %s", ErrCannotReadFile, err) }