diff --git a/internal/configuration/cyberghost.go b/internal/configuration/cyberghost.go index 91559788..567a53aa 100644 --- a/internal/configuration/cyberghost.go +++ b/internal/configuration/cyberghost.go @@ -34,7 +34,7 @@ func (settings *Provider) cyberghostLines() (lines []string) { func (settings *Provider) readCyberghost(r reader) (err error) { settings.Name = constants.Cyberghost - settings.ServerSelection.Protocol, err = readProtocol(r.env) + settings.ServerSelection.TCP, err = readProtocol(r.env) if err != nil { return err } diff --git a/internal/configuration/fastestvpn.go b/internal/configuration/fastestvpn.go index b1a0f2d2..5bd67cf1 100644 --- a/internal/configuration/fastestvpn.go +++ b/internal/configuration/fastestvpn.go @@ -19,7 +19,7 @@ func (settings *Provider) fastestvpnLines() (lines []string) { func (settings *Provider) readFastestvpn(r reader) (err error) { settings.Name = constants.Fastestvpn - settings.ServerSelection.Protocol, err = readProtocol(r.env) + settings.ServerSelection.TCP, err = readProtocol(r.env) if err != nil { return err } diff --git a/internal/configuration/hidemyass.go b/internal/configuration/hidemyass.go index f82b2504..9ec03051 100644 --- a/internal/configuration/hidemyass.go +++ b/internal/configuration/hidemyass.go @@ -27,7 +27,7 @@ func (settings *Provider) hideMyAssLines() (lines []string) { func (settings *Provider) readHideMyAss(r reader) (err error) { settings.Name = constants.HideMyAss - settings.ServerSelection.Protocol, err = readProtocol(r.env) + settings.ServerSelection.TCP, err = readProtocol(r.env) if err != nil { return err } diff --git a/internal/configuration/mullvad.go b/internal/configuration/mullvad.go index 2dd1ccbf..906b4fe4 100644 --- a/internal/configuration/mullvad.go +++ b/internal/configuration/mullvad.go @@ -38,7 +38,7 @@ func (settings *Provider) mullvadLines() (lines []string) { func (settings *Provider) readMullvad(r reader) (err error) { settings.Name = constants.Mullvad - settings.ServerSelection.Protocol, err = readProtocol(r.env) + settings.ServerSelection.TCP, err = readProtocol(r.env) if err != nil { return err } @@ -68,7 +68,7 @@ func (settings *Provider) readMullvad(r reader) (err error) { return err } - settings.ServerSelection.CustomPort, err = readCustomPort(r.env, settings.ServerSelection.Protocol, + settings.ServerSelection.CustomPort, err = readCustomPort(r.env, settings.ServerSelection.TCP, []uint16{80, 443, 1401}, []uint16{53, 1194, 1195, 1196, 1197, 1300, 1301, 1302, 1303, 1400}) if err != nil { return err diff --git a/internal/configuration/nordvpn.go b/internal/configuration/nordvpn.go index 029214a9..89a9b3ed 100644 --- a/internal/configuration/nordvpn.go +++ b/internal/configuration/nordvpn.go @@ -35,7 +35,7 @@ func (settings *Provider) nordvpnLines() (lines []string) { func (settings *Provider) readNordvpn(r reader) (err error) { settings.Name = constants.Nordvpn - settings.ServerSelection.Protocol, err = readProtocol(r.env) + settings.ServerSelection.TCP, err = readProtocol(r.env) if err != nil { return err } diff --git a/internal/configuration/openvpn_test.go b/internal/configuration/openvpn_test.go index 24e6f806..ac6eb835 100644 --- a/internal/configuration/openvpn_test.go +++ b/internal/configuration/openvpn_test.go @@ -29,7 +29,7 @@ func Test_OpenVPN_JSON(t *testing.T) { "provider": { "name": "name", "server_selection": { - "network_protocol": "", + "tcp": false, "regions": null, "group": "", "countries": null, diff --git a/internal/configuration/privado.go b/internal/configuration/privado.go index ddbe99c1..67b44cd9 100644 --- a/internal/configuration/privado.go +++ b/internal/configuration/privado.go @@ -27,7 +27,7 @@ func (settings *Provider) privadoLines() (lines []string) { func (settings *Provider) readPrivado(r reader) (err error) { settings.Name = constants.Privado - settings.ServerSelection.Protocol, err = readProtocol(r.env) + settings.ServerSelection.TCP, err = readProtocol(r.env) if err != nil { return err } diff --git a/internal/configuration/privateinternetaccess.go b/internal/configuration/privateinternetaccess.go index 288a15eb..91536410 100644 --- a/internal/configuration/privateinternetaccess.go +++ b/internal/configuration/privateinternetaccess.go @@ -37,7 +37,7 @@ func (settings *Provider) privateinternetaccessLines() (lines []string) { func (settings *Provider) readPrivateInternetAccess(r reader) (err error) { settings.Name = constants.PrivateInternetAccess - settings.ServerSelection.Protocol, err = readProtocol(r.env) + settings.ServerSelection.TCP, err = readProtocol(r.env) if err != nil { return err } diff --git a/internal/configuration/privatevpn.go b/internal/configuration/privatevpn.go index 65aac63b..a980e3dd 100644 --- a/internal/configuration/privatevpn.go +++ b/internal/configuration/privatevpn.go @@ -23,7 +23,7 @@ func (settings *Provider) privatevpnLines() (lines []string) { func (settings *Provider) readPrivatevpn(r reader) (err error) { settings.Name = constants.Privatevpn - settings.ServerSelection.Protocol, err = readProtocol(r.env) + settings.ServerSelection.TCP, err = readProtocol(r.env) if err != nil { return err } diff --git a/internal/configuration/protonvpn.go b/internal/configuration/protonvpn.go index f9b167ce..8d1f364f 100644 --- a/internal/configuration/protonvpn.go +++ b/internal/configuration/protonvpn.go @@ -31,7 +31,7 @@ func (settings *Provider) protonvpnLines() (lines []string) { func (settings *Provider) readProtonvpn(r reader) (err error) { settings.Name = constants.Protonvpn - settings.ServerSelection.Protocol, err = readProtocol(r.env) + settings.ServerSelection.TCP, err = readProtocol(r.env) if err != nil { return err } diff --git a/internal/configuration/provider.go b/internal/configuration/provider.go index 5c038cc1..e3ffeb83 100644 --- a/internal/configuration/provider.go +++ b/internal/configuration/provider.go @@ -1,7 +1,6 @@ package configuration import ( - "errors" "fmt" "net" "strings" @@ -21,10 +20,12 @@ type Provider struct { func (settings *Provider) lines() (lines []string) { lines = append(lines, lastIndent+strings.Title(settings.Name)+" settings:") - lines = append(lines, indent+lastIndent+"Network protocol: "+settings.ServerSelection.Protocol) + selection := settings.ServerSelection - if settings.ServerSelection.TargetIP != nil { - lines = append(lines, indent+lastIndent+"Target IP address: "+settings.ServerSelection.TargetIP.String()) + lines = append(lines, indent+lastIndent+"Network protocol: "+protoToString(selection.TCP)) + + if selection.TargetIP != nil { + lines = append(lines, indent+lastIndent+"Target IP address: "+selection.TargetIP.String()) } var providerLines []string @@ -73,19 +74,26 @@ func commaJoin(slice []string) string { return strings.Join(slice, ", ") } -func readProtocol(env params.Env) (protocol string, err error) { - return env.Inside("PROTOCOL", []string{constants.TCP, constants.UDP}, params.Default(constants.UDP)) +func readProtocol(env params.Env) (tcp bool, err error) { + protocol, err := env.Inside("PROTOCOL", []string{constants.TCP, constants.UDP}, params.Default(constants.UDP)) + if err != nil { + return false, err + } + return protocol == constants.TCP, nil +} + +func protoToString(tcp bool) string { + if tcp { + return constants.TCP + } + return constants.UDP } func readTargetIP(env params.Env) (targetIP net.IP, err error) { return readIP(env, "OPENVPN_TARGET_IP") } -var ( - ErrInvalidProtocol = errors.New("invalid network protocol") -) - -func readCustomPort(env params.Env, protocol string, +func readCustomPort(env params.Env, tcp bool, allowedTCP, allowedUDP []uint16) (port uint16, err error) { port, err = readPortOrZero(env, "PORT") if err != nil { @@ -94,22 +102,18 @@ func readCustomPort(env params.Env, protocol string, return 0, nil } - switch protocol { - case constants.TCP: + if tcp { for i := range allowedTCP { if allowedTCP[i] == port { return port, nil } } return 0, fmt.Errorf("%w: port %d for TCP protocol", ErrInvalidPort, port) - case constants.UDP: - for i := range allowedUDP { - if allowedUDP[i] == port { - return port, nil - } - } - return 0, fmt.Errorf("%w: port %d for UDP protocol", ErrInvalidPort, port) - default: - return 0, fmt.Errorf("%w: %s", ErrInvalidProtocol, protocol) } + for i := range allowedUDP { + if allowedUDP[i] == port { + return port, nil + } + } + return 0, fmt.Errorf("%w: port %d for UDP protocol", ErrInvalidPort, port) } diff --git a/internal/configuration/provider_test.go b/internal/configuration/provider_test.go index 8ea580ba..aa4b2e25 100644 --- a/internal/configuration/provider_test.go +++ b/internal/configuration/provider_test.go @@ -24,9 +24,8 @@ func Test_Provider_lines(t *testing.T) { settings: Provider{ Name: constants.Cyberghost, ServerSelection: ServerSelection{ - Protocol: constants.UDP, - Group: "group", - Regions: []string{"a", "El country"}, + Group: "group", + Regions: []string{"a", "El country"}, }, ExtraConfigOptions: ExtraConfigOptions{ ClientKey: "a", @@ -46,7 +45,6 @@ func Test_Provider_lines(t *testing.T) { settings: Provider{ Name: constants.Fastestvpn, ServerSelection: ServerSelection{ - Protocol: constants.UDP, Hostnames: []string{"a", "b"}, Countries: []string{"c", "d"}, }, @@ -62,7 +60,6 @@ func Test_Provider_lines(t *testing.T) { settings: Provider{ Name: constants.HideMyAss, ServerSelection: ServerSelection{ - Protocol: constants.UDP, Countries: []string{"a", "b"}, Cities: []string{"c", "d"}, Hostnames: []string{"e", "f"}, @@ -80,7 +77,6 @@ func Test_Provider_lines(t *testing.T) { settings: Provider{ Name: constants.Mullvad, ServerSelection: ServerSelection{ - Protocol: constants.UDP, Countries: []string{"a", "b"}, Cities: []string{"c", "d"}, ISPs: []string{"e", "f"}, @@ -104,9 +100,8 @@ func Test_Provider_lines(t *testing.T) { settings: Provider{ Name: constants.Nordvpn, ServerSelection: ServerSelection{ - Protocol: constants.UDP, - Regions: []string{"a", "b"}, - Numbers: []uint16{1, 2}, + Regions: []string{"a", "b"}, + Numbers: []uint16{1, 2}, }, }, lines: []string{ @@ -120,7 +115,6 @@ func Test_Provider_lines(t *testing.T) { settings: Provider{ Name: constants.Privado, ServerSelection: ServerSelection{ - Protocol: constants.UDP, Hostnames: []string{"a", "b"}, }, }, @@ -134,7 +128,6 @@ func Test_Provider_lines(t *testing.T) { settings: Provider{ Name: constants.Privatevpn, ServerSelection: ServerSelection{ - Protocol: constants.UDP, Hostnames: []string{"a", "b"}, Countries: []string{"c", "d"}, Cities: []string{"e", "f"}, @@ -152,7 +145,6 @@ func Test_Provider_lines(t *testing.T) { settings: Provider{ Name: constants.Protonvpn, ServerSelection: ServerSelection{ - Protocol: constants.UDP, Countries: []string{"a", "b"}, Regions: []string{"c", "d"}, Cities: []string{"e", "f"}, @@ -174,7 +166,6 @@ func Test_Provider_lines(t *testing.T) { settings: Provider{ Name: constants.PrivateInternetAccess, ServerSelection: ServerSelection{ - Protocol: constants.UDP, Regions: []string{"a", "b"}, EncryptionPreset: constants.PIAEncryptionPresetStrong, CustomPort: 1, @@ -198,7 +189,6 @@ func Test_Provider_lines(t *testing.T) { settings: Provider{ Name: constants.Purevpn, ServerSelection: ServerSelection{ - Protocol: constants.UDP, Regions: []string{"a", "b"}, Countries: []string{"c", "d"}, Cities: []string{"e", "f"}, @@ -216,8 +206,7 @@ func Test_Provider_lines(t *testing.T) { settings: Provider{ Name: constants.Surfshark, ServerSelection: ServerSelection{ - Protocol: constants.UDP, - Regions: []string{"a", "b"}, + Regions: []string{"a", "b"}, }, }, lines: []string{ @@ -230,7 +219,6 @@ func Test_Provider_lines(t *testing.T) { settings: Provider{ Name: constants.Torguard, ServerSelection: ServerSelection{ - Protocol: constants.UDP, Countries: []string{"a", "b"}, Cities: []string{"c", "d"}, Hostnames: []string{"e"}, @@ -248,8 +236,7 @@ func Test_Provider_lines(t *testing.T) { settings: Provider{ Name: constants.Vyprvpn, ServerSelection: ServerSelection{ - Protocol: constants.UDP, - Regions: []string{"a", "b"}, + Regions: []string{"a", "b"}, }, }, lines: []string{ @@ -262,7 +249,6 @@ func Test_Provider_lines(t *testing.T) { settings: Provider{ Name: constants.Windscribe, ServerSelection: ServerSelection{ - Protocol: constants.UDP, Regions: []string{"a", "b"}, Cities: []string{"c", "d"}, Hostnames: []string{"e", "f"}, @@ -296,18 +282,18 @@ func Test_readProtocol(t *testing.T) { t.Parallel() testCases := map[string]struct { - mockStr string - mockErr error - protocol string - err error + mockStr string + mockErr error + tcp bool + err error }{ "error": { mockErr: errDummy, err: errDummy, }, "success": { - mockStr: "tcp", - protocol: constants.TCP, + mockStr: "tcp", + tcp: true, }, } @@ -322,7 +308,7 @@ func Test_readProtocol(t *testing.T) { Inside("PROTOCOL", []string{"tcp", "udp"}, gomock.Any()). Return(testCase.mockStr, testCase.mockErr) - protocol, err := readProtocol(env) + tcp, err := readProtocol(env) if testCase.err != nil { require.Error(t, err) @@ -331,7 +317,7 @@ func Test_readProtocol(t *testing.T) { assert.NoError(t, err) } - assert.Equal(t, testCase.protocol, protocol) + assert.Equal(t, testCase.tcp, tcp) }) } } diff --git a/internal/configuration/purevpn.go b/internal/configuration/purevpn.go index 3c80a28b..238a9792 100644 --- a/internal/configuration/purevpn.go +++ b/internal/configuration/purevpn.go @@ -27,7 +27,7 @@ func (settings *Provider) purevpnLines() (lines []string) { func (settings *Provider) readPurevpn(r reader) (err error) { settings.Name = constants.Purevpn - settings.ServerSelection.Protocol, err = readProtocol(r.env) + settings.ServerSelection.TCP, err = readProtocol(r.env) if err != nil { return err } diff --git a/internal/configuration/selection.go b/internal/configuration/selection.go index 34ca7c68..6d0624e9 100644 --- a/internal/configuration/selection.go +++ b/internal/configuration/selection.go @@ -4,9 +4,9 @@ import ( "net" ) -type ServerSelection struct { +type ServerSelection struct { //nolint:maligned // Common - Protocol string `json:"network_protocol"` + TCP bool `json:"tcp"` // UDP if TCP is false TargetIP net.IP `json:"target_ip,omitempty"` // TODO comments // Cyberghost, PIA, Protonvpn, Surfshark, Windscribe, Vyprvpn, NordVPN diff --git a/internal/configuration/settings_test.go b/internal/configuration/settings_test.go index cb621e06..59890599 100644 --- a/internal/configuration/settings_test.go +++ b/internal/configuration/settings_test.go @@ -28,7 +28,7 @@ func Test_Settings_lines(t *testing.T) { " |--Verbosity level: 0", " |--Provider:", " |--Mullvad settings:", - " |--Network protocol: ", + " |--Network protocol: udp", "|--DNS:", "|--Firewall: disabled ⚠️", "|--System:", diff --git a/internal/configuration/surfshark.go b/internal/configuration/surfshark.go index e8ef31b9..0115614b 100644 --- a/internal/configuration/surfshark.go +++ b/internal/configuration/surfshark.go @@ -19,7 +19,7 @@ func (settings *Provider) surfsharkLines() (lines []string) { func (settings *Provider) readSurfshark(r reader) (err error) { settings.Name = constants.Surfshark - settings.ServerSelection.Protocol, err = readProtocol(r.env) + settings.ServerSelection.TCP, err = readProtocol(r.env) if err != nil { return err } diff --git a/internal/configuration/torguard.go b/internal/configuration/torguard.go index f34a6648..a91d8635 100644 --- a/internal/configuration/torguard.go +++ b/internal/configuration/torguard.go @@ -23,7 +23,7 @@ func (settings *Provider) torguardLines() (lines []string) { func (settings *Provider) readTorguard(r reader) (err error) { settings.Name = constants.Torguard - settings.ServerSelection.Protocol, err = readProtocol(r.env) + settings.ServerSelection.TCP, err = readProtocol(r.env) if err != nil { return err } diff --git a/internal/configuration/vyprvpn.go b/internal/configuration/vyprvpn.go index 6b09580b..68c94b92 100644 --- a/internal/configuration/vyprvpn.go +++ b/internal/configuration/vyprvpn.go @@ -15,7 +15,7 @@ func (settings *Provider) vyprvpnLines() (lines []string) { func (settings *Provider) readVyprvpn(r reader) (err error) { settings.Name = constants.Vyprvpn - settings.ServerSelection.Protocol, err = readProtocol(r.env) + settings.ServerSelection.TCP, err = readProtocol(r.env) if err != nil { return err } diff --git a/internal/configuration/windscribe.go b/internal/configuration/windscribe.go index 115f7deb..fc0f931f 100644 --- a/internal/configuration/windscribe.go +++ b/internal/configuration/windscribe.go @@ -27,7 +27,7 @@ func (settings *Provider) windscribeLines() (lines []string) { func (settings *Provider) readWindscribe(r reader) (err error) { settings.Name = constants.Windscribe - settings.ServerSelection.Protocol, err = readProtocol(r.env) + settings.ServerSelection.TCP, err = readProtocol(r.env) if err != nil { return err } @@ -52,7 +52,7 @@ func (settings *Provider) readWindscribe(r reader) (err error) { return err } - settings.ServerSelection.CustomPort, err = readCustomPort(r.env, settings.ServerSelection.Protocol, + settings.ServerSelection.CustomPort, err = readCustomPort(r.env, settings.ServerSelection.TCP, []uint16{21, 22, 80, 123, 143, 443, 587, 1194, 3306, 8080, 54783}, []uint16{53, 80, 123, 443, 1194, 54783}) if err != nil { diff --git a/internal/provider/cyberghost.go b/internal/provider/cyberghost.go index fe336597..62a61591 100644 --- a/internal/provider/cyberghost.go +++ b/internal/provider/cyberghost.go @@ -45,8 +45,9 @@ func (c *cyberghost) filterServers(regions, hostnames []string, group string) (s func (c *cyberghost) GetOpenVPNConnection(selection configuration.ServerSelection) ( connection models.OpenVPNConnection, err error) { const httpsPort = 443 + protocol := tcpBoolToProtocol(selection.TCP) if selection.TargetIP != nil { - return models.OpenVPNConnection{IP: selection.TargetIP, Port: httpsPort, Protocol: selection.Protocol}, nil + return models.OpenVPNConnection{IP: selection.TargetIP, Port: httpsPort, Protocol: protocol}, nil } servers := c.filterServers(selection.Regions, selection.Hostnames, selection.Group) @@ -58,7 +59,7 @@ func (c *cyberghost) GetOpenVPNConnection(selection configuration.ServerSelectio var connections []models.OpenVPNConnection for _, server := range servers { for _, IP := range server.IPs { - connections = append(connections, models.OpenVPNConnection{IP: IP, Port: httpsPort, Protocol: selection.Protocol}) + connections = append(connections, models.OpenVPNConnection{IP: IP, Port: httpsPort, Protocol: protocol}) } } diff --git a/internal/provider/fastestvpn.go b/internal/provider/fastestvpn.go index e2c91031..26876389 100644 --- a/internal/provider/fastestvpn.go +++ b/internal/provider/fastestvpn.go @@ -28,20 +28,13 @@ func newFastestvpn(servers []models.FastestvpnServer, timeNow timeNowFunc) *fast } } -func (f *fastestvpn) filterServers(countries, hostnames []string, protocol string) (servers []models.FastestvpnServer) { - var tcp, udp bool - if protocol == "tcp" { - tcp = true - } else { - udp = true - } - +func (f *fastestvpn) filterServers(countries, hostnames []string, tcp bool) (servers []models.FastestvpnServer) { for _, server := range f.servers { switch { case filterByPossibilities(server.Country, countries): case filterByPossibilities(server.Hostname, hostnames): case tcp && !server.TCP: - case udp && !server.UDP: + case !tcp && !server.UDP: default: servers = append(servers, server) } @@ -50,7 +43,7 @@ func (f *fastestvpn) filterServers(countries, hostnames []string, protocol strin } func (f *fastestvpn) notFoundErr(selection configuration.ServerSelection) error { - message := "no server found for protocol " + selection.Protocol + message := "no server found for protocol " + tcpBoolToProtocol(selection.TCP) if len(selection.Hostnames) > 0 { message += " + hostnames " + commaJoin(selection.Hostnames) @@ -66,12 +59,13 @@ func (f *fastestvpn) notFoundErr(selection configuration.ServerSelection) error func (f *fastestvpn) GetOpenVPNConnection(selection configuration.ServerSelection) ( connection models.OpenVPNConnection, err error) { var port uint16 = 4443 + protocol := tcpBoolToProtocol(selection.TCP) if selection.TargetIP != nil { - return models.OpenVPNConnection{IP: selection.TargetIP, Port: port, Protocol: selection.Protocol}, nil + return models.OpenVPNConnection{IP: selection.TargetIP, Port: port, Protocol: protocol}, nil } - servers := f.filterServers(selection.Countries, selection.Hostnames, selection.Protocol) + servers := f.filterServers(selection.Countries, selection.Hostnames, selection.TCP) if len(servers) == 0 { return connection, f.notFoundErr(selection) } @@ -82,7 +76,7 @@ func (f *fastestvpn) GetOpenVPNConnection(selection configuration.ServerSelectio connection := models.OpenVPNConnection{ IP: IP, Port: port, - Protocol: selection.Protocol, + Protocol: protocol, } connections = append(connections, connection) } diff --git a/internal/provider/hidemyass.go b/internal/provider/hidemyass.go index 2bbc4514..8f2daa9e 100644 --- a/internal/provider/hidemyass.go +++ b/internal/provider/hidemyass.go @@ -30,15 +30,15 @@ func newHideMyAss(servers []models.HideMyAssServer, timeNow timeNowFunc) *hideMy } func (h *hideMyAss) filterServers(countries, cities, hostnames []string, - protocol string) (servers []models.HideMyAssServer) { + tcp bool) (servers []models.HideMyAssServer) { for _, server := range h.servers { switch { case filterByPossibilities(server.Country, countries), filterByPossibilities(server.City, cities), filterByPossibilities(server.Hostname, hostnames), - protocol == constants.TCP && !server.TCP, - protocol == constants.UDP && !server.UDP: + tcp && !server.TCP, + !tcp && !server.UDP: default: servers = append(servers, server) } @@ -67,7 +67,9 @@ func (h *hideMyAss) notFoundErr(selection configuration.ServerSelection) error { func (h *hideMyAss) GetOpenVPNConnection(selection configuration.ServerSelection) ( connection models.OpenVPNConnection, err error) { var defaultPort uint16 = 553 - if selection.Protocol == constants.TCP { + protocol := constants.UDP + if selection.TCP { + protocol = constants.TCP defaultPort = 8080 } port := defaultPort @@ -76,10 +78,14 @@ func (h *hideMyAss) GetOpenVPNConnection(selection configuration.ServerSelection } if selection.TargetIP != nil { - return models.OpenVPNConnection{IP: selection.TargetIP, Port: port, Protocol: selection.Protocol}, nil + return models.OpenVPNConnection{ + IP: selection.TargetIP, + Port: port, + Protocol: protocol, + }, nil } - servers := h.filterServers(selection.Countries, selection.Cities, selection.Hostnames, selection.Protocol) + servers := h.filterServers(selection.Countries, selection.Cities, selection.Hostnames, selection.TCP) if len(servers) == 0 { return models.OpenVPNConnection{}, h.notFoundErr(selection) } @@ -87,7 +93,11 @@ func (h *hideMyAss) GetOpenVPNConnection(selection configuration.ServerSelection var connections []models.OpenVPNConnection for _, server := range servers { for _, IP := range server.IPs { - connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}) + connections = append(connections, models.OpenVPNConnection{ + IP: IP, + Port: port, + Protocol: protocol, + }) } } diff --git a/internal/provider/mullvad.go b/internal/provider/mullvad.go index 57633c2e..9458ebdb 100644 --- a/internal/provider/mullvad.go +++ b/internal/provider/mullvad.go @@ -48,8 +48,10 @@ func (m *mullvad) filterServers(countries, cities, hostnames, func (m *mullvad) GetOpenVPNConnection(selection configuration.ServerSelection) ( connection models.OpenVPNConnection, err error) { var defaultPort uint16 = 1194 - if selection.Protocol == constants.TCP { + protocol := constants.UDP + if selection.TCP { defaultPort = 443 + protocol = constants.TCP } port := defaultPort if selection.CustomPort > 0 { @@ -57,7 +59,7 @@ func (m *mullvad) GetOpenVPNConnection(selection configuration.ServerSelection) } if selection.TargetIP != nil { - return models.OpenVPNConnection{IP: selection.TargetIP, Port: port, Protocol: selection.Protocol}, nil + return models.OpenVPNConnection{IP: selection.TargetIP, Port: port, Protocol: protocol}, nil } servers := m.filterServers(selection.Countries, selection.Cities, @@ -70,7 +72,7 @@ func (m *mullvad) GetOpenVPNConnection(selection configuration.ServerSelection) var connections []models.OpenVPNConnection for _, server := range servers { for _, IP := range server.IPs { - connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}) + connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: protocol}) } } diff --git a/internal/provider/nordvpn.go b/internal/provider/nordvpn.go index 71f9f59b..e4780e10 100644 --- a/internal/provider/nordvpn.go +++ b/internal/provider/nordvpn.go @@ -29,7 +29,7 @@ func newNordvpn(servers []models.NordvpnServer, timeNow timeNowFunc) *nordvpn { } } -func (n *nordvpn) filterServers(regions, hostnames, names []string, numbers []uint16, protocol string) ( +func (n *nordvpn) filterServers(regions, hostnames, names []string, numbers []uint16, tcp bool) ( servers []models.NordvpnServer) { numbersStr := make([]string, len(numbers)) for i := range numbers { @@ -39,8 +39,8 @@ func (n *nordvpn) filterServers(regions, hostnames, names []string, numbers []ui numberStr := fmt.Sprintf("%d", server.Number) switch { case - protocol == constants.TCP && !server.TCP, - protocol == constants.UDP && !server.UDP, + tcp && !server.TCP, + !tcp && !server.UDP, filterByPossibilities(server.Region, regions), filterByPossibilities(server.Hostname, hostnames), filterByPossibilities(server.Name, names), @@ -55,7 +55,7 @@ func (n *nordvpn) filterServers(regions, hostnames, names []string, numbers []ui var errNoServerFound = errors.New("no server found") func (n *nordvpn) notFoundErr(selection configuration.ServerSelection) error { - message := "for protocol " + selection.Protocol + message := "for protocol " + tcpBoolToProtocol(selection.TCP) if len(selection.Regions) > 0 { message += " + regions " + commaJoin(selection.Regions) @@ -82,29 +82,26 @@ func (n *nordvpn) notFoundErr(selection configuration.ServerSelection) error { func (n *nordvpn) GetOpenVPNConnection(selection configuration.ServerSelection) ( connection models.OpenVPNConnection, err error) { - var port uint16 - switch { - case selection.Protocol == constants.UDP: - port = 1194 - case selection.Protocol == constants.TCP: + var port uint16 = 1194 + protocol := constants.UDP + if selection.TCP { port = 443 - default: - return connection, fmt.Errorf("protocol %q is unknown", selection.Protocol) + protocol = constants.TCP } if selection.TargetIP != nil { - return models.OpenVPNConnection{IP: selection.TargetIP, Port: port, Protocol: selection.Protocol}, nil + return models.OpenVPNConnection{IP: selection.TargetIP, Port: port, Protocol: protocol}, nil } servers := n.filterServers(selection.Regions, selection.Hostnames, - selection.Names, selection.Numbers, selection.Protocol) + selection.Names, selection.Numbers, selection.TCP) if len(servers) == 0 { return connection, n.notFoundErr(selection) } connections := make([]models.OpenVPNConnection, len(servers)) for i := range servers { - connections[i] = models.OpenVPNConnection{IP: servers[i].IP, Port: port, Protocol: selection.Protocol} + connections[i] = models.OpenVPNConnection{IP: servers[i].IP, Port: port, Protocol: protocol} } return pickRandomConnection(connections, n.randSource), nil diff --git a/internal/provider/piav4.go b/internal/provider/piav4.go index 50ddd42a..9be2ce4a 100644 --- a/internal/provider/piav4.go +++ b/internal/provider/piav4.go @@ -47,15 +47,14 @@ var ( func (p *pia) getPort(selection configuration.ServerSelection) (port uint16, err error) { if selection.CustomPort == 0 { - switch selection.Protocol { - case constants.TCP: + if selection.TCP { switch selection.EncryptionPreset { case constants.PIAEncryptionPresetNormal: port = 502 case constants.PIAEncryptionPresetStrong: port = 501 } - case constants.UDP: + } else { switch selection.EncryptionPreset { case constants.PIAEncryptionPresetNormal: port = 1198 @@ -63,38 +62,28 @@ func (p *pia) getPort(selection configuration.ServerSelection) (port uint16, err port = 1197 } } - - if port == 0 { - return 0, fmt.Errorf( - "%w: combination of protocol %q and encryption %q does not yield any port number", - ErrInvalidPort, selection.Protocol, selection.EncryptionPreset) - } return port, nil } port = selection.CustomPort - switch selection.Protocol { - case constants.TCP: + if selection.TCP { switch port { case 80, 110, 443: //nolint:gomnd + return port, nil default: - return 0, fmt.Errorf("%w: %d for protocol %s", - ErrInvalidPort, port, selection.Protocol) - } - case constants.UDP: - switch port { - case 53, 1194, 1197, 1198, 8080, 9201: //nolint:gomnd - default: - return 0, fmt.Errorf("%w: %d for protocol %s", - ErrInvalidPort, port, selection.Protocol) + return 0, fmt.Errorf("%w: %d for protocol TCP", ErrInvalidPort, port) } } - - return port, nil + switch port { + case 53, 1194, 1197, 1198, 8080, 9201: //nolint:gomnd + return port, nil + default: + return 0, fmt.Errorf("%w: %d for protocol UDP", ErrInvalidPort, port) + } } -func (p *pia) notFoundErr(regions, hostnames, names []string, protocol string) error { - message := "for protocol " + protocol +func (p *pia) notFoundErr(regions, hostnames, names []string, tcp bool) error { + message := "for protocol " + tcpBoolToProtocol(tcp) if len(regions) > 0 { message += " + regions " + commaJoin(regions) @@ -118,15 +107,17 @@ func (p *pia) GetOpenVPNConnection(selection configuration.ServerSelection) ( return connection, err } + protocol := tcpBoolToProtocol(selection.TCP) + servers := p.servers if selection.TargetIP != nil { - connection = models.OpenVPNConnection{IP: selection.TargetIP, Port: port, Protocol: selection.Protocol} + connection = models.OpenVPNConnection{IP: selection.TargetIP, Port: port, Protocol: protocol} } else { servers := p.filterServers(selection.Regions, selection.Hostnames, - selection.Names, selection.Protocol) + selection.Names, selection.TCP) if len(servers) == 0 { return connection, p.notFoundErr(selection.Regions, selection.Hostnames, - selection.Names, selection.Protocol) + selection.Names, selection.TCP) } var connections []models.OpenVPNConnection @@ -135,7 +126,7 @@ func (p *pia) GetOpenVPNConnection(selection configuration.ServerSelection) ( connection := models.OpenVPNConnection{ IP: ip, Port: port, - Protocol: selection.Protocol, + Protocol: protocol, } connections = append(connections, connection) } @@ -369,15 +360,15 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client, } } -func (p *pia) filterServers(regions, hostnames, names []string, protocol string) ( +func (p *pia) filterServers(regions, hostnames, names []string, tcp bool) ( filtered []models.PIAServer) { for _, server := range p.servers { switch { case filterByPossibilities(server.Region, regions), filterByPossibilities(server.Hostname, hostnames), filterByPossibilities(server.ServerName, names), - protocol == constants.TCP && !server.TCP, - protocol == constants.UDP && !server.UDP: + tcp && !server.TCP, + !tcp && !server.UDP: default: filtered = append(filtered, server) } diff --git a/internal/provider/privado.go b/internal/provider/privado.go index 436a6e8e..95127034 100644 --- a/internal/provider/privado.go +++ b/internal/provider/privado.go @@ -2,6 +2,7 @@ package provider import ( "context" + "errors" "fmt" "math/rand" "net" @@ -67,17 +68,22 @@ func (p *privado) notFoundErr(countries, regions, cities, hostnames []string) er return fmt.Errorf("%w: %s", errNoServerFound, message) } +var ErrProtocolUnsupported = errors.New("network protocol is not supported") + func (p *privado) GetOpenVPNConnection(selection configuration.ServerSelection) ( connection models.OpenVPNConnection, err error) { var port uint16 = 1194 - switch selection.Protocol { - case constants.UDP: - default: - return connection, fmt.Errorf("protocol %q is not supported by Privado", selection.Protocol) + const protocol = constants.UDP + if selection.TCP { + return connection, fmt.Errorf("%w: TCP for provider Privado", ErrProtocolUnsupported) } if selection.TargetIP != nil { - return models.OpenVPNConnection{IP: selection.TargetIP, Port: port, Protocol: selection.Protocol}, nil + return models.OpenVPNConnection{ + IP: selection.TargetIP, + Port: port, + Protocol: protocol, + }, nil } servers := p.filterServers(selection.Countries, selection.Regions, @@ -92,7 +98,7 @@ func (p *privado) GetOpenVPNConnection(selection configuration.ServerSelection) connection := models.OpenVPNConnection{ IP: servers[i].IP, Port: port, - Protocol: selection.Protocol, + Protocol: protocol, Hostname: servers[i].Hostname, } connections[i] = connection diff --git a/internal/provider/privatevpn.go b/internal/provider/privatevpn.go index d3efe142..d098eb78 100644 --- a/internal/provider/privatevpn.go +++ b/internal/provider/privatevpn.go @@ -43,7 +43,7 @@ func (p *privatevpn) filterServers(countries, cities, hostnames []string) (serve } func (p *privatevpn) notFoundErr(selection configuration.ServerSelection) error { - message := "no server found for protocol " + selection.Protocol + message := "no server found for protocol " + tcpBoolToProtocol(selection.TCP) if len(selection.Countries) > 0 { message += " + countries " + commaJoin(selection.Countries) @@ -63,14 +63,20 @@ func (p *privatevpn) notFoundErr(selection configuration.ServerSelection) error func (p *privatevpn) GetOpenVPNConnection(selection configuration.ServerSelection) ( connection models.OpenVPNConnection, err error) { var port uint16 - if selection.Protocol == constants.TCP { + protocol := constants.TCP + if selection.TCP { port = 443 } else { + protocol = constants.UDP port = 1194 } if selection.TargetIP != nil { - return models.OpenVPNConnection{IP: selection.TargetIP, Port: port, Protocol: selection.Protocol}, nil + return models.OpenVPNConnection{ + IP: selection.TargetIP, + Port: port, + Protocol: protocol, + }, nil } servers := p.filterServers(selection.Countries, selection.Cities, selection.Hostnames) @@ -84,7 +90,7 @@ func (p *privatevpn) GetOpenVPNConnection(selection configuration.ServerSelectio connection := models.OpenVPNConnection{ IP: ip, Port: port, - Protocol: selection.Protocol, + Protocol: protocol, } connections = append(connections, connection) } diff --git a/internal/provider/protonvpn.go b/internal/provider/protonvpn.go index 1d433940..7843826d 100644 --- a/internal/provider/protonvpn.go +++ b/internal/provider/protonvpn.go @@ -35,11 +35,13 @@ func (p *protonvpn) GetOpenVPNConnection(selection configuration.ServerSelection return connection, err } + protocol := tcpBoolToProtocol(selection.TCP) + if selection.TargetIP != nil { return models.OpenVPNConnection{ IP: selection.TargetIP, Port: port, - Protocol: selection.Protocol, + Protocol: protocol, }, nil } @@ -54,7 +56,7 @@ func (p *protonvpn) GetOpenVPNConnection(selection configuration.ServerSelection connections[i] = models.OpenVPNConnection{ IP: servers[i].EntryIP, Port: port, - Protocol: selection.Protocol, + Protocol: protocol, } } @@ -139,35 +141,29 @@ func (p *protonvpn) PortForward(ctx context.Context, client *http.Client, func (p *protonvpn) getPort(selection configuration.ServerSelection) (port uint16, err error) { if selection.CustomPort == 0 { - switch selection.Protocol { - case constants.TCP: + if selection.TCP { const defaultTCPPort = 443 return defaultTCPPort, nil - case constants.UDP: - const defaultUDPPort = 1194 - return defaultUDPPort, nil } + const defaultUDPPort = 1194 + return defaultUDPPort, nil } port = selection.CustomPort - switch selection.Protocol { - case constants.TCP: + if selection.TCP { switch port { case 443, 5995, 8443: //nolint:gomnd + return port, nil default: - return 0, fmt.Errorf("%w: %d for protocol %s", - ErrInvalidPort, port, selection.Protocol) - } - case constants.UDP: - switch port { - case 80, 443, 1194, 4569, 5060: //nolint:gomnd - default: - return 0, fmt.Errorf("%w: %d for protocol %s", - ErrInvalidPort, port, selection.Protocol) + return 0, fmt.Errorf("%w: %d for protocol TCP", ErrInvalidPort, port) } } - - return port, nil + switch port { + case 80, 443, 1194, 4569, 5060: //nolint:gomnd + return port, nil + default: + return 0, fmt.Errorf("%w: %d for protocol UDP", ErrInvalidPort, port) + } } func (p *protonvpn) filterServers(countries, regions, cities, names, hostnames []string) ( @@ -188,7 +184,7 @@ func (p *protonvpn) filterServers(countries, regions, cities, names, hostnames [ } func (p *protonvpn) notFoundErr(selection configuration.ServerSelection) error { - message := "no server found for protocol " + selection.Protocol + message := "no server found for protocol " + tcpBoolToProtocol(selection.TCP) if len(selection.Countries) > 0 { message += " + countries " + commaJoin(selection.Countries) diff --git a/internal/provider/purevpn.go b/internal/provider/purevpn.go index 7aaccc85..434586ff 100644 --- a/internal/provider/purevpn.go +++ b/internal/provider/purevpn.go @@ -7,7 +7,6 @@ import ( "net" "net/http" "strconv" - "strings" "github.com/qdm12/gluetun/internal/configuration" "github.com/qdm12/gluetun/internal/constants" @@ -30,7 +29,7 @@ func newPurevpn(servers []models.PurevpnServer, timeNow timeNowFunc) *purevpn { } func (p *purevpn) filterServers(regions, countries, cities, hostnames []string, - protocol string) (servers []models.PurevpnServer) { + tcp bool) (servers []models.PurevpnServer) { for _, server := range p.servers { switch { case @@ -38,8 +37,8 @@ func (p *purevpn) filterServers(regions, countries, cities, hostnames []string, filterByPossibilities(server.Country, countries), filterByPossibilities(server.City, cities), filterByPossibilities(server.Hostname, hostnames), - strings.EqualFold(protocol, "tcp") && !server.TCP, - strings.EqualFold(protocol, "udp") && !server.UDP: + tcp && !server.TCP, + !tcp && !server.UDP: default: servers = append(servers, server) } @@ -49,22 +48,19 @@ func (p *purevpn) filterServers(regions, countries, cities, hostnames []string, func (p *purevpn) GetOpenVPNConnection(selection configuration.ServerSelection) ( connection models.OpenVPNConnection, err error) { - var port uint16 - switch { - case selection.Protocol == constants.UDP: - port = 53 - case selection.Protocol == constants.TCP: + var port uint16 = 53 + protocol := constants.UDP + if selection.TCP { port = 80 - default: - return connection, fmt.Errorf("protocol %q is unknown", selection.Protocol) + protocol = constants.TCP } if selection.TargetIP != nil { - return models.OpenVPNConnection{IP: selection.TargetIP, Port: port, Protocol: selection.Protocol}, nil + return models.OpenVPNConnection{IP: selection.TargetIP, Port: port, Protocol: protocol}, nil } servers := p.filterServers(selection.Regions, selection.Countries, - selection.Cities, selection.Hostnames, selection.Protocol) + selection.Cities, selection.Hostnames, selection.TCP) if len(servers) == 0 { return connection, fmt.Errorf("no server found for regions %s, countries %s and cities %s", commaJoin(selection.Regions), commaJoin(selection.Countries), commaJoin(selection.Cities)) @@ -73,7 +69,7 @@ func (p *purevpn) GetOpenVPNConnection(selection configuration.ServerSelection) var connections []models.OpenVPNConnection for _, server := range servers { for _, IP := range server.IPs { - connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}) + connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: protocol}) } } diff --git a/internal/provider/surfshark.go b/internal/provider/surfshark.go index b30a58d9..6b7af1fe 100644 --- a/internal/provider/surfshark.go +++ b/internal/provider/surfshark.go @@ -7,7 +7,6 @@ import ( "net" "net/http" "strconv" - "strings" "github.com/qdm12/gluetun/internal/configuration" "github.com/qdm12/gluetun/internal/constants" @@ -29,14 +28,14 @@ func newSurfshark(servers []models.SurfsharkServer, timeNow timeNowFunc) *surfsh } } -func (s *surfshark) filterServers(regions, hostnames []string, protocol string) (servers []models.SurfsharkServer) { +func (s *surfshark) filterServers(regions, hostnames []string, tcp bool) (servers []models.SurfsharkServer) { for _, server := range s.servers { switch { case filterByPossibilities(server.Region, regions), filterByPossibilities(server.Hostname, hostnames), - strings.EqualFold(protocol, "tcp") && !server.TCP, - strings.EqualFold(protocol, "udp") && !server.UDP: + tcp && !server.TCP, + !tcp && !server.UDP: default: servers = append(servers, server) } @@ -45,7 +44,7 @@ func (s *surfshark) filterServers(regions, hostnames []string, protocol string) } func (s *surfshark) notFoundErr(selection configuration.ServerSelection) error { - message := "for protocol " + selection.Protocol + message := "for protocol " + tcpBoolToProtocol(selection.TCP) if len(selection.Countries) > 0 { message += " + regions " + commaJoin(selection.Regions) @@ -60,21 +59,18 @@ func (s *surfshark) notFoundErr(selection configuration.ServerSelection) error { func (s *surfshark) GetOpenVPNConnection(selection configuration.ServerSelection) ( connection models.OpenVPNConnection, err error) { - var port uint16 - switch { - case selection.Protocol == constants.TCP: + var port uint16 = 1194 + protocol := constants.UDP + if selection.TCP { port = 1443 - case selection.Protocol == constants.UDP: - port = 1194 - default: - return connection, fmt.Errorf("protocol %q is unknown", selection.Protocol) + protocol = constants.TCP } if selection.TargetIP != nil { - return models.OpenVPNConnection{IP: selection.TargetIP, Port: port, Protocol: selection.Protocol}, nil + return models.OpenVPNConnection{IP: selection.TargetIP, Port: port, Protocol: protocol}, nil } - servers := s.filterServers(selection.Regions, selection.Hostnames, selection.Protocol) + servers := s.filterServers(selection.Regions, selection.Hostnames, selection.TCP) if len(servers) == 0 { return connection, s.notFoundErr(selection) } @@ -82,7 +78,7 @@ func (s *surfshark) GetOpenVPNConnection(selection configuration.ServerSelection var connections []models.OpenVPNConnection for _, server := range servers { for _, IP := range server.IPs { - connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}) + connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: protocol}) } } diff --git a/internal/provider/torguard.go b/internal/provider/torguard.go index 3d641233..dc527845 100644 --- a/internal/provider/torguard.go +++ b/internal/provider/torguard.go @@ -7,7 +7,6 @@ import ( "net" "net/http" "strconv" - "strings" "github.com/qdm12/gluetun/internal/configuration" "github.com/qdm12/gluetun/internal/constants" @@ -30,15 +29,15 @@ func newTorguard(servers []models.TorguardServer, timeNow timeNowFunc) *torguard } func (t *torguard) filterServers(countries, cities, hostnames []string, - protocol string) (servers []models.TorguardServer) { + tcp bool) (servers []models.TorguardServer) { for _, server := range t.servers { switch { case filterByPossibilities(server.Country, countries), filterByPossibilities(server.City, cities), filterByPossibilities(server.Hostname, hostnames), - strings.EqualFold(protocol, "tcp") && !server.TCP, - strings.EqualFold(protocol, "udp") && !server.UDP: + tcp && !server.TCP, + !tcp && !server.UDP: default: servers = append(servers, server) } @@ -47,7 +46,7 @@ func (t *torguard) filterServers(countries, cities, hostnames []string, } func (t *torguard) notFoundErr(selection configuration.ServerSelection) error { - message := "no server found for protocol " + selection.Protocol + message := "no server found for protocol " + tcpBoolToProtocol(selection.TCP) if len(selection.Countries) > 0 { message += " + countries " + commaJoin(selection.Countries) @@ -71,12 +70,14 @@ func (t *torguard) GetOpenVPNConnection(selection configuration.ServerSelection) port = selection.CustomPort } + protocol := tcpBoolToProtocol(selection.TCP) + if selection.TargetIP != nil { - return models.OpenVPNConnection{IP: selection.TargetIP, Port: port, Protocol: selection.Protocol}, nil + return models.OpenVPNConnection{IP: selection.TargetIP, Port: port, Protocol: protocol}, nil } servers := t.filterServers(selection.Countries, selection.Cities, - selection.Hostnames, selection.Protocol) + selection.Hostnames, selection.TCP) if len(servers) == 0 { return connection, t.notFoundErr(selection) } @@ -87,7 +88,7 @@ func (t *torguard) GetOpenVPNConnection(selection configuration.ServerSelection) connection := models.OpenVPNConnection{ IP: ip, Port: port, - Protocol: selection.Protocol, + Protocol: protocol, } connections = append(connections, connection) } diff --git a/internal/provider/utils.go b/internal/provider/utils.go index fb9b312a..df80f0e4 100644 --- a/internal/provider/utils.go +++ b/internal/provider/utils.go @@ -52,3 +52,10 @@ func filterByPossibilities(value string, possibilities []string) (filtered bool) func commaJoin(slice []string) string { return strings.Join(slice, ",") } + +func tcpBoolToProtocol(tcp bool) (protocol string) { + if tcp { + return "tcp" + } + return "udp" +} diff --git a/internal/provider/vyprvpn.go b/internal/provider/vyprvpn.go index d3f1b296..a463fd72 100644 --- a/internal/provider/vyprvpn.go +++ b/internal/provider/vyprvpn.go @@ -7,7 +7,6 @@ import ( "net" "net/http" "strconv" - "strings" "github.com/qdm12/gluetun/internal/configuration" "github.com/qdm12/gluetun/internal/constants" @@ -29,14 +28,14 @@ func newVyprvpn(servers []models.VyprvpnServer, timeNow timeNowFunc) *vyprvpn { } } -func (v *vyprvpn) filterServers(regions, hostnames []string, protocol string) (servers []models.VyprvpnServer) { +func (v *vyprvpn) filterServers(regions, hostnames []string, tcp bool) (servers []models.VyprvpnServer) { for _, server := range v.servers { switch { case filterByPossibilities(server.Region, regions), filterByPossibilities(server.Hostname, hostnames), - strings.EqualFold(protocol, "tcp") && !server.TCP, - strings.EqualFold(protocol, "udp") && !server.UDP: + tcp && !server.TCP, + !tcp && !server.UDP: default: servers = append(servers, server) } @@ -47,20 +46,17 @@ func (v *vyprvpn) filterServers(regions, hostnames []string, protocol string) (s func (v *vyprvpn) GetOpenVPNConnection(selection configuration.ServerSelection) ( connection models.OpenVPNConnection, err error) { var port uint16 - switch { - case selection.Protocol == constants.TCP: - return connection, fmt.Errorf("TCP protocol not supported by this VPN provider") - case selection.Protocol == constants.UDP: - port = 443 - default: - return connection, fmt.Errorf("protocol %q is unknown", selection.Protocol) + const protocol = constants.TCP + if selection.TCP { + return connection, fmt.Errorf("%w: TCP for provider VyprVPN", + ErrProtocolUnsupported) } if selection.TargetIP != nil { - return models.OpenVPNConnection{IP: selection.TargetIP, Port: port, Protocol: selection.Protocol}, nil + return models.OpenVPNConnection{IP: selection.TargetIP, Port: port, Protocol: protocol}, nil } - servers := v.filterServers(selection.Regions, selection.Hostnames, selection.Protocol) + servers := v.filterServers(selection.Regions, selection.Hostnames, selection.TCP) if len(servers) == 0 { return connection, fmt.Errorf("no server found for region %s", commaJoin(selection.Regions)) } @@ -68,7 +64,7 @@ func (v *vyprvpn) GetOpenVPNConnection(selection configuration.ServerSelection) var connections []models.OpenVPNConnection for _, server := range servers { for _, IP := range server.IPs { - connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: selection.Protocol}) + connections = append(connections, models.OpenVPNConnection{IP: IP, Port: port, Protocol: protocol}) } } diff --git a/internal/provider/windscribe.go b/internal/provider/windscribe.go index c33a22f9..53e47352 100644 --- a/internal/provider/windscribe.go +++ b/internal/provider/windscribe.go @@ -45,20 +45,19 @@ func (w *windscribe) filterServers(regions, cities, hostnames []string) (servers //nolint:lll func (w *windscribe) GetOpenVPNConnection(selection configuration.ServerSelection) (connection models.OpenVPNConnection, err error) { - var port uint16 - switch { - case selection.CustomPort > 0: - port = selection.CustomPort - case selection.Protocol == constants.TCP: + var port uint16 = 443 + protocol := constants.UDP + if selection.TCP { port = 1194 - case selection.Protocol == constants.UDP: - port = 443 - default: - return connection, fmt.Errorf("protocol %q is unknown", selection.Protocol) + protocol = constants.TCP + } + + if selection.CustomPort > 0 { + port = selection.CustomPort } if selection.TargetIP != nil { - return models.OpenVPNConnection{IP: selection.TargetIP, Port: port, Protocol: selection.Protocol}, nil + return models.OpenVPNConnection{IP: selection.TargetIP, Port: port, Protocol: protocol}, nil } servers := w.filterServers(selection.Regions, selection.Cities, selection.Hostnames) @@ -72,7 +71,7 @@ func (w *windscribe) GetOpenVPNConnection(selection configuration.ServerSelectio connection := models.OpenVPNConnection{ IP: ip, Port: port, - Protocol: selection.Protocol, + Protocol: protocol, } connections = append(connections, connection) }