From a13be8f45e2ae42a6323f6857e41048a00073d4f Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Mon, 20 Jul 2020 00:39:59 +0000 Subject: [PATCH] Firewall simplifications - Only a map of allowed input port to interface - port forwarded is in the map of allowed input ports - port forwarded has the interface tun0 in this map - Always allow tcp and udp for allowed input ports - Port forward state is in openvpn looper only - Shadowsocks input port allowed on default interface only - Tinyproxy input port allowed on default interface only --- cmd/gluetun/main.go | 4 +- internal/firewall/enable.go | 18 +------- internal/firewall/firewall.go | 24 +++++----- internal/firewall/iptables.go | 9 ++-- internal/firewall/ports.go | 82 +++++++++-------------------------- internal/openvpn/loop.go | 10 +++-- internal/shadowsocks/loop.go | 52 +++++++++++----------- internal/tinyproxy/loop.go | 48 ++++++++++---------- 8 files changed, 99 insertions(+), 148 deletions(-) diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index ca96d5fe..53a8e2e6 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -158,11 +158,11 @@ func _main(background context.Context, args []string) int { go publicIPLooper.RunRestartTicker(ctx) setPublicIPPeriod(allSettings.PublicIPPeriod) // call after RunRestartTicker - tinyproxyLooper := tinyproxy.NewLooper(tinyProxyConf, firewallConf, allSettings.TinyProxy, logger, streamMerger, uid, gid) + tinyproxyLooper := tinyproxy.NewLooper(tinyProxyConf, firewallConf, allSettings.TinyProxy, logger, streamMerger, uid, gid, defaultInterface) restartTinyproxy := tinyproxyLooper.Restart go tinyproxyLooper.Run(ctx, wg) - shadowsocksLooper := shadowsocks.NewLooper(shadowsocksConf, firewallConf, allSettings.ShadowSocks, allSettings.DNS, logger, streamMerger, uid, gid) + shadowsocksLooper := shadowsocks.NewLooper(shadowsocksConf, firewallConf, allSettings.ShadowSocks, allSettings.DNS, logger, streamMerger, uid, gid, defaultInterface) restartShadowsocks := shadowsocksLooper.Restart go shadowsocksLooper.Run(ctx, wg) diff --git a/internal/firewall/enable.go b/internal/firewall/enable.go index f8df1c6e..09faab0c 100644 --- a/internal/firewall/enable.go +++ b/internal/firewall/enable.go @@ -114,22 +114,8 @@ func (c *configurator) enable(ctx context.Context) (err error) { //nolint:gocogn } } - for port := range c.allowedPorts { - // TODO restrict interface - if err := c.acceptInputToPort(ctx, "*", constants.TCP, port, remove); err != nil { - return fmt.Errorf("cannot enable firewall: %w", err) - } - if err := c.acceptInputToPort(ctx, "*", constants.UDP, port, remove); err != nil { - return fmt.Errorf("cannot enable firewall: %w", err) - } - } - - if c.portForwarded > 0 { - const tun = string(constants.TUN) - if err := c.acceptInputToPort(ctx, tun, constants.TCP, c.portForwarded, remove); err != nil { - return fmt.Errorf("cannot enable firewall: %w", err) - } - if err := c.acceptInputToPort(ctx, tun, constants.UDP, c.portForwarded, remove); err != nil { + for port, intf := range c.allowedInputPorts { + if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil { return fmt.Errorf("cannot enable firewall: %w", err) } } diff --git a/internal/firewall/firewall.go b/internal/firewall/firewall.go index 6339ac8a..26610ff8 100644 --- a/internal/firewall/firewall.go +++ b/internal/firewall/firewall.go @@ -18,9 +18,8 @@ type Configurator interface { SetEnabled(ctx context.Context, enabled bool) (err error) SetVPNConnections(ctx context.Context, connections []models.OpenVPNConnection) (err error) SetAllowedSubnets(ctx context.Context, subnets []net.IPNet) (err error) - SetAllowedPort(ctx context.Context, port uint16) error + SetAllowedPort(ctx context.Context, port uint16, intf string) (err error) RemoveAllowedPort(ctx context.Context, port uint16) (err error) - SetPortForward(ctx context.Context, port uint16) (err error) SetDebug() // SetNetworkInformation is meant to be called only once SetNetworkInformation(defaultInterface string, defaultGateway net.IP, localSubnet net.IPNet) @@ -39,22 +38,21 @@ type configurator struct { //nolint:maligned networkInfoMutex sync.Mutex // State - enabled bool - vpnConnections []models.OpenVPNConnection - allowedSubnets []net.IPNet - allowedPorts map[uint16]struct{} - portForwarded uint16 - stateMutex sync.Mutex + enabled bool + vpnConnections []models.OpenVPNConnection + allowedSubnets []net.IPNet + allowedInputPorts map[uint16]string // port to interface mapping + stateMutex sync.Mutex } // NewConfigurator creates a new Configurator instance func NewConfigurator(logger logging.Logger, routing routing.Routing, fileManager files.FileManager) Configurator { return &configurator{ - commander: command.NewCommander(), - logger: logger.WithPrefix("firewall: "), - routing: routing, - fileManager: fileManager, - allowedPorts: make(map[uint16]struct{}), + commander: command.NewCommander(), + logger: logger.WithPrefix("firewall: "), + routing: routing, + fileManager: fileManager, + allowedInputPorts: make(map[uint16]string), } } diff --git a/internal/firewall/iptables.go b/internal/firewall/iptables.go index 85b56553..c0b98543 100644 --- a/internal/firewall/iptables.go +++ b/internal/firewall/iptables.go @@ -134,14 +134,15 @@ func (c *configurator) acceptOutputFromSubnetToSubnet(ctx context.Context, intf } // Used for port forwarding, with intf set to tun -func (c *configurator) acceptInputToPort(ctx context.Context, intf string, protocol models.NetworkProtocol, port uint16, remove bool) error { +func (c *configurator) acceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error { interfaceFlag := "-i " + intf if intf == "*" { // all interfaces interfaceFlag = "" } - return c.runIptablesInstruction(ctx, - fmt.Sprintf("%s INPUT %s -p %s --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, protocol, port), - ) + return c.runIptablesInstructions(ctx, []string{ + fmt.Sprintf("%s INPUT %s -p tcp --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, port), + fmt.Sprintf("%s INPUT %s -p udp --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, port), + }) } func (c *configurator) runUserPostRules(ctx context.Context, filepath string, remove bool) error { diff --git a/internal/firewall/ports.go b/internal/firewall/ports.go index 7c79a3b7..a7ab3607 100644 --- a/internal/firewall/ports.go +++ b/internal/firewall/ports.go @@ -3,11 +3,9 @@ package firewall import ( "context" "fmt" - - "github.com/qdm12/private-internet-access-docker/internal/constants" ) -func (c *configurator) SetAllowedPort(ctx context.Context, port uint16) (err error) { +func (c *configurator) SetAllowedPort(ctx context.Context, port uint16, intf string) (err error) { c.stateMutex.Lock() defer c.stateMutex.Unlock() @@ -16,25 +14,28 @@ func (c *configurator) SetAllowedPort(ctx context.Context, port uint16) (err err } if !c.enabled { - c.logger.Info("firewall disabled, only updating allowed ports internal list") - c.allowedPorts[port] = struct{}{} + c.logger.Info("firewall disabled, only updating allowed ports internal state") + c.allowedInputPorts[port] = intf return nil } - c.logger.Info("setting allowed port %d through firewall...", port) + c.logger.Info("setting allowed input port %d through interface %s...", port, intf) - if _, ok := c.allowedPorts[port]; ok { - return nil + if existingIntf, ok := c.allowedInputPorts[port]; ok { + if intf == existingIntf { + return nil + } + const remove = true + if err := c.acceptInputToPort(ctx, existingIntf, port, remove); err != nil { + return fmt.Errorf("cannot remove old allowed port %d through interface %s: %w", port, existingIntf, err) + } } const remove = false - if err := c.acceptInputToPort(ctx, "*", constants.TCP, port, remove); err != nil { - return fmt.Errorf("cannot set allowed port %d through firewall: %w", port, err) + if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil { + return fmt.Errorf("cannot set allowed port %d through interface %s: %w", port, intf, err) } - if err := c.acceptInputToPort(ctx, "*", constants.UDP, port, remove); err != nil { - return fmt.Errorf("cannot set allowed port %d through firewall: %w", port, err) - } - c.allowedPorts[port] = struct{}{} + c.allowedInputPorts[port] = intf return nil } @@ -49,63 +50,22 @@ func (c *configurator) RemoveAllowedPort(ctx context.Context, port uint16) (err if !c.enabled { c.logger.Info("firewall disabled, only updating allowed ports internal list") - delete(c.allowedPorts, port) + delete(c.allowedInputPorts, port) return nil } c.logger.Info("removing allowed port %d through firewall...", port) - if _, ok := c.allowedPorts[port]; !ok { + intf, ok := c.allowedInputPorts[port] + if !ok { return nil } const remove = true - if err := c.acceptInputToPort(ctx, "*", constants.TCP, port, remove); err != nil { - return fmt.Errorf("cannot remove allowed port %d through firewall: %w", port, err) + if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil { + return fmt.Errorf("cannot remove allowed port %d through interface %s: %w", port, intf, err) } - if err := c.acceptInputToPort(ctx, "*", constants.UDP, port, remove); err != nil { - return fmt.Errorf("cannot remove allowed port %d through firewall: %w", port, err) - } - delete(c.allowedPorts, port) + delete(c.allowedInputPorts, port) return nil } - -// Use 0 to remove -func (c *configurator) SetPortForward(ctx context.Context, port uint16) (err error) { - c.stateMutex.Lock() - defer c.stateMutex.Unlock() - - if port == c.portForwarded { - return nil - } - - if !c.enabled { - c.logger.Info("firewall disabled, only updating port forwarded internally") - c.portForwarded = port - return nil - } - - const tun = string(constants.TUN) - if c.portForwarded > 0 { - if err := c.acceptInputToPort(ctx, tun, constants.TCP, c.portForwarded, true); err != nil { - return fmt.Errorf("cannot remove outdated port forward rule from firewall: %w", err) - } - if err := c.acceptInputToPort(ctx, tun, constants.UDP, c.portForwarded, true); err != nil { - return fmt.Errorf("cannot remove outdated port forward rule from firewall: %w", err) - } - } - - if port == 0 { // not changing port - c.portForwarded = 0 - return nil - } - - if err := c.acceptInputToPort(ctx, tun, constants.TCP, port, false); err != nil { - return fmt.Errorf("cannot accept port forwarded through firewall: %w", err) - } - if err := c.acceptInputToPort(ctx, tun, constants.UDP, port, false); err != nil { - return fmt.Errorf("cannot accept port forwarded through firewall: %w", err) - } - return nil -} diff --git a/internal/openvpn/loop.go b/internal/openvpn/loop.go index 8185141d..41ac35c3 100644 --- a/internal/openvpn/loop.go +++ b/internal/openvpn/loop.go @@ -195,6 +195,12 @@ func (l *looper) portForward(ctx context.Context, providerConf provider.Provider l.logger.Info("port forwarded is %d", port) l.portForwardedMutex.Lock() + if err := l.fw.RemoveAllowedPort(ctx, l.portForwarded); err != nil { + l.logger.Error(err) + } + if err := l.fw.SetAllowedPort(ctx, port, string(constants.TUN)); err != nil { + l.logger.Error(err) + } l.portForwarded = port l.portForwardedMutex.Unlock() @@ -207,10 +213,6 @@ func (l *looper) portForward(ctx context.Context, providerConf provider.Provider if err != nil { l.logger.Error(err) } - - if err := l.fw.SetPortForward(ctx, port); err != nil { - l.logger.Error(err) - } } func (l *looper) GetPortForwarded() (portForwarded uint16) { diff --git a/internal/shadowsocks/loop.go b/internal/shadowsocks/loop.go index 281dfc6f..d7e11ef1 100644 --- a/internal/shadowsocks/loop.go +++ b/internal/shadowsocks/loop.go @@ -21,18 +21,19 @@ type Looper interface { } type looper struct { - conf Configurator - firewallConf firewall.Configurator - settings settings.ShadowSocks - settingsMutex sync.RWMutex - dnsSettings settings.DNS // TODO - logger logging.Logger - streamMerger command.StreamMerger - uid int - gid int - restart chan struct{} - start chan struct{} - stop chan struct{} + conf Configurator + firewallConf firewall.Configurator + settings settings.ShadowSocks + settingsMutex sync.RWMutex + dnsSettings settings.DNS // TODO + logger logging.Logger + streamMerger command.StreamMerger + uid int + gid int + defaultInterface string + restart chan struct{} + start chan struct{} + stop chan struct{} } func (l *looper) logAndWait(ctx context.Context, err error) { @@ -44,19 +45,20 @@ func (l *looper) logAndWait(ctx context.Context, err error) { } func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings settings.ShadowSocks, dnsSettings settings.DNS, - logger logging.Logger, streamMerger command.StreamMerger, uid, gid int) Looper { + logger logging.Logger, streamMerger command.StreamMerger, uid, gid int, defaultInterface string) Looper { return &looper{ - conf: conf, - firewallConf: firewallConf, - settings: settings, - dnsSettings: dnsSettings, - logger: logger.WithPrefix("shadowsocks: "), - streamMerger: streamMerger, - uid: uid, - gid: gid, - restart: make(chan struct{}), - start: make(chan struct{}), - stop: make(chan struct{}), + conf: conf, + firewallConf: firewallConf, + settings: settings, + dnsSettings: dnsSettings, + logger: logger.WithPrefix("shadowsocks: "), + streamMerger: streamMerger, + uid: uid, + gid: gid, + defaultInterface: defaultInterface, + restart: make(chan struct{}), + start: make(chan struct{}), + stop: make(chan struct{}), } } @@ -141,7 +143,7 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { continue } } - if err := l.firewallConf.SetAllowedPort(ctx, settings.Port); err != nil { + if err := l.firewallConf.SetAllowedPort(ctx, settings.Port, l.defaultInterface); err != nil { l.logger.Error(err) continue } diff --git a/internal/tinyproxy/loop.go b/internal/tinyproxy/loop.go index 0c0fcf86..598ea042 100644 --- a/internal/tinyproxy/loop.go +++ b/internal/tinyproxy/loop.go @@ -21,17 +21,18 @@ type Looper interface { } type looper struct { - conf Configurator - firewallConf firewall.Configurator - settings settings.TinyProxy - settingsMutex sync.RWMutex - logger logging.Logger - streamMerger command.StreamMerger - uid int - gid int - restart chan struct{} - start chan struct{} - stop chan struct{} + conf Configurator + firewallConf firewall.Configurator + settings settings.TinyProxy + settingsMutex sync.RWMutex + logger logging.Logger + streamMerger command.StreamMerger + uid int + gid int + defaultInterface string + restart chan struct{} + start chan struct{} + stop chan struct{} } func (l *looper) logAndWait(ctx context.Context, err error) { @@ -43,18 +44,19 @@ func (l *looper) logAndWait(ctx context.Context, err error) { } func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings settings.TinyProxy, - logger logging.Logger, streamMerger command.StreamMerger, uid, gid int) Looper { + logger logging.Logger, streamMerger command.StreamMerger, uid, gid int, defaultInterface string) Looper { return &looper{ - conf: conf, - firewallConf: firewallConf, - settings: settings, - logger: logger.WithPrefix("tinyproxy: "), - streamMerger: streamMerger, - uid: uid, - gid: gid, - restart: make(chan struct{}), - start: make(chan struct{}), - stop: make(chan struct{}), + conf: conf, + firewallConf: firewallConf, + settings: settings, + logger: logger.WithPrefix("tinyproxy: "), + streamMerger: streamMerger, + uid: uid, + gid: gid, + defaultInterface: defaultInterface, + restart: make(chan struct{}), + start: make(chan struct{}), + stop: make(chan struct{}), } } @@ -133,7 +135,7 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { continue } } - if err := l.firewallConf.SetAllowedPort(ctx, settings.Port); err != nil { + if err := l.firewallConf.SetAllowedPort(ctx, settings.Port, l.defaultInterface); err != nil { l.logger.Error(err) continue }