Use PMTUD to set the MTU to the VPN interface

- Add `VPN_PMTUD` option enabled by default
- One can revert to use `VPN_PMTUD=off` to disable the new PMTUD mechanism
This commit is contained in:
Quentin McGaw
2025-09-10 14:43:21 +00:00
parent e21d798f57
commit 162d244865
12 changed files with 141 additions and 25 deletions

View File

@@ -77,6 +77,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
VPN_TYPE=openvpn \
# Common VPN options
VPN_INTERFACE=tun0 \
VPN_PMTUD=on \
# OpenVPN
OPENVPN_ENDPOINT_IP= \
OPENVPN_ENDPOINT_PORT= \

View File

@@ -580,6 +580,7 @@ type Linker interface {
LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) (err error)
LinkSetMTU(link netlink.Link, mtu int) error
}
type clier interface {

View File

@@ -18,6 +18,7 @@ type VPN struct {
Provider Provider `json:"provider"`
OpenVPN OpenVPN `json:"openvpn"`
Wireguard Wireguard `json:"wireguard"`
PMTUD *bool `json:"pmtud"`
}
// TODO v4 remove pointer for receiver (because of Surfshark).
@@ -54,6 +55,7 @@ func (v *VPN) Copy() (copied VPN) {
Provider: v.Provider.copy(),
OpenVPN: v.OpenVPN.copy(),
Wireguard: v.Wireguard.copy(),
PMTUD: gosettings.CopyPointer(v.PMTUD),
}
}
@@ -62,6 +64,7 @@ func (v *VPN) OverrideWith(other VPN) {
v.Provider.overrideWith(other.Provider)
v.OpenVPN.overrideWith(other.OpenVPN)
v.Wireguard.overrideWith(other.Wireguard)
v.PMTUD = gosettings.OverrideWithPointer(v.PMTUD, other.PMTUD)
}
func (v *VPN) setDefaults() {
@@ -69,6 +72,7 @@ func (v *VPN) setDefaults() {
v.Provider.setDefaults()
v.OpenVPN.setDefaults(v.Provider.Name)
v.Wireguard.setDefaults(v.Provider.Name)
v.PMTUD = gosettings.DefaultPointer(v.PMTUD, true)
}
func (v VPN) String() string {
@@ -86,6 +90,8 @@ func (v VPN) toLinesNode() (node *gotree.Node) {
node.AppendNode(v.Wireguard.toLinesNode())
}
node.Appendf("Path MTU discovery update: %s", gosettings.BoolToYesNo(v.PMTUD))
return node
}
@@ -107,5 +113,10 @@ func (v *VPN) read(r *reader.Reader) (err error) {
return fmt.Errorf("wireguard: %w", err)
}
v.PMTUD, err = r.BoolPtr("VPN_PMTUD")
if err != nil {
return err
}
return nil
}

View File

@@ -62,6 +62,10 @@ func (n *NetLink) LinkSetDown(link Link) (err error) {
return netlink.LinkSetDown(linkToNetlinkLink(&link))
}
func (n *NetLink) LinkSetMTU(link Link, mtu int) error {
return netlink.LinkSetMTU(linkToNetlinkLink(&link), mtu)
}
type netlinkLinkImpl struct {
attrs *netlink.LinkAttrs
linkType string

View File

@@ -1,10 +1,24 @@
package pmtud
import (
"context"
"errors"
"fmt"
"net"
"time"
)
var (
ErrICMPDestinationUnreachable = errors.New("ICMP destination unreachable")
ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported")
)
func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive
switch {
case errors.Is(timedCtx.Err(), context.DeadlineExceeded):
err = fmt.Errorf("%w (timed out after %s)", net.ErrClosed, pingTimeout)
case timedCtx.Err() != nil:
err = timedCtx.Err()
}
return err
}

View File

@@ -73,6 +73,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("writing ICMP message: %w", err)
}
@@ -84,6 +85,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
bytesRead, _, err := conn.ReadFrom(buffer)
if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
}
packetBytes := buffer[:bytesRead]
@@ -135,7 +137,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
if inboundID == outboundID {
return physicalLinkMTU, nil
}
logger.Debug("discarding received ICMP echo reply with id %d mismatching sent id %d",
logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
inboundID, outboundID)
continue
default:

View File

