feat(portforward): port redirection with VPN_PORT_FORWARDING_LISTENING_PORT

This commit is contained in:
Quentin McGaw
2023-11-10 17:21:35 +00:00
parent 8318be3159
commit 4105f74ce1
14 changed files with 226 additions and 6 deletions

View File

@@ -111,6 +111,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
# # Private Internet Access only: # # Private Internet Access only:
PRIVATE_INTERNET_ACCESS_OPENVPN_ENCRYPTION_PRESET= \ PRIVATE_INTERNET_ACCESS_OPENVPN_ENCRYPTION_PRESET= \
VPN_PORT_FORWARDING=off \ VPN_PORT_FORWARDING=off \
VPN_PORT_FORWARDING_LISTENING_PORT=0 \
VPN_PORT_FORWARDING_PROVIDER= \ VPN_PORT_FORWARDING_PROVIDER= \
VPN_PORT_FORWARDING_STATUS_FILE="/tmp/gluetun/forwarded_port" \ VPN_PORT_FORWARDING_STATUS_FILE="/tmp/gluetun/forwarded_port" \
# # Cyberghost only: # # Cyberghost only:

View File

@@ -100,7 +100,7 @@ func (d DoT) toLinesNode() (node *gotree.Node) {
return node return node
} }
update := "disabled" update := "disabled" //nolint:goconst
if *d.UpdatePeriod > 0 { if *d.UpdatePeriod > 0 {
update = "every " + d.UpdatePeriod.String() update = "every " + d.UpdatePeriod.String()
} }

View File

@@ -28,6 +28,10 @@ type PortForwarding struct {
// to write to a file. It cannot be nil for the // to write to a file. It cannot be nil for the
// internal state // internal state
Filepath *string `json:"status_file_path"` Filepath *string `json:"status_file_path"`
// ListeningPort is the port traffic would be redirected to from the
// forwarded port. The redirection is disabled if it is set to 0, which
// is its default as well.
ListeningPort *uint16 `json:"listening_port"`
} }
func (p PortForwarding) Validate(vpnProvider string) (err error) { func (p PortForwarding) Validate(vpnProvider string) (err error) {
@@ -61,9 +65,10 @@ func (p PortForwarding) Validate(vpnProvider string) (err error) {
func (p *PortForwarding) Copy() (copied PortForwarding) { func (p *PortForwarding) Copy() (copied PortForwarding) {
return PortForwarding{ return PortForwarding{
Enabled: gosettings.CopyPointer(p.Enabled), Enabled: gosettings.CopyPointer(p.Enabled),
Provider: gosettings.CopyPointer(p.Provider), Provider: gosettings.CopyPointer(p.Provider),
Filepath: gosettings.CopyPointer(p.Filepath), Filepath: gosettings.CopyPointer(p.Filepath),
ListeningPort: gosettings.CopyPointer(p.ListeningPort),
} }
} }
@@ -71,18 +76,21 @@ func (p *PortForwarding) mergeWith(other PortForwarding) {
p.Enabled = gosettings.MergeWithPointer(p.Enabled, other.Enabled) p.Enabled = gosettings.MergeWithPointer(p.Enabled, other.Enabled)
p.Provider = gosettings.MergeWithPointer(p.Provider, other.Provider) p.Provider = gosettings.MergeWithPointer(p.Provider, other.Provider)
p.Filepath = gosettings.MergeWithPointer(p.Filepath, other.Filepath) p.Filepath = gosettings.MergeWithPointer(p.Filepath, other.Filepath)
p.ListeningPort = gosettings.MergeWithPointer(p.ListeningPort, other.ListeningPort)
} }
func (p *PortForwarding) OverrideWith(other PortForwarding) { func (p *PortForwarding) OverrideWith(other PortForwarding) {
p.Enabled = gosettings.OverrideWithPointer(p.Enabled, other.Enabled) p.Enabled = gosettings.OverrideWithPointer(p.Enabled, other.Enabled)
p.Provider = gosettings.OverrideWithPointer(p.Provider, other.Provider) p.Provider = gosettings.OverrideWithPointer(p.Provider, other.Provider)
p.Filepath = gosettings.OverrideWithPointer(p.Filepath, other.Filepath) p.Filepath = gosettings.OverrideWithPointer(p.Filepath, other.Filepath)
p.ListeningPort = gosettings.OverrideWithPointer(p.ListeningPort, other.ListeningPort)
} }
func (p *PortForwarding) setDefaults() { func (p *PortForwarding) setDefaults() {
p.Enabled = gosettings.DefaultPointer(p.Enabled, false) p.Enabled = gosettings.DefaultPointer(p.Enabled, false)
p.Provider = gosettings.DefaultPointer(p.Provider, "") p.Provider = gosettings.DefaultPointer(p.Provider, "")
p.Filepath = gosettings.DefaultPointer(p.Filepath, "/tmp/gluetun/forwarded_port") p.Filepath = gosettings.DefaultPointer(p.Filepath, "/tmp/gluetun/forwarded_port")
p.ListeningPort = gosettings.DefaultPointer(p.ListeningPort, 0)
} }
func (p PortForwarding) String() string { func (p PortForwarding) String() string {
@@ -95,6 +103,13 @@ func (p PortForwarding) toLinesNode() (node *gotree.Node) {
} }
node = gotree.New("Automatic port forwarding settings:") node = gotree.New("Automatic port forwarding settings:")
listeningPort := "disabled"
if *p.ListeningPort != 0 {
listeningPort = fmt.Sprintf("%d", *p.ListeningPort)
}
node.Appendf("Redirection listening port: %s", listeningPort)
if *p.Provider == "" { if *p.Provider == "" {
node.Appendf("Use port forwarding code for current provider") node.Appendf("Use port forwarding code for current provider")
} else { } else {

View File

@@ -25,5 +25,10 @@ func (s *Source) readPortForward() (
"PRIVATE_INTERNET_ACCESS_VPN_PORT_FORWARDING_STATUS_FILE", "PRIVATE_INTERNET_ACCESS_VPN_PORT_FORWARDING_STATUS_FILE",
)) ))
portForwarding.ListeningPort, err = s.env.Uint16Ptr("VPN_PORT_FORWARDING_LISTENING_PORT")
if err != nil {
return portForwarding, err
}
return portForwarding, nil return portForwarding, nil
} }

