diff --git a/internal/storage/read.go b/internal/storage/read.go index c0496594..dac2d52d 100644 --- a/internal/storage/read.go +++ b/internal/storage/read.go @@ -16,7 +16,7 @@ import ( // readFromFile reads the servers from server.json. // It only reads servers that have the same version as the hardcoded servers version // to avoid JSON unmarshaling errors. -func (s *Storage) readFromFile(filepath string, hardcoded models.AllServers) ( +func (s *Storage) readFromFile(filepath string, hardcodedVersions map[string]uint16) ( servers models.AllServers, err error) { file, err := os.Open(filepath) if os.IsNotExist(err) { @@ -34,10 +34,10 @@ func (s *Storage) readFromFile(filepath string, hardcoded models.AllServers) ( return servers, err } - return s.extractServersFromBytes(b, hardcoded) + return s.extractServersFromBytes(b, hardcodedVersions) } -func (s *Storage) extractServersFromBytes(b []byte, hardcoded models.AllServers) ( +func (s *Storage) extractServersFromBytes(b []byte, hardcodedVersions map[string]uint16) ( servers models.AllServers, err error) { rawMessages := make(map[string]json.RawMessage) if err := json.Unmarshal(b, &rawMessages); err != nil { @@ -50,7 +50,7 @@ func (s *Storage) extractServersFromBytes(b []byte, hardcoded models.AllServers) servers.ProviderToServers = make(map[string]models.Servers, len(allProviders)) titleCaser := cases.Title(language.English) for _, provider := range allProviders { - hardcoded, ok := hardcoded.ProviderToServers[provider] + hardcodedVersion, ok := hardcodedVersions[provider] if !ok { panic(fmt.Sprintf("provider %s not found in hardcoded servers map", provider)) } @@ -64,7 +64,7 @@ func (s *Storage) extractServersFromBytes(b []byte, hardcoded models.AllServers) continue } - mergedServers, versionsMatch, err := s.readServers(provider, hardcoded, rawMessage, titleCaser) + mergedServers, versionsMatch, err := s.readServers(provider, hardcodedVersion, rawMessage, titleCaser) if err != nil { return models.AllServers{}, err } else if !versionsMatch { @@ -82,7 +82,7 @@ var ( errDecodeProvider = errors.New("cannot decode servers for provider") ) -func (s *Storage) readServers(provider string, hardcoded models.Servers, +func (s *Storage) readServers(provider string, hardcodedVersion uint16, rawMessage json.RawMessage, titleCaser cases.Caser) (servers models.Servers, versionsMatch bool, err error) { provider = titleCaser.String(provider) @@ -93,9 +93,9 @@ func (s *Storage) readServers(provider string, hardcoded models.Servers, return servers, false, fmt.Errorf("%w: %s: %s", errDecodeProvider, provider, err) } - versionsMatch = hardcoded.Version == persistedServers.Version + versionsMatch = hardcodedVersion == persistedServers.Version if !versionsMatch { - s.logVersionDiff(provider, hardcoded.Version, persistedServers.Version) + s.logVersionDiff(provider, hardcodedVersion, persistedServers.Version) return servers, versionsMatch, nil } diff --git a/internal/storage/read_test.go b/internal/storage/read_test.go index 7f5d77ad..8acd1e3d 100644 --- a/internal/storage/read_test.go +++ b/internal/storage/read_test.go @@ -11,7 +11,21 @@ import ( "github.com/stretchr/testify/require" ) -func populateProviders(allProviderVersion uint16, allProviderTimestamp int64, +func populateProviderToVersion(allProviderVersion uint16, + providerToVersion map[string]uint16) map[string]uint16 { + allProviders := providers.All() + for _, provider := range allProviders { + _, has := providerToVersion[provider] + if has { + continue + } + + providerToVersion[provider] = allProviderVersion + } + return providerToVersion +} + +func populateAllServersVersion(allProviderVersion uint16, servers models.AllServers) models.AllServers { allProviders := providers.All() if servers.ProviderToServers == nil { @@ -23,8 +37,7 @@ func populateProviders(allProviderVersion uint16, allProviderTimestamp int64, continue } servers.ProviderToServers[provider] = models.Servers{ - Version: allProviderVersion, - Timestamp: allProviderTimestamp, + Version: allProviderVersion, } } return servers @@ -34,54 +47,54 @@ func Test_extractServersFromBytes(t *testing.T) { t.Parallel() testCases := map[string]struct { - b []byte - hardcoded models.AllServers - logged []string - persisted models.AllServers - errMessage string + b []byte + hardcodedVersions map[string]uint16 + logged []string + persisted models.AllServers + errMessage string }{ "bad JSON": { b: []byte("garbage"), errMessage: "cannot decode servers: invalid character 'g' looking for beginning of value", }, "bad provider JSON": { - b: []byte(`{"cyberghost": "garbage"}`), - hardcoded: populateProviders(1, 0, models.AllServers{}), + b: []byte(`{"cyberghost": "garbage"}`), + hardcodedVersions: populateProviderToVersion(1, map[string]uint16{}), errMessage: "cannot decode servers for provider: Cyberghost: " + "json: cannot unmarshal string into Go value of type models.Servers", }, "absent provider keys": { - b: []byte(`{}`), - hardcoded: populateProviders(1, 0, models.AllServers{}), + b: []byte(`{}`), + hardcodedVersions: populateProviderToVersion(1, map[string]uint16{}), persisted: models.AllServers{ ProviderToServers: map[string]models.Servers{}, }, }, "same versions": { b: []byte(`{ - "cyberghost": {"version": 1, "timestamp": 1}, - "expressvpn": {"version": 1, "timestamp": 1}, - "fastestvpn": {"version": 1, "timestamp": 1}, - "hidemyass": {"version": 1, "timestamp": 1}, - "ipvanish": {"version": 1, "timestamp": 1}, - "ivpn": {"version": 1, "timestamp": 1}, - "mullvad": {"version": 1, "timestamp": 1}, - "nordvpn": {"version": 1, "timestamp": 1}, - "perfect privacy": {"version": 1, "timestamp": 1}, - "privado": {"version": 1, "timestamp": 1}, - "private internet access": {"version": 1, "timestamp": 1}, - "privatevpn": {"version": 1, "timestamp": 1}, - "protonvpn": {"version": 1, "timestamp": 1}, - "purevpn": {"version": 1, "timestamp": 1}, - "surfshark": {"version": 1, "timestamp": 1}, - "torguard": {"version": 1, "timestamp": 1}, - "vpn unlimited": {"version": 1, "timestamp": 1}, - "vyprvpn": {"version": 1, "timestamp": 1}, - "wevpn": {"version": 1, "timestamp": 1}, - "windscribe": {"version": 1, "timestamp": 1} + "cyberghost": {"version": 1, "timestamp": 0}, + "expressvpn": {"version": 1, "timestamp": 0}, + "fastestvpn": {"version": 1, "timestamp": 0}, + "hidemyass": {"version": 1, "timestamp": 0}, + "ipvanish": {"version": 1, "timestamp": 0}, + "ivpn": {"version": 1, "timestamp": 0}, + "mullvad": {"version": 1, "timestamp": 0}, + "nordvpn": {"version": 1, "timestamp": 0}, + "perfect privacy": {"version": 1, "timestamp": 0}, + "privado": {"version": 1, "timestamp": 0}, + "private internet access": {"version": 1, "timestamp": 0}, + "privatevpn": {"version": 1, "timestamp": 0}, + "protonvpn": {"version": 1, "timestamp": 0}, + "purevpn": {"version": 1, "timestamp": 0}, + "surfshark": {"version": 1, "timestamp": 0}, + "torguard": {"version": 1, "timestamp": 0}, + "vpn unlimited": {"version": 1, "timestamp": 0}, + "vyprvpn": {"version": 1, "timestamp": 0}, + "wevpn": {"version": 1, "timestamp": 0}, + "windscribe": {"version": 1, "timestamp": 0} }`), - hardcoded: populateProviders(1, 0, models.AllServers{}), - persisted: populateProviders(1, 1, models.AllServers{}), + hardcodedVersions: populateProviderToVersion(1, map[string]uint16{}), + persisted: populateAllServersVersion(1, models.AllServers{}), }, "different versions": { b: []byte(`{ @@ -106,7 +119,7 @@ func Test_extractServersFromBytes(t *testing.T) { "wevpn": {"version": 1, "timestamp": 1}, "windscribe": {"version": 1, "timestamp": 1} }`), - hardcoded: populateProviders(2, 0, models.AllServers{}), + hardcodedVersions: populateProviderToVersion(2, map[string]uint16{}), logged: []string{ "Cyberghost servers from file discarded because they have version 1 and hardcoded servers have version 2", "Expressvpn servers from file discarded because they have version 1 and hardcoded servers have version 2", @@ -155,7 +168,7 @@ func Test_extractServersFromBytes(t *testing.T) { logger: logger, } - servers, err := s.extractServersFromBytes(testCase.b, testCase.hardcoded) + servers, err := s.extractServersFromBytes(testCase.b, testCase.hardcodedVersions) if testCase.errMessage != "" { assert.EqualError(t, err, testCase.errMessage) @@ -176,15 +189,13 @@ func Test_extractServersFromBytes(t *testing.T) { require.GreaterOrEqual(t, len(allProviders), 2) b := []byte(`{}`) - hardcoded := models.AllServers{ - ProviderToServers: map[string]models.Servers{ - allProviders[0]: {}, - // Missing provider allProviders[1] - }, + hardcodedVersions := map[string]uint16{ + allProviders[0]: 1, + // Missing provider allProviders[1] } expectedPanicValue := fmt.Sprintf("provider %s not found in hardcoded servers map", allProviders[1]) assert.PanicsWithValue(t, expectedPanicValue, func() { - _, _ = s.extractServersFromBytes(b, hardcoded) + _, _ = s.extractServersFromBytes(b, hardcodedVersions) }) }) } diff --git a/internal/storage/storage.go b/internal/storage/storage.go index 1656e551..68890878 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -8,7 +8,10 @@ import ( //go:generate mockgen -destination=infoerrorer_mock_test.go -package $GOPACKAGE . InfoErrorer type Storage struct { - mergedServers models.AllServers + mergedServers models.AllServers + // this is stored in memory to avoid re-parsing + // the embedded JSON file on every call to the + // SyncServers method. hardcodedServers models.AllServers logger Infoer filepath string @@ -22,11 +25,12 @@ type Infoer interface { // embedded servers file and the file on disk. // Passing an empty filepath disables writing servers to a file. func New(logger Infoer, filepath string) (storage *Storage, err error) { - // error returned covered by unit test - harcodedServers, _ := parseHardcodedServers() + // A unit test prevents any error from being returned + // and ensures all providers are part of the servers returned. + hardcodedServers, _ := parseHardcodedServers() storage = &Storage{ - hardcodedServers: harcodedServers, + hardcodedServers: hardcodedServers, logger: logger, filepath: filepath, } diff --git a/internal/storage/sync.go b/internal/storage/sync.go index 881bcdce..ed5f173e 100644 --- a/internal/storage/sync.go +++ b/internal/storage/sync.go @@ -14,8 +14,14 @@ func countServers(allServers models.AllServers) (count int) { return count } +// SyncServers merges the hardcoded servers with the ones from the file. func (s *Storage) SyncServers() (err error) { - serversOnFile, err := s.readFromFile(s.filepath, s.hardcodedServers) + hardcodedVersions := make(map[string]uint16, len(s.hardcodedServers.ProviderToServers)) + for provider, servers := range s.hardcodedServers.ProviderToServers { + hardcodedVersions[provider] = servers.Version + } + + serversOnFile, err := s.readFromFile(s.filepath, hardcodedVersions) if err != nil { return fmt.Errorf("cannot read servers from file: %w", err) }