chore(storage): only pass hardcoded versions to read file

This commit is contained in:
Quentin McGaw
2022-05-28 22:36:16 +00:00
parent 22455ac76f
commit 90dd3b1b5c
4 changed files with 76 additions and 55 deletions

View File

@@ -16,7 +16,7 @@ import (
// readFromFile reads the servers from server.json.
// It only reads servers that have the same version as the hardcoded servers version
// to avoid JSON unmarshaling errors.
func (s *Storage) readFromFile(filepath string, hardcoded models.AllServers) (
func (s *Storage) readFromFile(filepath string, hardcodedVersions map[string]uint16) (
servers models.AllServers, err error) {
file, err := os.Open(filepath)
if os.IsNotExist(err) {
@@ -34,10 +34,10 @@ func (s *Storage) readFromFile(filepath string, hardcoded models.AllServers) (
return servers, err
}
return s.extractServersFromBytes(b, hardcoded)
return s.extractServersFromBytes(b, hardcodedVersions)
}
func (s *Storage) extractServersFromBytes(b []byte, hardcoded models.AllServers) (
func (s *Storage) extractServersFromBytes(b []byte, hardcodedVersions map[string]uint16) (
servers models.AllServers, err error) {
rawMessages := make(map[string]json.RawMessage)
if err := json.Unmarshal(b, &rawMessages); err != nil {
@@ -50,7 +50,7 @@ func (s *Storage) extractServersFromBytes(b []byte, hardcoded models.AllServers)
servers.ProviderToServers = make(map[string]models.Servers, len(allProviders))
titleCaser := cases.Title(language.English)
for _, provider := range allProviders {
hardcoded, ok := hardcoded.ProviderToServers[provider]
hardcodedVersion, ok := hardcodedVersions[provider]
if !ok {
panic(fmt.Sprintf("provider %s not found in hardcoded servers map", provider))
}
@@ -64,7 +64,7 @@ func (s *Storage) extractServersFromBytes(b []byte, hardcoded models.AllServers)
continue
}
mergedServers, versionsMatch, err := s.readServers(provider, hardcoded, rawMessage, titleCaser)
mergedServers, versionsMatch, err := s.readServers(provider, hardcodedVersion, rawMessage, titleCaser)
if err != nil {
return models.AllServers{}, err
} else if !versionsMatch {
@@ -82,7 +82,7 @@ var (
errDecodeProvider = errors.New("cannot decode servers for provider")
)
func (s *Storage) readServers(provider string, hardcoded models.Servers,
func (s *Storage) readServers(provider string, hardcodedVersion uint16,
rawMessage json.RawMessage, titleCaser cases.Caser) (servers models.Servers,
versionsMatch bool, err error) {
provider = titleCaser.String(provider)
@@ -93,9 +93,9 @@ func (s *Storage) readServers(provider string, hardcoded models.Servers,
return servers, false, fmt.Errorf("%w: %s: %s", errDecodeProvider, provider, err)
}
versionsMatch = hardcoded.Version == persistedServers.Version
versionsMatch = hardcodedVersion == persistedServers.Version
if !versionsMatch {
s.logVersionDiff(provider, hardcoded.Version, persistedServers.Version)
s.logVersionDiff(provider, hardcodedVersion, persistedServers.Version)
return servers, versionsMatch, nil
}

View File

@@ -11,7 +11,21 @@ import (
"github.com/stretchr/testify/require"
)
func populateProviders(allProviderVersion uint16, allProviderTimestamp int64,
func populateProviderToVersion(allProviderVersion uint16,
providerToVersion map[string]uint16) map[string]uint16 {
allProviders := providers.All()
for _, provider := range allProviders {
_, has := providerToVersion[provider]
if has {
continue
}
providerToVersion[provider] = allProviderVersion
}
return providerToVersion
}
func populateAllServersVersion(allProviderVersion uint16,
servers models.AllServers) models.AllServers {
allProviders := providers.All()
if servers.ProviderToServers == nil {
@@ -23,8 +37,7 @@ func populateProviders(allProviderVersion uint16, allProviderTimestamp int64,
continue
}
servers.ProviderToServers[provider] = models.Servers{
Version: allProviderVersion,
Timestamp: allProviderTimestamp,
Version: allProviderVersion,
}
}
return servers
@@ -34,54 +47,54 @@ func Test_extractServersFromBytes(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
b []byte
hardcoded models.AllServers
logged []string
persisted models.AllServers
errMessage string
b []byte
hardcodedVersions map[string]uint16
logged []string
persisted models.AllServers
errMessage string
}{
"bad JSON": {
b: []byte("garbage"),
errMessage: "cannot decode servers: invalid character 'g' looking for beginning of value",
},
"bad provider JSON": {
b: []byte(`{"cyberghost": "garbage"}`),
hardcoded: populateProviders(1, 0, models.AllServers{}),
b: []byte(`{"cyberghost": "garbage"}`),
hardcodedVersions: populateProviderToVersion(1, map[string]uint16{}),
errMessage: "cannot decode servers for provider: Cyberghost: " +
"json: cannot unmarshal string into Go value of type models.Servers",
},
"absent provider keys": {
b: []byte(`{}`),
hardcoded: populateProviders(1, 0, models.AllServers{}),
b: []byte(`{}`),
hardcodedVersions: populateProviderToVersion(1, map[string]uint16{}),
persisted: models.AllServers{
ProviderToServers: map[string]models.Servers{},
},
},
"same versions": {
b: []byte(`{
"cyberghost": {"version": 1, "timestamp": 1},
"expressvpn": {"version": 1, "timestamp": 1},
"fastestvpn": {"version": 1, "timestamp": 1},
"hidemyass": {"version": 1, "timestamp": 1},
"ipvanish": {"version": 1, "timestamp": 1},
"ivpn": {"version": 1, "timestamp": 1},
"mullvad": {"version": 1, "timestamp": 1},
"nordvpn": {"version": 1, "timestamp": 1},
"perfect privacy": {"version": 1, "timestamp": 1},
"privado": {"version": 1, "timestamp": 1},
"private internet access": {"version": 1, "timestamp": 1},
"privatevpn": {"version": 1, "timestamp": 1},
"protonvpn": {"version": 1, "timestamp": 1},
"purevpn": {"version": 1, "timestamp": 1},
"surfshark": {"version": 1, "timestamp": 1},
"torguard": {"version": 1, "timestamp": 1},
"vpn unlimited": {"version": 1, "timestamp": 1},
"vyprvpn": {"version": 1, "timestamp": 1},
"wevpn": {"version": 1, "timestamp": 1},
"windscribe": {"version": 1, "timestamp": 1}
"cyberghost": {"version": 1, "timestamp": 0},
"expressvpn": {"version": 1, "timestamp": 0},
"fastestvpn": {"version": 1, "timestamp": 0},
"hidemyass": {"version": 1, "timestamp": 0},
"ipvanish": {"version": 1, "timestamp": 0},
"ivpn": {"version": 1, "timestamp": 0},
"mullvad": {"version": 1, "timestamp": 0},
"nordvpn": {"version": 1, "timestamp": 0},
"perfect privacy": {"version": 1, "timestamp": 0},
"privado": {"version": 1, "timestamp": 0},
"private internet access": {"version": 1, "timestamp": 0},
"privatevpn": {"version": 1, "timestamp": 0},
"protonvpn": {"version": 1, "timestamp": 0},
"purevpn": {"version": 1, "timestamp": 0},
"surfshark": {"version": 1, "timestamp": 0},
"torguard": {"version": 1, "timestamp": 0},
"vpn unlimited": {"version": 1, "timestamp": 0},
"vyprvpn": {"version": 1, "timestamp": 0},
"wevpn": {"version": 1, "timestamp": 0},
"windscribe": {"version": 1, "timestamp": 0}
}`),
hardcoded: populateProviders(1, 0, models.AllServers{}),
persisted: populateProviders(1, 1, models.AllServers{}),
hardcodedVersions: populateProviderToVersion(1, map[string]uint16{}),
persisted: populateAllServersVersion(1, models.AllServers{}),
},
"different versions": {
b: []byte(`{
@@ -106,7 +119,7 @@ func Test_extractServersFromBytes(t *testing.T) {
"wevpn": {"version": 1, "timestamp": 1},
"windscribe": {"version": 1, "timestamp": 1}
}`),
hardcoded: populateProviders(2, 0, models.AllServers{}),
hardcodedVersions: populateProviderToVersion(2, map[string]uint16{}),
logged: []string{
"Cyberghost servers from file discarded because they have version 1 and hardcoded servers have version 2",
"Expressvpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
@@ -155,7 +168,7 @@ func Test_extractServersFromBytes(t *testing.T) {
logger: logger,
}
servers, err := s.extractServersFromBytes(testCase.b, testCase.hardcoded)
servers, err := s.extractServersFromBytes(testCase.b, testCase.hardcodedVersions)
if testCase.errMessage != "" {
assert.EqualError(t, err, testCase.errMessage)
@@ -176,15 +189,13 @@ func Test_extractServersFromBytes(t *testing.T) {
require.GreaterOrEqual(t, len(allProviders), 2)
b := []byte(`{}`)
hardcoded := models.AllServers{
ProviderToServers: map[string]models.Servers{
allProviders[0]: {},
// Missing provider allProviders[1]
},
hardcodedVersions := map[string]uint16{
allProviders[0]: 1,
// Missing provider allProviders[1]
}
expectedPanicValue := fmt.Sprintf("provider %s not found in hardcoded servers map", allProviders[1])
assert.PanicsWithValue(t, expectedPanicValue, func() {
_, _ = s.extractServersFromBytes(b, hardcoded)
_, _ = s.extractServersFromBytes(b, hardcodedVersions)
})
})
}

View File

@@ -8,7 +8,10 @@ import (
//go:generate mockgen -destination=infoerrorer_mock_test.go -package $GOPACKAGE . InfoErrorer
type Storage struct {
mergedServers models.AllServers
mergedServers models.AllServers
// this is stored in memory to avoid re-parsing
// the embedded JSON file on every call to the
// SyncServers method.
hardcodedServers models.AllServers
logger Infoer
filepath string
@@ -22,11 +25,12 @@ type Infoer interface {
// embedded servers file and the file on disk.
// Passing an empty filepath disables writing servers to a file.
func New(logger Infoer, filepath string) (storage *Storage, err error) {
// error returned covered by unit test
harcodedServers, _ := parseHardcodedServers()
// A unit test prevents any error from being returned
// and ensures all providers are part of the servers returned.
hardcodedServers, _ := parseHardcodedServers()
storage = &Storage{
hardcodedServers: harcodedServers,
hardcodedServers: hardcodedServers,
logger: logger,
filepath: filepath,
}

View File

@@ -14,8 +14,14 @@ func countServers(allServers models.AllServers) (count int) {
return count
}
// SyncServers merges the hardcoded servers with the ones from the file.
func (s *Storage) SyncServers() (err error) {
serversOnFile, err := s.readFromFile(s.filepath, s.hardcodedServers)
hardcodedVersions := make(map[string]uint16, len(s.hardcodedServers.ProviderToServers))
for provider, servers := range s.hardcodedServers.ProviderToServers {
hardcodedVersions[provider] = servers.Version
}
serversOnFile, err := s.readFromFile(s.filepath, hardcodedVersions)
if err != nil {
return fmt.Errorf("cannot read servers from file: %w", err)
}