diff --git a/internal/configuration/settings/helpers/belong.go b/internal/configuration/settings/helpers/belong.go index d60e6dd6..62aafde2 100644 --- a/internal/configuration/settings/helpers/belong.go +++ b/internal/configuration/settings/helpers/belong.go @@ -15,9 +15,16 @@ func IsOneOf(value string, choices ...string) (ok bool) { return false } -var ErrValueNotOneOf = errors.New("value is not one of the possible choices") +var ( + ErrNoChoice = errors.New("one or more values is set but there is no possible value available") + ErrValueNotOneOf = errors.New("value is not one of the possible choices") +) func AreAllOneOf(values, choices []string) (err error) { + if len(values) > 0 && len(choices) == 0 { + return ErrNoChoice + } + set := make(map[string]struct{}, len(choices)) for _, choice := range choices { choice = strings.ToLower(choice) diff --git a/internal/configuration/settings/serverselection.go b/internal/configuration/settings/serverselection.go index 776252f1..b0709257 100644 --- a/internal/configuration/settings/serverselection.go +++ b/internal/configuration/settings/serverselection.go @@ -1,6 +1,7 @@ package settings import ( + "errors" "fmt" "net" "strings" @@ -37,16 +38,16 @@ type ServerSelection struct { //nolint:maligned Numbers []uint16 // Hostnames is the list of hostnames to filter VPN servers with. Hostnames []string - // OwnedOnly is true if only VPN provider owned servers + // OwnedOnly is true if VPN provider servers that are not owned // should be filtered. This is used with Mullvad. OwnedOnly *bool - // FreeOnly is true if only free VPN servers - // should be filtered. This is used with ProtonVPN. + // FreeOnly is true if VPN servers that are not free should + // be filtered. This is used with ProtonVPN and VPN Unlimited. FreeOnly *bool - // FreeOnly is true if only free VPN servers - // should be filtered. This is used with ProtonVPN. + // StreamOnly is true if VPN servers not for streaming should + // be filtered. This is used with VPNUnlimited. StreamOnly *bool - // MultiHopOnly is true if only multihop VPN servers + // MultiHopOnly is true if VPN servers that are not multihop // should be filtered. This is used with Surfshark. MultiHopOnly *bool @@ -58,6 +59,13 @@ type ServerSelection struct { //nolint:maligned Wireguard WireguardSelection } +var ( + ErrOwnedOnlyNotSupported = errors.New("owned only filter is not supported") + ErrFreeOnlyNotSupported = errors.New("free only filter is not supported") + ErrStreamOnlyNotSupported = errors.New("stream only filter is not supported") + ErrMultiHopOnlyNotSupported = errors.New("multi hop only filter is not supported") +) + func (ss *ServerSelection) validate(vpnServiceProvider string, allServers models.AllServers) (err error) { switch ss.VPN { @@ -66,8 +74,71 @@ func (ss *ServerSelection) validate(vpnServiceProvider string, return fmt.Errorf("%w: %s", ErrVPNTypeNotValid, ss.VPN) } - var countryChoices, regionChoices, cityChoices, - ispChoices, nameChoices, hostnameChoices []string + countryChoices, regionChoices, cityChoices, + ispChoices, nameChoices, hostnameChoices, err := getLocationFilterChoices(vpnServiceProvider, ss, allServers) + if err != nil { + return err // already wrapped error + } + + err = validateServerFilters(*ss, countryChoices, regionChoices, cityChoices, + ispChoices, nameChoices, hostnameChoices) + if err != nil { + if errors.Is(err, helpers.ErrNoChoice) { + return fmt.Errorf("for VPN service provider %s: %w", vpnServiceProvider, err) + } + return err // already wrapped error + } + + if *ss.OwnedOnly && + vpnServiceProvider != constants.Mullvad { + return fmt.Errorf("%w: for VPN service provider %s", + ErrOwnedOnlyNotSupported, vpnServiceProvider) + } + + if *ss.FreeOnly && + !helpers.IsOneOf(vpnServiceProvider, + constants.Protonvpn, + constants.VPNUnlimited, + ) { + return fmt.Errorf("%w: for VPN service provider %s", + ErrFreeOnlyNotSupported, vpnServiceProvider) + } + + if *ss.StreamOnly && + !helpers.IsOneOf(vpnServiceProvider, + constants.Protonvpn, + constants.VPNUnlimited, + ) { + return fmt.Errorf("%w: for VPN service provider %s", + ErrStreamOnlyNotSupported, vpnServiceProvider) + } + + if *ss.MultiHopOnly && + vpnServiceProvider != constants.Surfshark { + return fmt.Errorf("%w: for VPN service provider %s", + ErrStreamOnlyNotSupported, vpnServiceProvider) + } + + if ss.VPN == constants.OpenVPN { + err = ss.OpenVPN.validate(vpnServiceProvider) + if err != nil { + return fmt.Errorf("OpenVPN server selection settings validation failed: %w", err) + } + } else { + err = ss.Wireguard.validate(vpnServiceProvider) + if err != nil { + return fmt.Errorf("Wireguard server selection settings validation failed: %w", err) + } + } + + return nil +} + +func getLocationFilterChoices(vpnServiceProvider string, ss *ServerSelection, + allServers models.AllServers) ( + countryChoices, regionChoices, cityChoices, + ispChoices, nameChoices, hostnameChoices []string, + err error) { switch vpnServiceProvider { case constants.Custom: case constants.Cyberghost: @@ -151,7 +222,7 @@ func (ss *ServerSelection) validate(vpnServiceProvider string, // TODO v4 remove regionChoices = append(regionChoices, validation.SurfsharkRetroLocChoices()...) if err := helpers.AreAllOneOf(ss.Regions, regionChoices); err != nil { - return fmt.Errorf("%w: %s", ErrRegionNotValid, err) + return nil, nil, nil, nil, nil, nil, fmt.Errorf("%w: %s", ErrRegionNotValid, err) } // Retro compatibility // TODO remove in v4 @@ -179,28 +250,11 @@ func (ss *ServerSelection) validate(vpnServiceProvider string, cityChoices = validation.WindscribeCityChoices(servers) hostnameChoices = validation.WindscribeHostnameChoices(servers) default: - return fmt.Errorf("%w: %s", ErrVPNProviderNameNotValid, vpnServiceProvider) + return nil, nil, nil, nil, nil, nil, fmt.Errorf("%w: %s", ErrVPNProviderNameNotValid, vpnServiceProvider) } - err = validateServerFilters(*ss, countryChoices, regionChoices, cityChoices, - ispChoices, nameChoices, hostnameChoices) - if err != nil { - return err // already wrapped error - } - - if ss.VPN == constants.OpenVPN { - err = ss.OpenVPN.validate(vpnServiceProvider) - if err != nil { - return fmt.Errorf("OpenVPN server selection settings validation failed: %w", err) - } - } else { - err = ss.Wireguard.validate(vpnServiceProvider) - if err != nil { - return fmt.Errorf("Wireguard server selection settings validation failed: %w", err) - } - } - - return nil + return countryChoices, regionChoices, cityChoices, + ispChoices, nameChoices, hostnameChoices, nil } // validateServerFilters validates filters against the choices given as arguments. @@ -208,40 +262,28 @@ func (ss *ServerSelection) validate(vpnServiceProvider string, func validateServerFilters(settings ServerSelection, countryChoices, regionChoices, cityChoices, ispChoices, nameChoices, hostnameChoices []string) (err error) { - if countryChoices != nil { - if err := helpers.AreAllOneOf(settings.Countries, countryChoices); err != nil { - return fmt.Errorf("%w: %s", ErrCountryNotValid, err) - } + if err := helpers.AreAllOneOf(settings.Countries, countryChoices); err != nil { + return fmt.Errorf("%w: %s", ErrCountryNotValid, err) } - if regionChoices != nil { - if err := helpers.AreAllOneOf(settings.Regions, regionChoices); err != nil { - return fmt.Errorf("%w: %s", ErrRegionNotValid, err) - } + if err := helpers.AreAllOneOf(settings.Regions, regionChoices); err != nil { + return fmt.Errorf("%w: %s", ErrRegionNotValid, err) } - if cityChoices != nil { - if err := helpers.AreAllOneOf(settings.Cities, cityChoices); err != nil { - return fmt.Errorf("%w: %s", ErrCityNotValid, err) - } + if err := helpers.AreAllOneOf(settings.Cities, cityChoices); err != nil { + return fmt.Errorf("%w: %s", ErrCityNotValid, err) } - if ispChoices != nil { - if err := helpers.AreAllOneOf(settings.ISPs, ispChoices); err != nil { - return fmt.Errorf("%w: %s", ErrISPNotValid, err) - } + if err := helpers.AreAllOneOf(settings.ISPs, ispChoices); err != nil { + return fmt.Errorf("%w: %s", ErrISPNotValid, err) } - if hostnameChoices != nil { - if err := helpers.AreAllOneOf(settings.Hostnames, hostnameChoices); err != nil { - return fmt.Errorf("%w: %s", ErrHostnameNotValid, err) - } + if err := helpers.AreAllOneOf(settings.Hostnames, hostnameChoices); err != nil { + return fmt.Errorf("%w: %s", ErrHostnameNotValid, err) } - if nameChoices != nil { - if err := helpers.AreAllOneOf(settings.Names, nameChoices); err != nil { - return fmt.Errorf("%w: %s", ErrNameNotValid, err) - } + if err := helpers.AreAllOneOf(settings.Names, nameChoices); err != nil { + return fmt.Errorf("%w: %s", ErrNameNotValid, err) } return nil