chore: simplify provider GetConnection

This commit is contained in:
Quentin McGaw
2022-04-19 14:28:57 +00:00
parent 306d8494d6
commit 0c0f1663b1
36 changed files with 243 additions and 707 deletions

View File

@@ -0,0 +1,60 @@
package utils
import (
"fmt"
"math/rand"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/models"
)
type ConnectionDefaults struct {
OpenVPNTCPPort uint16
OpenVPNUDPPort uint16
WireguardPort uint16
}
func NewConnectionDefaults(openvpnTCPPort, openvpnUDPPort,
wireguardPort uint16) ConnectionDefaults {
return ConnectionDefaults{
OpenVPNTCPPort: openvpnTCPPort,
OpenVPNUDPPort: openvpnUDPPort,
WireguardPort: wireguardPort,
}
}
func GetConnection(servers []models.Server,
selection settings.ServerSelection,
defaults ConnectionDefaults,
randSource rand.Source) (
connection models.Connection, err error) {
servers, err = FilterServers(servers, selection)
if err != nil {
return connection, fmt.Errorf("cannot filter servers: %w", err)
}
protocol := getProtocol(selection)
port := GetPort(selection, defaults.OpenVPNTCPPort,
defaults.OpenVPNUDPPort, defaults.WireguardPort)
connections := make([]models.Connection, 0, len(servers))
for _, server := range servers {
for _, ip := range server.IPs {
if ip.To4() == nil {
// do not use IPv6 connections for now
continue
}
connection := models.Connection{
Type: selection.VPN,
IP: ip,
Port: port,
Protocol: protocol,
Hostname: server.Hostname,
PubKey: server.WgPubKey, // Wireguard
}
connections = append(connections, connection)
}
}
return PickConnection(connections, selection, randSource)
}

View File

@@ -0,0 +1,9 @@
package utils
import "testing"
func Test_GetConnection(t *testing.T) {
t.Parallel()
// testCases := map[string]struct{}{}
}

View File

@@ -34,7 +34,7 @@ func filterServer(server models.Server,
return true
}
if FilterByProtocol(selection, server.TCP, server.UDP) {
if filterByProtocol(selection, server.TCP, server.UDP) {
return true
}

View File

@@ -1,7 +1,11 @@
package utils
import (
"errors"
"fmt"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/vpn"
)
@@ -13,6 +17,7 @@ func GetPort(selection settings.ServerSelection,
if customPort > 0 {
return customPort
}
checkDefined("Wireguard", defaultWireguard)
return defaultWireguard
default: // OpenVPN
customPort := *selection.OpenVPN.CustomPort
@@ -20,8 +25,40 @@ func GetPort(selection settings.ServerSelection,
return customPort
}
if *selection.OpenVPN.TCP {
checkDefined("OpenVPN TCP", defaultOpenVPNTCP)
return defaultOpenVPNTCP
}
checkDefined("OpenVPN UDP", defaultOpenVPNUDP)
return defaultOpenVPNUDP
}
}
func checkDefined(portName string, port uint16) {
if port > 0 {
return
}
message := fmt.Sprintf("no default %s port is defined!", portName)
panic(message)
}
var ErrInvalidPort = errors.New("invalid port number")
// CheckPortAllowed for custom port used for OpenVPN.
func CheckPortAllowed(port uint16, tcp bool,
allowedTCP, allowedUDP []uint16) (err error) {
allowedPorts := allowedUDP
protocol := constants.UDP
if tcp {
allowedPorts = allowedTCP
protocol = constants.TCP
}
for _, allowedPort := range allowedPorts {
if port == allowedPort {
return nil
}
}
return fmt.Errorf("%w: %d for protocol %s",
ErrInvalidPort, port, protocol)
}

View File

