Maint: common GetProtocol for OpenVPN+Wireguard providers
This commit is contained in:
@@ -10,7 +10,7 @@ import (
|
|||||||
func (i *Ivpn) GetConnection(selection configuration.ServerSelection) (
|
func (i *Ivpn) GetConnection(selection configuration.ServerSelection) (
|
||||||
connection models.Connection, err error) {
|
connection models.Connection, err error) {
|
||||||
port := getPort(selection)
|
port := getPort(selection)
|
||||||
protocol := getProtocol(selection)
|
protocol := utils.GetProtocol(selection)
|
||||||
|
|
||||||
servers, err := i.filterServers(selection)
|
servers, err := i.filterServers(selection)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -60,10 +60,3 @@ func getPort(selection configuration.ServerSelection) (port uint16) {
|
|||||||
return port
|
return port
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getProtocol(selection configuration.ServerSelection) (protocol string) {
|
|
||||||
if selection.VPN == constants.OpenVPN && selection.OpenVPN.TCP {
|
|
||||||
return constants.TCP
|
|
||||||
}
|
|
||||||
return constants.UDP
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -158,42 +158,3 @@ func Test_getPort(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_getProtocol(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
testCases := map[string]struct {
|
|
||||||
selection configuration.ServerSelection
|
|
||||||
protocol string
|
|
||||||
}{
|
|
||||||
"OpenVPN UDP": {
|
|
||||||
protocol: constants.UDP,
|
|
||||||
},
|
|
||||||
"OpenVPN TCP": {
|
|
||||||
selection: configuration.ServerSelection{
|
|
||||||
VPN: constants.OpenVPN,
|
|
||||||
OpenVPN: configuration.OpenVPNSelection{
|
|
||||||
TCP: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
protocol: constants.TCP,
|
|
||||||
},
|
|
||||||
"Wireguard": {
|
|
||||||
selection: configuration.ServerSelection{
|
|
||||||
VPN: constants.Wireguard,
|
|
||||||
},
|
|
||||||
protocol: constants.UDP,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, testCase := range testCases {
|
|
||||||
testCase := testCase
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
protocol := getProtocol(testCase.selection)
|
|
||||||
|
|
||||||
assert.Equal(t, testCase.protocol, protocol)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
func (m *Mullvad) GetConnection(selection configuration.ServerSelection) (
|
func (m *Mullvad) GetConnection(selection configuration.ServerSelection) (
|
||||||
connection models.Connection, err error) {
|
connection models.Connection, err error) {
|
||||||
port := getPort(selection)
|
port := getPort(selection)
|
||||||
protocol := getProtocol(selection)
|
protocol := utils.GetProtocol(selection)
|
||||||
|
|
||||||
servers, err := m.filterServers(selection)
|
servers, err := m.filterServers(selection)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -59,11 +59,3 @@ func getPort(selection configuration.ServerSelection) (port uint16) {
|
|||||||
return port
|
return port
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getProtocol(selection configuration.ServerSelection) (protocol string) {
|
|
||||||
protocol = constants.UDP
|
|
||||||
if selection.VPN == constants.OpenVPN && selection.OpenVPN.TCP {
|
|
||||||
protocol = constants.TCP
|
|
||||||
}
|
|
||||||
return protocol
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -157,48 +157,3 @@ func Test_getPort(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_getProtocol(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
testCases := map[string]struct {
|
|
||||||
selection configuration.ServerSelection
|
|
||||||
protocol string
|
|
||||||
}{
|
|
||||||
"default": {
|
|
||||||
protocol: constants.UDP,
|
|
||||||
},
|
|
||||||
"OpenVPN UDP": {
|
|
||||||
selection: configuration.ServerSelection{
|
|
||||||
VPN: constants.OpenVPN,
|
|
||||||
},
|
|
||||||
protocol: constants.UDP,
|
|
||||||
},
|
|
||||||
"OpenVPN TCP": {
|
|
||||||
selection: configuration.ServerSelection{
|
|
||||||
VPN: constants.OpenVPN,
|
|
||||||
OpenVPN: configuration.OpenVPNSelection{
|
|
||||||
TCP: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
protocol: constants.TCP,
|
|
||||||
},
|
|
||||||
"Wireguard": {
|
|
||||||
selection: configuration.ServerSelection{
|
|
||||||
VPN: constants.Wireguard,
|
|
||||||
},
|
|
||||||
protocol: constants.UDP,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, testCase := range testCases {
|
|
||||||
testCase := testCase
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
protocol := getProtocol(testCase.selection)
|
|
||||||
|
|
||||||
assert.Equal(t, testCase.protocol, protocol)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
13
internal/provider/utils/protocol.go
Normal file
13
internal/provider/utils/protocol.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/qdm12/gluetun/internal/configuration"
|
||||||
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetProtocol(selection configuration.ServerSelection) (protocol string) {
|
||||||
|
if selection.VPN == constants.OpenVPN && selection.OpenVPN.TCP {
|
||||||
|
return constants.TCP
|
||||||
|
}
|
||||||
|
return constants.UDP
|
||||||
|
}
|
||||||
54
internal/provider/utils/protocol_test.go
Normal file
54
internal/provider/utils/protocol_test.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/qdm12/gluetun/internal/configuration"
|
||||||
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_GetProtocol(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
selection configuration.ServerSelection
|
||||||
|
protocol string
|
||||||
|
}{
|
||||||
|
"default": {
|
||||||
|
protocol: constants.UDP,
|
||||||
|
},
|
||||||
|
"OpenVPN UDP": {
|
||||||
|
selection: configuration.ServerSelection{
|
||||||
|
VPN: constants.OpenVPN,
|
||||||
|
},
|
||||||
|
protocol: constants.UDP,
|
||||||
|
},
|
||||||
|
"OpenVPN TCP": {
|
||||||
|
selection: configuration.ServerSelection{
|
||||||
|
VPN: constants.OpenVPN,
|
||||||
|
OpenVPN: configuration.OpenVPNSelection{
|
||||||
|
TCP: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
protocol: constants.TCP,
|
||||||
|
},
|
||||||
|
"Wireguard": {
|
||||||
|
selection: configuration.ServerSelection{
|
||||||
|
VPN: constants.Wireguard,
|
||||||
|
},
|
||||||
|
protocol: constants.UDP,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
protocol := GetProtocol(testCase.selection)
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.protocol, protocol)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
func (w *Windscribe) GetConnection(selection configuration.ServerSelection) (
|
func (w *Windscribe) GetConnection(selection configuration.ServerSelection) (
|
||||||
connection models.Connection, err error) {
|
connection models.Connection, err error) {
|
||||||
port := getPort(selection)
|
port := getPort(selection)
|
||||||
protocol := getProtocol(selection)
|
protocol := utils.GetProtocol(selection)
|
||||||
|
|
||||||
servers, err := w.filterServers(selection)
|
servers, err := w.filterServers(selection)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -60,11 +60,3 @@ func getPort(selection configuration.ServerSelection) (port uint16) {
|
|||||||
return port
|
return port
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getProtocol(selection configuration.ServerSelection) (protocol string) {
|
|
||||||
protocol = constants.UDP
|
|
||||||
if selection.VPN == constants.OpenVPN && selection.OpenVPN.TCP {
|
|
||||||
protocol = constants.TCP
|
|
||||||
}
|
|
||||||
return protocol
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -157,48 +157,3 @@ func Test_getPort(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_getProtocol(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
testCases := map[string]struct {
|
|
||||||
selection configuration.ServerSelection
|
|
||||||
protocol string
|
|
||||||
}{
|
|
||||||
"default": {
|
|
||||||
protocol: constants.UDP,
|
|
||||||
},
|
|
||||||
"OpenVPN UDP": {
|
|
||||||
selection: configuration.ServerSelection{
|
|
||||||
VPN: constants.OpenVPN,
|
|
||||||
},
|
|
||||||
protocol: constants.UDP,
|
|
||||||
},
|
|
||||||
"OpenVPN TCP": {
|
|
||||||
selection: configuration.ServerSelection{
|
|
||||||
VPN: constants.OpenVPN,
|
|
||||||
OpenVPN: configuration.OpenVPNSelection{
|
|
||||||
TCP: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
protocol: constants.TCP,
|
|
||||||
},
|
|
||||||
"Wireguard": {
|
|
||||||
selection: configuration.ServerSelection{
|
|
||||||
VPN: constants.Wireguard,
|
|
||||||
},
|
|
||||||
protocol: constants.UDP,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, testCase := range testCases {
|
|
||||||
testCase := testCase
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
protocol := getProtocol(testCase.selection)
|
|
||||||
|
|
||||||
assert.Equal(t, testCase.protocol, protocol)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user