Use PMTUD to set the MTU to the VPN interface
- Add `VPN_PMTUD` option enabled by default - One can revert to use `VPN_PMTUD=off` to disable the new PMTUD mechanism
This commit is contained in:
@@ -77,6 +77,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
||||
VPN_TYPE=openvpn \
|
||||
# Common VPN options
|
||||
VPN_INTERFACE=tun0 \
|
||||
VPN_PMTUD=on \
|
||||
# OpenVPN
|
||||
OPENVPN_ENDPOINT_IP= \
|
||||
OPENVPN_ENDPOINT_PORT= \
|
||||
|
||||
@@ -580,6 +580,7 @@ type Linker interface {
|
||||
LinkDel(link netlink.Link) (err error)
|
||||
LinkSetUp(link netlink.Link) (linkIndex int, err error)
|
||||
LinkSetDown(link netlink.Link) (err error)
|
||||
LinkSetMTU(link netlink.Link, mtu int) error
|
||||
}
|
||||
|
||||
type clier interface {
|
||||
|
||||
@@ -18,6 +18,7 @@ type VPN struct {
|
||||
Provider Provider `json:"provider"`
|
||||
OpenVPN OpenVPN `json:"openvpn"`
|
||||
Wireguard Wireguard `json:"wireguard"`
|
||||
PMTUD *bool `json:"pmtud"`
|
||||
}
|
||||
|
||||
// TODO v4 remove pointer for receiver (because of Surfshark).
|
||||
@@ -54,6 +55,7 @@ func (v *VPN) Copy() (copied VPN) {
|
||||
Provider: v.Provider.copy(),
|
||||
OpenVPN: v.OpenVPN.copy(),
|
||||
Wireguard: v.Wireguard.copy(),
|
||||
PMTUD: gosettings.CopyPointer(v.PMTUD),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,6 +64,7 @@ func (v *VPN) OverrideWith(other VPN) {
|
||||
v.Provider.overrideWith(other.Provider)
|
||||
v.OpenVPN.overrideWith(other.OpenVPN)
|
||||
v.Wireguard.overrideWith(other.Wireguard)
|
||||
v.PMTUD = gosettings.OverrideWithPointer(v.PMTUD, other.PMTUD)
|
||||
}
|
||||
|
||||
func (v *VPN) setDefaults() {
|
||||
@@ -69,6 +72,7 @@ func (v *VPN) setDefaults() {
|
||||
v.Provider.setDefaults()
|
||||
v.OpenVPN.setDefaults(v.Provider.Name)
|
||||
v.Wireguard.setDefaults(v.Provider.Name)
|
||||
v.PMTUD = gosettings.DefaultPointer(v.PMTUD, true)
|
||||
}
|
||||
|
||||
func (v VPN) String() string {
|
||||
@@ -86,6 +90,8 @@ func (v VPN) toLinesNode() (node *gotree.Node) {
|
||||
node.AppendNode(v.Wireguard.toLinesNode())
|
||||
}
|
||||
|
||||
node.Appendf("Path MTU discovery update: %s", gosettings.BoolToYesNo(v.PMTUD))
|
||||
|
||||
return node
|
||||
}
|
||||
|
||||
@@ -107,5 +113,10 @@ func (v *VPN) read(r *reader.Reader) (err error) {
|
||||
return fmt.Errorf("wireguard: %w", err)
|
||||
}
|
||||
|
||||
v.PMTUD, err = r.BoolPtr("VPN_PMTUD")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -62,6 +62,10 @@ func (n *NetLink) LinkSetDown(link Link) (err error) {
|
||||
return netlink.LinkSetDown(linkToNetlinkLink(&link))
|
||||
}
|
||||
|
||||
func (n *NetLink) LinkSetMTU(link Link, mtu int) error {
|
||||
return netlink.LinkSetMTU(linkToNetlinkLink(&link), mtu)
|
||||
}
|
||||
|
||||
type netlinkLinkImpl struct {
|
||||
attrs *netlink.LinkAttrs
|
||||
linkType string
|
||||
|
||||
@@ -1,10 +1,24 @@
|
||||
package pmtud
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrICMPDestinationUnreachable = errors.New("ICMP destination unreachable")
|
||||
ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported")
|
||||
)
|
||||
|
||||
func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive
|
||||
switch {
|
||||
case errors.Is(timedCtx.Err(), context.DeadlineExceeded):
|
||||
err = fmt.Errorf("%w (timed out after %s)", net.ErrClosed, pingTimeout)
|
||||
case timedCtx.Err() != nil:
|
||||
err = timedCtx.Err()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -73,6 +73,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
|
||||
|
||||
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
|
||||
if err != nil {
|
||||
err = wrapConnErr(err, ctx, pingTimeout)
|
||||
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
||||
}
|
||||
|
||||
@@ -84,6 +85,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
|
||||
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
||||
bytesRead, _, err := conn.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
err = wrapConnErr(err, ctx, pingTimeout)
|
||||
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
|
||||
}
|
||||
packetBytes := buffer[:bytesRead]
|
||||
@@ -135,7 +137,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
|
||||
if inboundID == outboundID {
|
||||
return physicalLinkMTU, nil
|
||||
}
|
||||
logger.Debug("discarding received ICMP echo reply with id %d mismatching sent id %d",
|
||||
logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
|
||||
inboundID, outboundID)
|
||||
continue
|
||||
default:
|
||||
|
||||
@@ -53,6 +53,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
|
||||
|
||||
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice(), Zone: ip.Zone()})
|
||||
if err != nil {
|
||||
err = wrapConnErr(err, ctx, pingTimeout)
|
||||
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
||||
}
|
||||
|
||||
@@ -64,6 +65,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
|
||||
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
||||
bytesRead, _, err := conn.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
err = wrapConnErr(err, ctx, pingTimeout)
|
||||
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
|
||||
}
|
||||
packetBytes := buffer[:bytesRead]
|
||||
@@ -106,7 +108,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
|
||||
if inboundID == outboundID {
|
||||
return physicalLinkMTU, nil
|
||||
}
|
||||
logger.Debug("discarding received ICMP echo reply with id %d mismatching sent id %d",
|
||||
logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
|
||||
inboundID, outboundID)
|
||||
continue
|
||||
default:
|
||||
|
||||
@@ -81,6 +81,7 @@ type Linker interface {
|
||||
LinkDel(link netlink.Link) (err error)
|
||||
LinkSetUp(link netlink.Link) (linkIndex int, err error)
|
||||
LinkSetDown(link netlink.Link) (err error)
|
||||
LinkSetMTU(link netlink.Link, mtu int) (err error)
|
||||
}
|
||||
|
||||
type DNSLoop interface {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/openvpn"
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
)
|
||||
@@ -14,39 +15,39 @@ import (
|
||||
func setupOpenVPN(ctx context.Context, fw Firewall,
|
||||
openvpnConf OpenVPN, providerConf provider.Provider,
|
||||
settings settings.VPN, ipv6Supported bool, starter CmdStarter,
|
||||
logger openvpn.Logger) (runner *openvpn.Runner, serverName string,
|
||||
canPortForward bool, err error,
|
||||
logger openvpn.Logger) (runner *openvpn.Runner,
|
||||
connection models.Connection, err error,
|
||||
) {
|
||||
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
|
||||
connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
|
||||
if err != nil {
|
||||
return nil, "", false, fmt.Errorf("finding a valid server connection: %w", err)
|
||||
return nil, models.Connection{}, fmt.Errorf("finding a valid server connection: %w", err)
|
||||
}
|
||||
|
||||
lines := providerConf.OpenVPNConfig(connection, settings.OpenVPN, ipv6Supported)
|
||||
|
||||
if err := openvpnConf.WriteConfig(lines); err != nil {
|
||||
return nil, "", false, fmt.Errorf("writing configuration to file: %w", err)
|
||||
return nil, models.Connection{}, fmt.Errorf("writing configuration to file: %w", err)
|
||||
}
|
||||
|
||||
if *settings.OpenVPN.User != "" {
|
||||
err := openvpnConf.WriteAuthFile(*settings.OpenVPN.User, *settings.OpenVPN.Password)
|
||||
if err != nil {
|
||||
return nil, "", false, fmt.Errorf("writing auth to file: %w", err)
|
||||
return nil, models.Connection{}, fmt.Errorf("writing auth to file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if *settings.OpenVPN.KeyPassphrase != "" {
|
||||
err := openvpnConf.WriteAskPassFile(*settings.OpenVPN.KeyPassphrase)
|
||||
if err != nil {
|
||||
return nil, "", false, fmt.Errorf("writing askpass file: %w", err)
|
||||
return nil, models.Connection{}, fmt.Errorf("writing askpass file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := fw.SetVPNConnection(ctx, connection, settings.OpenVPN.Interface); err != nil {
|
||||
return nil, "", false, fmt.Errorf("allowing VPN connection through firewall: %w", err)
|
||||
return nil, models.Connection{}, fmt.Errorf("allowing VPN connection through firewall: %w", err)
|
||||
}
|
||||
|
||||
runner = openvpn.NewRunner(settings.OpenVPN, starter, logger)
|
||||
|
||||
return runner, connection.ServerName, connection.PortForward, nil
|
||||
return runner, connection, nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/log"
|
||||
)
|
||||
|
||||
@@ -28,17 +29,17 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
var vpnRunner interface {
|
||||
Run(ctx context.Context, waitError chan<- error, tunnelReady chan<- struct{})
|
||||
}
|
||||
var serverName, vpnInterface string
|
||||
var canPortForward bool
|
||||
var vpnInterface string
|
||||
var connection models.Connection
|
||||
var err error
|
||||
subLogger := l.logger.New(log.SetComponent(settings.Type))
|
||||
if settings.Type == vpn.OpenVPN {
|
||||
vpnInterface = settings.OpenVPN.Interface
|
||||
vpnRunner, serverName, canPortForward, err = setupOpenVPN(ctx, l.fw,
|
||||
vpnRunner, connection, err = setupOpenVPN(ctx, l.fw,
|
||||
l.openvpnConf, providerConf, settings, l.ipv6Supported, l.starter, subLogger)
|
||||
} else { // Wireguard
|
||||
vpnInterface = settings.Wireguard.Interface
|
||||
vpnRunner, serverName, canPortForward, err = setupWireguard(ctx, l.netLinker, l.fw,
|
||||
vpnRunner, connection, err = setupWireguard(ctx, l.netLinker, l.fw,
|
||||
providerConf, settings, l.ipv6Supported, subLogger)
|
||||
}
|
||||
if err != nil {
|
||||
@@ -46,8 +47,11 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
continue
|
||||
}
|
||||
tunnelUpData := tunnelUpData{
|
||||
serverName: serverName,
|
||||
canPortForward: canPortForward,
|
||||
PMTUD: *settings.PMTUD,
|
||||
serverIP: connection.IP,
|
||||
vpnType: settings.Type,
|
||||
serverName: connection.ServerName,
|
||||
canPortForward: connection.PortForward,
|
||||
portForwarder: portForwarder,
|
||||
vpnIntf: vpnInterface,
|
||||
username: settings.Provider.PortForwarding.Username,
|
||||
|
||||
@@ -2,15 +2,32 @@ package vpn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/dns/v2/pkg/check"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/pmtud"
|
||||
"github.com/qdm12/gluetun/internal/version"
|
||||
"github.com/qdm12/log"
|
||||
)
|
||||
|
||||
type tunnelUpData struct {
|
||||
// Port forwarding
|
||||
// vpnIntf is the name of the VPN network interface
|
||||
// which is used both for port forwarding and MTU discovery
|
||||
vpnIntf string
|
||||
// Path MTU discovery fields:
|
||||
// PMTUD indicates whether to perform Path MTU Discovery and
|
||||
// adjust the VPN interface MTU accordingly.
|
||||
PMTUD bool
|
||||
// serverIP is used for path MTU discovery
|
||||
serverIP netip.Addr
|
||||
// vpnType is used for path MTU discovery to find the protocol overhead.
|
||||
// It can be "wireguard" or "openvpn".
|
||||
vpnType string
|
||||
// Port forwarding fields:
|
||||
serverName string // used for PIA
|
||||
canPortForward bool // used for PIA
|
||||
username string // used for PIA
|
||||
@@ -21,6 +38,16 @@ type tunnelUpData struct {
|
||||
func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
|
||||
l.client.CloseIdleConnections()
|
||||
|
||||
if data.PMTUD {
|
||||
mtuLogger := l.logger.New(log.SetComponent("MTU discovery"))
|
||||
mtuLogger.Info("finding maximum MTU, this takes around 3 seconds")
|
||||
err := updateToMaxMTU(ctx, data.vpnIntf, data.serverIP, data.vpnType,
|
||||
l.netLinker, mtuLogger)
|
||||
if err != nil {
|
||||
l.logger.Error(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
for _, vpnPort := range l.vpnInputPorts {
|
||||
err := l.fw.SetAllowedPort(ctx, vpnPort, data.vpnIntf)
|
||||
if err != nil {
|
||||
@@ -57,3 +84,50 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
|
||||
l.logger.Error(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
var errVPNTypeUnknown = errors.New("unknown VPN type")
|
||||
|
||||
func updateToMaxMTU(ctx context.Context, vpnInterface string,
|
||||
serverIP netip.Addr, vpnType string, netlinker NetLinker, logger *log.Logger,
|
||||
) error {
|
||||
link, err := netlinker.LinkByName(vpnInterface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting VPN interface by name: %w", err)
|
||||
}
|
||||
|
||||
// Note: no point testing for an MTU of 1500, it will never work due to the VPN
|
||||
// protocol overhead, so start lower than 1500 according to the protocol used.
|
||||
const physicalLinkMTU = 1500
|
||||
vpnLinkMTU := physicalLinkMTU
|
||||
switch vpnType {
|
||||
case "wireguard":
|
||||
vpnLinkMTU -= 60 // Wireguard overhead
|
||||
case "openvpn":
|
||||
vpnLinkMTU -= 41 // OpenVPN overhead
|
||||
default:
|
||||
return fmt.Errorf("%w: %q", errVPNTypeUnknown, vpnType)
|
||||
}
|
||||
|
||||
// Setting the VPN link MTU to 1500 might interrupt the connection until
|
||||
// the new MTU is set again, but this is necessary to find the highest valid MTU.
|
||||
logger.Debugf("VPN interface %s MTU temporarily set to %d", vpnInterface, vpnLinkMTU)
|
||||
|
||||
err = netlinker.LinkSetMTU(link, vpnLinkMTU)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
|
||||
}
|
||||
|
||||
const pingTimeout = time.Second
|
||||
vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, serverIP, vpnLinkMTU, pingTimeout, logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("path MTU discovering: %w", err)
|
||||
}
|
||||
|
||||
err = netlinker.LinkSetMTU(link, vpnLinkMTU)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
|
||||
}
|
||||
|
||||
logger.Infof("VPN interface %s MTU set to maximum valid MTU %d", vpnInterface, vpnLinkMTU)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||
"github.com/qdm12/gluetun/internal/wireguard"
|
||||
@@ -16,11 +17,11 @@ import (
|
||||
func setupWireguard(ctx context.Context, netlinker NetLinker,
|
||||
fw Firewall, providerConf provider.Provider,
|
||||
settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) (
|
||||
wireguarder *wireguard.Wireguard, serverName string, canPortForward bool, err error,
|
||||
wireguarder *wireguard.Wireguard, connection models.Connection, err error,
|
||||
) {
|
||||
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
|
||||
connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
|
||||
if err != nil {
|
||||
return nil, "", false, fmt.Errorf("finding a VPN server: %w", err)
|
||||
return nil, models.Connection{}, fmt.Errorf("finding a VPN server: %w", err)
|
||||
}
|
||||
|
||||
wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard, ipv6Supported)
|
||||
@@ -31,13 +32,13 @@ func setupWireguard(ctx context.Context, netlinker NetLinker,
|
||||
|
||||
wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger)
|
||||
if err != nil {
|
||||
return nil, "", false, fmt.Errorf("creating Wireguard: %w", err)
|
||||
return nil, models.Connection{}, fmt.Errorf("creating Wireguard: %w", err)
|
||||
}
|
||||
|
||||
err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface)
|
||||
if err != nil {
|
||||
return nil, "", false, fmt.Errorf("setting firewall: %w", err)
|
||||
return nil, models.Connection{}, fmt.Errorf("setting firewall: %w", err)
|
||||
}
|
||||
|
||||
return wireguarder, connection.ServerName, connection.PortForward, nil
|
||||
return wireguarder, connection, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user