chore(all): memory and thread safe storage
- settings: get filter choices from storage for settings validation - updater: update servers to the storage - storage: minimal deep copying and data duplication - storage: add merged servers mutex for thread safety - connection: filter servers in storage - formatter: format servers to Markdown in storage - PIA: get server by name from storage directly - Updater: get servers count from storage directly - Updater: equality check done in storage, fix #882
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
@@ -24,20 +24,20 @@ func NewConnectionDefaults(openvpnTCPPort, openvpnUDPPort,
|
||||
}
|
||||
}
|
||||
|
||||
var ErrNoServer = errors.New("no server")
|
||||
type Storage interface {
|
||||
FilterServers(provider string, selection settings.ServerSelection) (
|
||||
servers []models.Server, err error)
|
||||
}
|
||||
|
||||
func GetConnection(servers []models.Server,
|
||||
func GetConnection(provider string,
|
||||
storage Storage,
|
||||
selection settings.ServerSelection,
|
||||
defaults ConnectionDefaults,
|
||||
randSource rand.Source) (
|
||||
connection models.Connection, err error) {
|
||||
if len(servers) == 0 {
|
||||
return connection, ErrNoServer
|
||||
}
|
||||
|
||||
servers = filterServers(servers, selection)
|
||||
if len(servers) == 0 {
|
||||
return connection, noServerFoundError(selection)
|
||||
servers, err := storage.FilterServers(provider, selection)
|
||||
if err != nil {
|
||||
return connection, fmt.Errorf("cannot filter servers: %w", err)
|
||||
}
|
||||
|
||||
protocol := getProtocol(selection)
|
||||
|
||||
@@ -1,23 +1,30 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/rand"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/constants/providers"
|
||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/common"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_GetConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
errTest := errors.New("test error")
|
||||
|
||||
testCases := map[string]struct {
|
||||
servers []models.Server
|
||||
provider string
|
||||
filteredServers []models.Server
|
||||
filterError error
|
||||
serverSelection settings.ServerSelection
|
||||
defaults ConnectionDefaults
|
||||
randSource rand.Source
|
||||
@@ -25,25 +32,13 @@ func Test_GetConnection(t *testing.T) {
|
||||
errWrapped error
|
||||
errMessage string
|
||||
}{
|
||||
"no server": {
|
||||
serverSelection: settings.ServerSelection{}.
|
||||
WithDefaults(providers.Mullvad),
|
||||
errWrapped: ErrNoServer,
|
||||
errMessage: "no server",
|
||||
},
|
||||
"all servers filtered": {
|
||||
servers: []models.Server{
|
||||
{VPN: vpn.Wireguard},
|
||||
{VPN: vpn.Wireguard},
|
||||
},
|
||||
serverSelection: settings.ServerSelection{
|
||||
VPN: vpn.OpenVPN,
|
||||
}.WithDefaults(providers.Mullvad),
|
||||
errWrapped: ErrNoServerFound,
|
||||
errMessage: "no server found: for VPN openvpn; protocol udp",
|
||||
"storage filter error": {
|
||||
filterError: errTest,
|
||||
errWrapped: errTest,
|
||||
errMessage: "cannot filter servers: test error",
|
||||
},
|
||||
"server without IPs": {
|
||||
servers: []models.Server{
|
||||
filteredServers: []models.Server{
|
||||
{VPN: vpn.OpenVPN, UDP: true},
|
||||
{VPN: vpn.OpenVPN, UDP: true},
|
||||
},
|
||||
@@ -58,7 +53,7 @@ func Test_GetConnection(t *testing.T) {
|
||||
errMessage: "no connection to pick from",
|
||||
},
|
||||
"OpenVPN server with hostname": {
|
||||
servers: []models.Server{
|
||||
filteredServers: []models.Server{
|
||||
{
|
||||
VPN: vpn.OpenVPN,
|
||||
UDP: true,
|
||||
@@ -79,7 +74,7 @@ func Test_GetConnection(t *testing.T) {
|
||||
},
|
||||
},
|
||||
"OpenVPN server with x509": {
|
||||
servers: []models.Server{
|
||||
filteredServers: []models.Server{
|
||||
{
|
||||
VPN: vpn.OpenVPN,
|
||||
UDP: true,
|
||||
@@ -101,7 +96,7 @@ func Test_GetConnection(t *testing.T) {
|
||||
},
|
||||
},
|
||||
"server with IPv4 and IPv6": {
|
||||
servers: []models.Server{
|
||||
filteredServers: []models.Server{
|
||||
{
|
||||
VPN: vpn.OpenVPN,
|
||||
UDP: true,
|
||||
@@ -128,7 +123,7 @@ func Test_GetConnection(t *testing.T) {
|
||||
},
|
||||
},
|
||||
"mixed servers": {
|
||||
servers: []models.Server{
|
||||
filteredServers: []models.Server{
|
||||
{
|
||||
VPN: vpn.OpenVPN,
|
||||
UDP: true,
|
||||
@@ -169,8 +164,14 @@ func Test_GetConnection(t *testing.T) {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
connection, err := GetConnection(testCase.servers,
|
||||
storage := common.NewMockStorage(ctrl)
|
||||
storage.EXPECT().
|
||||
FilterServers(testCase.provider, testCase.serverSelection).
|
||||
Return(testCase.filteredServers, testCase.filterError)
|
||||
|
||||
connection, err := GetConnection(testCase.provider, storage,
|
||||
testCase.serverSelection, testCase.defaults,
|
||||
testCase.randSource)
|
||||
|
||||
|
||||
@@ -1,120 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
)
|
||||
|
||||
func commaJoin(slice []string) string {
|
||||
return strings.Join(slice, ", ")
|
||||
}
|
||||
|
||||
var ErrNoServerFound = errors.New("no server found")
|
||||
|
||||
func noServerFoundError(selection settings.ServerSelection) (err error) {
|
||||
var messageParts []string
|
||||
|
||||
messageParts = append(messageParts, "VPN "+selection.VPN)
|
||||
|
||||
protocol := constants.UDP
|
||||
if *selection.OpenVPN.TCP {
|
||||
protocol = constants.TCP
|
||||
}
|
||||
messageParts = append(messageParts, "protocol "+protocol)
|
||||
|
||||
switch len(selection.Countries) {
|
||||
case 0:
|
||||
case 1:
|
||||
part := "country " + selection.Countries[0]
|
||||
messageParts = append(messageParts, part)
|
||||
default:
|
||||
part := "countries " + commaJoin(selection.Countries)
|
||||
messageParts = append(messageParts, part)
|
||||
}
|
||||
|
||||
switch len(selection.Regions) {
|
||||
case 0:
|
||||
case 1:
|
||||
part := "region " + selection.Regions[0]
|
||||
messageParts = append(messageParts, part)
|
||||
default:
|
||||
part := "regions " + commaJoin(selection.Regions)
|
||||
messageParts = append(messageParts, part)
|
||||
}
|
||||
|
||||
switch len(selection.Cities) {
|
||||
case 0:
|
||||
case 1:
|
||||
part := "city " + selection.Cities[0]
|
||||
messageParts = append(messageParts, part)
|
||||
default:
|
||||
part := "cities " + commaJoin(selection.Cities)
|
||||
messageParts = append(messageParts, part)
|
||||
}
|
||||
|
||||
if *selection.OwnedOnly {
|
||||
messageParts = append(messageParts, "owned servers only")
|
||||
}
|
||||
|
||||
switch len(selection.ISPs) {
|
||||
case 0:
|
||||
case 1:
|
||||
part := "ISP " + selection.ISPs[0]
|
||||
messageParts = append(messageParts, part)
|
||||
default:
|
||||
part := "ISPs " + commaJoin(selection.ISPs)
|
||||
messageParts = append(messageParts, part)
|
||||
}
|
||||
|
||||
switch len(selection.Hostnames) {
|
||||
case 0:
|
||||
case 1:
|
||||
part := "hostname " + selection.Hostnames[0]
|
||||
messageParts = append(messageParts, part)
|
||||
default:
|
||||
part := "hostnames " + commaJoin(selection.Hostnames)
|
||||
messageParts = append(messageParts, part)
|
||||
}
|
||||
|
||||
switch len(selection.Names) {
|
||||
case 0:
|
||||
case 1:
|
||||
part := "name " + selection.Names[0]
|
||||
messageParts = append(messageParts, part)
|
||||
default:
|
||||
part := "names " + commaJoin(selection.Names)
|
||||
messageParts = append(messageParts, part)
|
||||
}
|
||||
|
||||
switch len(selection.Numbers) {
|
||||
case 0:
|
||||
case 1:
|
||||
part := "server number " + strconv.Itoa(int(selection.Numbers[0]))
|
||||
messageParts = append(messageParts, part)
|
||||
default:
|
||||
serverNumbers := make([]string, len(selection.Numbers))
|
||||
for i := range selection.Numbers {
|
||||
serverNumbers[i] = strconv.Itoa(int(selection.Numbers[i]))
|
||||
}
|
||||
part := "server numbers " + commaJoin(serverNumbers)
|
||||
messageParts = append(messageParts, part)
|
||||
}
|
||||
|
||||
if *selection.OpenVPN.PIAEncPreset != "" {
|
||||
part := "encryption preset " + *selection.OpenVPN.PIAEncPreset
|
||||
messageParts = append(messageParts, part)
|
||||
}
|
||||
|
||||
if *selection.FreeOnly {
|
||||
messageParts = append(messageParts, "free tier only")
|
||||
}
|
||||
|
||||
message := "for " + strings.Join(messageParts, "; ")
|
||||
|
||||
return fmt.Errorf("%w: %s", ErrNoServerFound, message)
|
||||
}
|
||||
Reference in New Issue
Block a user