View File

@@ -51,6 +51,13 @@ func (c *Config) disable(ctx context.Context) (err error) {
if err = c.setIPv6AllPolicies(ctx, "ACCEPT"); err != nil { if err = c.setIPv6AllPolicies(ctx, "ACCEPT"); err != nil {
return fmt.Errorf("setting ipv6 policies: %w", err) return fmt.Errorf("setting ipv6 policies: %w", err)
} }
const remove = true
err = c.redirectPorts(ctx, remove)
if err != nil {
return fmt.Errorf("removing port redirections: %w", err)
}
return nil return nil
} }
@@ -124,6 +131,11 @@ func (c *Config) enable(ctx context.Context) (err error) {
return err return err
} }
err = c.redirectPorts(ctx, remove)
if err != nil {
return fmt.Errorf("redirecting ports: %w", err)
}
if err := c.runUserPostRules(ctx, c.customRulesPath, remove); err != nil { if err := c.runUserPostRules(ctx, c.customRulesPath, remove); err != nil {
return fmt.Errorf("running user defined post firewall rules: %w", err) return fmt.Errorf("running user defined post firewall rules: %w", err)
} }
@@ -188,3 +200,14 @@ func (c *Config) allowInputPorts(ctx context.Context) (err error) {
} }
return nil return nil
} }
func (c *Config) redirectPorts(ctx context.Context, remove bool) (err error) {
for _, portRedirection := range c.portRedirections {
err = c.redirectPort(ctx, portRedirection.interfaceName, portRedirection.sourcePort,
portRedirection.destinationPort, remove)
if err != nil {
return err
}
}
return nil
}

View File

@@ -29,6 +29,7 @@ type Config struct { //nolint:maligned
vpnIntf string vpnIntf string
outboundSubnets []netip.Prefix outboundSubnets []netip.Prefix
allowedInputPorts map[uint16]map[string]struct{} // port to interfaces set mapping allowedInputPorts map[uint16]map[string]struct{} // port to interfaces set mapping
portRedirections portRedirections
stateMutex sync.Mutex stateMutex sync.Mutex
} }

View File

