Use PMTUD to set the MTU to the VPN interface

- Add `VPN_PMTUD` option enabled by default
- One can revert to use `VPN_PMTUD=off` to disable the new PMTUD mechanism
This commit is contained in:
Quentin McGaw
2025-09-10 14:43:21 +00:00
parent e21d798f57
commit 162d244865
12 changed files with 141 additions and 25 deletions

View File

@@ -77,6 +77,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
VPN_TYPE=openvpn \ VPN_TYPE=openvpn \
# Common VPN options # Common VPN options
VPN_INTERFACE=tun0 \ VPN_INTERFACE=tun0 \
VPN_PMTUD=on \
# OpenVPN # OpenVPN
OPENVPN_ENDPOINT_IP= \ OPENVPN_ENDPOINT_IP= \
OPENVPN_ENDPOINT_PORT= \ OPENVPN_ENDPOINT_PORT= \

View File

@@ -580,6 +580,7 @@ type Linker interface {
LinkDel(link netlink.Link) (err error) LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error) LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) (err error) LinkSetDown(link netlink.Link) (err error)
LinkSetMTU(link netlink.Link, mtu int) error
} }
type clier interface { type clier interface {

View File

@@ -18,6 +18,7 @@ type VPN struct {
Provider Provider `json:"provider"` Provider Provider `json:"provider"`
OpenVPN OpenVPN `json:"openvpn"` OpenVPN OpenVPN `json:"openvpn"`
Wireguard Wireguard `json:"wireguard"` Wireguard Wireguard `json:"wireguard"`
PMTUD *bool `json:"pmtud"`
} }
// TODO v4 remove pointer for receiver (because of Surfshark). // TODO v4 remove pointer for receiver (because of Surfshark).
@@ -54,6 +55,7 @@ func (v *VPN) Copy() (copied VPN) {
Provider: v.Provider.copy(), Provider: v.Provider.copy(),
OpenVPN: v.OpenVPN.copy(), OpenVPN: v.OpenVPN.copy(),
Wireguard: v.Wireguard.copy(), Wireguard: v.Wireguard.copy(),
PMTUD: gosettings.CopyPointer(v.PMTUD),
} }
} }
@@ -62,6 +64,7 @@ func (v *VPN) OverrideWith(other VPN) {
v.Provider.overrideWith(other.Provider) v.Provider.overrideWith(other.Provider)
v.OpenVPN.overrideWith(other.OpenVPN) v.OpenVPN.overrideWith(other.OpenVPN)
v.Wireguard.overrideWith(other.Wireguard) v.Wireguard.overrideWith(other.Wireguard)
v.PMTUD = gosettings.OverrideWithPointer(v.PMTUD, other.PMTUD)
} }
func (v *VPN) setDefaults() { func (v *VPN) setDefaults() {
@@ -69,6 +72,7 @@ func (v *VPN) setDefaults() {
v.Provider.setDefaults() v.Provider.setDefaults()
v.OpenVPN.setDefaults(v.Provider.Name) v.OpenVPN.setDefaults(v.Provider.Name)
v.Wireguard.setDefaults(v.Provider.Name) v.Wireguard.setDefaults(v.Provider.Name)
v.PMTUD = gosettings.DefaultPointer(v.PMTUD, true)
} }
func (v VPN) String() string { func (v VPN) String() string {
@@ -86,6 +90,8 @@ func (v VPN) toLinesNode() (node *gotree.Node) {
node.AppendNode(v.Wireguard.toLinesNode()) node.AppendNode(v.Wireguard.toLinesNode())
} }
node.Appendf("Path MTU discovery update: %s", gosettings.BoolToYesNo(v.PMTUD))
return node return node
} }
@@ -107,5 +113,10 @@ func (v *VPN) read(r *reader.Reader) (err error) {
return fmt.Errorf("wireguard: %w", err) return fmt.Errorf("wireguard: %w", err)
} }
v.PMTUD, err = r.BoolPtr("VPN_PMTUD")
if err != nil {
return err
}
return nil return nil
} }

View File

