chore: simplify provider GetConnection
This commit is contained in:
60
internal/provider/utils/connection.go
Normal file
60
internal/provider/utils/connection.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
type ConnectionDefaults struct {
|
||||
OpenVPNTCPPort uint16
|
||||
OpenVPNUDPPort uint16
|
||||
WireguardPort uint16
|
||||
}
|
||||
|
||||
func NewConnectionDefaults(openvpnTCPPort, openvpnUDPPort,
|
||||
wireguardPort uint16) ConnectionDefaults {
|
||||
return ConnectionDefaults{
|
||||
OpenVPNTCPPort: openvpnTCPPort,
|
||||
OpenVPNUDPPort: openvpnUDPPort,
|
||||
WireguardPort: wireguardPort,
|
||||
}
|
||||
}
|
||||
|
||||
func GetConnection(servers []models.Server,
|
||||
selection settings.ServerSelection,
|
||||
defaults ConnectionDefaults,
|
||||
randSource rand.Source) (
|
||||
connection models.Connection, err error) {
|
||||
servers, err = FilterServers(servers, selection)
|
||||
if err != nil {
|
||||
return connection, fmt.Errorf("cannot filter servers: %w", err)
|
||||
}
|
||||
|
||||
protocol := getProtocol(selection)
|
||||
port := GetPort(selection, defaults.OpenVPNTCPPort,
|
||||
defaults.OpenVPNUDPPort, defaults.WireguardPort)
|
||||
|
||||
connections := make([]models.Connection, 0, len(servers))
|
||||
for _, server := range servers {
|
||||
for _, ip := range server.IPs {
|
||||
if ip.To4() == nil {
|
||||
// do not use IPv6 connections for now
|
||||
continue
|
||||
}
|
||||
connection := models.Connection{
|
||||
Type: selection.VPN,
|
||||
IP: ip,
|
||||
Port: port,
|
||||
Protocol: protocol,
|
||||
Hostname: server.Hostname,
|
||||
PubKey: server.WgPubKey, // Wireguard
|
||||
}
|
||||
connections = append(connections, connection)
|
||||
}
|
||||
}
|
||||
|
||||
return PickConnection(connections, selection, randSource)
|
||||
}
|
||||
9
internal/provider/utils/connection_test.go
Normal file
9
internal/provider/utils/connection_test.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package utils
|
||||
|
||||
import "testing"
|
||||
|
||||
func Test_GetConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// testCases := map[string]struct{}{}
|
||||
}
|
||||
@@ -34,7 +34,7 @@ func filterServer(server models.Server,
|
||||
return true
|
||||
}
|
||||
|
||||
if FilterByProtocol(selection, server.TCP, server.UDP) {
|
||||
if filterByProtocol(selection, server.TCP, server.UDP) {
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||
)
|
||||
|
||||
@@ -13,6 +17,7 @@ func GetPort(selection settings.ServerSelection,
|
||||
if customPort > 0 {
|
||||
return customPort
|
||||
}
|
||||
checkDefined("Wireguard", defaultWireguard)
|
||||
return defaultWireguard
|
||||
default: // OpenVPN
|
||||
customPort := *selection.OpenVPN.CustomPort
|
||||
@@ -20,8 +25,40 @@ func GetPort(selection settings.ServerSelection,
|
||||
return customPort
|
||||
}
|
||||
if *selection.OpenVPN.TCP {
|
||||
checkDefined("OpenVPN TCP", defaultOpenVPNTCP)
|
||||
return defaultOpenVPNTCP
|
||||
}
|
||||
|
||||
checkDefined("OpenVPN UDP", defaultOpenVPNUDP)
|
||||
return defaultOpenVPNUDP
|
||||
}
|
||||
}
|
||||
|
||||
func checkDefined(portName string, port uint16) {
|
||||
if port > 0 {
|
||||
return
|
||||
}
|
||||
message := fmt.Sprintf("no default %s port is defined!", portName)
|
||||
panic(message)
|
||||
}
|
||||
|
||||
var ErrInvalidPort = errors.New("invalid port number")
|
||||
|
||||
// CheckPortAllowed for custom port used for OpenVPN.
|
||||
func CheckPortAllowed(port uint16, tcp bool,
|
||||
allowedTCP, allowedUDP []uint16) (err error) {
|
||||
allowedPorts := allowedUDP
|
||||
protocol := constants.UDP
|
||||
if tcp {
|
||||
allowedPorts = allowedTCP
|
||||
protocol = constants.TCP
|
||||
}
|
||||
for _, allowedPort := range allowedPorts {
|
||||
if port == allowedPort {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("%w: %d for protocol %s",
|
||||
ErrInvalidPort, port, protocol)
|
||||
}
|
||||
|
||||
@@ -21,12 +21,19 @@ func Test_GetPort(t *testing.T) {
|
||||
)
|
||||
|
||||
testCases := map[string]struct {
|
||||
selection settings.ServerSelection
|
||||
port uint16
|
||||
selection settings.ServerSelection
|
||||
defaultOpenVPNTCP uint16
|
||||
defaultOpenVPNUDP uint16
|
||||
defaultWireguard uint16
|
||||
port uint16
|
||||
panics string
|
||||
}{
|
||||
"default": {
|
||||
selection: settings.ServerSelection{}.WithDefaults(""),
|
||||
port: defaultOpenVPNUDP,
|
||||
selection: settings.ServerSelection{}.WithDefaults(""),
|
||||
defaultOpenVPNTCP: defaultOpenVPNTCP,
|
||||
defaultOpenVPNUDP: defaultOpenVPNUDP,
|
||||
defaultWireguard: defaultWireguard,
|
||||
port: defaultOpenVPNUDP,
|
||||
},
|
||||
"OpenVPN UDP": {
|
||||
selection: settings.ServerSelection{
|
||||
@@ -36,7 +43,20 @@ func Test_GetPort(t *testing.T) {
|
||||
TCP: boolPtr(false),
|
||||
},
|
||||
},
|
||||
port: defaultOpenVPNUDP,
|
||||
defaultOpenVPNTCP: defaultOpenVPNTCP,
|
||||
defaultOpenVPNUDP: defaultOpenVPNUDP,
|
||||
defaultWireguard: defaultWireguard,
|
||||
port: defaultOpenVPNUDP,
|
||||
},
|
||||
"OpenVPN UDP no default port defined": {
|
||||
selection: settings.ServerSelection{
|
||||
VPN: vpn.OpenVPN,
|
||||
OpenVPN: settings.OpenVPNSelection{
|
||||
CustomPort: uint16Ptr(0),
|
||||
TCP: boolPtr(false),
|
||||
},
|
||||
},
|
||||
panics: "no default OpenVPN UDP port is defined!",
|
||||
},
|
||||
"OpenVPN TCP": {
|
||||
selection: settings.ServerSelection{
|
||||
@@ -46,7 +66,18 @@ func Test_GetPort(t *testing.T) {
|
||||
TCP: boolPtr(true),
|
||||
},
|
||||
},
|
||||
port: defaultOpenVPNTCP,
|
||||
defaultOpenVPNTCP: defaultOpenVPNTCP,
|
||||
port: defaultOpenVPNTCP,
|
||||
},
|
||||
"OpenVPN TCP no default port defined": {
|
||||
selection: settings.ServerSelection{
|
||||
VPN: vpn.OpenVPN,
|
||||
OpenVPN: settings.OpenVPNSelection{
|
||||
CustomPort: uint16Ptr(0),
|
||||
TCP: boolPtr(true),
|
||||
},
|
||||
},
|
||||
panics: "no default OpenVPN TCP port is defined!",
|
||||
},
|
||||
"OpenVPN custom port": {
|
||||
selection: settings.ServerSelection{
|
||||
@@ -61,7 +92,8 @@ func Test_GetPort(t *testing.T) {
|
||||
selection: settings.ServerSelection{
|
||||
VPN: vpn.Wireguard,
|
||||
}.WithDefaults(""),
|
||||
port: defaultWireguard,
|
||||
defaultWireguard: defaultWireguard,
|
||||
port: defaultWireguard,
|
||||
},
|
||||
"Wireguard custom port": {
|
||||
selection: settings.ServerSelection{
|
||||
@@ -70,7 +102,14 @@ func Test_GetPort(t *testing.T) {
|
||||
EndpointPort: uint16Ptr(1234),
|
||||
},
|
||||
},
|
||||
port: 1234,
|
||||
defaultWireguard: defaultWireguard,
|
||||
port: 1234,
|
||||
},
|
||||
"Wireguard no default port defined": {
|
||||
selection: settings.ServerSelection{
|
||||
VPN: vpn.Wireguard,
|
||||
}.WithDefaults(""),
|
||||
panics: "no default Wireguard port is defined!",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -79,8 +118,20 @@ func Test_GetPort(t *testing.T) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if testCase.panics != "" {
|
||||
assert.PanicsWithValue(t, testCase.panics, func() {
|
||||
_ = GetPort(testCase.selection,
|
||||
testCase.defaultOpenVPNTCP,
|
||||
testCase.defaultOpenVPNUDP,
|
||||
testCase.defaultWireguard)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
port := GetPort(testCase.selection,
|
||||
defaultOpenVPNTCP, defaultOpenVPNUDP, defaultWireguard)
|
||||
testCase.defaultOpenVPNTCP,
|
||||
testCase.defaultOpenVPNUDP,
|
||||
testCase.defaultWireguard)
|
||||
|
||||
assert.Equal(t, testCase.port, port)
|
||||
})
|
||||
|
||||
@@ -6,14 +6,14 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||
)
|
||||
|
||||
func GetProtocol(selection settings.ServerSelection) (protocol string) {
|
||||
func getProtocol(selection settings.ServerSelection) (protocol string) {
|
||||
if selection.VPN == vpn.OpenVPN && *selection.OpenVPN.TCP {
|
||||
return constants.TCP
|
||||
}
|
||||
return constants.UDP
|
||||
}
|
||||
|
||||
func FilterByProtocol(selection settings.ServerSelection,
|
||||
func filterByProtocol(selection settings.ServerSelection,
|
||||
serverTCP, serverUDP bool) (filtered bool) {
|
||||
switch selection.VPN {
|
||||
case vpn.Wireguard:
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_GetProtocol(t *testing.T) {
|
||||
func Test_getProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
@@ -50,14 +50,14 @@ func Test_GetProtocol(t *testing.T) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
protocol := GetProtocol(testCase.selection)
|
||||
protocol := getProtocol(testCase.selection)
|
||||
|
||||
assert.Equal(t, testCase.protocol, protocol)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_FilterByProtocol(t *testing.T) {
|
||||
func Test_filterByProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
@@ -127,7 +127,7 @@ func Test_FilterByProtocol(t *testing.T) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
filtered := FilterByProtocol(testCase.selection,
|
||||
filtered := filterByProtocol(testCase.selection,
|
||||
testCase.serverTCP, testCase.serverUDP)
|
||||
|
||||
assert.Equal(t, testCase.filtered, filtered)
|
||||
|
||||
Reference in New Issue
Block a user