@@ -21,12 +21,19 @@ func Test_GetPort(t *testing.T) {
)
testCases := map[string]struct {
selection settings.ServerSelection
port uint16
selection settings.ServerSelection
defaultOpenVPNTCP uint16
defaultOpenVPNUDP uint16
defaultWireguard uint16
port uint16
panics string
}{
"default": {
selection: settings.ServerSelection{}.WithDefaults(""),
port: defaultOpenVPNUDP,
selection: settings.ServerSelection{}.WithDefaults(""),
defaultOpenVPNTCP: defaultOpenVPNTCP,
defaultOpenVPNUDP: defaultOpenVPNUDP,
defaultWireguard: defaultWireguard,
port: defaultOpenVPNUDP,
},
"OpenVPN UDP": {
selection: settings.ServerSelection{
@@ -36,7 +43,20 @@ func Test_GetPort(t *testing.T) {
TCP: boolPtr(false),
},
},
port: defaultOpenVPNUDP,
defaultOpenVPNTCP: defaultOpenVPNTCP,
defaultOpenVPNUDP: defaultOpenVPNUDP,
defaultWireguard: defaultWireguard,
port: defaultOpenVPNUDP,
},
"OpenVPN UDP no default port defined": {
selection: settings.ServerSelection{
VPN: vpn.OpenVPN,
OpenVPN: settings.OpenVPNSelection{
CustomPort: uint16Ptr(0),
TCP: boolPtr(false),
},
},
panics: "no default OpenVPN UDP port is defined!",
},
"OpenVPN TCP": {
selection: settings.ServerSelection{
@@ -46,7 +66,18 @@ func Test_GetPort(t *testing.T) {
TCP: boolPtr(true),
},
},
port: defaultOpenVPNTCP,
defaultOpenVPNTCP: defaultOpenVPNTCP,
port: defaultOpenVPNTCP,
},
"OpenVPN TCP no default port defined": {
selection: settings.ServerSelection{
VPN: vpn.OpenVPN,
OpenVPN: settings.OpenVPNSelection{
CustomPort: uint16Ptr(0),
TCP: boolPtr(true),
},
},
panics: "no default OpenVPN TCP port is defined!",
},
"OpenVPN custom port": {
selection: settings.ServerSelection{
@@ -61,7 +92,8 @@ func Test_GetPort(t *testing.T) {
selection: settings.ServerSelection{
VPN: vpn.Wireguard,
}.WithDefaults(""),
port: defaultWireguard,
defaultWireguard: defaultWireguard,
port: defaultWireguard,
},
"Wireguard custom port": {
selection: settings.ServerSelection{
@@ -70,7 +102,14 @@ func Test_GetPort(t *testing.T) {
EndpointPort: uint16Ptr(1234),
},
},
port: 1234,
defaultWireguard: defaultWireguard,
port: 1234,
},
"Wireguard no default port defined": {
selection: settings.ServerSelection{
VPN: vpn.Wireguard,
}.WithDefaults(""),
panics: "no default Wireguard port is defined!",
},
}
@@ -79,8 +118,20 @@ func Test_GetPort(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()
if testCase.panics != "" {
assert.PanicsWithValue(t, testCase.panics, func() {
_ = GetPort(testCase.selection,
testCase.defaultOpenVPNTCP,
testCase.defaultOpenVPNUDP,
testCase.defaultWireguard)
})
return
}
port := GetPort(testCase.selection,
defaultOpenVPNTCP, defaultOpenVPNUDP, defaultWireguard)
testCase.defaultOpenVPNTCP,
testCase.defaultOpenVPNUDP,
testCase.defaultWireguard)
assert.Equal(t, testCase.port, port)
})

View File

@@ -6,14 +6,14 @@ import (
"github.com/qdm12/gluetun/internal/constants/vpn"
)
func GetProtocol(selection settings.ServerSelection) (protocol string) {
func getProtocol(selection settings.ServerSelection) (protocol string) {
if selection.VPN == vpn.OpenVPN && *selection.OpenVPN.TCP {
return constants.TCP
}
return constants.UDP
}
func FilterByProtocol(selection settings.ServerSelection,
func filterByProtocol(selection settings.ServerSelection,
serverTCP, serverUDP bool) (filtered bool) {
switch selection.VPN {
case vpn.Wireguard:

View File

@@ -9,7 +9,7 @@ import (
"github.com/stretchr/testify/assert"
)
func Test_GetProtocol(t *testing.T) {
func Test_getProtocol(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
@@ -50,14 +50,14 @@ func Test_GetProtocol(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()
protocol := GetProtocol(testCase.selection)
protocol := getProtocol(testCase.selection)
assert.Equal(t, testCase.protocol, protocol)
})
}
}
func Test_FilterByProtocol(t *testing.T) {
func Test_filterByProtocol(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
@@ -127,7 +127,7 @@ func Test_FilterByProtocol(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()
filtered := FilterByProtocol(testCase.selection,
filtered := filterByProtocol(testCase.selection,
testCase.serverTCP, testCase.serverUDP)
assert.Equal(t, testCase.filtered, filtered)