Firewall simplifications

- Only a map of allowed input port to interface
- port forwarded is in the map of allowed input ports
- port forwarded has the interface tun0 in this map
- Always allow tcp and udp for allowed input ports
- Port forward state is in openvpn looper only
- Shadowsocks input port allowed on default interface only
- Tinyproxy input port allowed on default interface only
This commit is contained in:
Quentin McGaw
2020-07-20 00:39:59 +00:00
parent 85bd4f2e8d
commit a13be8f45e
8 changed files with 99 additions and 148 deletions

View File

@@ -158,11 +158,11 @@ func _main(background context.Context, args []string) int {
go publicIPLooper.RunRestartTicker(ctx) go publicIPLooper.RunRestartTicker(ctx)
setPublicIPPeriod(allSettings.PublicIPPeriod) // call after RunRestartTicker setPublicIPPeriod(allSettings.PublicIPPeriod) // call after RunRestartTicker
tinyproxyLooper := tinyproxy.NewLooper(tinyProxyConf, firewallConf, allSettings.TinyProxy, logger, streamMerger, uid, gid) tinyproxyLooper := tinyproxy.NewLooper(tinyProxyConf, firewallConf, allSettings.TinyProxy, logger, streamMerger, uid, gid, defaultInterface)
restartTinyproxy := tinyproxyLooper.Restart restartTinyproxy := tinyproxyLooper.Restart
go tinyproxyLooper.Run(ctx, wg) go tinyproxyLooper.Run(ctx, wg)
shadowsocksLooper := shadowsocks.NewLooper(shadowsocksConf, firewallConf, allSettings.ShadowSocks, allSettings.DNS, logger, streamMerger, uid, gid) shadowsocksLooper := shadowsocks.NewLooper(shadowsocksConf, firewallConf, allSettings.ShadowSocks, allSettings.DNS, logger, streamMerger, uid, gid, defaultInterface)
restartShadowsocks := shadowsocksLooper.Restart restartShadowsocks := shadowsocksLooper.Restart
go shadowsocksLooper.Run(ctx, wg) go shadowsocksLooper.Run(ctx, wg)

View File

@@ -114,22 +114,8 @@ func (c *configurator) enable(ctx context.Context) (err error) { //nolint:gocogn
} }
} }
for port := range c.allowedPorts { for port, intf := range c.allowedInputPorts {
// TODO restrict interface if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
if err := c.acceptInputToPort(ctx, "*", constants.TCP, port, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
if err := c.acceptInputToPort(ctx, "*", constants.UDP, port, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
}
if c.portForwarded > 0 {
const tun = string(constants.TUN)
if err := c.acceptInputToPort(ctx, tun, constants.TCP, c.portForwarded, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
if err := c.acceptInputToPort(ctx, tun, constants.UDP, c.portForwarded, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err) return fmt.Errorf("cannot enable firewall: %w", err)
} }
} }

View File

@@ -18,9 +18,8 @@ type Configurator interface {
SetEnabled(ctx context.Context, enabled bool) (err error) SetEnabled(ctx context.Context, enabled bool) (err error)
SetVPNConnections(ctx context.Context, connections []models.OpenVPNConnection) (err error) SetVPNConnections(ctx context.Context, connections []models.OpenVPNConnection) (err error)
SetAllowedSubnets(ctx context.Context, subnets []net.IPNet) (err error) SetAllowedSubnets(ctx context.Context, subnets []net.IPNet) (err error)
SetAllowedPort(ctx context.Context, port uint16) error SetAllowedPort(ctx context.Context, port uint16, intf string) (err error)
RemoveAllowedPort(ctx context.Context, port uint16) (err error) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
SetPortForward(ctx context.Context, port uint16) (err error)
SetDebug() SetDebug()
// SetNetworkInformation is meant to be called only once // SetNetworkInformation is meant to be called only once
SetNetworkInformation(defaultInterface string, defaultGateway net.IP, localSubnet net.IPNet) SetNetworkInformation(defaultInterface string, defaultGateway net.IP, localSubnet net.IPNet)
@@ -39,22 +38,21 @@ type configurator struct { //nolint:maligned
networkInfoMutex sync.Mutex networkInfoMutex sync.Mutex
// State // State
enabled bool enabled bool
vpnConnections []models.OpenVPNConnection vpnConnections []models.OpenVPNConnection
allowedSubnets []net.IPNet allowedSubnets []net.IPNet
allowedPorts map[uint16]struct{} allowedInputPorts map[uint16]string // port to interface mapping
portForwarded uint16 stateMutex sync.Mutex
stateMutex sync.Mutex
} }
// NewConfigurator creates a new Configurator instance // NewConfigurator creates a new Configurator instance
func NewConfigurator(logger logging.Logger, routing routing.Routing, fileManager files.FileManager) Configurator { func NewConfigurator(logger logging.Logger, routing routing.Routing, fileManager files.FileManager) Configurator {
return &configurator{ return &configurator{
commander: command.NewCommander(), commander: command.NewCommander(),
logger: logger.WithPrefix("firewall: "), logger: logger.WithPrefix("firewall: "),
routing: routing, routing: routing,
fileManager: fileManager, fileManager: fileManager,
allowedPorts: make(map[uint16]struct{}), allowedInputPorts: make(map[uint16]string),
} }
} }

View File

@@ -134,14 +134,15 @@ func (c *configurator) acceptOutputFromSubnetToSubnet(ctx context.Context, intf
} }
// Used for port forwarding, with intf set to tun // Used for port forwarding, with intf set to tun
func (c *configurator) acceptInputToPort(ctx context.Context, intf string, protocol models.NetworkProtocol, port uint16, remove bool) error { func (c *configurator) acceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error {
interfaceFlag := "-i " + intf interfaceFlag := "-i " + intf
if intf == "*" { // all interfaces if intf == "*" { // all interfaces
interfaceFlag = "" interfaceFlag = ""
} }
return c.runIptablesInstruction(ctx, return c.runIptablesInstructions(ctx, []string{
fmt.Sprintf("%s INPUT %s -p %s --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, protocol, port), fmt.Sprintf("%s INPUT %s -p tcp --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, port),
) fmt.Sprintf("%s INPUT %s -p udp --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, port),
})
} }
func (c *configurator) runUserPostRules(ctx context.Context, filepath string, remove bool) error { func (c *configurator) runUserPostRules(ctx context.Context, filepath string, remove bool) error {

View File

@@ -3,11 +3,9 @@ package firewall
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/qdm12/private-internet-access-docker/internal/constants"
) )
func (c *configurator) SetAllowedPort(ctx context.Context, port uint16) (err error) { func (c *configurator) SetAllowedPort(ctx context.Context, port uint16, intf string) (err error) {
c.stateMutex.Lock() c.stateMutex.Lock()
defer c.stateMutex.Unlock() defer c.stateMutex.Unlock()
@@ -16,25 +14,28 @@ func (c *configurator) SetAllowedPort(ctx context.Context, port uint16) (err err
} }
if !c.enabled { if !c.enabled {
c.logger.Info("firewall disabled, only updating allowed ports internal list") c.logger.Info("firewall disabled, only updating allowed ports internal state")
c.allowedPorts[port] = struct{}{} c.allowedInputPorts[port] = intf
return nil return nil
} }
c.logger.Info("setting allowed port %d through firewall...", port) c.logger.Info("setting allowed input port %d through interface %s...", port, intf)
if _, ok := c.allowedPorts[port]; ok { if existingIntf, ok := c.allowedInputPorts[port]; ok {
return nil if intf == existingIntf {
return nil
}
const remove = true
if err := c.acceptInputToPort(ctx, existingIntf, port, remove); err != nil {
return fmt.Errorf("cannot remove old allowed port %d through interface %s: %w", port, existingIntf, err)
}
} }
const remove = false const remove = false
if err := c.acceptInputToPort(ctx, "*", constants.TCP, port, remove); err != nil { if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
return fmt.Errorf("cannot set allowed port %d through firewall: %w", port, err) return fmt.Errorf("cannot set allowed port %d through interface %s: %w", port, intf, err)
} }
if err := c.acceptInputToPort(ctx, "*", constants.UDP, port, remove); err != nil { c.allowedInputPorts[port] = intf
return fmt.Errorf("cannot set allowed port %d through firewall: %w", port, err)
}
c.allowedPorts[port] = struct{}{}
return nil return nil
} }
@@ -49,63 +50,22 @@ func (c *configurator) RemoveAllowedPort(ctx context.Context, port uint16) (err
if !c.enabled { if !c.enabled {
c.logger.Info("firewall disabled, only updating allowed ports internal list") c.logger.Info("firewall disabled, only updating allowed ports internal list")
delete(c.allowedPorts, port) delete(c.allowedInputPorts, port)
return nil return nil
} }
c.logger.Info("removing allowed port %d through firewall...", port) c.logger.Info("removing allowed port %d through firewall...", port)
if _, ok := c.allowedPorts[port]; !ok { intf, ok := c.allowedInputPorts[port]
if !ok {
return nil return nil
} }
const remove = true const remove = true
if err := c.acceptInputToPort(ctx, "*", constants.TCP, port, remove); err != nil { if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
return fmt.Errorf("cannot remove allowed port %d through firewall: %w", port, err) return fmt.Errorf("cannot remove allowed port %d through interface %s: %w", port, intf, err)
} }
if err := c.acceptInputToPort(ctx, "*", constants.UDP, port, remove); err != nil { delete(c.allowedInputPorts, port)
return fmt.Errorf("cannot remove allowed port %d through firewall: %w", port, err)
}
delete(c.allowedPorts, port)
return nil return nil
} }
// Use 0 to remove
func (c *configurator) SetPortForward(ctx context.Context, port uint16) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
if port == c.portForwarded {
return nil
}
if !c.enabled {
c.logger.Info("firewall disabled, only updating port forwarded internally")
c.portForwarded = port
return nil
}
const tun = string(constants.TUN)
if c.portForwarded > 0 {
if err := c.acceptInputToPort(ctx, tun, constants.TCP, c.portForwarded, true); err != nil {
return fmt.Errorf("cannot remove outdated port forward rule from firewall: %w", err)
}
if err := c.acceptInputToPort(ctx, tun, constants.UDP, c.portForwarded, true); err != nil {
return fmt.Errorf("cannot remove outdated port forward rule from firewall: %w", err)
}
}
if port == 0 { // not changing port
c.portForwarded = 0
return nil
}
if err := c.acceptInputToPort(ctx, tun, constants.TCP, port, false); err != nil {
return fmt.Errorf("cannot accept port forwarded through firewall: %w", err)
}
if err := c.acceptInputToPort(ctx, tun, constants.UDP, port, false); err != nil {
return fmt.Errorf("cannot accept port forwarded through firewall: %w", err)
}
return nil
}

