Use PMTUD to set the MTU to the VPN interface
- Add `VPN_PMTUD` option enabled by default - One can revert to use `VPN_PMTUD=off` to disable the new PMTUD mechanism
This commit is contained in:
@@ -77,6 +77,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
|
|||||||
VPN_TYPE=openvpn \
|
VPN_TYPE=openvpn \
|
||||||
# Common VPN options
|
# Common VPN options
|
||||||
VPN_INTERFACE=tun0 \
|
VPN_INTERFACE=tun0 \
|
||||||
|
VPN_PMTUD=on \
|
||||||
# OpenVPN
|
# OpenVPN
|
||||||
OPENVPN_ENDPOINT_IP= \
|
OPENVPN_ENDPOINT_IP= \
|
||||||
OPENVPN_ENDPOINT_PORT= \
|
OPENVPN_ENDPOINT_PORT= \
|
||||||
|
|||||||
@@ -580,6 +580,7 @@ type Linker interface {
|
|||||||
LinkDel(link netlink.Link) (err error)
|
LinkDel(link netlink.Link) (err error)
|
||||||
LinkSetUp(link netlink.Link) (linkIndex int, err error)
|
LinkSetUp(link netlink.Link) (linkIndex int, err error)
|
||||||
LinkSetDown(link netlink.Link) (err error)
|
LinkSetDown(link netlink.Link) (err error)
|
||||||
|
LinkSetMTU(link netlink.Link, mtu int) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type clier interface {
|
type clier interface {
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ type VPN struct {
|
|||||||
Provider Provider `json:"provider"`
|
Provider Provider `json:"provider"`
|
||||||
OpenVPN OpenVPN `json:"openvpn"`
|
OpenVPN OpenVPN `json:"openvpn"`
|
||||||
Wireguard Wireguard `json:"wireguard"`
|
Wireguard Wireguard `json:"wireguard"`
|
||||||
|
PMTUD *bool `json:"pmtud"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO v4 remove pointer for receiver (because of Surfshark).
|
// TODO v4 remove pointer for receiver (because of Surfshark).
|
||||||
@@ -54,6 +55,7 @@ func (v *VPN) Copy() (copied VPN) {
|
|||||||
Provider: v.Provider.copy(),
|
Provider: v.Provider.copy(),
|
||||||
OpenVPN: v.OpenVPN.copy(),
|
OpenVPN: v.OpenVPN.copy(),
|
||||||
Wireguard: v.Wireguard.copy(),
|
Wireguard: v.Wireguard.copy(),
|
||||||
|
PMTUD: gosettings.CopyPointer(v.PMTUD),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,6 +64,7 @@ func (v *VPN) OverrideWith(other VPN) {
|
|||||||
v.Provider.overrideWith(other.Provider)
|
v.Provider.overrideWith(other.Provider)
|
||||||
v.OpenVPN.overrideWith(other.OpenVPN)
|
v.OpenVPN.overrideWith(other.OpenVPN)
|
||||||
v.Wireguard.overrideWith(other.Wireguard)
|
v.Wireguard.overrideWith(other.Wireguard)
|
||||||
|
v.PMTUD = gosettings.OverrideWithPointer(v.PMTUD, other.PMTUD)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *VPN) setDefaults() {
|
func (v *VPN) setDefaults() {
|
||||||
@@ -69,6 +72,7 @@ func (v *VPN) setDefaults() {
|
|||||||
v.Provider.setDefaults()
|
v.Provider.setDefaults()
|
||||||
v.OpenVPN.setDefaults(v.Provider.Name)
|
v.OpenVPN.setDefaults(v.Provider.Name)
|
||||||
v.Wireguard.setDefaults(v.Provider.Name)
|
v.Wireguard.setDefaults(v.Provider.Name)
|
||||||
|
v.PMTUD = gosettings.DefaultPointer(v.PMTUD, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v VPN) String() string {
|
func (v VPN) String() string {
|
||||||
@@ -86,6 +90,8 @@ func (v VPN) toLinesNode() (node *gotree.Node) {
|
|||||||
node.AppendNode(v.Wireguard.toLinesNode())
|
node.AppendNode(v.Wireguard.toLinesNode())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
node.Appendf("Path MTU discovery update: %s", gosettings.BoolToYesNo(v.PMTUD))
|
||||||
|
|
||||||
return node
|
return node
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -107,5 +113,10 @@ func (v *VPN) read(r *reader.Reader) (err error) {
|
|||||||
return fmt.Errorf("wireguard: %w", err)
|
return fmt.Errorf("wireguard: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
v.PMTUD, err = r.BoolPtr("VPN_PMTUD")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -62,6 +62,10 @@ func (n *NetLink) LinkSetDown(link Link) (err error) {
|
|||||||
return netlink.LinkSetDown(linkToNetlinkLink(&link))
|
return netlink.LinkSetDown(linkToNetlinkLink(&link))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *NetLink) LinkSetMTU(link Link, mtu int) error {
|
||||||
|
return netlink.LinkSetMTU(linkToNetlinkLink(&link), mtu)
|
||||||
|
}
|
||||||
|
|
||||||
type netlinkLinkImpl struct {
|
type netlinkLinkImpl struct {
|
||||||
attrs *netlink.LinkAttrs
|
attrs *netlink.LinkAttrs
|
||||||
linkType string
|
linkType string
|
||||||
|
|||||||
@@ -1,10 +1,24 @@
|
|||||||
package pmtud
|
package pmtud
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrICMPDestinationUnreachable = errors.New("ICMP destination unreachable")
|
ErrICMPDestinationUnreachable = errors.New("ICMP destination unreachable")
|
||||||
ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported")
|
ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func wrapConnErr(err error, timedCtx context.Context, pingTimeout time.Duration) error { //nolint:revive
|
||||||
|
switch {
|
||||||
|
case errors.Is(timedCtx.Err(), context.DeadlineExceeded):
|
||||||
|
err = fmt.Errorf("%w (timed out after %s)", net.ErrClosed, pingTimeout)
|
||||||
|
case timedCtx.Err() != nil:
|
||||||
|
err = timedCtx.Err()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|||||||
@@ -73,6 +73,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
|
|||||||
|
|
||||||
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
|
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
err = wrapConnErr(err, ctx, pingTimeout)
|
||||||
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -84,6 +85,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
|
|||||||
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
||||||
bytesRead, _, err := conn.ReadFrom(buffer)
|
bytesRead, _, err := conn.ReadFrom(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
err = wrapConnErr(err, ctx, pingTimeout)
|
||||||
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
|
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
|
||||||
}
|
}
|
||||||
packetBytes := buffer[:bytesRead]
|
packetBytes := buffer[:bytesRead]
|
||||||
@@ -135,7 +137,7 @@ func findIPv4NextHopMTU(ctx context.Context, ip netip.Addr,
|
|||||||
if inboundID == outboundID {
|
if inboundID == outboundID {
|
||||||
return physicalLinkMTU, nil
|
return physicalLinkMTU, nil
|
||||||
}
|
}
|
||||||
logger.Debug("discarding received ICMP echo reply with id %d mismatching sent id %d",
|
logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
|
||||||
inboundID, outboundID)
|
inboundID, outboundID)
|
||||||
continue
|
continue
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
|
|||||||
|
|
||||||
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice(), Zone: ip.Zone()})
|
_, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice(), Zone: ip.Zone()})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
err = wrapConnErr(err, ctx, pingTimeout)
|
||||||
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
return 0, fmt.Errorf("writing ICMP message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,6 +65,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
|
|||||||
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
// https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J
|
||||||
bytesRead, _, err := conn.ReadFrom(buffer)
|
bytesRead, _, err := conn.ReadFrom(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
err = wrapConnErr(err, ctx, pingTimeout)
|
||||||
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
|
return 0, fmt.Errorf("reading from ICMP connection: %w", err)
|
||||||
}
|
}
|
||||||
packetBytes := buffer[:bytesRead]
|
packetBytes := buffer[:bytesRead]
|
||||||
@@ -106,7 +108,7 @@ func getIPv6PacketTooBig(ctx context.Context, ip netip.Addr,
|
|||||||
if inboundID == outboundID {
|
if inboundID == outboundID {
|
||||||
return physicalLinkMTU, nil
|
return physicalLinkMTU, nil
|
||||||
}
|
}
|
||||||
logger.Debug("discarding received ICMP echo reply with id %d mismatching sent id %d",
|
logger.Debugf("discarding received ICMP echo reply with id %d mismatching sent id %d",
|
||||||
inboundID, outboundID)
|
inboundID, outboundID)
|
||||||
continue
|
continue
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -81,6 +81,7 @@ type Linker interface {
|
|||||||
LinkDel(link netlink.Link) (err error)
|
LinkDel(link netlink.Link) (err error)
|
||||||
LinkSetUp(link netlink.Link) (linkIndex int, err error)
|
LinkSetUp(link netlink.Link) (linkIndex int, err error)
|
||||||
LinkSetDown(link netlink.Link) (err error)
|
LinkSetDown(link netlink.Link) (err error)
|
||||||
|
LinkSetMTU(link netlink.Link, mtu int) (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type DNSLoop interface {
|
type DNSLoop interface {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||||
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
"github.com/qdm12/gluetun/internal/openvpn"
|
"github.com/qdm12/gluetun/internal/openvpn"
|
||||||
"github.com/qdm12/gluetun/internal/provider"
|
"github.com/qdm12/gluetun/internal/provider"
|
||||||
)
|
)
|
||||||
@@ -14,39 +15,39 @@ import (
|
|||||||
func setupOpenVPN(ctx context.Context, fw Firewall,
|
func setupOpenVPN(ctx context.Context, fw Firewall,
|
||||||
openvpnConf OpenVPN, providerConf provider.Provider,
|
openvpnConf OpenVPN, providerConf provider.Provider,
|
||||||
settings settings.VPN, ipv6Supported bool, starter CmdStarter,
|
settings settings.VPN, ipv6Supported bool, starter CmdStarter,
|
||||||
logger openvpn.Logger) (runner *openvpn.Runner, serverName string,
|
logger openvpn.Logger) (runner *openvpn.Runner,
|
||||||
canPortForward bool, err error,
|
connection models.Connection, err error,
|
||||||
) {
|
) {
|
||||||
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
|
connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", false, fmt.Errorf("finding a valid server connection: %w", err)
|
return nil, models.Connection{}, fmt.Errorf("finding a valid server connection: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
lines := providerConf.OpenVPNConfig(connection, settings.OpenVPN, ipv6Supported)
|
lines := providerConf.OpenVPNConfig(connection, settings.OpenVPN, ipv6Supported)
|
||||||
|
|
||||||
if err := openvpnConf.WriteConfig(lines); err != nil {
|
if err := openvpnConf.WriteConfig(lines); err != nil {
|
||||||
return nil, "", false, fmt.Errorf("writing configuration to file: %w", err)
|
return nil, models.Connection{}, fmt.Errorf("writing configuration to file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if *settings.OpenVPN.User != "" {
|
if *settings.OpenVPN.User != "" {
|
||||||
err := openvpnConf.WriteAuthFile(*settings.OpenVPN.User, *settings.OpenVPN.Password)
|
err := openvpnConf.WriteAuthFile(*settings.OpenVPN.User, *settings.OpenVPN.Password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", false, fmt.Errorf("writing auth to file: %w", err)
|
return nil, models.Connection{}, fmt.Errorf("writing auth to file: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if *settings.OpenVPN.KeyPassphrase != "" {
|
if *settings.OpenVPN.KeyPassphrase != "" {
|
||||||
err := openvpnConf.WriteAskPassFile(*settings.OpenVPN.KeyPassphrase)
|
err := openvpnConf.WriteAskPassFile(*settings.OpenVPN.KeyPassphrase)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", false, fmt.Errorf("writing askpass file: %w", err)
|
return nil, models.Connection{}, fmt.Errorf("writing askpass file: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := fw.SetVPNConnection(ctx, connection, settings.OpenVPN.Interface); err != nil {
|
if err := fw.SetVPNConnection(ctx, connection, settings.OpenVPN.Interface); err != nil {
|
||||||
return nil, "", false, fmt.Errorf("allowing VPN connection through firewall: %w", err)
|
return nil, models.Connection{}, fmt.Errorf("allowing VPN connection through firewall: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
runner = openvpn.NewRunner(settings.OpenVPN, starter, logger)
|
runner = openvpn.NewRunner(settings.OpenVPN, starter, logger)
|
||||||
|
|
||||||
return runner, connection.ServerName, connection.PortForward, nil
|
return runner, connection, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/constants"
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
"github.com/qdm12/gluetun/internal/constants/vpn"
|
"github.com/qdm12/gluetun/internal/constants/vpn"
|
||||||
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
"github.com/qdm12/log"
|
"github.com/qdm12/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -28,17 +29,17 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
|||||||
var vpnRunner interface {
|
var vpnRunner interface {
|
||||||
Run(ctx context.Context, waitError chan<- error, tunnelReady chan<- struct{})
|
Run(ctx context.Context, waitError chan<- error, tunnelReady chan<- struct{})
|
||||||
}
|
}
|
||||||
var serverName, vpnInterface string
|
var vpnInterface string
|
||||||
var canPortForward bool
|
var connection models.Connection
|
||||||
var err error
|
var err error
|
||||||
subLogger := l.logger.New(log.SetComponent(settings.Type))
|
subLogger := l.logger.New(log.SetComponent(settings.Type))
|
||||||
if settings.Type == vpn.OpenVPN {
|
if settings.Type == vpn.OpenVPN {
|
||||||
vpnInterface = settings.OpenVPN.Interface
|
vpnInterface = settings.OpenVPN.Interface
|
||||||
vpnRunner, serverName, canPortForward, err = setupOpenVPN(ctx, l.fw,
|
vpnRunner, connection, err = setupOpenVPN(ctx, l.fw,
|
||||||
l.openvpnConf, providerConf, settings, l.ipv6Supported, l.starter, subLogger)
|
l.openvpnConf, providerConf, settings, l.ipv6Supported, l.starter, subLogger)
|
||||||
} else { // Wireguard
|
} else { // Wireguard
|
||||||
vpnInterface = settings.Wireguard.Interface
|
vpnInterface = settings.Wireguard.Interface
|
||||||
vpnRunner, serverName, canPortForward, err = setupWireguard(ctx, l.netLinker, l.fw,
|
vpnRunner, connection, err = setupWireguard(ctx, l.netLinker, l.fw,
|
||||||
providerConf, settings, l.ipv6Supported, subLogger)
|
providerConf, settings, l.ipv6Supported, subLogger)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -46,8 +47,11 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
tunnelUpData := tunnelUpData{
|
tunnelUpData := tunnelUpData{
|
||||||
serverName: serverName,
|
PMTUD: *settings.PMTUD,
|
||||||
canPortForward: canPortForward,
|
serverIP: connection.IP,
|
||||||
|
vpnType: settings.Type,
|
||||||
|
serverName: connection.ServerName,
|
||||||
|
canPortForward: connection.PortForward,
|
||||||
portForwarder: portForwarder,
|
portForwarder: portForwarder,
|
||||||
vpnIntf: vpnInterface,
|
vpnIntf: vpnInterface,
|
||||||
username: settings.Provider.PortForwarding.Username,
|
username: settings.Provider.PortForwarding.Username,
|
||||||
|
|||||||
@@ -2,15 +2,32 @@ package vpn
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/qdm12/dns/v2/pkg/check"
|
"github.com/qdm12/dns/v2/pkg/check"
|
||||||
"github.com/qdm12/gluetun/internal/constants"
|
"github.com/qdm12/gluetun/internal/constants"
|
||||||
|
"github.com/qdm12/gluetun/internal/pmtud"
|
||||||
"github.com/qdm12/gluetun/internal/version"
|
"github.com/qdm12/gluetun/internal/version"
|
||||||
|
"github.com/qdm12/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
type tunnelUpData struct {
|
type tunnelUpData struct {
|
||||||
// Port forwarding
|
// vpnIntf is the name of the VPN network interface
|
||||||
vpnIntf string
|
// which is used both for port forwarding and MTU discovery
|
||||||
|
vpnIntf string
|
||||||
|
// Path MTU discovery fields:
|
||||||
|
// PMTUD indicates whether to perform Path MTU Discovery and
|
||||||
|
// adjust the VPN interface MTU accordingly.
|
||||||
|
PMTUD bool
|
||||||
|
// serverIP is used for path MTU discovery
|
||||||
|
serverIP netip.Addr
|
||||||
|
// vpnType is used for path MTU discovery to find the protocol overhead.
|
||||||
|
// It can be "wireguard" or "openvpn".
|
||||||
|
vpnType string
|
||||||
|
// Port forwarding fields:
|
||||||
serverName string // used for PIA
|
serverName string // used for PIA
|
||||||
canPortForward bool // used for PIA
|
canPortForward bool // used for PIA
|
||||||
username string // used for PIA
|
username string // used for PIA
|
||||||
@@ -21,6 +38,16 @@ type tunnelUpData struct {
|
|||||||
func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
|
func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
|
||||||
l.client.CloseIdleConnections()
|
l.client.CloseIdleConnections()
|
||||||
|
|
||||||
|
if data.PMTUD {
|
||||||
|
mtuLogger := l.logger.New(log.SetComponent("MTU discovery"))
|
||||||
|
mtuLogger.Info("finding maximum MTU, this takes around 3 seconds")
|
||||||
|
err := updateToMaxMTU(ctx, data.vpnIntf, data.serverIP, data.vpnType,
|
||||||
|
l.netLinker, mtuLogger)
|
||||||
|
if err != nil {
|
||||||
|
l.logger.Error(err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, vpnPort := range l.vpnInputPorts {
|
for _, vpnPort := range l.vpnInputPorts {
|
||||||
err := l.fw.SetAllowedPort(ctx, vpnPort, data.vpnIntf)
|
err := l.fw.SetAllowedPort(ctx, vpnPort, data.vpnIntf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -57,3 +84,50 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
|
|||||||
l.logger.Error(err.Error())
|
l.logger.Error(err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var errVPNTypeUnknown = errors.New("unknown VPN type")
|
||||||
|
|
||||||
|
func updateToMaxMTU(ctx context.Context, vpnInterface string,
|
||||||
|
serverIP netip.Addr, vpnType string, netlinker NetLinker, logger *log.Logger,
|
||||||
|
) error {
|
||||||
|
link, err := netlinker.LinkByName(vpnInterface)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("getting VPN interface by name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: no point testing for an MTU of 1500, it will never work due to the VPN
|
||||||
|
// protocol overhead, so start lower than 1500 according to the protocol used.
|
||||||
|
const physicalLinkMTU = 1500
|
||||||
|
vpnLinkMTU := physicalLinkMTU
|
||||||
|
switch vpnType {
|
||||||
|
case "wireguard":
|
||||||
|
vpnLinkMTU -= 60 // Wireguard overhead
|
||||||
|
case "openvpn":
|
||||||
|
vpnLinkMTU -= 41 // OpenVPN overhead
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%w: %q", errVPNTypeUnknown, vpnType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setting the VPN link MTU to 1500 might interrupt the connection until
|
||||||
|
// the new MTU is set again, but this is necessary to find the highest valid MTU.
|
||||||
|
logger.Debugf("VPN interface %s MTU temporarily set to %d", vpnInterface, vpnLinkMTU)
|
||||||
|
|
||||||
|
err = netlinker.LinkSetMTU(link, vpnLinkMTU)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
const pingTimeout = time.Second
|
||||||
|
vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, serverIP, vpnLinkMTU, pingTimeout, logger)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("path MTU discovering: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = netlinker.LinkSetMTU(link, vpnLinkMTU)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Infof("VPN interface %s MTU set to maximum valid MTU %d", vpnInterface, vpnLinkMTU)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/configuration/settings"
|
"github.com/qdm12/gluetun/internal/configuration/settings"
|
||||||
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
"github.com/qdm12/gluetun/internal/provider"
|
"github.com/qdm12/gluetun/internal/provider"
|
||||||
"github.com/qdm12/gluetun/internal/provider/utils"
|
"github.com/qdm12/gluetun/internal/provider/utils"
|
||||||
"github.com/qdm12/gluetun/internal/wireguard"
|
"github.com/qdm12/gluetun/internal/wireguard"
|
||||||
@@ -16,11 +17,11 @@ import (
|
|||||||
func setupWireguard(ctx context.Context, netlinker NetLinker,
|
func setupWireguard(ctx context.Context, netlinker NetLinker,
|
||||||
fw Firewall, providerConf provider.Provider,
|
fw Firewall, providerConf provider.Provider,
|
||||||
settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) (
|
settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) (
|
||||||
wireguarder *wireguard.Wireguard, serverName string, canPortForward bool, err error,
|
wireguarder *wireguard.Wireguard, connection models.Connection, err error,
|
||||||
) {
|
) {
|
||||||
connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
|
connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", false, fmt.Errorf("finding a VPN server: %w", err)
|
return nil, models.Connection{}, fmt.Errorf("finding a VPN server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard, ipv6Supported)
|
wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard, ipv6Supported)
|
||||||
@@ -31,13 +32,13 @@ func setupWireguard(ctx context.Context, netlinker NetLinker,
|
|||||||
|
|
||||||
wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger)
|
wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", false, fmt.Errorf("creating Wireguard: %w", err)
|
return nil, models.Connection{}, fmt.Errorf("creating Wireguard: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface)
|
err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", false, fmt.Errorf("setting firewall: %w", err)
|
return nil, models.Connection{}, fmt.Errorf("setting firewall: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return wireguarder, connection.ServerName, connection.PortForward, nil
|
return wireguarder, connection, nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user