diff --git a/internal/provider/fastestvpn/filter.go b/internal/provider/fastestvpn/filter.go index 93e5c8c6..67d37542 100644 --- a/internal/provider/fastestvpn/filter.go +++ b/internal/provider/fastestvpn/filter.go @@ -13,8 +13,7 @@ func (f *Fastestvpn) filterServers(selection configuration.ServerSelection) ( case utils.FilterByPossibilities(server.Country, selection.Countries), utils.FilterByPossibilities(server.Hostname, selection.Hostnames), - selection.OpenVPN.TCP && !server.TCP, - !selection.OpenVPN.TCP && !server.UDP: + utils.FilterByProtocol(selection, server.TCP, server.UDP): default: servers = append(servers, server) } diff --git a/internal/provider/hidemyass/filter.go b/internal/provider/hidemyass/filter.go index 2a721f46..fa1c3e79 100644 --- a/internal/provider/hidemyass/filter.go +++ b/internal/provider/hidemyass/filter.go @@ -14,8 +14,7 @@ func (h *HideMyAss) filterServers(selection configuration.ServerSelection) ( utils.FilterByPossibilities(server.Country, selection.Countries), utils.FilterByPossibilities(server.City, selection.Cities), utils.FilterByPossibilities(server.Hostname, selection.Hostnames), - selection.OpenVPN.TCP && !server.TCP, - !selection.OpenVPN.TCP && !server.UDP: + utils.FilterByProtocol(selection, server.TCP, server.UDP): default: servers = append(servers, server) } diff --git a/internal/provider/ipvanish/filter.go b/internal/provider/ipvanish/filter.go index 393ea0c6..02e3c9a8 100644 --- a/internal/provider/ipvanish/filter.go +++ b/internal/provider/ipvanish/filter.go @@ -14,8 +14,7 @@ func (i *Ipvanish) filterServers(selection configuration.ServerSelection) ( utils.FilterByPossibilities(server.Country, selection.Countries), utils.FilterByPossibilities(server.City, selection.Cities), utils.FilterByPossibilities(server.Hostname, selection.Hostnames), - selection.OpenVPN.TCP && !server.TCP, - !selection.OpenVPN.TCP && !server.UDP: + utils.FilterByProtocol(selection, server.TCP, server.UDP): default: servers = append(servers, server) } diff --git a/internal/provider/ivpn/filter.go b/internal/provider/ivpn/filter.go index 580de4bf..a757970b 100644 --- a/internal/provider/ivpn/filter.go +++ b/internal/provider/ivpn/filter.go @@ -16,8 +16,7 @@ func (i *Ivpn) filterServers(selection configuration.ServerSelection) ( utils.FilterByPossibilities(server.Country, selection.Countries), utils.FilterByPossibilities(server.City, selection.Cities), utils.FilterByPossibilities(server.Hostname, selection.Hostnames), - selection.OpenVPN.TCP && !server.TCP, - !selection.OpenVPN.TCP && !server.UDP: + utils.FilterByProtocol(selection, server.TCP, server.UDP): default: servers = append(servers, server) } diff --git a/internal/provider/nordvpn/filter.go b/internal/provider/nordvpn/filter.go index 6942fff1..e60b00f0 100644 --- a/internal/provider/nordvpn/filter.go +++ b/internal/provider/nordvpn/filter.go @@ -23,8 +23,7 @@ func (n *Nordvpn) filterServers(selection configuration.ServerSelection) ( utils.FilterByPossibilities(server.Hostname, selection.Hostnames), utils.FilterByPossibilities(server.Name, selection.Names), utils.FilterByPossibilities(serverNumber, selectedNumbers), - selection.OpenVPN.TCP && !server.TCP, - !selection.OpenVPN.TCP && !server.UDP: + utils.FilterByProtocol(selection, server.TCP, server.UDP): default: servers = append(servers, server) } diff --git a/internal/provider/privateinternetaccess/filter.go b/internal/provider/privateinternetaccess/filter.go index abd10936..5efb55e8 100644 --- a/internal/provider/privateinternetaccess/filter.go +++ b/internal/provider/privateinternetaccess/filter.go @@ -14,8 +14,7 @@ func (p *PIA) filterServers(selection configuration.ServerSelection) ( utils.FilterByPossibilities(server.Region, selection.Regions), utils.FilterByPossibilities(server.Hostname, selection.Hostnames), utils.FilterByPossibilities(server.ServerName, selection.Names), - selection.OpenVPN.TCP && !server.TCP, - !selection.OpenVPN.TCP && !server.UDP: + utils.FilterByProtocol(selection, server.TCP, server.UDP): default: servers = append(servers, server) } diff --git a/internal/provider/purevpn/filter.go b/internal/provider/purevpn/filter.go index 40b5fc49..9f2d0f0e 100644 --- a/internal/provider/purevpn/filter.go +++ b/internal/provider/purevpn/filter.go @@ -15,8 +15,7 @@ func (p *Purevpn) filterServers(selection configuration.ServerSelection) ( utils.FilterByPossibilities(server.Country, selection.Countries), utils.FilterByPossibilities(server.City, selection.Cities), utils.FilterByPossibilities(server.Hostname, selection.Hostnames), - selection.OpenVPN.TCP && !server.TCP, - !selection.OpenVPN.TCP && !server.UDP: + utils.FilterByProtocol(selection, server.TCP, server.UDP): default: servers = append(servers, server) } diff --git a/internal/provider/surfshark/filter.go b/internal/provider/surfshark/filter.go index 3b82311b..8d71fa77 100644 --- a/internal/provider/surfshark/filter.go +++ b/internal/provider/surfshark/filter.go @@ -15,8 +15,7 @@ func (s *Surfshark) filterServers(selection configuration.ServerSelection) ( utils.FilterByPossibilities(server.Country, selection.Countries), utils.FilterByPossibilities(server.City, selection.Cities), utils.FilterByPossibilities(server.Hostname, selection.Hostnames), - selection.OpenVPN.TCP && !server.TCP, - !selection.OpenVPN.TCP && !server.UDP, + utils.FilterByProtocol(selection, server.TCP, server.UDP), selection.MultiHopOnly && !server.MultiHop: default: servers = append(servers, server) diff --git a/internal/provider/torguard/filter.go b/internal/provider/torguard/filter.go index 782b70d5..c8a8ebb2 100644 --- a/internal/provider/torguard/filter.go +++ b/internal/provider/torguard/filter.go @@ -14,8 +14,7 @@ func (t *Torguard) filterServers(selection configuration.ServerSelection) ( utils.FilterByPossibilities(server.Country, selection.Countries), utils.FilterByPossibilities(server.City, selection.Cities), utils.FilterByPossibilities(server.Hostname, selection.Hostnames), - selection.OpenVPN.TCP && !server.TCP, - !selection.OpenVPN.TCP && !server.UDP: + utils.FilterByProtocol(selection, server.TCP, server.UDP): default: servers = append(servers, server) } diff --git a/internal/provider/utils/protocol.go b/internal/provider/utils/protocol.go index 518c99cc..48fbb606 100644 --- a/internal/provider/utils/protocol.go +++ b/internal/provider/utils/protocol.go @@ -11,3 +11,15 @@ func GetProtocol(selection configuration.ServerSelection) (protocol string) { } return constants.UDP } + +func FilterByProtocol(selection configuration.ServerSelection, + serverTCP, serverUDP bool) (filtered bool) { + switch selection.VPN { + case constants.Wireguard: + return !serverUDP + default: // OpenVPN + wantTCP := selection.OpenVPN.TCP + wantUDP := !wantTCP + return (wantTCP && !serverTCP) || (wantUDP && !serverUDP) + } +} diff --git a/internal/provider/utils/protocol_test.go b/internal/provider/utils/protocol_test.go index 08cdedf0..c31a6ff3 100644 --- a/internal/provider/utils/protocol_test.go +++ b/internal/provider/utils/protocol_test.go @@ -52,3 +52,81 @@ func Test_GetProtocol(t *testing.T) { }) } } + +func Test_FilterByProtocol(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + selection configuration.ServerSelection + serverTCP bool + serverUDP bool + filtered bool + }{ + "Wireguard and server has UDP": { + selection: configuration.ServerSelection{ + VPN: constants.Wireguard, + }, + serverUDP: true, + filtered: false, + }, + "Wireguard and server has not UDP": { + selection: configuration.ServerSelection{ + VPN: constants.Wireguard, + }, + serverUDP: false, + filtered: true, + }, + "OpenVPN UDP and server has UDP": { + selection: configuration.ServerSelection{ + VPN: constants.OpenVPN, + OpenVPN: configuration.OpenVPNSelection{ + TCP: false, + }, + }, + serverUDP: true, + filtered: false, + }, + "OpenVPN UDP and server has not UDP": { + selection: configuration.ServerSelection{ + VPN: constants.OpenVPN, + OpenVPN: configuration.OpenVPNSelection{ + TCP: false, + }, + }, + serverUDP: false, + filtered: true, + }, + "OpenVPN TCP and server has TCP": { + selection: configuration.ServerSelection{ + VPN: constants.OpenVPN, + OpenVPN: configuration.OpenVPNSelection{ + TCP: true, + }, + }, + serverTCP: true, + filtered: false, + }, + "OpenVPN TCP and server has not TCP": { + selection: configuration.ServerSelection{ + VPN: constants.OpenVPN, + OpenVPN: configuration.OpenVPNSelection{ + TCP: true, + }, + }, + serverTCP: false, + filtered: true, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + filtered := FilterByProtocol(testCase.selection, + testCase.serverTCP, testCase.serverUDP) + + assert.Equal(t, testCase.filtered, filtered) + }) + } +} diff --git a/internal/provider/vpnunlimited/filter.go b/internal/provider/vpnunlimited/filter.go index d1ed7044..a368d367 100644 --- a/internal/provider/vpnunlimited/filter.go +++ b/internal/provider/vpnunlimited/filter.go @@ -16,8 +16,7 @@ func (p *Provider) filterServers(selection configuration.ServerSelection) ( utils.FilterByPossibilities(server.Hostname, selection.Hostnames), selection.FreeOnly && !server.Free, selection.StreamOnly && !server.Stream, - selection.OpenVPN.TCP && !server.TCP, - !selection.OpenVPN.TCP && !server.UDP: + utils.FilterByProtocol(selection, server.TCP, server.UDP): default: servers = append(servers, server) } diff --git a/internal/provider/vyprvpn/filter.go b/internal/provider/vyprvpn/filter.go index c4d0db27..ae9dac80 100644 --- a/internal/provider/vyprvpn/filter.go +++ b/internal/provider/vyprvpn/filter.go @@ -13,8 +13,7 @@ func (v *Vyprvpn) filterServers(selection configuration.ServerSelection) ( case utils.FilterByPossibilities(server.Region, selection.Regions), utils.FilterByPossibilities(server.Hostname, selection.Hostnames), - selection.OpenVPN.TCP && !server.TCP, - !selection.OpenVPN.TCP && !server.UDP: + utils.FilterByProtocol(selection, server.TCP, server.UDP): default: servers = append(servers, server) }