fix(settings): validate Wireguard addresses depending on IPv6 support

This commit is contained in:
Quentin McGaw
2022-12-14 11:29:40 +00:00
parent 16acd1b162
commit f70f0aca9c
9 changed files with 44 additions and 31 deletions

View File

@@ -232,7 +232,12 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
return err return err
} }
err = allSettings.Validate(storage) ipv6Supported, err := netLinker.IsIPv6Supported()
if err != nil {
return fmt.Errorf("checking for IPv6 support: %w", err)
}
err = allSettings.Validate(storage, ipv6Supported)
if err != nil { if err != nil {
return err return err
} }
@@ -296,11 +301,6 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
return err return err
} }
ipv6Supported, err := netLinker.IsIPv6Supported()
if err != nil {
return fmt.Errorf("checking for IPv6 support: %w", err)
}
if err := routingConf.Setup(); err != nil { if err := routingConf.Setup(); err != nil {
if strings.Contains(err.Error(), "operation not permitted") { if strings.Contains(err.Error(), "operation not permitted") {
logger.Warn("💡 Tip: Are you passing NET_ADMIN capability to gluetun?") logger.Warn("💡 Tip: Are you passing NET_ADMIN capability to gluetun?")
@@ -451,7 +451,8 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
"http server", goroutine.OptionTimeout(defaultShutdownTimeout)) "http server", goroutine.OptionTimeout(defaultShutdownTimeout))
httpServer, err := server.New(httpServerCtx, controlServerAddress, controlServerLogging, httpServer, err := server.New(httpServerCtx, controlServerAddress, controlServerLogging,
logger.New(log.SetComponent("http server")), logger.New(log.SetComponent("http server")),
buildInfo, vpnLooper, portForwardLooper, unboundLooper, updaterLooper, publicIPLooper, storage) buildInfo, vpnLooper, portForwardLooper, unboundLooper, updaterLooper, publicIPLooper,
storage, ipv6Supported)
if err != nil { if err != nil {
return fmt.Errorf("cannot setup control server: %w", err) return fmt.Errorf("cannot setup control server: %w", err)
} }

View File

@@ -51,15 +51,15 @@ func (c *CLI) OpenvpnConfig(logger OpenvpnConfigLogger, source Source,
return err return err
} }
if err = allSettings.Validate(storage); err != nil {
return err
}
ipv6Supported, err := ipv6Checker.IsIPv6Supported() ipv6Supported, err := ipv6Checker.IsIPv6Supported()
if err != nil { if err != nil {
return fmt.Errorf("checking for IPv6 support: %w", err) return fmt.Errorf("checking for IPv6 support: %w", err)
} }
if err = allSettings.Validate(storage, ipv6Supported); err != nil {
return fmt.Errorf("validating settings: %w", err)
}
// Unused by this CLI command // Unused by this CLI command
unzipper := (Unzipper)(nil) unzipper := (Unzipper)(nil)
client := (*http.Client)(nil) client := (*http.Client)(nil)

View File

@@ -39,6 +39,7 @@ var (
ErrWireguardEndpointPortNotSet = errors.New("endpoint port is not set") ErrWireguardEndpointPortNotSet = errors.New("endpoint port is not set")
ErrWireguardEndpointPortSet = errors.New("endpoint port is set") ErrWireguardEndpointPortSet = errors.New("endpoint port is set")
ErrWireguardInterfaceAddressNotSet = errors.New("interface address is not set") ErrWireguardInterfaceAddressNotSet = errors.New("interface address is not set")
ErrWireguardInterfaceAddressIPv6 = errors.New("interface address is IPv6 but IPv6 is not supported")
ErrWireguardInterfaceNotValid = errors.New("interface name is not valid") ErrWireguardInterfaceNotValid = errors.New("interface name is not valid")
ErrWireguardPreSharedKeyNotSet = errors.New("pre-shared key is not set") ErrWireguardPreSharedKeyNotSet = errors.New("pre-shared key is not set")
ErrWireguardPrivateKeyNotSet = errors.New("private key is not set") ErrWireguardPrivateKeyNotSet = errors.New("private key is not set")

View File

@@ -31,7 +31,7 @@ type Storage interface {
// Validate validates all the settings and returns an error // Validate validates all the settings and returns an error
// if one of them is not valid. // if one of them is not valid.
// TODO v4 remove pointer for receiver (because of Surfshark). // TODO v4 remove pointer for receiver (because of Surfshark).
func (s *Settings) Validate(storage Storage) (err error) { func (s *Settings) Validate(storage Storage, ipv6Supported bool) (err error) {
nameToValidation := map[string]func() error{ nameToValidation := map[string]func() error{
"control server": s.ControlServer.validate, "control server": s.ControlServer.validate,
"dns": s.DNS.validate, "dns": s.DNS.validate,
@@ -46,7 +46,7 @@ func (s *Settings) Validate(storage Storage) (err error) {
"version": s.Version.validate, "version": s.Version.validate,
// Pprof validation done in pprof constructor // Pprof validation done in pprof constructor
"VPN": func() error { "VPN": func() error {
return s.VPN.Validate(storage) return s.VPN.Validate(storage, ipv6Supported)
}, },
} }
@@ -95,7 +95,7 @@ func (s *Settings) MergeWith(other Settings) {
} }
func (s *Settings) OverrideWith(other Settings, func (s *Settings) OverrideWith(other Settings,
storage Storage) (err error) { storage Storage, ipv6Supported bool) (err error) {
patchedSettings := s.copy() patchedSettings := s.copy()
patchedSettings.ControlServer.overrideWith(other.ControlServer) patchedSettings.ControlServer.overrideWith(other.ControlServer)
patchedSettings.DNS.overrideWith(other.DNS) patchedSettings.DNS.overrideWith(other.DNS)
@@ -110,7 +110,7 @@ func (s *Settings) OverrideWith(other Settings,
patchedSettings.Version.overrideWith(other.Version) patchedSettings.Version.overrideWith(other.Version)
patchedSettings.VPN.OverrideWith(other.VPN) patchedSettings.VPN.OverrideWith(other.VPN)
patchedSettings.Pprof.OverrideWith(other.Pprof) patchedSettings.Pprof.OverrideWith(other.Pprof)
err = patchedSettings.Validate(storage) err = patchedSettings.Validate(storage, ipv6Supported)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -20,7 +20,7 @@ type VPN struct {
} }
// TODO v4 remove pointer for receiver (because of Surfshark). // TODO v4 remove pointer for receiver (because of Surfshark).
func (v *VPN) Validate(storage Storage) (err error) { func (v *VPN) Validate(storage Storage, ipv6Supported bool) (err error) {
// Validate Type // Validate Type
validVPNTypes := []string{vpn.OpenVPN, vpn.Wireguard} validVPNTypes := []string{vpn.OpenVPN, vpn.Wireguard}
if !helpers.IsOneOf(v.Type, validVPNTypes...) { if !helpers.IsOneOf(v.Type, validVPNTypes...) {
@@ -39,7 +39,7 @@ func (v *VPN) Validate(storage Storage) (err error) {
return fmt.Errorf("OpenVPN settings: %w", err) return fmt.Errorf("OpenVPN settings: %w", err)
} }
} else { } else {
err := v.Wireguard.validate(*v.Provider.Name) err := v.Wireguard.validate(*v.Provider.Name, ipv6Supported)
if err != nil { if err != nil {
return fmt.Errorf("Wireguard settings: %w", err) return fmt.Errorf("Wireguard settings: %w", err)
} }

View File

@@ -38,7 +38,7 @@ var regexpInterfaceName = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
// Validate validates Wireguard settings. // Validate validates Wireguard settings.
// It should only be ran if the VPN type chosen is Wireguard. // It should only be ran if the VPN type chosen is Wireguard.
func (w Wireguard) validate(vpnProvider string) (err error) { func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error) {
if !helpers.IsOneOf(vpnProvider, if !helpers.IsOneOf(vpnProvider,
providers.Custom, providers.Custom,
providers.Ivpn, providers.Ivpn,
@@ -82,6 +82,12 @@ func (w Wireguard) validate(vpnProvider string) (err error) {
return fmt.Errorf("%w: for address at index %d: %s", return fmt.Errorf("%w: for address at index %d: %s",
ErrWireguardInterfaceAddressNotSet, i, ipNet.String()) ErrWireguardInterfaceAddressNotSet, i, ipNet.String())
} }
ipv6Net := ipNet.IP.To4() == nil
if ipv6Net && !ipv6Supported {
return fmt.Errorf("%w: address %s",
ErrWireguardInterfaceAddressIPv6, ipNet)
}
} }
// Validate interface // Validate interface

View File

@@ -16,10 +16,11 @@ func newHandler(ctx context.Context, logger infoWarner, logging bool,
updaterLooper UpdaterLooper, updaterLooper UpdaterLooper,
publicIPLooper PublicIPLoop, publicIPLooper PublicIPLoop,
storage Storage, storage Storage,
ipv6Supported bool,
) http.Handler { ) http.Handler {
handler := &handler{} handler := &handler{}
vpn := newVPNHandler(ctx, vpnLooper, storage, logger) vpn := newVPNHandler(ctx, vpnLooper, storage, ipv6Supported, logger)
openvpn := newOpenvpnHandler(ctx, vpnLooper, pfGetter, logger) openvpn := newOpenvpnHandler(ctx, vpnLooper, pfGetter, logger)
dns := newDNSHandler(ctx, unboundLooper, logger) dns := newDNSHandler(ctx, unboundLooper, logger)
updater := newUpdaterHandler(ctx, updaterLooper, logger) updater := newUpdaterHandler(ctx, updaterLooper, logger)

View File

@@ -11,10 +11,12 @@ import (
func New(ctx context.Context, address string, logEnabled bool, logger Logger, func New(ctx context.Context, address string, logEnabled bool, logger Logger,
buildInfo models.BuildInformation, openvpnLooper VPNLooper, buildInfo models.BuildInformation, openvpnLooper VPNLooper,
pfGetter PortForwardedGetter, unboundLooper DNSLoop, pfGetter PortForwardedGetter, unboundLooper DNSLoop,
updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, storage Storage) ( updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, storage Storage,
ipv6Supported bool) (
server *httpserver.Server, err error) { server *httpserver.Server, err error) {
handler := newHandler(ctx, logger, logEnabled, buildInfo, handler := newHandler(ctx, logger, logEnabled, buildInfo,
openvpnLooper, pfGetter, unboundLooper, updaterLooper, publicIPLooper, storage) openvpnLooper, pfGetter, unboundLooper, updaterLooper, publicIPLooper,
storage, ipv6Supported)
httpServerSettings := httpserver.Settings{ httpServerSettings := httpserver.Settings{
Address: address, Address: address,

View File

@@ -10,20 +10,22 @@ import (
) )
func newVPNHandler(ctx context.Context, looper VPNLooper, func newVPNHandler(ctx context.Context, looper VPNLooper,
storage Storage, w warner) http.Handler { storage Storage, ipv6Supported bool, w warner) http.Handler {
return &vpnHandler{ return &vpnHandler{
ctx: ctx, ctx: ctx,
looper: looper, looper: looper,
storage: storage, storage: storage,
warner: w, ipv6Supported: ipv6Supported,
warner: w,
} }
} }
type vpnHandler struct { type vpnHandler struct {
ctx context.Context //nolint:containedctx ctx context.Context //nolint:containedctx
looper VPNLooper looper VPNLooper
storage Storage storage Storage
warner warner ipv6Supported bool
warner warner
} }
func (h *vpnHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *vpnHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@@ -114,7 +116,7 @@ func (h *vpnHandler) patchSettings(w http.ResponseWriter, r *http.Request) {
updatedSettings := h.looper.GetSettings() // already copied updatedSettings := h.looper.GetSettings() // already copied
updatedSettings.OverrideWith(overrideSettings) updatedSettings.OverrideWith(overrideSettings)
err = updatedSettings.Validate(h.storage) err = updatedSettings.Validate(h.storage, h.ipv6Supported)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return