diff --git a/internal/provider/cyberghost/connection.go b/internal/provider/cyberghost/connection.go index 9e224cfa..19fbb3ca 100644 --- a/internal/provider/cyberghost/connection.go +++ b/internal/provider/cyberghost/connection.go @@ -33,9 +33,5 @@ func (c *Cyberghost) GetConnection(selection configuration.ServerSelection) ( } } - if selection.TargetIP != nil { - return utils.GetTargetIPConnection(connections, selection.TargetIP) - } - - return utils.PickRandomConnection(connections, c.randSource), nil + return utils.PickConnection(connections, selection, c.randSource) } diff --git a/internal/provider/fastestvpn/connection.go b/internal/provider/fastestvpn/connection.go index 0dbb13ae..6dc1c0a0 100644 --- a/internal/provider/fastestvpn/connection.go +++ b/internal/provider/fastestvpn/connection.go @@ -33,9 +33,5 @@ func (f *Fastestvpn) GetConnection(selection configuration.ServerSelection) ( } } - if selection.TargetIP != nil { - return utils.GetTargetIPConnection(connections, selection.TargetIP) - } - - return utils.PickRandomConnection(connections, f.randSource), nil + return utils.PickConnection(connections, selection, f.randSource) } diff --git a/internal/provider/hidemyass/connection.go b/internal/provider/hidemyass/connection.go index 608a8020..8ef2018e 100644 --- a/internal/provider/hidemyass/connection.go +++ b/internal/provider/hidemyass/connection.go @@ -38,9 +38,5 @@ func (h *HideMyAss) GetConnection(selection configuration.ServerSelection) ( } } - if selection.TargetIP != nil { - return utils.GetTargetIPConnection(connections, selection.TargetIP) - } - - return utils.PickRandomConnection(connections, h.randSource), nil + return utils.PickConnection(connections, selection, h.randSource) } diff --git a/internal/provider/ipvanish/connection.go b/internal/provider/ipvanish/connection.go index 970c1f11..270dab07 100644 --- a/internal/provider/ipvanish/connection.go +++ b/internal/provider/ipvanish/connection.go @@ -38,9 +38,5 @@ func (i *Ipvanish) GetConnection(selection configuration.ServerSelection) ( } } - if selection.TargetIP != nil { - return utils.GetTargetIPConnection(connections, selection.TargetIP) - } - - return utils.PickRandomConnection(connections, i.randSource), nil + return utils.PickConnection(connections, selection, i.randSource) } diff --git a/internal/provider/ivpn/connection.go b/internal/provider/ivpn/connection.go index b1f2009b..703a203d 100644 --- a/internal/provider/ivpn/connection.go +++ b/internal/provider/ivpn/connection.go @@ -31,11 +31,7 @@ func (i *Ivpn) GetConnection(selection configuration.ServerSelection) ( } } - if selection.TargetIP != nil { - return utils.GetTargetIPConnection(connections, selection.TargetIP) - } - - return utils.PickRandomConnection(connections, i.randSource), nil + return utils.PickConnection(connections, selection, i.randSource) } func getPort(selection configuration.ServerSelection) (port uint16) { diff --git a/internal/provider/mullvad/connection.go b/internal/provider/mullvad/connection.go index 341c7d62..de7937cd 100644 --- a/internal/provider/mullvad/connection.go +++ b/internal/provider/mullvad/connection.go @@ -30,11 +30,7 @@ func (m *Mullvad) GetConnection(selection configuration.ServerSelection) ( } } - if selection.TargetIP != nil { - return utils.GetTargetIPConnection(connections, selection.TargetIP) - } - - return utils.PickRandomConnection(connections, m.randSource), nil + return utils.PickConnection(connections, selection, m.randSource) } func getPort(selection configuration.ServerSelection) (port uint16) { diff --git a/internal/provider/nordvpn/connection.go b/internal/provider/nordvpn/connection.go index 7cc04e7c..15abc61d 100644 --- a/internal/provider/nordvpn/connection.go +++ b/internal/provider/nordvpn/connection.go @@ -32,9 +32,5 @@ func (n *Nordvpn) GetConnection(selection configuration.ServerSelection) ( connections[i] = connection } - if selection.TargetIP != nil { - return utils.GetTargetIPConnection(connections, selection.TargetIP) - } - - return utils.PickRandomConnection(connections, n.randSource), nil + return utils.PickConnection(connections, selection, n.randSource) } diff --git a/internal/provider/privado/connection.go b/internal/provider/privado/connection.go index 8f5d5c63..41b621af 100644 --- a/internal/provider/privado/connection.go +++ b/internal/provider/privado/connection.go @@ -37,9 +37,5 @@ func (p *Privado) GetConnection(selection configuration.ServerSelection) ( connections[i] = connection } - if selection.TargetIP != nil { - return utils.GetTargetIPConnection(connections, selection.TargetIP) - } - - return utils.PickRandomConnection(connections, p.randSource), nil + return utils.PickConnection(connections, selection, p.randSource) } diff --git a/internal/provider/privateinternetaccess/connection.go b/internal/provider/privateinternetaccess/connection.go index 33de8e80..e7fcbde7 100644 --- a/internal/provider/privateinternetaccess/connection.go +++ b/internal/provider/privateinternetaccess/connection.go @@ -38,15 +38,5 @@ func (p *PIA) GetConnection(selection configuration.ServerSelection) ( } } - if selection.TargetIP != nil { - connection, err = utils.GetTargetIPConnection(connections, selection.TargetIP) - } else { - connection, err = utils.PickRandomConnection(connections, p.randSource), nil - } - - if err != nil { - return connection, err - } - - return connection, nil + return utils.PickConnection(connections, selection, p.randSource) } diff --git a/internal/provider/privatevpn/connection.go b/internal/provider/privatevpn/connection.go index 44d39c36..9e76f43b 100644 --- a/internal/provider/privatevpn/connection.go +++ b/internal/provider/privatevpn/connection.go @@ -34,9 +34,5 @@ func (p *Privatevpn) GetConnection(selection configuration.ServerSelection) ( } } - if selection.TargetIP != nil { - return utils.GetTargetIPConnection(connections, selection.TargetIP) - } - - return utils.PickRandomConnection(connections, p.randSource), nil + return utils.PickConnection(connections, selection, p.randSource) } diff --git a/internal/provider/protonvpn/connection.go b/internal/provider/protonvpn/connection.go index 3ac1b02d..48f63367 100644 --- a/internal/provider/protonvpn/connection.go +++ b/internal/provider/protonvpn/connection.go @@ -34,9 +34,5 @@ func (p *Protonvpn) GetConnection(selection configuration.ServerSelection) ( } } - if selection.TargetIP != nil { - return utils.GetTargetIPConnection(connections, selection.TargetIP) - } - - return utils.PickRandomConnection(connections, p.randSource), nil + return utils.PickConnection(connections, selection, p.randSource) } diff --git a/internal/provider/purevpn/connection.go b/internal/provider/purevpn/connection.go index 1837d0e8..b0ea9119 100644 --- a/internal/provider/purevpn/connection.go +++ b/internal/provider/purevpn/connection.go @@ -34,9 +34,5 @@ func (p *Purevpn) GetConnection(selection configuration.ServerSelection) ( } } - if selection.TargetIP != nil { - return utils.GetTargetIPConnection(connections, selection.TargetIP) - } - - return utils.PickRandomConnection(connections, p.randSource), nil + return utils.PickConnection(connections, selection, p.randSource) } diff --git a/internal/provider/surfshark/connection.go b/internal/provider/surfshark/connection.go index 2da549ad..2eeba1e5 100644 --- a/internal/provider/surfshark/connection.go +++ b/internal/provider/surfshark/connection.go @@ -34,9 +34,5 @@ func (s *Surfshark) GetConnection(selection configuration.ServerSelection) ( } } - if selection.TargetIP != nil { - return utils.GetTargetIPConnection(connections, selection.TargetIP) - } - - return utils.PickRandomConnection(connections, s.randSource), nil + return utils.PickConnection(connections, selection, s.randSource) } diff --git a/internal/provider/torguard/connection.go b/internal/provider/torguard/connection.go index f5e4e7d5..0abc4511 100644 --- a/internal/provider/torguard/connection.go +++ b/internal/provider/torguard/connection.go @@ -37,9 +37,5 @@ func (t *Torguard) GetConnection(selection configuration.ServerSelection) ( } } - if selection.TargetIP != nil { - return utils.GetTargetIPConnection(connections, selection.TargetIP) - } - - return utils.PickRandomConnection(connections, t.randSource), nil + return utils.PickConnection(connections, selection, t.randSource) } diff --git a/internal/provider/utils/pick.go b/internal/provider/utils/pick.go index 71d67d21..423261db 100644 --- a/internal/provider/utils/pick.go +++ b/internal/provider/utils/pick.go @@ -1,12 +1,51 @@ package utils import ( + "errors" + "fmt" "math/rand" + "net" + "github.com/qdm12/gluetun/internal/configuration" + "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/models" ) -func PickRandomConnection(connections []models.Connection, +// PickConnection picks a connection from a pool of connections. +// If the VPN protocol is Wireguard and the target IP is set, +// it finds the connection corresponding to this target IP. +// Otherwise, it picks a random connection from the pool of connections +// and sets the target IP address as the IP if this one is set. +func PickConnection(connections []models.Connection, + selection configuration.ServerSelection, randSource rand.Source) ( + connection models.Connection, err error) { + if selection.TargetIP != nil && selection.VPN == constants.Wireguard { + // we need the right public key + return getTargetIPConnection(connections, selection.TargetIP) + } + + connection = pickRandomConnection(connections, randSource) + if selection.TargetIP != nil { + connection.IP = selection.TargetIP + } + + return connection, nil +} + +func pickRandomConnection(connections []models.Connection, source rand.Source) models.Connection { return connections[rand.New(source).Intn(len(connections))] //nolint:gosec } + +var errTargetIPNotFound = errors.New("target IP address not found") + +func getTargetIPConnection(connections []models.Connection, + targetIP net.IP) (connection models.Connection, err error) { + for _, connection := range connections { + if targetIP.Equal(connection.IP) { + return connection, nil + } + } + return connection, fmt.Errorf("%w: in %d filtered connections", + errTargetIPNotFound, len(connections)) +} diff --git a/internal/provider/utils/pick_test.go b/internal/provider/utils/pick_test.go index 6f14d95f..83380684 100644 --- a/internal/provider/utils/pick_test.go +++ b/internal/provider/utils/pick_test.go @@ -8,19 +8,19 @@ import ( "github.com/stretchr/testify/assert" ) -func Test_PickRandomConnection(t *testing.T) { +func Test_pickRandomConnection(t *testing.T) { t.Parallel() connections := []models.Connection{ {Port: 1}, {Port: 2}, {Port: 3}, {Port: 4}, } source := rand.NewSource(0) - connection := PickRandomConnection(connections, source) + connection := pickRandomConnection(connections, source) assert.Equal(t, models.Connection{Port: 3}, connection) - connection = PickRandomConnection(connections, source) + connection = pickRandomConnection(connections, source) assert.Equal(t, models.Connection{Port: 3}, connection) - connection = PickRandomConnection(connections, source) + connection = pickRandomConnection(connections, source) assert.Equal(t, models.Connection{Port: 2}, connection) } diff --git a/internal/provider/utils/targetip.go b/internal/provider/utils/targetip.go deleted file mode 100644 index 23107cdc..00000000 --- a/internal/provider/utils/targetip.go +++ /dev/null @@ -1,22 +0,0 @@ -package utils - -import ( - "errors" - "fmt" - "net" - - "github.com/qdm12/gluetun/internal/models" -) - -var ErrTargetIPNotFound = errors.New("target IP address not found") - -func GetTargetIPConnection(connections []models.Connection, - targetIP net.IP) (connection models.Connection, err error) { - for _, connection := range connections { - if targetIP.Equal(connection.IP) { - return connection, nil - } - } - return connection, fmt.Errorf("%w: in %d filtered connections", - ErrTargetIPNotFound, len(connections)) -} diff --git a/internal/provider/vpnunlimited/connection.go b/internal/provider/vpnunlimited/connection.go index f09f7e2e..8afcca49 100644 --- a/internal/provider/vpnunlimited/connection.go +++ b/internal/provider/vpnunlimited/connection.go @@ -37,9 +37,5 @@ func (p *Provider) GetConnection(selection configuration.ServerSelection) ( } } - if selection.TargetIP != nil { - return utils.GetTargetIPConnection(connections, selection.TargetIP) - } - - return utils.PickRandomConnection(connections, p.randSource), nil + return utils.PickConnection(connections, selection, p.randSource) } diff --git a/internal/provider/vyprvpn/connection.go b/internal/provider/vyprvpn/connection.go index 6babf608..774f1da2 100644 --- a/internal/provider/vyprvpn/connection.go +++ b/internal/provider/vyprvpn/connection.go @@ -38,9 +38,5 @@ func (v *Vyprvpn) GetConnection(selection configuration.ServerSelection) ( } } - if selection.TargetIP != nil { - return utils.GetTargetIPConnection(connections, selection.TargetIP) - } - - return utils.PickRandomConnection(connections, v.randSource), nil + return utils.PickConnection(connections, selection, v.randSource) } diff --git a/internal/provider/windscribe/connection.go b/internal/provider/windscribe/connection.go index 5c0ea014..d853984c 100644 --- a/internal/provider/windscribe/connection.go +++ b/internal/provider/windscribe/connection.go @@ -31,11 +31,7 @@ func (w *Windscribe) GetConnection(selection configuration.ServerSelection) ( } } - if selection.TargetIP != nil { - return utils.GetTargetIPConnection(connections, selection.TargetIP) - } - - return utils.PickRandomConnection(connections, w.randSource), nil + return utils.PickConnection(connections, selection, w.randSource) } func getPort(selection configuration.ServerSelection) (port uint16) {