Maint: common GetProtocol for OpenVPN+Wireguard providers

This commit is contained in:
Quentin McGaw (desktop)
2021-08-23 16:07:47 +00:00
parent 06a2d79cb4
commit dbf5c569ea
8 changed files with 70 additions and 155 deletions

View File

@@ -10,7 +10,7 @@ import (
func (i *Ivpn) GetConnection(selection configuration.ServerSelection) ( func (i *Ivpn) GetConnection(selection configuration.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
port := getPort(selection) port := getPort(selection)
protocol := getProtocol(selection) protocol := utils.GetProtocol(selection)
servers, err := i.filterServers(selection) servers, err := i.filterServers(selection)
if err != nil { if err != nil {
@@ -60,10 +60,3 @@ func getPort(selection configuration.ServerSelection) (port uint16) {
return port return port
} }
} }
func getProtocol(selection configuration.ServerSelection) (protocol string) {
if selection.VPN == constants.OpenVPN && selection.OpenVPN.TCP {
return constants.TCP
}
return constants.UDP
}

View File

@@ -158,42 +158,3 @@ func Test_getPort(t *testing.T) {
}) })
} }
} }
func Test_getProtocol(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
selection configuration.ServerSelection
protocol string
}{
"OpenVPN UDP": {
protocol: constants.UDP,
},
"OpenVPN TCP": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
OpenVPN: configuration.OpenVPNSelection{
TCP: true,
},
},
protocol: constants.TCP,
},
"Wireguard": {
selection: configuration.ServerSelection{
VPN: constants.Wireguard,
},
protocol: constants.UDP,
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
protocol := getProtocol(testCase.selection)
assert.Equal(t, testCase.protocol, protocol)
})
}
}

View File

@@ -10,7 +10,7 @@ import (
func (m *Mullvad) GetConnection(selection configuration.ServerSelection) ( func (m *Mullvad) GetConnection(selection configuration.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
port := getPort(selection) port := getPort(selection)
protocol := getProtocol(selection) protocol := utils.GetProtocol(selection)
servers, err := m.filterServers(selection) servers, err := m.filterServers(selection)
if err != nil { if err != nil {
@@ -59,11 +59,3 @@ func getPort(selection configuration.ServerSelection) (port uint16) {
return port return port
} }
} }
func getProtocol(selection configuration.ServerSelection) (protocol string) {
protocol = constants.UDP
if selection.VPN == constants.OpenVPN && selection.OpenVPN.TCP {
protocol = constants.TCP
}
return protocol
}

View File

@@ -157,48 +157,3 @@ func Test_getPort(t *testing.T) {
}) })
} }
} }
func Test_getProtocol(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
selection configuration.ServerSelection
protocol string
}{
"default": {
protocol: constants.UDP,
},
"OpenVPN UDP": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
},
protocol: constants.UDP,
},
"OpenVPN TCP": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
OpenVPN: configuration.OpenVPNSelection{
TCP: true,
},
},
protocol: constants.TCP,
},
"Wireguard": {
selection: configuration.ServerSelection{
VPN: constants.Wireguard,
},
protocol: constants.UDP,
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
protocol := getProtocol(testCase.selection)
assert.Equal(t, testCase.protocol, protocol)
})
}
}

View File

@@ -0,0 +1,13 @@
package utils
import (
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/constants"
)
func GetProtocol(selection configuration.ServerSelection) (protocol string) {
if selection.VPN == constants.OpenVPN && selection.OpenVPN.TCP {
return constants.TCP
}
return constants.UDP
}

View File

@@ -0,0 +1,54 @@
package utils
import (
"testing"
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/constants"
"github.com/stretchr/testify/assert"
)
func Test_GetProtocol(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
selection configuration.ServerSelection
protocol string
}{
"default": {
protocol: constants.UDP,
},
"OpenVPN UDP": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
},
protocol: constants.UDP,
},
"OpenVPN TCP": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
OpenVPN: configuration.OpenVPNSelection{
TCP: true,
},
},
protocol: constants.TCP,
},
"Wireguard": {
selection: configuration.ServerSelection{
VPN: constants.Wireguard,
},
protocol: constants.UDP,
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
protocol := GetProtocol(testCase.selection)
assert.Equal(t, testCase.protocol, protocol)
})
}
}

View File

@@ -10,7 +10,7 @@ import (
func (w *Windscribe) GetConnection(selection configuration.ServerSelection) ( func (w *Windscribe) GetConnection(selection configuration.ServerSelection) (
connection models.Connection, err error) { connection models.Connection, err error) {
port := getPort(selection) port := getPort(selection)
protocol := getProtocol(selection) protocol := utils.GetProtocol(selection)
servers, err := w.filterServers(selection) servers, err := w.filterServers(selection)
if err != nil { if err != nil {
@@ -60,11 +60,3 @@ func getPort(selection configuration.ServerSelection) (port uint16) {
return port return port
} }
} }
func getProtocol(selection configuration.ServerSelection) (protocol string) {
protocol = constants.UDP
if selection.VPN == constants.OpenVPN && selection.OpenVPN.TCP {
protocol = constants.TCP
}
return protocol
}

View File

@@ -157,48 +157,3 @@ func Test_getPort(t *testing.T) {
}) })
} }
} }
func Test_getProtocol(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
selection configuration.ServerSelection
protocol string
}{
"default": {
protocol: constants.UDP,
},
"OpenVPN UDP": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
},
protocol: constants.UDP,
},
"OpenVPN TCP": {
selection: configuration.ServerSelection{
VPN: constants.OpenVPN,
OpenVPN: configuration.OpenVPNSelection{
TCP: true,
},
},
protocol: constants.TCP,
},
"Wireguard": {
selection: configuration.ServerSelection{
VPN: constants.Wireguard,
},
protocol: constants.UDP,
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
protocol := getProtocol(testCase.selection)
assert.Equal(t, testCase.protocol, protocol)
})
}
}