diff --git a/internal/configuration/mullvad.go b/internal/configuration/mullvad.go index 4c943222..df057997 100644 --- a/internal/configuration/mullvad.go +++ b/internal/configuration/mullvad.go @@ -40,7 +40,12 @@ func (settings *Provider) readMullvad(r reader) (err error) { return fmt.Errorf("environment variable OWNED: %w", err) } - return settings.ServerSelection.OpenVPN.readMullvad(r.env) + err = settings.ServerSelection.OpenVPN.readMullvad(r.env) + if err != nil { + return err + } + + return settings.ServerSelection.Wireguard.readMullvad(r.env) } func (settings *OpenVPNSelection) readMullvad(env params.Interface) (err error) { @@ -57,3 +62,12 @@ func (settings *OpenVPNSelection) readMullvad(env params.Interface) (err error) return nil } + +func (settings *WireguardSelection) readMullvad(env params.Interface) (err error) { + settings.CustomPort, err = readWireguardCustomPort(env, nil) + if err != nil { + return err + } + + return nil +} diff --git a/internal/configuration/provider.go b/internal/configuration/provider.go index 82351463..f84d4735 100644 --- a/internal/configuration/provider.go +++ b/internal/configuration/provider.go @@ -167,6 +167,7 @@ func readOpenVPNCustomPort(env params.Interface, tcp bool, ErrInvalidPort, port, portsToString(allowedUDP)) } +// note: set allowed to an empty slice to allow all valid ports func readWireguardCustomPort(env params.Interface, allowed []uint16) (port uint16, err error) { port, err = readPortOrZero(env, "WIREGUARD_PORT") if err != nil { @@ -175,11 +176,16 @@ func readWireguardCustomPort(env params.Interface, allowed []uint16) (port uint1 return 0, nil } + if len(allowed) == 0 { + return port, nil + } + for i := range allowed { if allowed[i] == port { return port, nil } } + return 0, fmt.Errorf( "environment variable WIREGUARD_PORT: %w: port %d, can only be one of %s", ErrInvalidPort, port, portsToString(allowed))