diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index bf5150ec..1264d173 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -248,7 +248,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, Prefix: "firewall: ", Level: firewallLogLevel, }) - firewallConf := firewall.NewConfigurator(firewallLogger, cmder, routingConf, + firewallConf := firewall.NewConfig(firewallLogger, cmder, routingConf, defaultInterface, defaultGateway, localNetworks, defaultIP) if err := routingConf.Setup(); err != nil { diff --git a/internal/firewall/enable.go b/internal/firewall/enable.go index 004bc295..c01c4180 100644 --- a/internal/firewall/enable.go +++ b/internal/firewall/enable.go @@ -14,7 +14,11 @@ var ( ErrUserPostRules = errors.New("cannot run user post firewall rules") ) -func (c *configurator) SetEnabled(ctx context.Context, enabled bool) (err error) { +type Enabler interface { + SetEnabled(ctx context.Context, enabled bool) (err error) +} + +func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) { c.stateMutex.Lock() defer c.stateMutex.Unlock() @@ -48,7 +52,7 @@ func (c *configurator) SetEnabled(ctx context.Context, enabled bool) (err error) return nil } -func (c *configurator) disable(ctx context.Context) (err error) { +func (c *Config) disable(ctx context.Context) (err error) { if err = c.clearAllRules(ctx); err != nil { return fmt.Errorf("cannot disable firewall: %w", err) } @@ -62,7 +66,7 @@ func (c *configurator) disable(ctx context.Context) (err error) { } // To use in defered call when enabling the firewall. -func (c *configurator) fallbackToDisabled(ctx context.Context) { +func (c *Config) fallbackToDisabled(ctx context.Context) { if ctx.Err() != nil { return } @@ -71,7 +75,7 @@ func (c *configurator) fallbackToDisabled(ctx context.Context) { } } -func (c *configurator) enable(ctx context.Context) (err error) { +func (c *Config) enable(ctx context.Context) (err error) { touched := false if err = c.setIPv4AllPolicies(ctx, "DROP"); err != nil { return fmt.Errorf("cannot enable firewall: %w", err) diff --git a/internal/firewall/firewall.go b/internal/firewall/firewall.go index 87ca5e9d..a96c0600 100644 --- a/internal/firewall/firewall.go +++ b/internal/firewall/firewall.go @@ -13,19 +13,17 @@ import ( "github.com/qdm12/golibs/logging" ) +var _ Configurator = (*Config)(nil) + // Configurator allows to change firewall rules and modify network routes. type Configurator interface { - SetEnabled(ctx context.Context, enabled bool) (err error) - SetVPNConnection(ctx context.Context, connection models.OpenVPNConnection) (err error) - SetAllowedPort(ctx context.Context, port uint16, intf string) (err error) - SetOutboundSubnets(ctx context.Context, subnets []net.IPNet) (err error) - RemoveAllowedPort(ctx context.Context, port uint16) (err error) - // SetNetworkInformation is meant to be called only once - SetNetworkInformation(defaultInterface string, defaultGateway net.IP, - localNetworks []routing.LocalNetwork, localIP net.IP) + Enabler + VPNConnectionSetter + PortAllower + OutboundSubnetsSetter } -type configurator struct { //nolint:maligned +type Config struct { //nolint:maligned commander command.Commander logger logging.Logger routing routing.Routing @@ -35,7 +33,6 @@ type configurator struct { //nolint:maligned defaultGateway net.IP localNetworks []routing.LocalNetwork localIP net.IP - networkInfoMutex sync.Mutex // Fixed state ip6Tables bool @@ -49,8 +46,8 @@ type configurator struct { //nolint:maligned stateMutex sync.Mutex } -// NewConfigurator creates a new Configurator instance. -func NewConfigurator(logger logging.Logger, cmder command.Commander, +// NewConfig creates a new Config instance. +func NewConfig(logger logging.Logger, cmder command.Commander, routing routing.Routing, defaultInterface string, defaultGateway net.IP, localNetworks []routing.LocalNetwork, localIP net.IP) *Config { return &Config{ diff --git a/internal/firewall/ip6tables.go b/internal/firewall/ip6tables.go index 8f5c57bb..fcee0ab7 100644 --- a/internal/firewall/ip6tables.go +++ b/internal/firewall/ip6tables.go @@ -23,7 +23,7 @@ func ip6tablesSupported(ctx context.Context, commander command.Commander) (suppo return true } -func (c *configurator) runIP6tablesInstructions(ctx context.Context, instructions []string) error { +func (c *Config) runIP6tablesInstructions(ctx context.Context, instructions []string) error { for _, instruction := range instructions { if err := c.runIP6tablesInstruction(ctx, instruction); err != nil { return err @@ -32,7 +32,7 @@ func (c *configurator) runIP6tablesInstructions(ctx context.Context, instruction return nil } -func (c *configurator) runIP6tablesInstruction(ctx context.Context, instruction string) error { +func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string) error { if !c.ip6Tables { return nil } @@ -51,7 +51,7 @@ func (c *configurator) runIP6tablesInstruction(ctx context.Context, instruction var errPolicyNotValid = errors.New("policy is not valid") -func (c *configurator) setIPv6AllPolicies(ctx context.Context, policy string) error { +func (c *Config) setIPv6AllPolicies(ctx context.Context, policy string) error { switch policy { case "ACCEPT", "DROP": default: diff --git a/internal/firewall/iptables.go b/internal/firewall/iptables.go index 385a2eaa..026483e4 100644 --- a/internal/firewall/iptables.go +++ b/internal/firewall/iptables.go @@ -61,7 +61,7 @@ func Version(ctx context.Context, commander command.Commander) (string, error) { return words[1], nil } -func (c *configurator) runIptablesInstructions(ctx context.Context, instructions []string) error { +func (c *Config) runIptablesInstructions(ctx context.Context, instructions []string) error { for _, instruction := range instructions { if err := c.runIptablesInstruction(ctx, instruction); err != nil { return err @@ -70,7 +70,7 @@ func (c *configurator) runIptablesInstructions(ctx context.Context, instructions return nil } -func (c *configurator) runIptablesInstruction(ctx context.Context, instruction string) error { +func (c *Config) runIptablesInstruction(ctx context.Context, instruction string) error { c.iptablesMutex.Lock() // only one iptables command at once defer c.iptablesMutex.Unlock() @@ -84,7 +84,7 @@ func (c *configurator) runIptablesInstruction(ctx context.Context, instruction s return nil } -func (c *configurator) clearAllRules(ctx context.Context) error { +func (c *Config) clearAllRules(ctx context.Context) error { if err := c.runMixedIptablesInstructions(ctx, []string{ "--flush", // flush all chains "--delete-chain", // delete all chains @@ -94,7 +94,7 @@ func (c *configurator) clearAllRules(ctx context.Context) error { return nil } -func (c *configurator) setIPv4AllPolicies(ctx context.Context, policy string) error { +func (c *Config) setIPv4AllPolicies(ctx context.Context, policy string) error { switch policy { case "ACCEPT", "DROP": default: @@ -110,13 +110,13 @@ func (c *configurator) setIPv4AllPolicies(ctx context.Context, policy string) er return nil } -func (c *configurator) acceptInputThroughInterface(ctx context.Context, intf string, remove bool) error { +func (c *Config) acceptInputThroughInterface(ctx context.Context, intf string, remove bool) error { return c.runMixedIptablesInstruction(ctx, fmt.Sprintf( "%s INPUT -i %s -j ACCEPT", appendOrDelete(remove), intf, )) } -func (c *configurator) acceptInputToSubnet(ctx context.Context, intf string, destination net.IPNet, remove bool) error { +func (c *Config) acceptInputToSubnet(ctx context.Context, intf string, destination net.IPNet, remove bool) error { isIP4Subnet := destination.IP.To4() != nil interfaceFlag := "-i " + intf @@ -136,20 +136,20 @@ func (c *configurator) acceptInputToSubnet(ctx context.Context, intf string, des return c.runIP6tablesInstruction(ctx, instruction) } -func (c *configurator) acceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error { +func (c *Config) acceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error { return c.runMixedIptablesInstruction(ctx, fmt.Sprintf( "%s OUTPUT -o %s -j ACCEPT", appendOrDelete(remove), intf, )) } -func (c *configurator) acceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error { +func (c *Config) acceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error { return c.runMixedIptablesInstructions(ctx, []string{ fmt.Sprintf("%s OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)), fmt.Sprintf("%s INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)), }) } -func (c *configurator) acceptOutputTrafficToVPN(ctx context.Context, +func (c *Config) acceptOutputTrafficToVPN(ctx context.Context, defaultInterface string, connection models.OpenVPNConnection, remove bool) error { instruction := fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT", appendOrDelete(remove), connection.IP, defaultInterface, connection.Protocol, @@ -164,7 +164,7 @@ func (c *configurator) acceptOutputTrafficToVPN(ctx context.Context, } // Thanks to @npawelek. -func (c *configurator) acceptOutputFromIPToSubnet(ctx context.Context, +func (c *Config) acceptOutputFromIPToSubnet(ctx context.Context, intf string, sourceIP net.IP, destinationSubnet net.IPNet, remove bool) error { doIPv4 := sourceIP.To4() != nil && destinationSubnet.IP.To4() != nil @@ -185,7 +185,7 @@ func (c *configurator) acceptOutputFromIPToSubnet(ctx context.Context, } // Used for port forwarding, with intf set to tun. -func (c *configurator) acceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error { +func (c *Config) acceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error { interfaceFlag := "-i " + intf if intf == "*" { // all interfaces interfaceFlag = "" @@ -196,7 +196,7 @@ func (c *configurator) acceptInputToPort(ctx context.Context, intf string, port }) } -func (c *configurator) runUserPostRules(ctx context.Context, filepath string, remove bool) error { +func (c *Config) runUserPostRules(ctx context.Context, filepath string, remove bool) error { file, err := os.OpenFile(filepath, os.O_RDONLY, 0) if os.IsNotExist(err) { return nil diff --git a/internal/firewall/iptablesmix.go b/internal/firewall/iptablesmix.go index e3456330..8d45c737 100644 --- a/internal/firewall/iptablesmix.go +++ b/internal/firewall/iptablesmix.go @@ -4,7 +4,7 @@ import ( "context" ) -func (c *configurator) runMixedIptablesInstructions(ctx context.Context, instructions []string) error { +func (c *Config) runMixedIptablesInstructions(ctx context.Context, instructions []string) error { for _, instruction := range instructions { if err := c.runMixedIptablesInstruction(ctx, instruction); err != nil { return err @@ -13,7 +13,7 @@ func (c *configurator) runMixedIptablesInstructions(ctx context.Context, instruc return nil } -func (c *configurator) runMixedIptablesInstruction(ctx context.Context, instruction string) error { +func (c *Config) runMixedIptablesInstruction(ctx context.Context, instruction string) error { if err := c.runIptablesInstruction(ctx, instruction); err != nil { return err } diff --git a/internal/firewall/outboundsubnets.go b/internal/firewall/outboundsubnets.go index 248d6d66..771f1a2b 100644 --- a/internal/firewall/outboundsubnets.go +++ b/internal/firewall/outboundsubnets.go @@ -6,7 +6,11 @@ import ( "net" ) -func (c *configurator) SetOutboundSubnets(ctx context.Context, subnets []net.IPNet) (err error) { +type OutboundSubnetsSetter interface { + SetOutboundSubnets(ctx context.Context, subnets []net.IPNet) (err error) +} + +func (c *Config) SetOutboundSubnets(ctx context.Context, subnets []net.IPNet) (err error) { c.stateMutex.Lock() defer c.stateMutex.Unlock() @@ -33,7 +37,7 @@ func (c *configurator) SetOutboundSubnets(ctx context.Context, subnets []net.IPN return nil } -func (c *configurator) removeOutboundSubnets(ctx context.Context, subnets []net.IPNet) { +func (c *Config) removeOutboundSubnets(ctx context.Context, subnets []net.IPNet) { const remove = true for _, subnet := range subnets { if err := c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, subnet, remove); err != nil { @@ -44,7 +48,7 @@ func (c *configurator) removeOutboundSubnets(ctx context.Context, subnets []net. } } -func (c *configurator) addOutboundSubnets(ctx context.Context, subnets []net.IPNet) error { +func (c *Config) addOutboundSubnets(ctx context.Context, subnets []net.IPNet) error { const remove = false for _, subnet := range subnets { if err := c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, subnet, remove); err != nil { diff --git a/internal/firewall/ports.go b/internal/firewall/ports.go index dcb75ec9..5aee05e0 100644 --- a/internal/firewall/ports.go +++ b/internal/firewall/ports.go @@ -6,7 +6,12 @@ import ( "strconv" ) -func (c *configurator) SetAllowedPort(ctx context.Context, port uint16, intf string) (err error) { +type PortAllower interface { + SetAllowedPort(ctx context.Context, port uint16, intf string) (err error) + RemoveAllowedPort(ctx context.Context, port uint16) (err error) +} + +func (c *Config) SetAllowedPort(ctx context.Context, port uint16, intf string) (err error) { c.stateMutex.Lock() defer c.stateMutex.Unlock() @@ -41,7 +46,7 @@ func (c *configurator) SetAllowedPort(ctx context.Context, port uint16, intf str return nil } -func (c *configurator) RemoveAllowedPort(ctx context.Context, port uint16) (err error) { +func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error) { c.stateMutex.Lock() defer c.stateMutex.Unlock() diff --git a/internal/firewall/vpn.go b/internal/firewall/vpn.go index 896d0c13..68824dcf 100644 --- a/internal/firewall/vpn.go +++ b/internal/firewall/vpn.go @@ -7,7 +7,11 @@ import ( "github.com/qdm12/gluetun/internal/models" ) -func (c *configurator) SetVPNConnection(ctx context.Context, connection models.OpenVPNConnection) (err error) { +type VPNConnectionSetter interface { + SetVPNConnection(ctx context.Context, connection models.OpenVPNConnection) error +} + +func (c *Config) SetVPNConnection(ctx context.Context, connection models.OpenVPNConnection) (err error) { c.stateMutex.Lock() defer c.stateMutex.Unlock()