feat(wireguard): WIREGUARD_ALLOWED_IPS variable (#1291)

This commit is contained in:
Quentin McGaw
2023-07-06 10:08:59 +03:00
committed by GitHub
parent 9c0f187a12
commit 919b55c3aa
11 changed files with 225 additions and 69 deletions

View File

@@ -48,6 +48,9 @@ func Test_New(t *testing.T) {
Addresses: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 32),
},
AllowedIPs: []netip.Prefix{
allIPv4(),
},
FirewallMark: 100,
MTU: device.DefaultMTU,
IPv6: ptr(false),

View File

@@ -3,11 +3,30 @@ package wireguard
import (
"fmt"
"net/netip"
"strings"
"github.com/qdm12/gluetun/internal/netlink"
)
// TODO add IPv6 route if IPv6 is supported
func (w *Wireguard) addRoutes(link netlink.Link, destinations []netip.Prefix,
firewallMark int) (err error) {
for _, dst := range destinations {
err = w.addRoute(link, dst, firewallMark)
if err == nil {
continue
}
if dst.Addr().Is6() && strings.Contains(err.Error(), "permission denied") {
w.logger.Errorf("cannot add route for IPv6 due to a permission denial. "+
"Ignoring and continuing execution; "+
"Please report to https://github.com/qdm12/gluetun/issues/998 if you find a fix. "+
"Full error string: %s", err)
continue
}
return fmt.Errorf("adding route for destination %s: %w", dst, err)
}
return nil
}
func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
firewallMark int) (err error) {

View File

@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"net"
"strings"
"github.com/qdm12/gluetun/internal/netlink"
"golang.org/x/sys/unix"
@@ -103,7 +102,7 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
return w.netlink.LinkSetDown(link)
})
err = w.addRoute(link, allIPv4(), w.settings.FirewallMark)
err = w.addRoutes(link, w.settings.AllowedIPs, w.settings.FirewallMark)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err)
return
@@ -111,11 +110,13 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
if *w.settings.IPv6 {
// requires net.ipv6.conf.all.disable_ipv6=0
err = w.setupIPv6(link, &closers)
ruleCleanup6, err := w.addRule(w.settings.RulePriority,
w.settings.FirewallMark, unix.AF_INET6)
if err != nil {
waitError <- fmt.Errorf("setting up IPv6: %w", err)
waitError <- fmt.Errorf("adding IPv6 rule: %w", err)
return
}
closers.add("removing IPv6 rule", stepOne, ruleCleanup6)
}
ruleCleanup, err := w.addRule(w.settings.RulePriority,
@@ -132,31 +133,6 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
waitError <- waitAndCleanup()
}
func (w *Wireguard) setupIPv6(link netlink.Link, closers *closers) (err error) {
// requires net.ipv6.conf.all.disable_ipv6=0
err = w.addRoute(link, allIPv6(), w.settings.FirewallMark)
if err != nil {
if strings.Contains(err.Error(), "permission denied") {
w.logger.Errorf("cannot add route for IPv6 due to a permission denial. "+
"Ignoring and continuing execution; "+
"Please report to https://github.com/qdm12/gluetun/issues/998 if you find a fix. "+
"Full error string: %s", err)
return nil
}
return fmt.Errorf("%w: %s", ErrRouteAdd, err)
}
ruleCleanup6, ruleErr := w.addRule(
w.settings.RulePriority, w.settings.FirewallMark,
unix.AF_INET6)
if ruleErr != nil {
return fmt.Errorf("adding IPv6 rule: %w", ruleErr)
}
closers.add("removing IPv6 rule", stepOne, ruleCleanup6)
return nil
}
type waitAndCleanupFunc func() error
func setupKernelSpace(ctx context.Context,

View File

@@ -26,6 +26,10 @@ type Settings struct {
// Addresses assigned to the client.
// Note IPv6 addresses are ignored if IPv6 is not supported.
Addresses []netip.Prefix
// AllowedIPs is the IP networks to be routed through
// the Wireguard interface.
// Note IPv6 addresses are ignored if IPv6 is not supported.
AllowedIPs []netip.Prefix
// FirewallMark to be used in routing tables and IP rules.
// It defaults to 51820 if left to 0.
FirewallMark int
@@ -68,6 +72,13 @@ func (s *Settings) SetDefaults() {
s.IPv6 = &ipv6
}
if len(s.AllowedIPs) == 0 {
s.AllowedIPs = append(s.AllowedIPs, allIPv4())
if *s.IPv6 {
s.AllowedIPs = append(s.AllowedIPs, allIPv6())
}
}
if s.Implementation == "" {
const defaultImplementation = "auto"
s.Implementation = defaultImplementation
@@ -75,19 +86,22 @@ func (s *Settings) SetDefaults() {
}
var (
ErrInterfaceNameInvalid = errors.New("invalid interface name")
ErrPrivateKeyMissing = errors.New("private key is missing")
ErrPrivateKeyInvalid = errors.New("cannot parse private key")
ErrPublicKeyMissing = errors.New("public key is missing")
ErrPublicKeyInvalid = errors.New("cannot parse public key")
ErrPreSharedKeyInvalid = errors.New("cannot parse pre-shared key")
ErrEndpointAddrMissing = errors.New("endpoint address is missing")
ErrEndpointPortMissing = errors.New("endpoint port is missing")
ErrAddressMissing = errors.New("interface address is missing")
ErrAddressNotValid = errors.New("interface address is not valid")
ErrFirewallMarkMissing = errors.New("firewall mark is missing")
ErrMTUMissing = errors.New("MTU is missing")
ErrImplementationInvalid = errors.New("invalid implementation")
ErrInterfaceNameInvalid = errors.New("invalid interface name")
ErrPrivateKeyMissing = errors.New("private key is missing")
ErrPrivateKeyInvalid = errors.New("cannot parse private key")
ErrPublicKeyMissing = errors.New("public key is missing")
ErrPublicKeyInvalid = errors.New("cannot parse public key")
ErrPreSharedKeyInvalid = errors.New("cannot parse pre-shared key")
ErrEndpointAddrMissing = errors.New("endpoint address is missing")
ErrEndpointPortMissing = errors.New("endpoint port is missing")
ErrAddressMissing = errors.New("interface address is missing")
ErrAddressNotValid = errors.New("interface address is not valid")
ErrAllowedIPsMissing = errors.New("allowed IPs are missing")
ErrAllowedIPNotValid = errors.New("allowed IP is not valid")
ErrAllowedIPv6NotSupported = errors.New("allowed IPv6 address not supported")
ErrFirewallMarkMissing = errors.New("firewall mark is missing")
ErrMTUMissing = errors.New("MTU is missing")
ErrImplementationInvalid = errors.New("invalid implementation")
)
var interfaceNameRegexp = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
@@ -132,6 +146,20 @@ func (s *Settings) Check() (err error) {
}
}
if len(s.AllowedIPs) == 0 {
return fmt.Errorf("%w", ErrAllowedIPsMissing)
}
for i, allowedIP := range s.AllowedIPs {
switch {
case !allowedIP.IsValid():
return fmt.Errorf("%w: for allowed IP %d of %d",
ErrAllowedIPNotValid, i+1, len(s.AllowedIPs))
case allowedIP.Addr().Is6() && !*s.IPv6:
return fmt.Errorf("%w: for allowed IP %s",
ErrAllowedIPv6NotSupported, allowedIP)
}
}
if s.FirewallMark == 0 {
return fmt.Errorf("%w", ErrFirewallMarkMissing)
}
@@ -247,5 +275,16 @@ func (s Settings) ToLines(settings ToLinesSettings) (lines []string) {
}
}
if len(s.AllowedIPs) > 0 {
lines = append(lines, fieldPrefix+"Allowed IPs:")
for i, allowedIP := range s.AllowedIPs {
prefix := fieldPrefix
if i == len(s.AllowedIPs)-1 {
prefix = lastFieldPrefix
}
lines = append(lines, indent+prefix+allowedIP.String())
}
}
return lines
}

