feat(wireguard): WIREGUARD_MTU enviromnent variable (#1571)

This commit is contained in:
Lars Haalck
2023-05-21 15:11:07 +02:00
committed by GitHub
parent 63303bc311
commit 1dd38bc658
8 changed files with 63 additions and 6 deletions

View File

@@ -97,6 +97,7 @@ ENV VPN_SERVICE_PROVIDER=pia \
WIREGUARD_PRESHARED_KEY= \ WIREGUARD_PRESHARED_KEY= \
WIREGUARD_PUBLIC_KEY= \ WIREGUARD_PUBLIC_KEY= \
WIREGUARD_ADDRESSES= \ WIREGUARD_ADDRESSES= \
WIREGUARD_MTU= \
WIREGUARD_IMPLEMENTATION=auto \ WIREGUARD_IMPLEMENTATION=auto \
# VPN server filtering # VPN server filtering
SERVER_REGIONS= \ SERVER_REGIONS= \

View File

@@ -8,6 +8,7 @@ import (
"github.com/qdm12/gluetun/internal/configuration/settings/helpers" "github.com/qdm12/gluetun/internal/configuration/settings/helpers"
"github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/providers"
"github.com/qdm12/gotree" "github.com/qdm12/gotree"
wireguarddevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@@ -27,6 +28,10 @@ type Wireguard struct {
// to create. It cannot be the empty string in the // to create. It cannot be the empty string in the
// internal state. // internal state.
Interface string Interface string
// Maximum Transmission Unit (MTU) of the Wireguard interface.
// It cannot be zero in the internal state, and defaults to
// the wireguard-go MTU default of 1420.
MTU uint16
// Implementation is the Wireguard implementation to use. // Implementation is the Wireguard implementation to use.
// It can be "auto", "userspace" or "kernelspace". // It can be "auto", "userspace" or "kernelspace".
// It defaults to "auto" and cannot be the empty string // It defaults to "auto" and cannot be the empty string
@@ -110,6 +115,7 @@ func (w *Wireguard) copy() (copied Wireguard) {
PreSharedKey: helpers.CopyPointer(w.PreSharedKey), PreSharedKey: helpers.CopyPointer(w.PreSharedKey),
Addresses: helpers.CopySlice(w.Addresses), Addresses: helpers.CopySlice(w.Addresses),
Interface: w.Interface, Interface: w.Interface,
MTU: w.MTU,
Implementation: w.Implementation, Implementation: w.Implementation,
} }
} }
@@ -119,6 +125,7 @@ func (w *Wireguard) mergeWith(other Wireguard) {
w.PreSharedKey = helpers.MergeWithPointer(w.PreSharedKey, other.PreSharedKey) w.PreSharedKey = helpers.MergeWithPointer(w.PreSharedKey, other.PreSharedKey)
w.Addresses = helpers.MergeSlices(w.Addresses, other.Addresses) w.Addresses = helpers.MergeSlices(w.Addresses, other.Addresses)
w.Interface = helpers.MergeWithString(w.Interface, other.Interface) w.Interface = helpers.MergeWithString(w.Interface, other.Interface)
w.MTU = helpers.MergeWithNumber(w.MTU, other.MTU)
w.Implementation = helpers.MergeWithString(w.Implementation, other.Implementation) w.Implementation = helpers.MergeWithString(w.Implementation, other.Implementation)
} }
@@ -127,6 +134,7 @@ func (w *Wireguard) overrideWith(other Wireguard) {
w.PreSharedKey = helpers.OverrideWithPointer(w.PreSharedKey, other.PreSharedKey) w.PreSharedKey = helpers.OverrideWithPointer(w.PreSharedKey, other.PreSharedKey)
w.Addresses = helpers.OverrideWithSlice(w.Addresses, other.Addresses) w.Addresses = helpers.OverrideWithSlice(w.Addresses, other.Addresses)
w.Interface = helpers.OverrideWithString(w.Interface, other.Interface) w.Interface = helpers.OverrideWithString(w.Interface, other.Interface)
w.MTU = helpers.OverrideWithNumber(w.MTU, other.MTU)
w.Implementation = helpers.OverrideWithString(w.Implementation, other.Implementation) w.Implementation = helpers.OverrideWithString(w.Implementation, other.Implementation)
} }
@@ -134,6 +142,7 @@ func (w *Wireguard) setDefaults() {
w.PrivateKey = helpers.DefaultPointer(w.PrivateKey, "") w.PrivateKey = helpers.DefaultPointer(w.PrivateKey, "")
w.PreSharedKey = helpers.DefaultPointer(w.PreSharedKey, "") w.PreSharedKey = helpers.DefaultPointer(w.PreSharedKey, "")
w.Interface = helpers.DefaultString(w.Interface, "wg0") w.Interface = helpers.DefaultString(w.Interface, "wg0")
w.MTU = helpers.DefaultNumber(w.MTU, wireguarddevice.DefaultMTU)
w.Implementation = helpers.DefaultString(w.Implementation, "auto") w.Implementation = helpers.DefaultString(w.Implementation, "auto")
} }
@@ -159,7 +168,8 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) {
addressesNode.Appendf(address.String()) addressesNode.Appendf(address.String())
} }
node.Appendf("Network interface: %s", w.Interface) interfaceNode := node.Appendf("Network interface: %s", w.Interface)
interfaceNode.Appendf("MTU: %d", w.MTU)
if w.Implementation != "auto" { if w.Implementation != "auto" {
node.Appendf("Implementation: %s", w.Implementation) node.Appendf("Implementation: %s", w.Implementation)

View File

@@ -21,6 +21,12 @@ func (s *Source) readWireguard() (wireguard settings.Wireguard, err error) {
if err != nil { if err != nil {
return wireguard, err // already wrapped return wireguard, err // already wrapped
} }
mtuPtr, err := envToUint16Ptr("WIREGUARD_MTU")
if err != nil {
return wireguard, fmt.Errorf("environment variable WIREGUARD_MTU: %w", err)
} else if mtuPtr != nil {
wireguard.MTU = *mtuPtr
}
return wireguard, nil return wireguard, nil
} }

View File

@@ -15,6 +15,7 @@ func BuildWireguardSettings(connection models.Connection,
settings.PreSharedKey = *userSettings.PreSharedKey settings.PreSharedKey = *userSettings.PreSharedKey
settings.InterfaceName = userSettings.Interface settings.InterfaceName = userSettings.Interface
settings.Implementation = userSettings.Implementation settings.Implementation = userSettings.Implementation
settings.MTU = userSettings.MTU
settings.IPv6 = &ipv6Supported settings.IPv6 = &ipv6Supported
const rulePriority = 101 // 100 is to receive external connections const rulePriority = 101 // 100 is to receive external connections

View File

@@ -6,6 +6,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/device"
) )
func Test_New(t *testing.T) { func Test_New(t *testing.T) {
@@ -48,6 +49,7 @@ func Test_New(t *testing.T) {
netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 32), netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 32),
}, },
FirewallMark: 100, FirewallMark: 100,
MTU: device.DefaultMTU,
IPv6: ptr(false), IPv6: ptr(false),
Implementation: "auto", Implementation: "auto",
}, },

