diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index c9907ef9..ca96d5fe 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -104,6 +104,18 @@ func _main(background context.Context, args []string) int { routingConf.SetDebug() } + defaultInterface, defaultGateway, err := routingConf.DefaultRoute() + if err != nil { + fatalOnError(err) + } + + localSubnet, err := routingConf.LocalSubnet() + if err != nil { + fatalOnError(err) + } + + firewallConf.SetNetworkInformation(defaultInterface, defaultGateway, localSubnet) + if err := ovpnConf.CheckTUN(); err != nil { logger.Warn(err) err = ovpnConf.CreateTUN() diff --git a/internal/firewall/enable.go b/internal/firewall/enable.go index 21b34350..f8df1c6e 100644 --- a/internal/firewall/enable.go +++ b/internal/firewall/enable.go @@ -62,15 +62,6 @@ func (c *configurator) fallbackToDisabled(ctx context.Context) { } func (c *configurator) enable(ctx context.Context) (err error) { //nolint:gocognit - defaultInterface, defaultGateway, err := c.routing.DefaultRoute() - if err != nil { - return fmt.Errorf("cannot enable firewall: %w", err) - } - localSubnet, err := c.routing.LocalSubnet() - if err != nil { - return fmt.Errorf("cannot enable firewall: %w", err) - } - if err = c.setAllPolicies(ctx, "DROP"); err != nil { return fmt.Errorf("cannot enable firewall: %w", err) } @@ -95,30 +86,30 @@ func (c *configurator) enable(ctx context.Context) (err error) { //nolint:gocogn return fmt.Errorf("cannot enable firewall: %w", err) } for _, conn := range c.vpnConnections { - if err = c.acceptOutputTrafficToVPN(ctx, defaultInterface, conn, remove); err != nil { + if err = c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, conn, remove); err != nil { return fmt.Errorf("cannot enable firewall: %w", err) } } if err = c.acceptOutputThroughInterface(ctx, string(constants.TUN), remove); err != nil { return fmt.Errorf("cannot enable firewall: %w", err) } - if err := c.acceptInputFromSubnetToSubnet(ctx, "*", localSubnet, localSubnet, remove); err != nil { + if err := c.acceptInputFromSubnetToSubnet(ctx, "*", c.localSubnet, c.localSubnet, remove); err != nil { return fmt.Errorf("cannot enable firewall: %w", err) } - if err := c.acceptOutputFromSubnetToSubnet(ctx, "*", localSubnet, localSubnet, remove); err != nil { + if err := c.acceptOutputFromSubnetToSubnet(ctx, "*", c.localSubnet, c.localSubnet, remove); err != nil { return fmt.Errorf("cannot enable firewall: %w", err) } for _, subnet := range c.allowedSubnets { - if err := c.acceptInputFromSubnetToSubnet(ctx, defaultInterface, subnet, localSubnet, remove); err != nil { + if err := c.acceptInputFromSubnetToSubnet(ctx, c.defaultInterface, subnet, c.localSubnet, remove); err != nil { return fmt.Errorf("cannot enable firewall: %w", err) } - if err := c.acceptOutputFromSubnetToSubnet(ctx, defaultInterface, localSubnet, subnet, remove); err != nil { + if err := c.acceptOutputFromSubnetToSubnet(ctx, c.defaultInterface, c.localSubnet, subnet, remove); err != nil { return fmt.Errorf("cannot enable firewall: %w", err) } } // Re-ensure all routes exist for _, subnet := range c.allowedSubnets { - if err := c.routing.AddRouteVia(ctx, subnet, defaultGateway, defaultInterface); err != nil { + if err := c.routing.AddRouteVia(ctx, subnet, c.defaultGateway, c.defaultInterface); err != nil { return fmt.Errorf("cannot enable firewall: %w", err) } } diff --git a/internal/firewall/firewall.go b/internal/firewall/firewall.go index b8a478f7..6339ac8a 100644 --- a/internal/firewall/firewall.go +++ b/internal/firewall/firewall.go @@ -22,15 +22,21 @@ type Configurator interface { RemoveAllowedPort(ctx context.Context, port uint16) (err error) SetPortForward(ctx context.Context, port uint16) (err error) SetDebug() + // SetNetworkInformation is meant to be called only once + SetNetworkInformation(defaultInterface string, defaultGateway net.IP, localSubnet net.IPNet) } type configurator struct { //nolint:maligned - commander command.Commander - logger logging.Logger - routing routing.Routing - fileManager files.FileManager // for custom iptables rules - iptablesMutex sync.Mutex - debug bool + commander command.Commander + logger logging.Logger + routing routing.Routing + fileManager files.FileManager // for custom iptables rules + iptablesMutex sync.Mutex + debug bool + defaultInterface string + defaultGateway net.IP + localSubnet net.IPNet + networkInfoMutex sync.Mutex // State enabled bool @@ -55,3 +61,11 @@ func NewConfigurator(logger logging.Logger, routing routing.Routing, fileManager func (c *configurator) SetDebug() { c.debug = true } + +func (c *configurator) SetNetworkInformation(defaultInterface string, defaultGateway net.IP, localSubnet net.IPNet) { + c.networkInfoMutex.Lock() + defer c.networkInfoMutex.Unlock() + c.defaultInterface = defaultInterface + c.defaultGateway = defaultGateway + c.localSubnet = localSubnet +} diff --git a/internal/firewall/subnets.go b/internal/firewall/subnets.go index 9fa25222..d8715776 100644 --- a/internal/firewall/subnets.go +++ b/internal/firewall/subnets.go @@ -12,9 +12,7 @@ func (c *configurator) SetAllowedSubnets(ctx context.Context, subnets []net.IPNe if !c.enabled { c.logger.Info("firewall disabled, only updating allowed subnets internal list and updating routes") - if err := c.updateSubnetRoutes(ctx, c.allowedSubnets, subnets); err != nil { - return err - } + c.updateSubnetRoutes(ctx, c.allowedSubnets, subnets) c.allowedSubnets = make([]net.IPNet, len(subnets)) copy(c.allowedSubnets, subnets) return nil @@ -28,17 +26,8 @@ func (c *configurator) SetAllowedSubnets(ctx context.Context, subnets []net.IPNe return nil } - defaultInterface, defaultGateway, err := c.routing.DefaultRoute() - if err != nil { - return fmt.Errorf("cannot set allowed subnets through firewall: %w", err) - } - localSubnet, err := c.routing.LocalSubnet() - if err != nil { - return fmt.Errorf("cannot set allowed subnets through firewall: %w", err) - } - - c.removeSubnets(ctx, subnetsToRemove, defaultInterface, localSubnet) - if err := c.addSubnets(ctx, subnetsToAdd, defaultInterface, defaultGateway, localSubnet); err != nil { + c.removeSubnets(ctx, subnetsToRemove, c.defaultInterface, c.localSubnet) + if err := c.addSubnets(ctx, subnetsToAdd, c.defaultInterface, c.defaultGateway, c.localSubnet); err != nil { return fmt.Errorf("cannot set allowed subnets through firewall: %w", err) } @@ -135,15 +124,12 @@ func (c *configurator) addSubnets(ctx context.Context, subnets []net.IPNet, defa return nil } -func (c *configurator) updateSubnetRoutes(ctx context.Context, oldSubnets, newSubnets []net.IPNet) error { +// updateSubnetRoutes does not return an error in order to try to run as many route commands as possible +func (c *configurator) updateSubnetRoutes(ctx context.Context, oldSubnets, newSubnets []net.IPNet) { subnetsToAdd := findSubnetsToAdd(oldSubnets, newSubnets) subnetsToRemove := findSubnetsToRemove(oldSubnets, newSubnets) if len(subnetsToAdd) == 0 && len(subnetsToRemove) == 0 { - return nil - } - defaultInterface, defaultGateway, err := c.routing.DefaultRoute() - if err != nil { - return err + return } for _, subnet := range subnetsToRemove { if err := c.routing.DeleteRouteVia(ctx, subnet); err != nil { @@ -151,9 +137,8 @@ func (c *configurator) updateSubnetRoutes(ctx context.Context, oldSubnets, newSu } } for _, subnet := range subnetsToAdd { - if err := c.routing.AddRouteVia(ctx, subnet, defaultGateway, defaultInterface); err != nil { + if err := c.routing.AddRouteVia(ctx, subnet, c.defaultGateway, c.defaultInterface); err != nil { c.logger.Error("cannot add route for subnet: %s", err) } } - return nil } diff --git a/internal/firewall/vpn.go b/internal/firewall/vpn.go index c1d58c58..fa123c03 100644 --- a/internal/firewall/vpn.go +++ b/internal/firewall/vpn.go @@ -26,13 +26,8 @@ func (c *configurator) SetVPNConnections(ctx context.Context, connections []mode return nil } - defaultInterface, _, err := c.routing.DefaultRoute() - if err != nil { - return fmt.Errorf("cannot set VPN connections through firewall: %w", err) - } - - c.removeConnections(ctx, connectionsToRemove, defaultInterface) - if err := c.addConnections(ctx, connectionsToAdd, defaultInterface); err != nil { + c.removeConnections(ctx, connectionsToRemove, c.defaultInterface) + if err := c.addConnections(ctx, connectionsToAdd, c.defaultInterface); err != nil { return fmt.Errorf("cannot set VPN connections through firewall: %w", err) }