View File

@@ -195,6 +195,12 @@ func (l *looper) portForward(ctx context.Context, providerConf provider.Provider
l.logger.Info("port forwarded is %d", port) l.logger.Info("port forwarded is %d", port)
l.portForwardedMutex.Lock() l.portForwardedMutex.Lock()
if err := l.fw.RemoveAllowedPort(ctx, l.portForwarded); err != nil {
l.logger.Error(err)
}
if err := l.fw.SetAllowedPort(ctx, port, string(constants.TUN)); err != nil {
l.logger.Error(err)
}
l.portForwarded = port l.portForwarded = port
l.portForwardedMutex.Unlock() l.portForwardedMutex.Unlock()
@@ -207,10 +213,6 @@ func (l *looper) portForward(ctx context.Context, providerConf provider.Provider
if err != nil { if err != nil {
l.logger.Error(err) l.logger.Error(err)
} }
if err := l.fw.SetPortForward(ctx, port); err != nil {
l.logger.Error(err)
}
} }
func (l *looper) GetPortForwarded() (portForwarded uint16) { func (l *looper) GetPortForwarded() (portForwarded uint16) {

View File

@@ -21,18 +21,19 @@ type Looper interface {
} }
type looper struct { type looper struct {
conf Configurator conf Configurator
firewallConf firewall.Configurator firewallConf firewall.Configurator
settings settings.ShadowSocks settings settings.ShadowSocks
settingsMutex sync.RWMutex settingsMutex sync.RWMutex
dnsSettings settings.DNS // TODO dnsSettings settings.DNS // TODO
logger logging.Logger logger logging.Logger
streamMerger command.StreamMerger streamMerger command.StreamMerger
uid int uid int
gid int gid int
restart chan struct{} defaultInterface string
start chan struct{} restart chan struct{}
stop chan struct{} start chan struct{}
stop chan struct{}
} }
func (l *looper) logAndWait(ctx context.Context, err error) { func (l *looper) logAndWait(ctx context.Context, err error) {
@@ -44,19 +45,20 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
} }
func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings settings.ShadowSocks, dnsSettings settings.DNS, func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings settings.ShadowSocks, dnsSettings settings.DNS,
logger logging.Logger, streamMerger command.StreamMerger, uid, gid int) Looper { logger logging.Logger, streamMerger command.StreamMerger, uid, gid int, defaultInterface string) Looper {
return &looper{ return &looper{
conf: conf, conf: conf,
firewallConf: firewallConf, firewallConf: firewallConf,
settings: settings, settings: settings,
dnsSettings: dnsSettings, dnsSettings: dnsSettings,
logger: logger.WithPrefix("shadowsocks: "), logger: logger.WithPrefix("shadowsocks: "),
streamMerger: streamMerger, streamMerger: streamMerger,
uid: uid, uid: uid,
gid: gid, gid: gid,
restart: make(chan struct{}), defaultInterface: defaultInterface,
start: make(chan struct{}), restart: make(chan struct{}),
stop: make(chan struct{}), start: make(chan struct{}),
stop: make(chan struct{}),
} }
} }
@@ -141,7 +143,7 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
continue continue
} }
} }
if err := l.firewallConf.SetAllowedPort(ctx, settings.Port); err != nil { if err := l.firewallConf.SetAllowedPort(ctx, settings.Port, l.defaultInterface); err != nil {
l.logger.Error(err) l.logger.Error(err)
continue continue
} }