View File

@@ -74,7 +74,7 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
defer closers.cleanup(w.logger) defer closers.cleanup(w.logger)
link, waitAndCleanup, err := setupFunction(ctx, link, waitAndCleanup, err := setupFunction(ctx,
w.settings.InterfaceName, w.netlink, &closers, w.logger) w.settings.InterfaceName, w.netlink, w.settings.MTU, &closers, w.logger)
if err != nil { if err != nil {
waitError <- err waitError <- err
return return
@@ -158,12 +158,12 @@ func (w *Wireguard) setupIPv6(link netlink.Link, closers *closers) (err error) {
type waitAndCleanupFunc func() error type waitAndCleanupFunc func() error
func setupKernelSpace(ctx context.Context, func setupKernelSpace(ctx context.Context,
interfaceName string, netLinker NetLinker, interfaceName string, netLinker NetLinker, mtu uint16,
closers *closers, logger Logger) ( closers *closers, logger Logger) (
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error) { link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error) {
linkAttrs := netlink.LinkAttrs{ linkAttrs := netlink.LinkAttrs{
Name: interfaceName, Name: interfaceName,
MTU: device.DefaultMTU, // TODO MTU: int(mtu),
} }
link = &netlink.Wireguard{ link = &netlink.Wireguard{
LinkAttrs: linkAttrs, LinkAttrs: linkAttrs,
@@ -186,10 +186,10 @@ func setupKernelSpace(ctx context.Context,
} }
func setupUserSpace(ctx context.Context, func setupUserSpace(ctx context.Context,
interfaceName string, netLinker NetLinker, interfaceName string, netLinker NetLinker, mtu uint16,
closers *closers, logger Logger) ( closers *closers, logger Logger) (
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error) { link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error) {
tun, err := tun.CreateTUN(interfaceName, device.DefaultMTU) tun, err := tun.CreateTUN(interfaceName, int(mtu))
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("%w: %s", ErrCreateTun, err) return nil, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
} }

View File

@@ -7,6 +7,7 @@ import (
"regexp" "regexp"
"strings" "strings"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@@ -28,6 +29,9 @@ type Settings struct {
// FirewallMark to be used in routing tables and IP rules. // FirewallMark to be used in routing tables and IP rules.
// It defaults to 51820 if left to 0. // It defaults to 51820 if left to 0.
FirewallMark int FirewallMark int
// Maximum Transmission Unit (MTU) setting for the network interface.
// It defaults to device.DefaultMTU from wireguard-go which is 1420
MTU uint16
// RulePriority is the priority for the rule created with the // RulePriority is the priority for the rule created with the
// FirewallMark. // FirewallMark.
RulePriority int RulePriority int
@@ -55,6 +59,10 @@ func (s *Settings) SetDefaults() {
s.FirewallMark = defaultFirewallMark s.FirewallMark = defaultFirewallMark
} }
if s.MTU == 0 {
s.MTU = device.DefaultMTU
}
if s.IPv6 == nil { if s.IPv6 == nil {
ipv6 := false // this should be injected from host ipv6 := false // this should be injected from host
s.IPv6 = &ipv6 s.IPv6 = &ipv6
@@ -78,6 +86,7 @@ var (
ErrAddressMissing = errors.New("interface address is missing") ErrAddressMissing = errors.New("interface address is missing")
ErrAddressNotValid = errors.New("interface address is not valid") ErrAddressNotValid = errors.New("interface address is not valid")
ErrFirewallMarkMissing = errors.New("firewall mark is missing") ErrFirewallMarkMissing = errors.New("firewall mark is missing")
ErrMTUMissing = errors.New("MTU is missing")
ErrImplementationInvalid = errors.New("invalid implementation") ErrImplementationInvalid = errors.New("invalid implementation")
) )
@@ -127,6 +136,10 @@ func (s *Settings) Check() (err error) {
return fmt.Errorf("%w", ErrFirewallMarkMissing) return fmt.Errorf("%w", ErrFirewallMarkMissing)
} }
if s.MTU == 0 {
return fmt.Errorf("%w", ErrMTUMissing)
}
switch s.Implementation { switch s.Implementation {
case "auto", "kernelspace", "userspace": case "auto", "kernelspace", "userspace":
default: default:
@@ -209,6 +222,10 @@ func (s Settings) ToLines(settings ToLinesSettings) (lines []string) {
lines = append(lines, fieldPrefix+"Firewall mark: "+fmt.Sprint(s.FirewallMark)) lines = append(lines, fieldPrefix+"Firewall mark: "+fmt.Sprint(s.FirewallMark))
} }
if s.MTU != 0 {
lines = append(lines, fieldPrefix+"MTU: "+fmt.Sprint(s.MTU))
}
if s.RulePriority != 0 { if s.RulePriority != 0 {
lines = append(lines, fieldPrefix+"Rule priority: "+fmt.Sprint(s.RulePriority)) lines = append(lines, fieldPrefix+"Rule priority: "+fmt.Sprint(s.RulePriority))
} }

View File

@@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/device"
) )
func ptr[T any](v T) *T { return &v } func ptr[T any](v T) *T { return &v }
@@ -22,6 +23,7 @@ func Test_Settings_SetDefaults(t *testing.T) {
expected: Settings{ expected: Settings{
InterfaceName: "wg0", InterfaceName: "wg0",
FirewallMark: 51820, FirewallMark: 51820,
MTU: device.DefaultMTU,
IPv6: ptr(false), IPv6: ptr(false),
Implementation: "auto", Implementation: "auto",
}, },
@@ -34,6 +36,7 @@ func Test_Settings_SetDefaults(t *testing.T) {
InterfaceName: "wg0", InterfaceName: "wg0",
FirewallMark: 51820, FirewallMark: 51820,
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820), Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
MTU: device.DefaultMTU,
IPv6: ptr(false), IPv6: ptr(false),
Implementation: "auto", Implementation: "auto",
}, },
@@ -43,6 +46,7 @@ func Test_Settings_SetDefaults(t *testing.T) {
InterfaceName: "wg1", InterfaceName: "wg1",
FirewallMark: 999, FirewallMark: 999,
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 9999), Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 9999),
MTU: device.DefaultMTU,
IPv6: ptr(true), IPv6: ptr(true),
Implementation: "userspace", Implementation: "userspace",
}, },
@@ -50,6 +54,7 @@ func Test_Settings_SetDefaults(t *testing.T) {
InterfaceName: "wg1", InterfaceName: "wg1",
FirewallMark: 999, FirewallMark: 999,
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 9999), Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 9999),
MTU: device.DefaultMTU,
IPv6: ptr(true), IPv6: ptr(true),
Implementation: "userspace", Implementation: "userspace",
}, },
@@ -174,6 +179,19 @@ func Test_Settings_Check(t *testing.T) {
}, },
err: ErrFirewallMarkMissing, err: ErrFirewallMarkMissing,
}, },
"missing_MTU": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
Addresses: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
},
FirewallMark: 999,
},
err: ErrMTUMissing,
},
"invalid implementation": { "invalid implementation": {
settings: Settings{ settings: Settings{
InterfaceName: "wg0", InterfaceName: "wg0",
@@ -184,6 +202,7 @@ func Test_Settings_Check(t *testing.T) {
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24), netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
}, },
FirewallMark: 999, FirewallMark: 999,
MTU: 1420,
Implementation: "x", Implementation: "x",
}, },
err: errors.New("invalid implementation: x"), err: errors.New("invalid implementation: x"),
@@ -198,6 +217,7 @@ func Test_Settings_Check(t *testing.T) {
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24), netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
}, },
FirewallMark: 999, FirewallMark: 999,
MTU: 1420,
Implementation: "userspace", Implementation: "userspace",
}, },
}, },