@@ -62,6 +62,10 @@ func (n *NetLink) LinkSetDown(link Link) (err error) {
return netlink.LinkSetDown(linkToNetlinkLink(&link)) return netlink.LinkSetDown(linkToNetlinkLink(&link))
} }
func (n *NetLink) LinkSetMTU(link Link, mtu int) error {
return netlink.LinkSetMTU(linkToNetlinkLink(&link), mtu)
}
type netlinkLinkImpl struct { type netlinkLinkImpl struct {
attrs *netlink.LinkAttrs attrs *netlink.LinkAttrs
linkType string linkType string

View File

@@ -1,10 +1,24 @@
package pmtud package pmtud
import ( import (
"context"
"errors" "errors"
"fmt"
"net"
"time"
) )
var ( var (
ErrICMPDestinationUnreachable = errors.New("ICMP destination unreachable") ErrICMPDestinationUnreachable = errors.New("ICMP destination unreachable")
ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported") ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported")
) )
func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive
switch {
case errors.Is(timedCtx.Err(), context.DeadlineExceeded):
err = fmt.Errorf("%w (timed out after %s)", net.ErrClosed, pingTimeout)
case timedCtx.Err() != nil:
err = timedCtx.Err()
}
return err
}

View File

@@ -73,6 +73,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()}) _, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
if err != nil { if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("writing ICMP message: %w", err) return 0, fmt.Errorf("writing ICMP message: %w", err)
} }
@@ -84,6 +85,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J // https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
bytesRead, _, err := conn.ReadFrom(buffer) bytesRead, _, err := conn.ReadFrom(buffer)
if err != nil { if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("reading from ICMP connection: %w", err) return 0, fmt.Errorf("reading from ICMP connection: %w", err)
} }
packetBytes := buffer[:bytesRead] packetBytes := buffer[:bytesRead]
@@ -135,7 +137,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
if inboundID == outboundID { if inboundID == outboundID {
return physicalLinkMTU, nil return physicalLinkMTU, nil
} }
logger.Debug("discarding received ICMP echo reply with id %d mismatching sent id %d", logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
inboundID, outboundID) inboundID, outboundID)
continue continue
default: default:

View File

@@ -53,6 +53,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice(), Zone: ip.Zone()}) _, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice(), Zone: ip.Zone()})
if err != nil { if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("writing ICMP message: %w", err) return 0, fmt.Errorf("writing ICMP message: %w", err)
} }
@@ -64,6 +65,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J // https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
bytesRead, _, err := conn.ReadFrom(buffer) bytesRead, _, err := conn.ReadFrom(buffer)
if err != nil { if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("reading from ICMP connection: %w", err) return 0, fmt.Errorf("reading from ICMP connection: %w", err)
} }
packetBytes := buffer[:bytesRead] packetBytes := buffer[:bytesRead]
@@ -106,7 +108,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
if inboundID == outboundID { if inboundID == outboundID {
return physicalLinkMTU, nil return physicalLinkMTU, nil
} }
logger.Debug("discarding received ICMP echo reply with id %d mismatching sent id %d", logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
inboundID, outboundID) inboundID, outboundID)
continue continue
default: default:

View File

@@ -81,6 +81,7 @@ type Linker interface {
LinkDel(link netlink.Link) (err error) LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error) LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) (err error) LinkSetDown(link netlink.Link) (err error)
LinkSetMTU(link netlink.Link, mtu int) (err error)
} }
type DNSLoop interface { type DNSLoop interface {

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/openvpn" "github.com/qdm12/gluetun/internal/openvpn"
"github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/provider"
) )
@@ -14,39 +15,39 @@ import (
func setupOpenVPN(ctx context.Context, fw Firewall, func setupOpenVPN(ctx context.Context, fw Firewall,
openvpnConf OpenVPN, providerConf provider.Provider, openvpnConf OpenVPN, providerConf provider.Provider,
settings settings.VPN, ipv6Supported bool, starter CmdStarter, settings settings.VPN, ipv6Supported bool, starter CmdStarter,
logger openvpn.Logger) (runner *openvpn.Runner, serverName string, logger openvpn.Logger) (runner *openvpn.Runner,
canPortForward bool, err error, connection models.Connection, err error,
) { ) {
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported) connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
if err != nil { if err != nil {
return nil, "", false, fmt.Errorf("finding a valid server connection: %w", err) return nil, models.Connection{}, fmt.Errorf("finding a valid server connection: %w", err)
} }
lines := providerConf.OpenVPNConfig(connection, settings.OpenVPN, ipv6Supported) lines := providerConf.OpenVPNConfig(connection, settings.OpenVPN, ipv6Supported)
if err := openvpnConf.WriteConfig(lines); err != nil { if err := openvpnConf.WriteConfig(lines); err != nil {
return nil, "", false, fmt.Errorf("writing configuration to file: %w", err) return nil, models.Connection{}, fmt.Errorf("writing configuration to file: %w", err)
} }
if *settings.OpenVPN.User != "" { if *settings.OpenVPN.User != "" {
err := openvpnConf.WriteAuthFile(*settings.OpenVPN.User, *settings.OpenVPN.Password) err := openvpnConf.WriteAuthFile(*settings.OpenVPN.User, *settings.OpenVPN.Password)
if err != nil { if err != nil {
return nil, "", false, fmt.Errorf("writing auth to file: %w", err) return nil, models.Connection{}, fmt.Errorf("writing auth to file: %w", err)
} }
} }
if *settings.OpenVPN.KeyPassphrase != "" { if *settings.OpenVPN.KeyPassphrase != "" {
err := openvpnConf.WriteAskPassFile(*settings.OpenVPN.KeyPassphrase) err := openvpnConf.WriteAskPassFile(*settings.OpenVPN.KeyPassphrase)
if err != nil { if err != nil {
return nil, "", false, fmt.Errorf("writing askpass file: %w", err) return nil, models.Connection{}, fmt.Errorf("writing askpass file: %w", err)
} }
} }
if err := fw.SetVPNConnection(ctx, connection, settings.OpenVPN.Interface); err != nil { if err := fw.SetVPNConnection(ctx, connection, settings.OpenVPN.Interface); err != nil {
return nil, "", false, fmt.Errorf("allowing VPN connection through firewall: %w", err) return nil, models.Connection{}, fmt.Errorf("allowing VPN connection through firewall: %w", err)
} }
runner = openvpn.NewRunner(settings.OpenVPN, starter, logger) runner = openvpn.NewRunner(settings.OpenVPN, starter, logger)
return runner, connection.ServerName, connection.PortForward, nil return runner, connection, nil
} }