View File

@@ -21,17 +21,18 @@ type Looper interface {
} }
type looper struct { type looper struct {
conf Configurator conf Configurator
firewallConf firewall.Configurator firewallConf firewall.Configurator
settings settings.TinyProxy settings settings.TinyProxy
settingsMutex sync.RWMutex settingsMutex sync.RWMutex
logger logging.Logger logger logging.Logger
streamMerger command.StreamMerger streamMerger command.StreamMerger
uid int uid int
gid int gid int
restart chan struct{} defaultInterface string
start chan struct{} restart chan struct{}
stop chan struct{} start chan struct{}
stop chan struct{}
} }
func (l *looper) logAndWait(ctx context.Context, err error) { func (l *looper) logAndWait(ctx context.Context, err error) {
@@ -43,18 +44,19 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
} }
func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings settings.TinyProxy, func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings settings.TinyProxy,
logger logging.Logger, streamMerger command.StreamMerger, uid, gid int) Looper { logger logging.Logger, streamMerger command.StreamMerger, uid, gid int, defaultInterface string) Looper {
return &looper{ return &looper{
conf: conf, conf: conf,
firewallConf: firewallConf, firewallConf: firewallConf,
settings: settings, settings: settings,
logger: logger.WithPrefix("tinyproxy: "), logger: logger.WithPrefix("tinyproxy: "),
streamMerger: streamMerger, streamMerger: streamMerger,
uid: uid, uid: uid,
gid: gid, gid: gid,
restart: make(chan struct{}), defaultInterface: defaultInterface,
start: make(chan struct{}), restart: make(chan struct{}),
stop: make(chan struct{}), start: make(chan struct{}),
stop: make(chan struct{}),
} }
} }
@@ -133,7 +135,7 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
continue continue
} }
} }
if err := l.firewallConf.SetAllowedPort(ctx, settings.Port); err != nil { if err := l.firewallConf.SetAllowedPort(ctx, settings.Port, l.defaultInterface); err != nil {
l.logger.Error(err) l.logger.Error(err)
continue continue
} }