diff --git a/internal/provider/ivpn/connection.go b/internal/provider/ivpn/connection.go index cb65a74b..d507eab6 100644 --- a/internal/provider/ivpn/connection.go +++ b/internal/provider/ivpn/connection.go @@ -10,7 +10,7 @@ import ( func (i *Ivpn) GetConnection(selection configuration.ServerSelection) ( connection models.Connection, err error) { port := getPort(selection) - protocol := getProtocol(selection) + protocol := utils.GetProtocol(selection) servers, err := i.filterServers(selection) if err != nil { @@ -60,10 +60,3 @@ func getPort(selection configuration.ServerSelection) (port uint16) { return port } } - -func getProtocol(selection configuration.ServerSelection) (protocol string) { - if selection.VPN == constants.OpenVPN && selection.OpenVPN.TCP { - return constants.TCP - } - return constants.UDP -} diff --git a/internal/provider/ivpn/connection_test.go b/internal/provider/ivpn/connection_test.go index bbf634c8..3a45f25a 100644 --- a/internal/provider/ivpn/connection_test.go +++ b/internal/provider/ivpn/connection_test.go @@ -158,42 +158,3 @@ func Test_getPort(t *testing.T) { }) } } - -func Test_getProtocol(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - selection configuration.ServerSelection - protocol string - }{ - "OpenVPN UDP": { - protocol: constants.UDP, - }, - "OpenVPN TCP": { - selection: configuration.ServerSelection{ - VPN: constants.OpenVPN, - OpenVPN: configuration.OpenVPNSelection{ - TCP: true, - }, - }, - protocol: constants.TCP, - }, - "Wireguard": { - selection: configuration.ServerSelection{ - VPN: constants.Wireguard, - }, - protocol: constants.UDP, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - protocol := getProtocol(testCase.selection) - - assert.Equal(t, testCase.protocol, protocol) - }) - } -} diff --git a/internal/provider/mullvad/connection.go b/internal/provider/mullvad/connection.go index c22efb02..694d4a12 100644 --- a/internal/provider/mullvad/connection.go +++ b/internal/provider/mullvad/connection.go @@ -10,7 +10,7 @@ import ( func (m *Mullvad) GetConnection(selection configuration.ServerSelection) ( connection models.Connection, err error) { port := getPort(selection) - protocol := getProtocol(selection) + protocol := utils.GetProtocol(selection) servers, err := m.filterServers(selection) if err != nil { @@ -59,11 +59,3 @@ func getPort(selection configuration.ServerSelection) (port uint16) { return port } } - -func getProtocol(selection configuration.ServerSelection) (protocol string) { - protocol = constants.UDP - if selection.VPN == constants.OpenVPN && selection.OpenVPN.TCP { - protocol = constants.TCP - } - return protocol -} diff --git a/internal/provider/mullvad/connection_test.go b/internal/provider/mullvad/connection_test.go index 88c67349..864e41d5 100644 --- a/internal/provider/mullvad/connection_test.go +++ b/internal/provider/mullvad/connection_test.go @@ -157,48 +157,3 @@ func Test_getPort(t *testing.T) { }) } } - -func Test_getProtocol(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - selection configuration.ServerSelection - protocol string - }{ - "default": { - protocol: constants.UDP, - }, - "OpenVPN UDP": { - selection: configuration.ServerSelection{ - VPN: constants.OpenVPN, - }, - protocol: constants.UDP, - }, - "OpenVPN TCP": { - selection: configuration.ServerSelection{ - VPN: constants.OpenVPN, - OpenVPN: configuration.OpenVPNSelection{ - TCP: true, - }, - }, - protocol: constants.TCP, - }, - "Wireguard": { - selection: configuration.ServerSelection{ - VPN: constants.Wireguard, - }, - protocol: constants.UDP, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - protocol := getProtocol(testCase.selection) - - assert.Equal(t, testCase.protocol, protocol) - }) - } -} diff --git a/internal/provider/utils/protocol.go b/internal/provider/utils/protocol.go new file mode 100644 index 00000000..518c99cc --- /dev/null +++ b/internal/provider/utils/protocol.go @@ -0,0 +1,13 @@ +package utils + +import ( + "github.com/qdm12/gluetun/internal/configuration" + "github.com/qdm12/gluetun/internal/constants" +) + +func GetProtocol(selection configuration.ServerSelection) (protocol string) { + if selection.VPN == constants.OpenVPN && selection.OpenVPN.TCP { + return constants.TCP + } + return constants.UDP +} diff --git a/internal/provider/utils/protocol_test.go b/internal/provider/utils/protocol_test.go new file mode 100644 index 00000000..08cdedf0 --- /dev/null +++ b/internal/provider/utils/protocol_test.go @@ -0,0 +1,54 @@ +package utils + +import ( + "testing" + + "github.com/qdm12/gluetun/internal/configuration" + "github.com/qdm12/gluetun/internal/constants" + "github.com/stretchr/testify/assert" +) + +func Test_GetProtocol(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + selection configuration.ServerSelection + protocol string + }{ + "default": { + protocol: constants.UDP, + }, + "OpenVPN UDP": { + selection: configuration.ServerSelection{ + VPN: constants.OpenVPN, + }, + protocol: constants.UDP, + }, + "OpenVPN TCP": { + selection: configuration.ServerSelection{ + VPN: constants.OpenVPN, + OpenVPN: configuration.OpenVPNSelection{ + TCP: true, + }, + }, + protocol: constants.TCP, + }, + "Wireguard": { + selection: configuration.ServerSelection{ + VPN: constants.Wireguard, + }, + protocol: constants.UDP, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + protocol := GetProtocol(testCase.selection) + + assert.Equal(t, testCase.protocol, protocol) + }) + } +} diff --git a/internal/provider/windscribe/connection.go b/internal/provider/windscribe/connection.go index 8f565a4d..bf589e03 100644 --- a/internal/provider/windscribe/connection.go +++ b/internal/provider/windscribe/connection.go @@ -10,7 +10,7 @@ import ( func (w *Windscribe) GetConnection(selection configuration.ServerSelection) ( connection models.Connection, err error) { port := getPort(selection) - protocol := getProtocol(selection) + protocol := utils.GetProtocol(selection) servers, err := w.filterServers(selection) if err != nil { @@ -60,11 +60,3 @@ func getPort(selection configuration.ServerSelection) (port uint16) { return port } } - -func getProtocol(selection configuration.ServerSelection) (protocol string) { - protocol = constants.UDP - if selection.VPN == constants.OpenVPN && selection.OpenVPN.TCP { - protocol = constants.TCP - } - return protocol -} diff --git a/internal/provider/windscribe/connection_test.go b/internal/provider/windscribe/connection_test.go index d6521130..aa4fafdc 100644 --- a/internal/provider/windscribe/connection_test.go +++ b/internal/provider/windscribe/connection_test.go @@ -157,48 +157,3 @@ func Test_getPort(t *testing.T) { }) } } - -func Test_getProtocol(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - selection configuration.ServerSelection - protocol string - }{ - "default": { - protocol: constants.UDP, - }, - "OpenVPN UDP": { - selection: configuration.ServerSelection{ - VPN: constants.OpenVPN, - }, - protocol: constants.UDP, - }, - "OpenVPN TCP": { - selection: configuration.ServerSelection{ - VPN: constants.OpenVPN, - OpenVPN: configuration.OpenVPNSelection{ - TCP: true, - }, - }, - protocol: constants.TCP, - }, - "Wireguard": { - selection: configuration.ServerSelection{ - VPN: constants.Wireguard, - }, - protocol: constants.UDP, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - protocol := getProtocol(testCase.selection) - - assert.Equal(t, testCase.protocol, protocol) - }) - } -}