View File

@@ -5,6 +5,7 @@ import (
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/vpn" "github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/log" "github.com/qdm12/log"
) )
@@ -28,17 +29,17 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
var vpnRunner interface { var vpnRunner interface {
Run(ctx context.Context, waitError chan<- error, tunnelReady chan<- struct{}) Run(ctx context.Context, waitError chan<- error, tunnelReady chan<- struct{})
} }
var serverName, vpnInterface string var vpnInterface string
var canPortForward bool var connection models.Connection
var err error var err error
subLogger := l.logger.New(log.SetComponent(settings.Type)) subLogger := l.logger.New(log.SetComponent(settings.Type))
if settings.Type == vpn.OpenVPN { if settings.Type == vpn.OpenVPN {
vpnInterface = settings.OpenVPN.Interface vpnInterface = settings.OpenVPN.Interface
vpnRunner, serverName, canPortForward, err = setupOpenVPN(ctx, l.fw, vpnRunner, connection, err = setupOpenVPN(ctx, l.fw,
l.openvpnConf, providerConf, settings, l.ipv6Supported, l.starter, subLogger) l.openvpnConf, providerConf, settings, l.ipv6Supported, l.starter, subLogger)
} else { // Wireguard } else { // Wireguard
vpnInterface = settings.Wireguard.Interface vpnInterface = settings.Wireguard.Interface
vpnRunner, serverName, canPortForward, err = setupWireguard(ctx, l.netLinker, l.fw, vpnRunner, connection, err = setupWireguard(ctx, l.netLinker, l.fw,
providerConf, settings, l.ipv6Supported, subLogger) providerConf, settings, l.ipv6Supported, subLogger)
} }
if err != nil { if err != nil {
@@ -46,8 +47,11 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
continue continue
} }
tunnelUpData := tunnelUpData{ tunnelUpData := tunnelUpData{
serverName: serverName, PMTUD: *settings.PMTUD,
canPortForward: canPortForward, serverIP: connection.IP,
vpnType: settings.Type,
serverName: connection.ServerName,
canPortForward: connection.PortForward,
portForwarder: portForwarder, portForwarder: portForwarder,
vpnIntf: vpnInterface, vpnIntf: vpnInterface,
username: settings.Provider.PortForwarding.Username, username: settings.Provider.PortForwarding.Username,

View File

@@ -2,15 +2,32 @@ package vpn
import ( import (
"context" "context"
"errors"
"fmt"
"net/netip"
"time"
"github.com/qdm12/dns/v2/pkg/check" "github.com/qdm12/dns/v2/pkg/check"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/pmtud"
"github.com/qdm12/gluetun/internal/version" "github.com/qdm12/gluetun/internal/version"
"github.com/qdm12/log"
) )
type tunnelUpData struct { type tunnelUpData struct {
// Port forwarding // vpnIntf is the name of the VPN network interface
vpnIntf string // which is used both for port forwarding and MTU discovery
vpnIntf string
// Path MTU discovery fields:
// PMTUD indicates whether to perform Path MTU Discovery and
// adjust the VPN interface MTU accordingly.
PMTUD bool
// serverIP is used for path MTU discovery
serverIP netip.Addr
// vpnType is used for path MTU discovery to find the protocol overhead.
// It can be "wireguard" or "openvpn".
vpnType string
// Port forwarding fields:
serverName string // used for PIA serverName string // used for PIA
canPortForward bool // used for PIA canPortForward bool // used for PIA
username string // used for PIA username string // used for PIA
@@ -21,6 +38,16 @@ type tunnelUpData struct {
func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) { func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
l.client.CloseIdleConnections() l.client.CloseIdleConnections()
if data.PMTUD {
mtuLogger := l.logger.New(log.SetComponent("MTU discovery"))
mtuLogger.Info("finding maximum MTU, this takes around 3 seconds")
err := updateToMaxMTU(ctx, data.vpnIntf, data.serverIP, data.vpnType,
l.netLinker, mtuLogger)
if err != nil {
l.logger.Error(err.Error())
}
}
for _, vpnPort := range l.vpnInputPorts { for _, vpnPort := range l.vpnInputPorts {
err := l.fw.SetAllowedPort(ctx, vpnPort, data.vpnIntf) err := l.fw.SetAllowedPort(ctx, vpnPort, data.vpnIntf)
if err != nil { if err != nil {
@@ -57,3 +84,50 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
l.logger.Error(err.Error()) l.logger.Error(err.Error())
} }
} }
var errVPNTypeUnknown = errors.New("unknown VPN type")
func updateToMaxMTU(ctx context.Context, vpnInterface string,
serverIP netip.Addr, vpnType string, netlinker NetLinker, logger *log.Logger,
) error {
link, err := netlinker.LinkByName(vpnInterface)
if err != nil {
return fmt.Errorf("getting VPN interface by name: %w", err)
}
// Note: no point testing for an MTU of 1500, it will never work due to the VPN
// protocol overhead, so start lower than 1500 according to the protocol used.
const physicalLinkMTU = 1500
vpnLinkMTU := physicalLinkMTU
switch vpnType {
case "wireguard":
vpnLinkMTU -= 60 // Wireguard overhead
case "openvpn":
vpnLinkMTU -= 41 // OpenVPN overhead
default:
return fmt.Errorf("%w: %q", errVPNTypeUnknown, vpnType)
}
// Setting the VPN link MTU to 1500 might interrupt the connection until
// the new MTU is set again, but this is necessary to find the highest valid MTU.
logger.Debugf("VPN interface %s MTU temporarily set to %d", vpnInterface, vpnLinkMTU)
err = netlinker.LinkSetMTU(link, vpnLinkMTU)
if err != nil {
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
}
const pingTimeout = time.Second
vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, serverIP, vpnLinkMTU, pingTimeout, logger)
if err != nil {
return fmt.Errorf("path MTU discovering: %w", err)
}
err = netlinker.LinkSetMTU(link, vpnLinkMTU)
if err != nil {
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
}
logger.Infof("VPN interface %s MTU set to maximum valid MTU %d", vpnInterface, vpnLinkMTU)
return nil
}

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/provider/utils"
"github.com/qdm12/gluetun/internal/wireguard" "github.com/qdm12/gluetun/internal/wireguard"
@@ -16,11 +17,11 @@ import (
func setupWireguard(ctx context.Context, netlinker NetLinker, func setupWireguard(ctx context.Context, netlinker NetLinker,
fw Firewall, providerConf provider.Provider, fw Firewall, providerConf provider.Provider,
settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) ( settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) (
wireguarder *wireguard.Wireguard, serverName string, canPortForward bool, err error, wireguarder *wireguard.Wireguard, connection models.Connection, err error,
) { ) {
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported) connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
if err != nil { if err != nil {
return nil, "", false, fmt.Errorf("finding a VPN server: %w", err) return nil, models.Connection{}, fmt.Errorf("finding a VPN server: %w", err)
} }
wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard, ipv6Supported) wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard, ipv6Supported)
@@ -31,13 +32,13 @@ func setupWireguard(ctx context.Context, netlinker NetLinker,
wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger) wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger)
if err != nil { if err != nil {
return nil, "", false, fmt.Errorf("creating Wireguard: %w", err) return nil, models.Connection{}, fmt.Errorf("creating Wireguard: %w", err)
} }
err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface) err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface)
if err != nil { if err != nil {
return nil, "", false, fmt.Errorf("setting firewall: %w", err) return nil, models.Connection{}, fmt.Errorf("setting firewall: %w", err)
} }
return wireguarder, connection.ServerName, connection.PortForward, nil return wireguarder, connection, nil
} }