diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index c76151ec..9bb265ae 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -184,7 +184,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, } routingConf := routing.New(netLinker, routingLogger) - defaultInterface, defaultGateway, err := routingConf.DefaultRoute() + defaultRoutes, err := routingConf.DefaultRoutes() if err != nil { return err } @@ -194,11 +194,6 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, return err } - defaultIP, err := routingConf.DefaultIP() - if err != nil { - return err - } - firewallLogger := logger.NewChild(logging.Settings{ Prefix: "firewall: ", }) @@ -206,7 +201,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, firewallLogger.PatchLevel(logging.LevelDebug) } firewallConf, err := firewall.NewConfig(ctx, firewallLogger, cmder, - defaultInterface, defaultGateway, localNetworks, defaultIP) + defaultRoutes, localNetworks) if err != nil { return err } @@ -321,9 +316,11 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, } for _, port := range allSettings.Firewall.InputPorts { - err = firewallConf.SetAllowedPort(ctx, port, defaultInterface) - if err != nil { - return err + for _, defaultRoute := range defaultRoutes { + err = firewallConf.SetAllowedPort(ctx, port, defaultRoute.NetInterface) + if err != nil { + return err + } } } // TODO move inside firewall? diff --git a/internal/firewall/enable.go b/internal/firewall/enable.go index ee2cb318..bf8b20ca 100644 --- a/internal/firewall/enable.go +++ b/internal/firewall/enable.go @@ -96,13 +96,9 @@ func (c *Config) enable(ctx context.Context) (err error) { if err = c.acceptEstablishedRelatedTraffic(ctx, remove); err != nil { return err } - if c.vpnConnection.IP != nil { - if err = c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, c.vpnConnection, remove); err != nil { - return err - } - if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil { - return err - } + + if err = c.allowVPNIP(ctx); err != nil { + return err } for _, network := range c.localNetworks { @@ -111,10 +107,8 @@ func (c *Config) enable(ctx context.Context) (err error) { } } - for _, subnet := range c.outboundSubnets { - if err := c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, subnet, remove); err != nil { - return err - } + if err = c.allowOutboundSubnets(ctx); err != nil { + return err } // Allows packets from any IP address to go through eth0 / local network @@ -125,10 +119,8 @@ func (c *Config) enable(ctx context.Context) (err error) { } } - for port, intf := range c.allowedInputPorts { - if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil { - return err - } + if err = c.allowInputPorts(ctx); err != nil { + return err } if err := c.runUserPostRules(ctx, c.customRulesPath, remove); err != nil { @@ -137,3 +129,47 @@ func (c *Config) enable(ctx context.Context) (err error) { return nil } + +func (c *Config) allowVPNIP(ctx context.Context) (err error) { + if c.vpnConnection.IP == nil { + return nil + } + + const remove = false + for _, defaultRoute := range c.defaultRoutes { + err = c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove) + if err != nil { + return fmt.Errorf("cannot accept output traffic through VPN: %w", err) + } + } + + return nil +} + +func (c *Config) allowOutboundSubnets(ctx context.Context) (err error) { + for _, subnet := range c.outboundSubnets { + for _, defaultRoute := range c.defaultRoutes { + const remove = false + err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface, + defaultRoute.AssignedIP, subnet, remove) + if err != nil { + return err + } + } + } + return nil +} + +func (c *Config) allowInputPorts(ctx context.Context) (err error) { + for port, netInterfaces := range c.allowedInputPorts { + for netInterface := range netInterfaces { + const remove = false + err = c.acceptInputToPort(ctx, netInterface, port, remove) + if err != nil { + return fmt.Errorf("cannot accept input port %d on interface %s: %w", + port, netInterface, err) + } + } + } + return nil +} diff --git a/internal/firewall/firewall.go b/internal/firewall/firewall.go index 45698757..3b6cb3f7 100644 --- a/internal/firewall/firewall.go +++ b/internal/firewall/firewall.go @@ -23,14 +23,12 @@ type Configurator interface { } type Config struct { //nolint:maligned - runner command.Runner - logger Logger - iptablesMutex sync.Mutex - ip6tablesMutex sync.Mutex - defaultInterface string - defaultGateway net.IP - localNetworks []routing.LocalNetwork - localIP net.IP + runner command.Runner + logger Logger + iptablesMutex sync.Mutex + ip6tablesMutex sync.Mutex + defaultRoutes []routing.DefaultRoute + localNetworks []routing.LocalNetwork // Fixed state ipTables string @@ -42,16 +40,15 @@ type Config struct { //nolint:maligned vpnConnection models.Connection vpnIntf string outboundSubnets []net.IPNet - allowedInputPorts map[uint16]string // port to interface mapping + allowedInputPorts map[uint16]map[string]struct{} // port to interfaces set mapping stateMutex sync.Mutex } // NewConfig creates a new Config instance and returns an error // if no iptables implementation is available. func NewConfig(ctx context.Context, logger Logger, - runner command.Runner, defaultInterface string, - defaultGateway net.IP, localNetworks []routing.LocalNetwork, - localIP net.IP) (config *Config, err error) { + runner command.Runner, defaultRoutes []routing.DefaultRoute, + localNetworks []routing.LocalNetwork) (config *Config, err error) { iptables, err := findIptablesSupported(ctx, runner) if err != nil { return nil, err @@ -60,14 +57,12 @@ func NewConfig(ctx context.Context, logger Logger, return &Config{ runner: runner, logger: logger, - allowedInputPorts: make(map[uint16]string), + allowedInputPorts: make(map[uint16]map[string]struct{}), ipTables: iptables, ip6Tables: findIP6tablesSupported(ctx, runner), customRulesPath: "/iptables/post-rules.txt", // Obtained from routing - defaultInterface: defaultInterface, - defaultGateway: defaultGateway, - localNetworks: localNetworks, - localIP: localIP, + defaultRoutes: defaultRoutes, + localNetworks: localNetworks, }, nil } diff --git a/internal/firewall/outboundsubnets.go b/internal/firewall/outboundsubnets.go index c204a5eb..c5291f73 100644 --- a/internal/firewall/outboundsubnets.go +++ b/internal/firewall/outboundsubnets.go @@ -41,9 +41,13 @@ func (c *Config) SetOutboundSubnets(ctx context.Context, subnets []net.IPNet) (e func (c *Config) removeOutboundSubnets(ctx context.Context, subnets []net.IPNet) { const remove = true for _, subNet := range subnets { - if err := c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, subNet, remove); err != nil { - c.logger.Error("cannot remove outdated outbound subnet: " + err.Error()) - continue + for _, defaultRoute := range c.defaultRoutes { + err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface, + defaultRoute.AssignedIP, subNet, remove) + if err != nil { + c.logger.Error("cannot remove outdated outbound subnet: " + err.Error()) + continue + } } c.outboundSubnets = subnet.RemoveSubnetFromSubnets(c.outboundSubnets, subNet) } @@ -52,8 +56,12 @@ func (c *Config) removeOutboundSubnets(ctx context.Context, subnets []net.IPNet) func (c *Config) addOutboundSubnets(ctx context.Context, subnets []net.IPNet) error { const remove = false for _, subnet := range subnets { - if err := c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, subnet, remove); err != nil { - return err + for _, defaultRoute := range c.defaultRoutes { + err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface, + defaultRoute.AssignedIP, subnet, remove) + if err != nil { + return err + } } c.outboundSubnets = append(c.outboundSubnets, subnet) } diff --git a/internal/firewall/ports.go b/internal/firewall/ports.go index e6b51e91..efed4552 100644 --- a/internal/firewall/ports.go +++ b/internal/firewall/ports.go @@ -21,27 +21,30 @@ func (c *Config) SetAllowedPort(ctx context.Context, port uint16, intf string) ( if !c.enabled { c.logger.Info("firewall disabled, only updating allowed ports internal state") - c.allowedInputPorts[port] = intf + existingInterfaces, ok := c.allowedInputPorts[port] + if !ok { + existingInterfaces = make(map[string]struct{}) + } + existingInterfaces[intf] = struct{}{} + c.allowedInputPorts[port] = existingInterfaces + return nil + } + + netInterfaces, has := c.allowedInputPorts[port] + if !has { + netInterfaces = make(map[string]struct{}) + } else if _, exists := netInterfaces[intf]; exists { return nil } c.logger.Info("setting allowed input port " + fmt.Sprint(port) + " through interface " + intf + "...") - if existingIntf, ok := c.allowedInputPorts[port]; ok { - if intf == existingIntf { - return nil - } - const remove = true - if err := c.acceptInputToPort(ctx, existingIntf, port, remove); err != nil { - return fmt.Errorf("cannot remove old allowed port %d: %w", port, err) - } - } - const remove = false if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil { - return fmt.Errorf("cannot allow input to port %d: %w", port, err) + return fmt.Errorf("cannot allow input to port %d through interface %s: %w", + port, intf, err) } - c.allowedInputPorts[port] = intf + netInterfaces[intf] = struct{}{} return nil } @@ -60,17 +63,24 @@ func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error) return nil } - c.logger.Info("removing allowed port " + strconv.Itoa(int(port)) + " ...") + c.logger.Info("removing allowed port " + strconv.Itoa(int(port)) + "...") - intf, ok := c.allowedInputPorts[port] + interfacesSet, ok := c.allowedInputPorts[port] if !ok { return nil } const remove = true - if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil { - return fmt.Errorf("cannot remove allowed port %d: %w", port, err) + for netInterface := range interfacesSet { + err := c.acceptInputToPort(ctx, netInterface, port, remove) + if err != nil { + return fmt.Errorf("cannot remove allowed port %d on interface %s: %w", + port, netInterface, err) + } + delete(interfacesSet, netInterface) } + + // All interfaces were removed successfully, so remove the port entry. delete(c.allowedInputPorts, port) return nil diff --git a/internal/firewall/vpn.go b/internal/firewall/vpn.go index 96ae1cf1..9c48fc35 100644 --- a/internal/firewall/vpn.go +++ b/internal/firewall/vpn.go @@ -31,8 +31,10 @@ func (c *Config) SetVPNConnection(ctx context.Context, remove := true if c.vpnConnection.IP != nil { - if err := c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, c.vpnConnection, remove); err != nil { - c.logger.Error("cannot remove outdated VPN connection rule: " + err.Error()) + for _, defaultRoute := range c.defaultRoutes { + if err := c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove); err != nil { + c.logger.Error("cannot remove outdated VPN connection rule: " + err.Error()) + } } } c.vpnConnection = models.Connection{} @@ -46,8 +48,10 @@ func (c *Config) SetVPNConnection(ctx context.Context, remove = false - if err := c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, connection, remove); err != nil { - return fmt.Errorf("cannot allow output traffic through VPN connection: %w", err) + for _, defaultRoute := range c.defaultRoutes { + if err := c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, connection, remove); err != nil { + return fmt.Errorf("cannot allow output traffic through VPN connection: %w", err) + } } c.vpnConnection = connection diff --git a/internal/routing/default.go b/internal/routing/default.go index e8e2d0b6..1baec75d 100644 --- a/internal/routing/default.go +++ b/internal/routing/default.go @@ -13,56 +13,60 @@ var ( ) type DefaultRouteGetter interface { - DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) + DefaultRoutes() (defaultRoutes []DefaultRoute, err error) } -func (r *Routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) { - routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) - if err != nil { - return "", nil, fmt.Errorf("cannot list routes: %w", err) - } - for _, route := range routes { - if route.Dst == nil { - defaultGateway = route.Gw - linkIndex := route.LinkIndex - link, err := r.netLinker.LinkByIndex(linkIndex) - if err != nil { - return "", nil, fmt.Errorf("cannot obtain link by index: for default route at index %d: %w", linkIndex, err) - } - attributes := link.Attrs() - defaultInterface = attributes.Name - r.logger.Info("default route found: interface " + defaultInterface + - ", gateway " + defaultGateway.String()) - return defaultInterface, defaultGateway, nil - } - } - return "", nil, fmt.Errorf("%w: in %d route(s)", ErrRouteDefaultNotFound, len(routes)) +type DefaultRoute struct { + NetInterface string + Gateway net.IP + AssignedIP net.IP } -type DefaultIPGetter interface { - DefaultIP() (defaultIP net.IP, err error) +func (d DefaultRoute) String() string { + return fmt.Sprintf("interface %s, gateway %s and assigned IP %s", + d.NetInterface, d.Gateway, d.AssignedIP) } -func (r *Routing) DefaultIP() (ip net.IP, err error) { +func (r *Routing) DefaultRoutes() (defaultRoutes []DefaultRoute, err error) { routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) if err != nil { return nil, fmt.Errorf("cannot list routes: %w", err) } - defaultLinkName := "" for _, route := range routes { if route.Dst == nil { + defaultRoute := DefaultRoute{ + Gateway: route.Gw, + } linkIndex := route.LinkIndex link, err := r.netLinker.LinkByIndex(linkIndex) if err != nil { - return nil, fmt.Errorf("cannot find link by index: for default route at index %d: %w", linkIndex, err) + return nil, fmt.Errorf("cannot obtain link by index: for default route at index %d: %w", linkIndex, err) } - defaultLinkName = link.Attrs().Name + attributes := link.Attrs() + defaultRoute.NetInterface = attributes.Name + + defaultRoute.AssignedIP, err = r.assignedIP(defaultRoute.NetInterface) + if err != nil { + return nil, fmt.Errorf("cannot get assigned IP of %s: %w", defaultRoute.NetInterface, err) + } + + r.logger.Info("default route found: " + defaultRoute.String()) + defaultRoutes = append(defaultRoutes, defaultRoute) } } - if defaultLinkName == "" { - return nil, fmt.Errorf("%w: in %d route(s)", ErrLinkDefaultNotFound, len(routes)) + + if len(defaultRoutes) == 0 { + return nil, fmt.Errorf("%w: in %d route(s)", ErrRouteDefaultNotFound, len(routes)) } - return r.assignedIP(defaultLinkName) + return defaultRoutes, nil +} + +func DefaultRoutesInterfaces(defaultRoutes []DefaultRoute) (interfaces []string) { + interfaces = make([]string, len(defaultRoutes)) + for i := range defaultRoutes { + interfaces[i] = defaultRoutes[i].NetInterface + } + return interfaces } diff --git a/internal/routing/enable.go b/internal/routing/enable.go index c330eb9c..e4c93437 100644 --- a/internal/routing/enable.go +++ b/internal/routing/enable.go @@ -9,9 +9,9 @@ type Setuper interface { } func (r *Routing) Setup() (err error) { - defaultInterfaceName, defaultGateway, err := r.DefaultRoute() + defaultRoutes, err := r.DefaultRoutes() if err != nil { - return fmt.Errorf("cannot get default route: %w", err) + return fmt.Errorf("cannot get default routes: %w", err) } touched := false @@ -25,7 +25,7 @@ func (r *Routing) Setup() (err error) { touched = true - err = r.routeInboundFromDefault(defaultGateway, defaultInterfaceName) + err = r.routeInboundFromDefault(defaultRoutes) if err != nil { return fmt.Errorf("cannot add routes for inbound traffic from default IP: %w", err) } @@ -33,7 +33,7 @@ func (r *Routing) Setup() (err error) { r.stateMutex.RLock() outboundSubnets := r.outboundSubnets r.stateMutex.RUnlock() - if err := r.setOutboundRoutes(outboundSubnets, defaultInterfaceName, defaultGateway); err != nil { + if err := r.setOutboundRoutes(outboundSubnets, defaultRoutes); err != nil { return fmt.Errorf("cannot set outbound subnets routes: %w", err) } @@ -45,17 +45,17 @@ type TearDowner interface { } func (r *Routing) TearDown() error { - defaultInterfaceName, defaultGateway, err := r.DefaultRoute() + defaultRoutes, err := r.DefaultRoutes() if err != nil { return fmt.Errorf("cannot get default route: %w", err) } - err = r.unrouteInboundFromDefault(defaultGateway, defaultInterfaceName) + err = r.unrouteInboundFromDefault(defaultRoutes) if err != nil { return fmt.Errorf("cannot remove routes for inbound traffic from default IP: %w", err) } - if err := r.setOutboundRoutes(nil, defaultInterfaceName, defaultGateway); err != nil { + if err := r.setOutboundRoutes(nil, defaultRoutes); err != nil { return fmt.Errorf("cannot set outbound subnets routes: %w", err) } diff --git a/internal/routing/inbound.go b/internal/routing/inbound.go index 2ce77bab..21a91619 100644 --- a/internal/routing/inbound.go +++ b/internal/routing/inbound.go @@ -12,61 +12,62 @@ const ( inboundPriority = 100 ) -func (r *Routing) routeInboundFromDefault(defaultGateway net.IP, - defaultInterface string) (err error) { - if err := r.addRuleInboundFromDefault(inboundTable); err != nil { +func (r *Routing) routeInboundFromDefault(defaultRoutes []DefaultRoute) (err error) { + if err := r.addRuleInboundFromDefault(inboundTable, defaultRoutes); err != nil { return fmt.Errorf("cannot add rule: %w", err) } defaultDestination := net.IPNet{IP: net.IPv4(0, 0, 0, 0), Mask: net.IPv4Mask(0, 0, 0, 0)} - if err := r.addRouteVia(defaultDestination, defaultGateway, defaultInterface, inboundTable); err != nil { - return fmt.Errorf("cannot add route: %w", err) + // TODO IPv6 + + for _, defaultRoute := range defaultRoutes { + err := r.addRouteVia(defaultDestination, defaultRoute.Gateway, defaultRoute.NetInterface, inboundTable) + if err != nil { + return fmt.Errorf("cannot add route: %w", err) + } } return nil } -func (r *Routing) unrouteInboundFromDefault(defaultGateway net.IP, - defaultInterface string) (err error) { +func (r *Routing) unrouteInboundFromDefault(defaultRoutes []DefaultRoute) (err error) { defaultDestination := net.IPNet{IP: net.IPv4(0, 0, 0, 0), Mask: net.IPv4Mask(0, 0, 0, 0)} - if err := r.deleteRouteVia(defaultDestination, defaultGateway, defaultInterface, inboundTable); err != nil { - return fmt.Errorf("cannot delete route: %w", err) + + for _, defaultRoute := range defaultRoutes { + err := r.deleteRouteVia(defaultDestination, defaultRoute.Gateway, defaultRoute.NetInterface, inboundTable) + if err != nil { + return fmt.Errorf("cannot delete route: %w", err) + } } - if err := r.delRuleInboundFromDefault(inboundTable); err != nil { + if err := r.delRuleInboundFromDefault(inboundTable, defaultRoutes); err != nil { return fmt.Errorf("cannot delete rule: %w", err) } return nil } -func (r *Routing) addRuleInboundFromDefault(table int) (err error) { - defaultIP, err := r.DefaultIP() - if err != nil { - return fmt.Errorf("cannot find default IP: %w", err) - } - - defaultIPMasked32 := netlink.NewIPNet(defaultIP) - ruleDstNet := (*net.IPNet)(nil) - err = r.addIPRule(defaultIPMasked32, ruleDstNet, table, inboundPriority) - if err != nil { - return fmt.Errorf("cannot add rule: %w", err) +func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) { + for _, defaultRoute := range defaultRoutes { + defaultIPMasked32 := netlink.NewIPNet(defaultRoute.AssignedIP) + ruleDstNet := (*net.IPNet)(nil) + err = r.addIPRule(defaultIPMasked32, ruleDstNet, table, inboundPriority) + if err != nil { + return fmt.Errorf("cannot add rule for default route %s: %w", defaultRoute, err) + } } return nil } -func (r *Routing) delRuleInboundFromDefault(table int) (err error) { - defaultIP, err := r.DefaultIP() - if err != nil { - return fmt.Errorf("cannot find default IP: %w", err) - } - - defaultIPMasked32 := netlink.NewIPNet(defaultIP) - ruleDstNet := (*net.IPNet)(nil) - err = r.deleteIPRule(defaultIPMasked32, ruleDstNet, table, inboundPriority) - if err != nil { - return fmt.Errorf("cannot delete rule: %w", err) +func (r *Routing) delRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) { + for _, defaultRoute := range defaultRoutes { + defaultIPMasked32 := netlink.NewIPNet(defaultRoute.AssignedIP) + ruleDstNet := (*net.IPNet)(nil) + err = r.deleteIPRule(defaultIPMasked32, ruleDstNet, table, inboundPriority) + if err != nil { + return fmt.Errorf("cannot delete rule for default route %s: %w", defaultRoute, err) + } } return nil diff --git a/internal/routing/outbound.go b/internal/routing/outbound.go index 39ec3af8..9f42ea7a 100644 --- a/internal/routing/outbound.go +++ b/internal/routing/outbound.go @@ -17,15 +17,15 @@ type OutboundRoutesSetter interface { } func (r *Routing) SetOutboundRoutes(outboundSubnets []net.IPNet) error { - defaultInterface, defaultGateway, err := r.DefaultRoute() + defaultRoutes, err := r.DefaultRoutes() if err != nil { return err } - return r.setOutboundRoutes(outboundSubnets, defaultInterface, defaultGateway) + return r.setOutboundRoutes(outboundSubnets, defaultRoutes) } func (r *Routing) setOutboundRoutes(outboundSubnets []net.IPNet, - defaultInterfaceName string, defaultGateway net.IP) (err error) { + defaultRoutes []DefaultRoute) (err error) { r.stateMutex.Lock() defer r.stateMutex.Unlock() @@ -36,12 +36,12 @@ func (r *Routing) setOutboundRoutes(outboundSubnets []net.IPNet, return nil } - warnings := r.removeOutboundSubnets(subnetsToRemove, defaultInterfaceName, defaultGateway) + warnings := r.removeOutboundSubnets(subnetsToRemove, defaultRoutes) for _, warning := range warnings { r.logger.Warn("cannot remove outdated outbound subnet from routing: " + warning) } - err = r.addOutboundSubnets(subnetsToAdd, defaultInterfaceName, defaultGateway) + err = r.addOutboundSubnets(subnetsToAdd, defaultRoutes) if err != nil { return fmt.Errorf("cannot add outbound subnet to routes: %w", err) } @@ -50,17 +50,19 @@ func (r *Routing) setOutboundRoutes(outboundSubnets []net.IPNet, } func (r *Routing) removeOutboundSubnets(subnets []net.IPNet, - defaultInterfaceName string, defaultGateway net.IP) (warnings []string) { + defaultRoutes []DefaultRoute) (warnings []string) { for i, subNet := range subnets { - err := r.deleteRouteVia(subNet, defaultGateway, defaultInterfaceName, outboundTable) - if err != nil { - warnings = append(warnings, err.Error()) - continue + for _, defaultRoute := range defaultRoutes { + err := r.deleteRouteVia(subNet, defaultRoute.Gateway, defaultRoute.NetInterface, outboundTable) + if err != nil { + warnings = append(warnings, err.Error()) + continue + } } ruleSrcNet := (*net.IPNet)(nil) ruleDstNet := &subnets[i] - err = r.deleteIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority) + err := r.deleteIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority) if err != nil { warnings = append(warnings, "cannot delete rule: for subnet "+subNet.String()+": "+err.Error()) @@ -74,11 +76,13 @@ func (r *Routing) removeOutboundSubnets(subnets []net.IPNet, } func (r *Routing) addOutboundSubnets(subnets []net.IPNet, - defaultInterfaceName string, defaultGateway net.IP) error { + defaultRoutes []DefaultRoute) (err error) { for i, subnet := range subnets { - err := r.addRouteVia(subnet, defaultGateway, defaultInterfaceName, outboundTable) - if err != nil { - return fmt.Errorf("cannot add route for subnet %s: %w", subnet, err) + for _, defaultRoute := range defaultRoutes { + err = r.addRouteVia(subnet, defaultRoute.Gateway, defaultRoute.NetInterface, outboundTable) + if err != nil { + return fmt.Errorf("cannot add route for subnet %s: %w", subnet, err) + } } ruleSrcNet := (*net.IPNet)(nil) diff --git a/internal/routing/routing.go b/internal/routing/routing.go index 9a44c9b1..27ccea95 100644 --- a/internal/routing/routing.go +++ b/internal/routing/routing.go @@ -15,7 +15,6 @@ type ReadWriter interface { type Reader interface { DefaultRouteGetter - DefaultIPGetter LocalSubnetGetter LocalNetworksGetter VPNGetter