@@ -198,6 +198,38 @@ func (c *Config) acceptInputToPort(ctx context.Context, intf string, port uint16
}) })
} }
// Used for VPN server side port forwarding, with intf set to the VPN tunnel interface.
func (c *Config) redirectPort(ctx context.Context, intf string,
sourcePort, destinationPort uint16, remove bool) (err error) {
interfaceFlag := "-i " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
err = c.runIptablesInstructions(ctx, []string{
fmt.Sprintf("-t nat %s PREROUTING %s -d 127.0.0.1 -p tcp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("-t nat %s PREROUTING %s -d 127.0.0.1 -p udp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
})
if err != nil {
return fmt.Errorf("redirecting IPv4 source port %d to destination port %d on interface %s: %w",
sourcePort, destinationPort, intf, err)
}
err = c.runIP6tablesInstructions(ctx, []string{
fmt.Sprintf("-t nat %s PREROUTING %s -d ::1 -p tcp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
fmt.Sprintf("-t nat %s PREROUTING %s -d ::1 -p udp --dport %d -j REDIRECT --to-ports %d",
appendOrDelete(remove), interfaceFlag, sourcePort, destinationPort),
})
if err != nil {
return fmt.Errorf("redirecting IPv6 source port %d to destination port %d on interface %s: %w",
sourcePort, destinationPort, intf, err)
}
return nil
}
func (c *Config) runUserPostRules(ctx context.Context, filepath string, remove bool) error { func (c *Config) runUserPostRules(ctx context.Context, filepath string, remove bool) error {
file, err := os.OpenFile(filepath, os.O_RDONLY, 0) file, err := os.OpenFile(filepath, os.O_RDONLY, 0)
if os.IsNotExist(err) { if os.IsNotExist(err) {

View File

@@ -0,0 +1,119 @@
package firewall
import (
"context"
"fmt"
)
// RedirectPort redirects a source port to a destination port on the interface
// intf. If intf is empty, it is set to "*" which means all interfaces.
// If a redirection for the source port given already exists, it is removed first.
// If the destination port is zero, the redirection for the source port is removed
// and no new redirection is added.
func (c *Config) RedirectPort(ctx context.Context, intf string, sourcePort,
destinationPort uint16) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
if sourcePort == 0 {
panic("source port cannot be 0")
}
newRedirection := portRedirection{
interfaceName: intf,
sourcePort: sourcePort,
destinationPort: destinationPort,
}
if !c.enabled {
c.logger.Info("firewall disabled, only updating redirected ports internal state")
if destinationPort == 0 {
c.portRedirections.remove(intf, sourcePort)
return nil
}
exists, conflict := c.portRedirections.check(newRedirection)
switch {
case exists:
return nil
case conflict != nil:
c.portRedirections.remove(conflict.interfaceName,
conflict.sourcePort)
}
c.portRedirections.append(newRedirection)
return nil
}
exists, conflict := c.portRedirections.check(newRedirection)
switch {
case exists:
return nil
case conflict != nil:
const remove = true
err = c.redirectPort(ctx, conflict.interfaceName, conflict.sourcePort,
conflict.destinationPort, remove)
if err != nil {
return fmt.Errorf("removing conflicting redirection: %w", err)
}
c.portRedirections.remove(conflict.interfaceName,
conflict.sourcePort)
}
const remove = false
err = c.redirectPort(ctx, intf, sourcePort, destinationPort, remove)
if err != nil {
return fmt.Errorf("redirecting port: %w", err)
}
c.portRedirections.append(newRedirection)
return nil
}
type portRedirection struct {
interfaceName string
sourcePort uint16
destinationPort uint16
}
type portRedirections []portRedirection
func (p *portRedirections) remove(intf string, sourcePort uint16) {
slice := *p
for i, redirection := range slice {
interfaceMatch := intf == "" || intf == redirection.interfaceName
if redirection.sourcePort == sourcePort && interfaceMatch {
// Remove redirection - note: order does not matter
slice[i] = slice[len(slice)-1]
slice = slice[:len(slice)-1]
}
}
*p = slice
}
func (p *portRedirections) check(dryRun portRedirection) (alreadyExists bool,
conflict *portRedirection) {
slice := *p
for _, redirection := range slice {
interfaceMatch := redirection.interfaceName == "" ||
redirection.interfaceName == dryRun.interfaceName
if redirection.sourcePort == dryRun.sourcePort &&
redirection.destinationPort == dryRun.destinationPort &&
interfaceMatch {
return true, nil
}
if redirection.sourcePort == dryRun.sourcePort &&
interfaceMatch {
// Source port has a redirection already for the same interface or all interfaces
return false, &redirection
}
}
return false, nil
}
// append should be called after running `check` to avoid rule conflicts.
func (p *portRedirections) append(newRedirection portRedirection) {
slice := *p
slice = append(slice, newRedirection)
*p = slice
}

View File

@@ -18,6 +18,8 @@ type Routing interface {
type PortAllower interface { type PortAllower interface {
SetAllowedPort(ctx context.Context, port uint16, intf string) (err error) SetAllowedPort(ctx context.Context, port uint16, intf string) (err error)
RemoveAllowedPort(ctx context.Context, port uint16) (err error) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
RedirectPort(ctx context.Context, intf string, sourcePort,
destinationPort uint16) (err error)
} }
type Logger interface { type Logger interface {

View File

@@ -39,8 +39,9 @@ func NewLoop(settings settings.PortForwarding, routing Routing,
settings: Settings{ settings: Settings{
VPNIsUp: ptrTo(false), VPNIsUp: ptrTo(false),
Service: service.Settings{ Service: service.Settings{
Enabled: settings.Enabled, Enabled: settings.Enabled,
Filepath: *settings.Filepath, Filepath: *settings.Filepath,
ListeningPort: *settings.ListeningPort,
}, },
}, },
routing: routing, routing: routing,

View File

@@ -10,6 +10,8 @@ import (
type PortAllower interface { type PortAllower interface {
SetAllowedPort(ctx context.Context, port uint16, intf string) (err error) SetAllowedPort(ctx context.Context, port uint16, intf string) (err error)
RemoveAllowedPort(ctx context.Context, port uint16) (err error) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
RedirectPort(ctx context.Context, intf string, sourcePort,
destinationPort uint16) (err error)
} }
type Routing interface { type Routing interface {

View File

@@ -14,6 +14,7 @@ type Settings struct {
Filepath string Filepath string
Interface string // needed for PIA and ProtonVPN, tun0 for example Interface string // needed for PIA and ProtonVPN, tun0 for example
ServerName string // needed for PIA ServerName string // needed for PIA
ListeningPort uint16
} }
func (s Settings) Copy() (copied Settings) { func (s Settings) Copy() (copied Settings) {
@@ -22,6 +23,7 @@ func (s Settings) Copy() (copied Settings) {
copied.Filepath = s.Filepath copied.Filepath = s.Filepath
copied.Interface = s.Interface copied.Interface = s.Interface
copied.ServerName = s.ServerName copied.ServerName = s.ServerName
copied.ListeningPort = s.ListeningPort
return copied return copied
} }
@@ -31,6 +33,7 @@ func (s *Settings) OverrideWith(update Settings) {
s.Filepath = gosettings.OverrideWithString(s.Filepath, update.Filepath) s.Filepath = gosettings.OverrideWithString(s.Filepath, update.Filepath)
s.Interface = gosettings.OverrideWithString(s.Interface, update.Interface) s.Interface = gosettings.OverrideWithString(s.Interface, update.Interface)
s.ServerName = gosettings.OverrideWithString(s.ServerName, update.ServerName) s.ServerName = gosettings.OverrideWithString(s.ServerName, update.ServerName)
s.ListeningPort = gosettings.OverrideWithNumber(s.ListeningPort, update.ListeningPort)
} }
var ( var (

View File

@@ -40,6 +40,13 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, err error)
return nil, fmt.Errorf("allowing port in firewall: %w", err) return nil, fmt.Errorf("allowing port in firewall: %w", err)
} }
if s.settings.ListeningPort != 0 {
err = s.portAllower.RedirectPort(ctx, s.settings.Interface, port, s.settings.ListeningPort)
if err != nil {
return nil, fmt.Errorf("redirecting port in firewall: %w", err)
}
}
err = s.writePortForwardedFile(port) err = s.writePortForwardedFile(port)
if err != nil { if err != nil {
_ = s.cleanup() _ = s.cleanup()

View File

@@ -35,6 +35,15 @@ func (s *Service) cleanup() (err error) {
return fmt.Errorf("blocking previous port in firewall: %w", err) return fmt.Errorf("blocking previous port in firewall: %w", err)
} }
if s.settings.ListeningPort != 0 {
ctx := context.Background()
const listeningPort = 0 // 0 to clear the redirection
err = s.portAllower.RedirectPort(ctx, s.settings.Interface, s.port, listeningPort)
if err != nil {
return fmt.Errorf("removing previous port redirection in firewall: %w", err)
}
}
s.port = 0 s.port = 0
filepath := s.settings.Filepath filepath := s.settings.Filepath