View File

@@ -1,12 +1,10 @@
package wireguard
import (
"errors"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/device"
)
@@ -23,6 +21,7 @@ func Test_Settings_SetDefaults(t *testing.T) {
expected: Settings{
InterfaceName: "wg0",
FirewallMark: 51820,
AllowedIPs: []netip.Prefix{allIPv4()},
MTU: device.DefaultMTU,
IPv6: ptr(false),
Implementation: "auto",
@@ -36,6 +35,7 @@ func Test_Settings_SetDefaults(t *testing.T) {
InterfaceName: "wg0",
FirewallMark: 51820,
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
AllowedIPs: []netip.Prefix{allIPv4()},
MTU: device.DefaultMTU,
IPv6: ptr(false),
Implementation: "auto",
@@ -46,6 +46,7 @@ func Test_Settings_SetDefaults(t *testing.T) {
InterfaceName: "wg1",
FirewallMark: 999,
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 9999),
AllowedIPs: []netip.Prefix{allIPv4()},
MTU: device.DefaultMTU,
IPv6: ptr(true),
Implementation: "userspace",
@@ -54,6 +55,7 @@ func Test_Settings_SetDefaults(t *testing.T) {
InterfaceName: "wg1",
FirewallMark: 999,
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 9999),
AllowedIPs: []netip.Prefix{allIPv4()},
MTU: device.DefaultMTU,
IPv6: ptr(true),
Implementation: "userspace",
@@ -82,37 +84,43 @@ func Test_Settings_Check(t *testing.T) {
)
testCases := map[string]struct {
settings Settings
err error
settings Settings
errWrapped error
errMessage string
}{
"empty settings": {
err: errors.New("invalid interface name: "),
errWrapped: ErrInterfaceNameInvalid,
errMessage: "invalid interface name: ",
},
"bad interface name": {
settings: Settings{
InterfaceName: "$H1T",
},
err: errors.New("invalid interface name: $H1T"),
errWrapped: ErrInterfaceNameInvalid,
errMessage: "invalid interface name: $H1T",
},
"empty private key": {
settings: Settings{
InterfaceName: "wg0",
},
err: ErrPrivateKeyMissing,
errWrapped: ErrPrivateKeyMissing,
errMessage: "private key is missing",
},
"bad private key": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: "bad key",
},
err: ErrPrivateKeyInvalid,
errWrapped: ErrPrivateKeyInvalid,
errMessage: "cannot parse private key",
},
"empty public key": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
},
err: ErrPublicKeyMissing,
errWrapped: ErrPublicKeyMissing,
errMessage: "public key is missing",
},
"bad public key": {
settings: Settings{
@@ -120,7 +128,8 @@ func Test_Settings_Check(t *testing.T) {
PrivateKey: validKey1,
PublicKey: "bad key",
},
err: errors.New("cannot parse public key: bad key"),
errWrapped: ErrPublicKeyInvalid,
errMessage: "cannot parse public key: bad key",
},
"bad preshared key": {
settings: Settings{
@@ -129,7 +138,8 @@ func Test_Settings_Check(t *testing.T) {
PublicKey: validKey2,
PreSharedKey: "bad key",
},
err: errors.New("cannot parse pre-shared key"),
errWrapped: ErrPreSharedKeyInvalid,
errMessage: "cannot parse pre-shared key",
},
"invalid endpoint address": {
settings: Settings{
@@ -137,7 +147,8 @@ func Test_Settings_Check(t *testing.T) {
PrivateKey: validKey1,
PublicKey: validKey2,
},
err: ErrEndpointAddrMissing,
errWrapped: ErrEndpointAddrMissing,
errMessage: "endpoint address is missing",
},
"zero endpoint port": {
settings: Settings{
@@ -146,7 +157,8 @@ func Test_Settings_Check(t *testing.T) {
PublicKey: validKey2,
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 0),
},
err: ErrEndpointPortMissing,
errWrapped: ErrEndpointPortMissing,
errMessage: "endpoint port is missing",
},
"no address": {
settings: Settings{
@@ -155,7 +167,8 @@ func Test_Settings_Check(t *testing.T) {
PublicKey: validKey2,
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
},
err: ErrAddressMissing,
errWrapped: ErrAddressMissing,
errMessage: "interface address is missing",
},
"invalid address": {
settings: Settings{
@@ -165,7 +178,53 @@ func Test_Settings_Check(t *testing.T) {
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
Addresses: []netip.Prefix{{}},
},
err: errors.New("interface address is not valid: for address 1 of 1"),
errWrapped: ErrAddressNotValid,
errMessage: "interface address is not valid: for address 1 of 1",
},
"no allowed IP": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
Addresses: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 24),
},
},
errWrapped: ErrAllowedIPsMissing,
errMessage: "allowed IPs are missing",
},
"invalid allowed IP": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
Addresses: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 24),
},
AllowedIPs: []netip.Prefix{{}},
},
errWrapped: ErrAllowedIPNotValid,
errMessage: "allowed IP is not valid: for allowed IP 1 of 1",
},
"ipv6 allowed IP": {
settings: Settings{
InterfaceName: "wg0",
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
Addresses: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{5, 6, 7, 8}), 24),
},
AllowedIPs: []netip.Prefix{
allIPv6(),
},
IPv6: ptrTo(false),
},
errWrapped: ErrAllowedIPv6NotSupported,
errMessage: "allowed IPv6 address not supported: for allowed IP ::/0",
},
"zero firewall mark": {
settings: Settings{
@@ -173,11 +232,13 @@ func Test_Settings_Check(t *testing.T) {
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
AllowedIPs: []netip.Prefix{allIPv4()},
Addresses: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
},
},
err: ErrFirewallMarkMissing,
errWrapped: ErrFirewallMarkMissing,
errMessage: "firewall mark is missing",
},
"missing_MTU": {
settings: Settings{
@@ -185,12 +246,14 @@ func Test_Settings_Check(t *testing.T) {
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
AllowedIPs: []netip.Prefix{allIPv4()},
Addresses: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
},
FirewallMark: 999,
},
err: ErrMTUMissing,
errWrapped: ErrMTUMissing,
errMessage: "MTU is missing",
},
"invalid implementation": {
settings: Settings{
@@ -198,6 +261,7 @@ func Test_Settings_Check(t *testing.T) {
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
AllowedIPs: []netip.Prefix{allIPv4()},
Addresses: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
},
@@ -205,7 +269,8 @@ func Test_Settings_Check(t *testing.T) {
MTU: 1420,
Implementation: "x",
},
err: errors.New("invalid implementation: x"),
errWrapped: ErrImplementationInvalid,
errMessage: "invalid implementation: x",
},
"all valid": {
settings: Settings{
@@ -213,11 +278,15 @@ func Test_Settings_Check(t *testing.T) {
PrivateKey: validKey1,
PublicKey: validKey2,
Endpoint: netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 51820),
AllowedIPs: []netip.Prefix{
allIPv6(),
},
Addresses: []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
},
FirewallMark: 999,
MTU: 1420,
IPv6: ptrTo(true),
Implementation: "userspace",
},
},
@@ -230,11 +299,9 @@ func Test_Settings_Check(t *testing.T) {
err := testCase.settings.Check()
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
assert.NoError(t, err)
assert.ErrorIs(t, err, testCase.errWrapped)
if testCase.errWrapped != nil {
assert.EqualError(t, err, testCase.errMessage)
}
})
}