@@ -53,6 +53,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice(), Zone: ip.Zone()})
if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("writing ICMP message: %w", err)
}
@@ -64,6 +65,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
bytesRead, _, err := conn.ReadFrom(buffer)
if err != nil {
err = wrapConnErr(err, ctx, pingTimeout)
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
}
packetBytes := buffer[:bytesRead]
@@ -106,7 +108,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
if inboundID == outboundID {
return physicalLinkMTU, nil
}
logger.Debug("discarding received ICMP echo reply with id %d mismatching sent id %d",
logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
inboundID, outboundID)
continue
default:

View File

@@ -81,6 +81,7 @@ type Linker interface {
LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) (err error)
LinkSetMTU(link netlink.Link, mtu int) (err error)
}
type DNSLoop interface {

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/openvpn"
"github.com/qdm12/gluetun/internal/provider"
)
@@ -14,39 +15,39 @@ import (
func setupOpenVPN(ctx context.Context, fw Firewall,
openvpnConf OpenVPN, providerConf provider.Provider,
settings settings.VPN, ipv6Supported bool, starter CmdStarter,
logger openvpn.Logger) (runner *openvpn.Runner, serverName string,
canPortForward bool, err error,
logger openvpn.Logger) (runner *openvpn.Runner,
connection models.Connection, err error,
) {
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
if err != nil {
return nil, "", false, fmt.Errorf("finding a valid server connection: %w", err)
return nil, models.Connection{}, fmt.Errorf("finding a valid server connection: %w", err)
}
lines := providerConf.OpenVPNConfig(connection, settings.OpenVPN, ipv6Supported)
if err := openvpnConf.WriteConfig(lines); err != nil {
return nil, "", false, fmt.Errorf("writing configuration to file: %w", err)
return nil, models.Connection{}, fmt.Errorf("writing configuration to file: %w", err)
}
if *settings.OpenVPN.User != "" {
err := openvpnConf.WriteAuthFile(*settings.OpenVPN.User, *settings.OpenVPN.Password)
if err != nil {
return nil, "", false, fmt.Errorf("writing auth to file: %w", err)
return nil, models.Connection{}, fmt.Errorf("writing auth to file: %w", err)
}
}
if *settings.OpenVPN.KeyPassphrase != "" {
err := openvpnConf.WriteAskPassFile(*settings.OpenVPN.KeyPassphrase)
if err != nil {
return nil, "", false, fmt.Errorf("writing askpass file: %w", err)
return nil, models.Connection{}, fmt.Errorf("writing askpass file: %w", err)
}
}
if err := fw.SetVPNConnection(ctx, connection, settings.OpenVPN.Interface); err != nil {
return nil, "", false, fmt.Errorf("allowing VPN connection through firewall: %w", err)
return nil, models.Connection{}, fmt.Errorf("allowing VPN connection through firewall: %w", err)
}
runner = openvpn.NewRunner(settings.OpenVPN, starter, logger)
return runner, connection.ServerName, connection.PortForward, nil
return runner, connection, nil
}

View File

@@ -5,6 +5,7 @@ import (
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/vpn"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/log"
)
@@ -28,17 +29,17 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
var vpnRunner interface {
Run(ctx context.Context, waitError chan<- error, tunnelReady chan<- struct{})
}
var serverName, vpnInterface string
var canPortForward bool
var vpnInterface string
var connection models.Connection
var err error
subLogger := l.logger.New(log.SetComponent(settings.Type))
if settings.Type == vpn.OpenVPN {
vpnInterface = settings.OpenVPN.Interface
vpnRunner, serverName, canPortForward, err = setupOpenVPN(ctx, l.fw,
vpnRunner, connection, err = setupOpenVPN(ctx, l.fw,
l.openvpnConf, providerConf, settings, l.ipv6Supported, l.starter, subLogger)
} else { // Wireguard
vpnInterface = settings.Wireguard.Interface
vpnRunner, serverName, canPortForward, err = setupWireguard(ctx, l.netLinker, l.fw,
vpnRunner, connection, err = setupWireguard(ctx, l.netLinker, l.fw,
providerConf, settings, l.ipv6Supported, subLogger)
}
if err != nil {
@@ -46,8 +47,11 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
continue
}
tunnelUpData := tunnelUpData{
serverName: serverName,
canPortForward: canPortForward,
PMTUD: *settings.PMTUD,
serverIP: connection.IP,
vpnType: settings.Type,
serverName: connection.ServerName,
canPortForward: connection.PortForward,
portForwarder: portForwarder,
vpnIntf: vpnInterface,
username: settings.Provider.PortForwarding.Username,

View File

@@ -2,15 +2,32 @@ package vpn
import (
"context"
"errors"
"fmt"
"net/netip"
"time"
"github.com/qdm12/dns/v2/pkg/check"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/pmtud"
"github.com/qdm12/gluetun/internal/version"
"github.com/qdm12/log"
)
type tunnelUpData struct {
// Port forwarding
// vpnIntf is the name of the VPN network interface
// which is used both for port forwarding and MTU discovery
vpnIntf string
// Path MTU discovery fields:
// PMTUD indicates whether to perform Path MTU Discovery and
// adjust the VPN interface MTU accordingly.
PMTUD bool
// serverIP is used for path MTU discovery
serverIP netip.Addr
// vpnType is used for path MTU discovery to find the protocol overhead.
// It can be "wireguard" or "openvpn".
vpnType string
// Port forwarding fields:
serverName string // used for PIA
canPortForward bool // used for PIA
username string // used for PIA
@@ -21,6 +38,16 @@ type tunnelUpData struct {
func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
l.client.CloseIdleConnections()
if data.PMTUD {
mtuLogger := l.logger.New(log.SetComponent("MTU discovery"))
mtuLogger.Info("finding maximum MTU, this takes around 3 seconds")
err := updateToMaxMTU(ctx, data.vpnIntf, data.serverIP, data.vpnType,
l.netLinker, mtuLogger)
if err != nil {
l.logger.Error(err.Error())
}
}
for _, vpnPort := range l.vpnInputPorts {
err := l.fw.SetAllowedPort(ctx, vpnPort, data.vpnIntf)
if err != nil {
@@ -57,3 +84,50 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
l.logger.Error(err.Error())
}
}
var errVPNTypeUnknown = errors.New("unknown VPN type")
func updateToMaxMTU(ctx context.Context, vpnInterface string,
serverIP netip.Addr, vpnType string, netlinker NetLinker, logger *log.Logger,
) error {
link, err := netlinker.LinkByName(vpnInterface)
if err != nil {
return fmt.Errorf("getting VPN interface by name: %w", err)
}
// Note: no point testing for an MTU of 1500, it will never work due to the VPN
// protocol overhead, so start lower than 1500 according to the protocol used.
const physicalLinkMTU = 1500
vpnLinkMTU := physicalLinkMTU
switch vpnType {
case "wireguard":
vpnLinkMTU -= 60 // Wireguard overhead
case "openvpn":
vpnLinkMTU -= 41 // OpenVPN overhead
default:
return fmt.Errorf("%w: %q", errVPNTypeUnknown, vpnType)
}
// Setting the VPN link MTU to 1500 might interrupt the connection until
// the new MTU is set again, but this is necessary to find the highest valid MTU.
logger.Debugf("VPN interface %s MTU temporarily set to %d", vpnInterface, vpnLinkMTU)
err = netlinker.LinkSetMTU(link, vpnLinkMTU)
if err != nil {
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
}
const pingTimeout = time.Second
vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, serverIP, vpnLinkMTU, pingTimeout, logger)
if err != nil {
return fmt.Errorf("path MTU discovering: %w", err)
}
err = netlinker.LinkSetMTU(link, vpnLinkMTU)
if err != nil {
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
}
logger.Infof("VPN interface %s MTU set to maximum valid MTU %d", vpnInterface, vpnLinkMTU)
return nil
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gluetun/internal/provider/utils"
"github.com/qdm12/gluetun/internal/wireguard"
@@ -16,11 +17,11 @@ import (
func setupWireguard(ctx context.Context, netlinker NetLinker,
fw Firewall, providerConf provider.Provider,
settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) (
wireguarder *wireguard.Wireguard, serverName string, canPortForward bool, err error,
wireguarder *wireguard.Wireguard, connection models.Connection, err error,
) {
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
if err != nil {
return nil, "", false, fmt.Errorf("finding a VPN server: %w", err)
return nil, models.Connection{}, fmt.Errorf("finding a VPN server: %w", err)
}
wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard, ipv6Supported)
@@ -31,13 +32,13 @@ func setupWireguard(ctx context.Context, netlinker NetLinker,
wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger)
if err != nil {
return nil, "", false, fmt.Errorf("creating Wireguard: %w", err)
return nil, models.Connection{}, fmt.Errorf("creating Wireguard: %w", err)
}
err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface)
if err != nil {
return nil, "", false, fmt.Errorf("setting firewall: %w", err)
return nil, models.Connection{}, fmt.Errorf("setting firewall: %w", err)
}
return wireguarder, connection.ServerName, connection.PortForward, nil
return wireguarder, connection, nil
}