Firewall simplifications
- Only a map of allowed input port to interface - port forwarded is in the map of allowed input ports - port forwarded has the interface tun0 in this map - Always allow tcp and udp for allowed input ports - Port forward state is in openvpn looper only - Shadowsocks input port allowed on default interface only - Tinyproxy input port allowed on default interface only
This commit is contained in:
@@ -158,11 +158,11 @@ func _main(background context.Context, args []string) int {
|
||||
go publicIPLooper.RunRestartTicker(ctx)
|
||||
setPublicIPPeriod(allSettings.PublicIPPeriod) // call after RunRestartTicker
|
||||
|
||||
tinyproxyLooper := tinyproxy.NewLooper(tinyProxyConf, firewallConf, allSettings.TinyProxy, logger, streamMerger, uid, gid)
|
||||
tinyproxyLooper := tinyproxy.NewLooper(tinyProxyConf, firewallConf, allSettings.TinyProxy, logger, streamMerger, uid, gid, defaultInterface)
|
||||
restartTinyproxy := tinyproxyLooper.Restart
|
||||
go tinyproxyLooper.Run(ctx, wg)
|
||||
|
||||
shadowsocksLooper := shadowsocks.NewLooper(shadowsocksConf, firewallConf, allSettings.ShadowSocks, allSettings.DNS, logger, streamMerger, uid, gid)
|
||||
shadowsocksLooper := shadowsocks.NewLooper(shadowsocksConf, firewallConf, allSettings.ShadowSocks, allSettings.DNS, logger, streamMerger, uid, gid, defaultInterface)
|
||||
restartShadowsocks := shadowsocksLooper.Restart
|
||||
go shadowsocksLooper.Run(ctx, wg)
|
||||
|
||||
|
||||
@@ -114,22 +114,8 @@ func (c *configurator) enable(ctx context.Context) (err error) { //nolint:gocogn
|
||||
}
|
||||
}
|
||||
|
||||
for port := range c.allowedPorts {
|
||||
// TODO restrict interface
|
||||
if err := c.acceptInputToPort(ctx, "*", constants.TCP, port, remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
}
|
||||
if err := c.acceptInputToPort(ctx, "*", constants.UDP, port, remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if c.portForwarded > 0 {
|
||||
const tun = string(constants.TUN)
|
||||
if err := c.acceptInputToPort(ctx, tun, constants.TCP, c.portForwarded, remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
}
|
||||
if err := c.acceptInputToPort(ctx, tun, constants.UDP, c.portForwarded, remove); err != nil {
|
||||
for port, intf := range c.allowedInputPorts {
|
||||
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,9 +18,8 @@ type Configurator interface {
|
||||
SetEnabled(ctx context.Context, enabled bool) (err error)
|
||||
SetVPNConnections(ctx context.Context, connections []models.OpenVPNConnection) (err error)
|
||||
SetAllowedSubnets(ctx context.Context, subnets []net.IPNet) (err error)
|
||||
SetAllowedPort(ctx context.Context, port uint16) error
|
||||
SetAllowedPort(ctx context.Context, port uint16, intf string) (err error)
|
||||
RemoveAllowedPort(ctx context.Context, port uint16) (err error)
|
||||
SetPortForward(ctx context.Context, port uint16) (err error)
|
||||
SetDebug()
|
||||
// SetNetworkInformation is meant to be called only once
|
||||
SetNetworkInformation(defaultInterface string, defaultGateway net.IP, localSubnet net.IPNet)
|
||||
@@ -42,8 +41,7 @@ type configurator struct { //nolint:maligned
|
||||
enabled bool
|
||||
vpnConnections []models.OpenVPNConnection
|
||||
allowedSubnets []net.IPNet
|
||||
allowedPorts map[uint16]struct{}
|
||||
portForwarded uint16
|
||||
allowedInputPorts map[uint16]string // port to interface mapping
|
||||
stateMutex sync.Mutex
|
||||
}
|
||||
|
||||
@@ -54,7 +52,7 @@ func NewConfigurator(logger logging.Logger, routing routing.Routing, fileManager
|
||||
logger: logger.WithPrefix("firewall: "),
|
||||
routing: routing,
|
||||
fileManager: fileManager,
|
||||
allowedPorts: make(map[uint16]struct{}),
|
||||
allowedInputPorts: make(map[uint16]string),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -134,14 +134,15 @@ func (c *configurator) acceptOutputFromSubnetToSubnet(ctx context.Context, intf
|
||||
}
|
||||
|
||||
// Used for port forwarding, with intf set to tun
|
||||
func (c *configurator) acceptInputToPort(ctx context.Context, intf string, protocol models.NetworkProtocol, port uint16, remove bool) error {
|
||||
func (c *configurator) acceptInputToPort(ctx context.Context, intf string, port uint16, remove bool) error {
|
||||
interfaceFlag := "-i " + intf
|
||||
if intf == "*" { // all interfaces
|
||||
interfaceFlag = ""
|
||||
}
|
||||
return c.runIptablesInstruction(ctx,
|
||||
fmt.Sprintf("%s INPUT %s -p %s --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, protocol, port),
|
||||
)
|
||||
return c.runIptablesInstructions(ctx, []string{
|
||||
fmt.Sprintf("%s INPUT %s -p tcp --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, port),
|
||||
fmt.Sprintf("%s INPUT %s -p udp --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, port),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *configurator) runUserPostRules(ctx context.Context, filepath string, remove bool) error {
|
||||
|
||||
@@ -3,11 +3,9 @@ package firewall
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
)
|
||||
|
||||
func (c *configurator) SetAllowedPort(ctx context.Context, port uint16) (err error) {
|
||||
func (c *configurator) SetAllowedPort(ctx context.Context, port uint16, intf string) (err error) {
|
||||
c.stateMutex.Lock()
|
||||
defer c.stateMutex.Unlock()
|
||||
|
||||
@@ -16,25 +14,28 @@ func (c *configurator) SetAllowedPort(ctx context.Context, port uint16) (err err
|
||||
}
|
||||
|
||||
if !c.enabled {
|
||||
c.logger.Info("firewall disabled, only updating allowed ports internal list")
|
||||
c.allowedPorts[port] = struct{}{}
|
||||
c.logger.Info("firewall disabled, only updating allowed ports internal state")
|
||||
c.allowedInputPorts[port] = intf
|
||||
return nil
|
||||
}
|
||||
|
||||
c.logger.Info("setting allowed port %d through firewall...", port)
|
||||
c.logger.Info("setting allowed input port %d through interface %s...", port, intf)
|
||||
|
||||
if _, ok := c.allowedPorts[port]; ok {
|
||||
if existingIntf, ok := c.allowedInputPorts[port]; ok {
|
||||
if intf == existingIntf {
|
||||
return nil
|
||||
}
|
||||
const remove = true
|
||||
if err := c.acceptInputToPort(ctx, existingIntf, port, remove); err != nil {
|
||||
return fmt.Errorf("cannot remove old allowed port %d through interface %s: %w", port, existingIntf, err)
|
||||
}
|
||||
}
|
||||
|
||||
const remove = false
|
||||
if err := c.acceptInputToPort(ctx, "*", constants.TCP, port, remove); err != nil {
|
||||
return fmt.Errorf("cannot set allowed port %d through firewall: %w", port, err)
|
||||
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
|
||||
return fmt.Errorf("cannot set allowed port %d through interface %s: %w", port, intf, err)
|
||||
}
|
||||
if err := c.acceptInputToPort(ctx, "*", constants.UDP, port, remove); err != nil {
|
||||
return fmt.Errorf("cannot set allowed port %d through firewall: %w", port, err)
|
||||
}
|
||||
c.allowedPorts[port] = struct{}{}
|
||||
c.allowedInputPorts[port] = intf
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -49,63 +50,22 @@ func (c *configurator) RemoveAllowedPort(ctx context.Context, port uint16) (err
|
||||
|
||||
if !c.enabled {
|
||||
c.logger.Info("firewall disabled, only updating allowed ports internal list")
|
||||
delete(c.allowedPorts, port)
|
||||
delete(c.allowedInputPorts, port)
|
||||
return nil
|
||||
}
|
||||
|
||||
c.logger.Info("removing allowed port %d through firewall...", port)
|
||||
|
||||
if _, ok := c.allowedPorts[port]; !ok {
|
||||
intf, ok := c.allowedInputPorts[port]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
const remove = true
|
||||
if err := c.acceptInputToPort(ctx, "*", constants.TCP, port, remove); err != nil {
|
||||
return fmt.Errorf("cannot remove allowed port %d through firewall: %w", port, err)
|
||||
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
|
||||
return fmt.Errorf("cannot remove allowed port %d through interface %s: %w", port, intf, err)
|
||||
}
|
||||
if err := c.acceptInputToPort(ctx, "*", constants.UDP, port, remove); err != nil {
|
||||
return fmt.Errorf("cannot remove allowed port %d through firewall: %w", port, err)
|
||||
}
|
||||
delete(c.allowedPorts, port)
|
||||
delete(c.allowedInputPorts, port)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use 0 to remove
|
||||
func (c *configurator) SetPortForward(ctx context.Context, port uint16) (err error) {
|
||||
c.stateMutex.Lock()
|
||||
defer c.stateMutex.Unlock()
|
||||
|
||||
if port == c.portForwarded {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !c.enabled {
|
||||
c.logger.Info("firewall disabled, only updating port forwarded internally")
|
||||
c.portForwarded = port
|
||||
return nil
|
||||
}
|
||||
|
||||
const tun = string(constants.TUN)
|
||||
if c.portForwarded > 0 {
|
||||
if err := c.acceptInputToPort(ctx, tun, constants.TCP, c.portForwarded, true); err != nil {
|
||||
return fmt.Errorf("cannot remove outdated port forward rule from firewall: %w", err)
|
||||
}
|
||||
if err := c.acceptInputToPort(ctx, tun, constants.UDP, c.portForwarded, true); err != nil {
|
||||
return fmt.Errorf("cannot remove outdated port forward rule from firewall: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if port == 0 { // not changing port
|
||||
c.portForwarded = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := c.acceptInputToPort(ctx, tun, constants.TCP, port, false); err != nil {
|
||||
return fmt.Errorf("cannot accept port forwarded through firewall: %w", err)
|
||||
}
|
||||
if err := c.acceptInputToPort(ctx, tun, constants.UDP, port, false); err != nil {
|
||||
return fmt.Errorf("cannot accept port forwarded through firewall: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -195,6 +195,12 @@ func (l *looper) portForward(ctx context.Context, providerConf provider.Provider
|
||||
|
||||
l.logger.Info("port forwarded is %d", port)
|
||||
l.portForwardedMutex.Lock()
|
||||
if err := l.fw.RemoveAllowedPort(ctx, l.portForwarded); err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
if err := l.fw.SetAllowedPort(ctx, port, string(constants.TUN)); err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
l.portForwarded = port
|
||||
l.portForwardedMutex.Unlock()
|
||||
|
||||
@@ -207,10 +213,6 @@ func (l *looper) portForward(ctx context.Context, providerConf provider.Provider
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
|
||||
if err := l.fw.SetPortForward(ctx, port); err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *looper) GetPortForwarded() (portForwarded uint16) {
|
||||
|
||||
@@ -30,6 +30,7 @@ type looper struct {
|
||||
streamMerger command.StreamMerger
|
||||
uid int
|
||||
gid int
|
||||
defaultInterface string
|
||||
restart chan struct{}
|
||||
start chan struct{}
|
||||
stop chan struct{}
|
||||
@@ -44,7 +45,7 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
|
||||
}
|
||||
|
||||
func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings settings.ShadowSocks, dnsSettings settings.DNS,
|
||||
logger logging.Logger, streamMerger command.StreamMerger, uid, gid int) Looper {
|
||||
logger logging.Logger, streamMerger command.StreamMerger, uid, gid int, defaultInterface string) Looper {
|
||||
return &looper{
|
||||
conf: conf,
|
||||
firewallConf: firewallConf,
|
||||
@@ -54,6 +55,7 @@ func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings s
|
||||
streamMerger: streamMerger,
|
||||
uid: uid,
|
||||
gid: gid,
|
||||
defaultInterface: defaultInterface,
|
||||
restart: make(chan struct{}),
|
||||
start: make(chan struct{}),
|
||||
stop: make(chan struct{}),
|
||||
@@ -141,7 +143,7 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if err := l.firewallConf.SetAllowedPort(ctx, settings.Port); err != nil {
|
||||
if err := l.firewallConf.SetAllowedPort(ctx, settings.Port, l.defaultInterface); err != nil {
|
||||
l.logger.Error(err)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ type looper struct {
|
||||
streamMerger command.StreamMerger
|
||||
uid int
|
||||
gid int
|
||||
defaultInterface string
|
||||
restart chan struct{}
|
||||
start chan struct{}
|
||||
stop chan struct{}
|
||||
@@ -43,7 +44,7 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
|
||||
}
|
||||
|
||||
func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings settings.TinyProxy,
|
||||
logger logging.Logger, streamMerger command.StreamMerger, uid, gid int) Looper {
|
||||
logger logging.Logger, streamMerger command.StreamMerger, uid, gid int, defaultInterface string) Looper {
|
||||
return &looper{
|
||||
conf: conf,
|
||||
firewallConf: firewallConf,
|
||||
@@ -52,6 +53,7 @@ func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings s
|
||||
streamMerger: streamMerger,
|
||||
uid: uid,
|
||||
gid: gid,
|
||||
defaultInterface: defaultInterface,
|
||||
restart: make(chan struct{}),
|
||||
start: make(chan struct{}),
|
||||
stop: make(chan struct{}),
|
||||
@@ -133,7 +135,7 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if err := l.firewallConf.SetAllowedPort(ctx, settings.Port); err != nil {
|
||||
if err := l.firewallConf.SetAllowedPort(ctx, settings.Port, l.defaultInterface); err != nil {
|
||||
l.logger.Error(err)
|
||||
continue
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user