Maint: improve servers data embedding

- use embed.FS to have immutable data
- use sync.Once to parse only once without data races
This commit is contained in:
Quentin McGaw (desktop)
2021-07-20 19:01:49 +00:00
parent e0735b57ce
commit 82533c1453
2 changed files with 15 additions and 10 deletions

View File

@@ -1,29 +1,34 @@
package constants package constants
import ( import (
_ "embed" "embed"
"encoding/json" "encoding/json"
"sync"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
) )
//go:embed servers.json //go:embed servers.json
var allServersBytes []byte //nolint:gochecknoglobals var allServersEmbedFS embed.FS //nolint:gochecknoglobals
var allServers models.AllServers //nolint:gochecknoglobals var allServers models.AllServers //nolint:gochecknoglobals
var parseOnce sync.Once //nolint:gochecknoglobals
func init() { //nolint:gochecknoinits func init() { //nolint:gochecknoinits
// error returned covered by unit test // error returned covered by unit test
allServers, _ = parseAllServers(allServersBytes) parseOnce.Do(func() { allServers, _ = parseAllServers() })
} }
func parseAllServers(b []byte) (allServers models.AllServers, err error) { func parseAllServers() (allServers models.AllServers, err error) {
err = json.Unmarshal(b, &allServers) f, err := allServersEmbedFS.Open("servers.json")
if err != nil {
return allServers, err
}
decoder := json.NewDecoder(f)
err = decoder.Decode(&allServers)
return allServers, err return allServers, err
} }
func GetAllServers() (allServers models.AllServers) { func GetAllServers() models.AllServers {
if allServers.Version == 0 { // not parsed yet - for unit tests mostly parseOnce.Do(func() { allServers, _ = parseAllServers() }) // init did not execute, used in tests
allServers, _ = parseAllServers(allServersBytes)
}
return allServers return allServers
} }

View File

@@ -15,7 +15,7 @@ import (
func Test_parseAllServers(t *testing.T) { func Test_parseAllServers(t *testing.T) {
t.Parallel() t.Parallel()
servers, err := parseAllServers(allServersBytes) servers, err := parseAllServers()
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, len(servers.Cyberghost.Servers)) require.NotEmpty(t, len(servers.Cyberghost.Servers))