chore(storage): only pass hardcoded versions to read file
This commit is contained in:
@@ -16,7 +16,7 @@ import (
|
||||
// readFromFile reads the servers from server.json.
|
||||
// It only reads servers that have the same version as the hardcoded servers version
|
||||
// to avoid JSON unmarshaling errors.
|
||||
func (s *Storage) readFromFile(filepath string, hardcoded models.AllServers) (
|
||||
func (s *Storage) readFromFile(filepath string, hardcodedVersions map[string]uint16) (
|
||||
servers models.AllServers, err error) {
|
||||
file, err := os.Open(filepath)
|
||||
if os.IsNotExist(err) {
|
||||
@@ -34,10 +34,10 @@ func (s *Storage) readFromFile(filepath string, hardcoded models.AllServers) (
|
||||
return servers, err
|
||||
}
|
||||
|
||||
return s.extractServersFromBytes(b, hardcoded)
|
||||
return s.extractServersFromBytes(b, hardcodedVersions)
|
||||
}
|
||||
|
||||
func (s *Storage) extractServersFromBytes(b []byte, hardcoded models.AllServers) (
|
||||
func (s *Storage) extractServersFromBytes(b []byte, hardcodedVersions map[string]uint16) (
|
||||
servers models.AllServers, err error) {
|
||||
rawMessages := make(map[string]json.RawMessage)
|
||||
if err := json.Unmarshal(b, &rawMessages); err != nil {
|
||||
@@ -50,7 +50,7 @@ func (s *Storage) extractServersFromBytes(b []byte, hardcoded models.AllServers)
|
||||
servers.ProviderToServers = make(map[string]models.Servers, len(allProviders))
|
||||
titleCaser := cases.Title(language.English)
|
||||
for _, provider := range allProviders {
|
||||
hardcoded, ok := hardcoded.ProviderToServers[provider]
|
||||
hardcodedVersion, ok := hardcodedVersions[provider]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("provider %s not found in hardcoded servers map", provider))
|
||||
}
|
||||
@@ -64,7 +64,7 @@ func (s *Storage) extractServersFromBytes(b []byte, hardcoded models.AllServers)
|
||||
continue
|
||||
}
|
||||
|
||||
mergedServers, versionsMatch, err := s.readServers(provider, hardcoded, rawMessage, titleCaser)
|
||||
mergedServers, versionsMatch, err := s.readServers(provider, hardcodedVersion, rawMessage, titleCaser)
|
||||
if err != nil {
|
||||
return models.AllServers{}, err
|
||||
} else if !versionsMatch {
|
||||
@@ -82,7 +82,7 @@ var (
|
||||
errDecodeProvider = errors.New("cannot decode servers for provider")
|
||||
)
|
||||
|
||||
func (s *Storage) readServers(provider string, hardcoded models.Servers,
|
||||
func (s *Storage) readServers(provider string, hardcodedVersion uint16,
|
||||
rawMessage json.RawMessage, titleCaser cases.Caser) (servers models.Servers,
|
||||
versionsMatch bool, err error) {
|
||||
provider = titleCaser.String(provider)
|
||||
@@ -93,9 +93,9 @@ func (s *Storage) readServers(provider string, hardcoded models.Servers,
|
||||
return servers, false, fmt.Errorf("%w: %s: %s", errDecodeProvider, provider, err)
|
||||
}
|
||||
|
||||
versionsMatch = hardcoded.Version == persistedServers.Version
|
||||
versionsMatch = hardcodedVersion == persistedServers.Version
|
||||
if !versionsMatch {
|
||||
s.logVersionDiff(provider, hardcoded.Version, persistedServers.Version)
|
||||
s.logVersionDiff(provider, hardcodedVersion, persistedServers.Version)
|
||||
return servers, versionsMatch, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,21 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func populateProviders(allProviderVersion uint16, allProviderTimestamp int64,
|
||||
func populateProviderToVersion(allProviderVersion uint16,
|
||||
providerToVersion map[string]uint16) map[string]uint16 {
|
||||
allProviders := providers.All()
|
||||
for _, provider := range allProviders {
|
||||
_, has := providerToVersion[provider]
|
||||
if has {
|
||||
continue
|
||||
}
|
||||
|
||||
providerToVersion[provider] = allProviderVersion
|
||||
}
|
||||
return providerToVersion
|
||||
}
|
||||
|
||||
func populateAllServersVersion(allProviderVersion uint16,
|
||||
servers models.AllServers) models.AllServers {
|
||||
allProviders := providers.All()
|
||||
if servers.ProviderToServers == nil {
|
||||
@@ -23,8 +37,7 @@ func populateProviders(allProviderVersion uint16, allProviderTimestamp int64,
|
||||
continue
|
||||
}
|
||||
servers.ProviderToServers[provider] = models.Servers{
|
||||
Version: allProviderVersion,
|
||||
Timestamp: allProviderTimestamp,
|
||||
Version: allProviderVersion,
|
||||
}
|
||||
}
|
||||
return servers
|
||||
@@ -34,54 +47,54 @@ func Test_extractServersFromBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
b []byte
|
||||
hardcoded models.AllServers
|
||||
logged []string
|
||||
persisted models.AllServers
|
||||
errMessage string
|
||||
b []byte
|
||||
hardcodedVersions map[string]uint16
|
||||
logged []string
|
||||
persisted models.AllServers
|
||||
errMessage string
|
||||
}{
|
||||
"bad JSON": {
|
||||
b: []byte("garbage"),
|
||||
errMessage: "cannot decode servers: invalid character 'g' looking for beginning of value",
|
||||
},
|
||||
"bad provider JSON": {
|
||||
b: []byte(`{"cyberghost": "garbage"}`),
|
||||
hardcoded: populateProviders(1, 0, models.AllServers{}),
|
||||
b: []byte(`{"cyberghost": "garbage"}`),
|
||||
hardcodedVersions: populateProviderToVersion(1, map[string]uint16{}),
|
||||
errMessage: "cannot decode servers for provider: Cyberghost: " +
|
||||
"json: cannot unmarshal string into Go value of type models.Servers",
|
||||
},
|
||||
"absent provider keys": {
|
||||
b: []byte(`{}`),
|
||||
hardcoded: populateProviders(1, 0, models.AllServers{}),
|
||||
b: []byte(`{}`),
|
||||
hardcodedVersions: populateProviderToVersion(1, map[string]uint16{}),
|
||||
persisted: models.AllServers{
|
||||
ProviderToServers: map[string]models.Servers{},
|
||||
},
|
||||
},
|
||||
"same versions": {
|
||||
b: []byte(`{
|
||||
"cyberghost": {"version": 1, "timestamp": 1},
|
||||
"expressvpn": {"version": 1, "timestamp": 1},
|
||||
"fastestvpn": {"version": 1, "timestamp": 1},
|
||||
"hidemyass": {"version": 1, "timestamp": 1},
|
||||
"ipvanish": {"version": 1, "timestamp": 1},
|
||||
"ivpn": {"version": 1, "timestamp": 1},
|
||||
"mullvad": {"version": 1, "timestamp": 1},
|
||||
"nordvpn": {"version": 1, "timestamp": 1},
|
||||
"perfect privacy": {"version": 1, "timestamp": 1},
|
||||
"privado": {"version": 1, "timestamp": 1},
|
||||
"private internet access": {"version": 1, "timestamp": 1},
|
||||
"privatevpn": {"version": 1, "timestamp": 1},
|
||||
"protonvpn": {"version": 1, "timestamp": 1},
|
||||
"purevpn": {"version": 1, "timestamp": 1},
|
||||
"surfshark": {"version": 1, "timestamp": 1},
|
||||
"torguard": {"version": 1, "timestamp": 1},
|
||||
"vpn unlimited": {"version": 1, "timestamp": 1},
|
||||
"vyprvpn": {"version": 1, "timestamp": 1},
|
||||
"wevpn": {"version": 1, "timestamp": 1},
|
||||
"windscribe": {"version": 1, "timestamp": 1}
|
||||
"cyberghost": {"version": 1, "timestamp": 0},
|
||||
"expressvpn": {"version": 1, "timestamp": 0},
|
||||
"fastestvpn": {"version": 1, "timestamp": 0},
|
||||
"hidemyass": {"version": 1, "timestamp": 0},
|
||||
"ipvanish": {"version": 1, "timestamp": 0},
|
||||
"ivpn": {"version": 1, "timestamp": 0},
|
||||
"mullvad": {"version": 1, "timestamp": 0},
|
||||
"nordvpn": {"version": 1, "timestamp": 0},
|
||||
"perfect privacy": {"version": 1, "timestamp": 0},
|
||||
"privado": {"version": 1, "timestamp": 0},
|
||||
"private internet access": {"version": 1, "timestamp": 0},
|
||||
"privatevpn": {"version": 1, "timestamp": 0},
|
||||
"protonvpn": {"version": 1, "timestamp": 0},
|
||||
"purevpn": {"version": 1, "timestamp": 0},
|
||||
"surfshark": {"version": 1, "timestamp": 0},
|
||||
"torguard": {"version": 1, "timestamp": 0},
|
||||
"vpn unlimited": {"version": 1, "timestamp": 0},
|
||||
"vyprvpn": {"version": 1, "timestamp": 0},
|
||||
"wevpn": {"version": 1, "timestamp": 0},
|
||||
"windscribe": {"version": 1, "timestamp": 0}
|
||||
}`),
|
||||
hardcoded: populateProviders(1, 0, models.AllServers{}),
|
||||
persisted: populateProviders(1, 1, models.AllServers{}),
|
||||
hardcodedVersions: populateProviderToVersion(1, map[string]uint16{}),
|
||||
persisted: populateAllServersVersion(1, models.AllServers{}),
|
||||
},
|
||||
"different versions": {
|
||||
b: []byte(`{
|
||||
@@ -106,7 +119,7 @@ func Test_extractServersFromBytes(t *testing.T) {
|
||||
"wevpn": {"version": 1, "timestamp": 1},
|
||||
"windscribe": {"version": 1, "timestamp": 1}
|
||||
}`),
|
||||
hardcoded: populateProviders(2, 0, models.AllServers{}),
|
||||
hardcodedVersions: populateProviderToVersion(2, map[string]uint16{}),
|
||||
logged: []string{
|
||||
"Cyberghost servers from file discarded because they have version 1 and hardcoded servers have version 2",
|
||||
"Expressvpn servers from file discarded because they have version 1 and hardcoded servers have version 2",
|
||||
@@ -155,7 +168,7 @@ func Test_extractServersFromBytes(t *testing.T) {
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
servers, err := s.extractServersFromBytes(testCase.b, testCase.hardcoded)
|
||||
servers, err := s.extractServersFromBytes(testCase.b, testCase.hardcodedVersions)
|
||||
|
||||
if testCase.errMessage != "" {
|
||||
assert.EqualError(t, err, testCase.errMessage)
|
||||
@@ -176,15 +189,13 @@ func Test_extractServersFromBytes(t *testing.T) {
|
||||
require.GreaterOrEqual(t, len(allProviders), 2)
|
||||
|
||||
b := []byte(`{}`)
|
||||
hardcoded := models.AllServers{
|
||||
ProviderToServers: map[string]models.Servers{
|
||||
allProviders[0]: {},
|
||||
// Missing provider allProviders[1]
|
||||
},
|
||||
hardcodedVersions := map[string]uint16{
|
||||
allProviders[0]: 1,
|
||||
// Missing provider allProviders[1]
|
||||
}
|
||||
expectedPanicValue := fmt.Sprintf("provider %s not found in hardcoded servers map", allProviders[1])
|
||||
assert.PanicsWithValue(t, expectedPanicValue, func() {
|
||||
_, _ = s.extractServersFromBytes(b, hardcoded)
|
||||
_, _ = s.extractServersFromBytes(b, hardcodedVersions)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -8,7 +8,10 @@ import (
|
||||
//go:generate mockgen -destination=infoerrorer_mock_test.go -package $GOPACKAGE . InfoErrorer
|
||||
|
||||
type Storage struct {
|
||||
mergedServers models.AllServers
|
||||
mergedServers models.AllServers
|
||||
// this is stored in memory to avoid re-parsing
|
||||
// the embedded JSON file on every call to the
|
||||
// SyncServers method.
|
||||
hardcodedServers models.AllServers
|
||||
logger Infoer
|
||||
filepath string
|
||||
@@ -22,11 +25,12 @@ type Infoer interface {
|
||||
// embedded servers file and the file on disk.
|
||||
// Passing an empty filepath disables writing servers to a file.
|
||||
func New(logger Infoer, filepath string) (storage *Storage, err error) {
|
||||
// error returned covered by unit test
|
||||
harcodedServers, _ := parseHardcodedServers()
|
||||
// A unit test prevents any error from being returned
|
||||
// and ensures all providers are part of the servers returned.
|
||||
hardcodedServers, _ := parseHardcodedServers()
|
||||
|
||||
storage = &Storage{
|
||||
hardcodedServers: harcodedServers,
|
||||
hardcodedServers: hardcodedServers,
|
||||
logger: logger,
|
||||
filepath: filepath,
|
||||
}
|
||||
|
||||
@@ -14,8 +14,14 @@ func countServers(allServers models.AllServers) (count int) {
|
||||
return count
|
||||
}
|
||||
|
||||
// SyncServers merges the hardcoded servers with the ones from the file.
|
||||
func (s *Storage) SyncServers() (err error) {
|
||||
serversOnFile, err := s.readFromFile(s.filepath, s.hardcodedServers)
|
||||
hardcodedVersions := make(map[string]uint16, len(s.hardcodedServers.ProviderToServers))
|
||||
for provider, servers := range s.hardcodedServers.ProviderToServers {
|
||||
hardcodedVersions[provider] = servers.Version
|
||||
}
|
||||
|
||||
serversOnFile, err := s.readFromFile(s.filepath, hardcodedVersions)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot read servers from file: %w", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user