feat(netlink): detect IPv6 using query to address

- If a default IPv6 route is found, query the ip:port defined by `IPV6_CHECK_ADDRESS` to check for internet access
This commit is contained in:
Quentin McGaw
2024-12-12 06:48:43 +00:00
parent dae44051f6
commit 5ca13021e7
13 changed files with 384 additions and 7 deletions

View File

@@ -159,6 +159,8 @@ ENV VPN_SERVICE_PROVIDER=pia \
FIREWALL_INPUT_PORTS= \ FIREWALL_INPUT_PORTS= \
FIREWALL_OUTBOUND_SUBNETS= \ FIREWALL_OUTBOUND_SUBNETS= \
FIREWALL_DEBUG=off \ FIREWALL_DEBUG=off \
# IPv6
IPV6_CHECK_ADDRESS=[2606:4700::6810:84e5]:443 \
# Logging # Logging
LOG_LEVEL=info \ LOG_LEVEL=info \
# Health # Health

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"io/fs" "io/fs"
"net/http" "net/http"
"net/netip"
"os" "os"
"os/exec" "os/exec"
"os/signal" "os/signal"
@@ -242,7 +243,8 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
return err return err
} }
ipv6SupportLevel, err := netLinker.FindIPv6SupportLevel() ipv6SupportLevel, err := netLinker.FindIPv6SupportLevel(ctx,
allSettings.IPv6.CheckAddress, firewallConf)
if err != nil { if err != nil {
return fmt.Errorf("checking for IPv6 support: %w", err) return fmt.Errorf("checking for IPv6 support: %w", err)
} }
@@ -552,7 +554,9 @@ type netLinker interface {
Ruler Ruler
Linker Linker
IsWireguardSupported() (ok bool, err error) IsWireguardSupported() (ok bool, err error)
FindIPv6SupportLevel() (level netlink.IPv6SupportLevel, err error) FindIPv6SupportLevel(ctx context.Context,
checkAddress netip.AddrPort, firewall netlink.Firewall,
) (level netlink.IPv6SupportLevel, err error)
PatchLoggerLevel(level log.Level) PatchLoggerLevel(level log.Level)
} }

View File

@@ -0,0 +1,14 @@
package cli
import (
"context"
"net/netip"
)
type noopFirewall struct{}
func (f *noopFirewall) AcceptOutput(_ context.Context, _, _ string, _ netip.Addr,
_ uint16, _ bool,
) (err error) {
return nil
}

View File

@@ -41,7 +41,9 @@ type IPFetcher interface {
} }
type IPv6Checker interface { type IPv6Checker interface {
FindIPv6SupportLevel() (level netlink.IPv6SupportLevel, err error) FindIPv6SupportLevel(ctx context.Context,
checkAddress netip.AddrPort, firewall netlink.Firewall,
) (level netlink.IPv6SupportLevel, err error)
} }
func (c *CLI) OpenvpnConfig(logger OpenvpnConfigLogger, reader *reader.Reader, func (c *CLI) OpenvpnConfig(logger OpenvpnConfigLogger, reader *reader.Reader,
@@ -59,7 +61,8 @@ func (c *CLI) OpenvpnConfig(logger OpenvpnConfigLogger, reader *reader.Reader,
} }
allSettings.SetDefaults() allSettings.SetDefaults()
ipv6SupportLevel, err := ipv6Checker.FindIPv6SupportLevel() ipv6SupportLevel, err := ipv6Checker.FindIPv6SupportLevel(context.Background(),
allSettings.IPv6.CheckAddress, &noopFirewall{})
if err != nil { if err != nil {
return fmt.Errorf("checking for IPv6 support: %w", err) return fmt.Errorf("checking for IPv6 support: %w", err)
} }

View File

@@ -0,0 +1,51 @@
package settings
import (
"net/netip"
"github.com/qdm12/gosettings"
"github.com/qdm12/gosettings/reader"
"github.com/qdm12/gotree"
)
// IPv6 contains settings regarding IPv6 configuration.
type IPv6 struct {
// CheckAddress is the TCP ip:port address to dial to check
// IPv6 is supported, in case a default IPv6 route is found.
// It defaults to cloudflare.com address [2606:4700::6810:84e5]:443
CheckAddress netip.AddrPort
}
func (i IPv6) validate() (err error) {
return nil
}
func (i *IPv6) copy() (copied IPv6) {
return IPv6{
CheckAddress: i.CheckAddress,
}
}
func (i *IPv6) overrideWith(other IPv6) {
i.CheckAddress = gosettings.OverrideWithValidator(i.CheckAddress, other.CheckAddress)
}
func (i *IPv6) setDefaults() {
defaultCheckAddress := netip.MustParseAddrPort("[2606:4700::6810:84e5]:443")
i.CheckAddress = gosettings.DefaultComparable(i.CheckAddress, defaultCheckAddress)
}
func (i IPv6) String() string {
return i.toLinesNode().String()
}
func (i IPv6) toLinesNode() (node *gotree.Node) {
node = gotree.New("IPv6 settings:")
node.Appendf("Check address: %s", i.CheckAddress)
return node
}
func (i *IPv6) read(r *reader.Reader) (err error) {
i.CheckAddress, err = r.NetipAddrPort("IPV6_CHECK_ADDRESS")
return err
}

View File

@@ -27,6 +27,7 @@ type Settings struct {
Updater Updater Updater Updater
Version Version Version Version
VPN VPN VPN VPN
IPv6 IPv6
Pprof pprof.Settings Pprof pprof.Settings
} }
@@ -53,6 +54,7 @@ func (s *Settings) Validate(filterChoicesGetter FilterChoicesGetter, ipv6Support
"system": s.System.validate, "system": s.System.validate,
"updater": s.Updater.Validate, "updater": s.Updater.Validate,
"version": s.Version.validate, "version": s.Version.validate,
"ipv6": s.IPv6.validate,
// Pprof validation done in pprof constructor // Pprof validation done in pprof constructor
"VPN": func() error { "VPN": func() error {
return s.VPN.Validate(filterChoicesGetter, ipv6Supported, warner) return s.VPN.Validate(filterChoicesGetter, ipv6Supported, warner)
@@ -85,6 +87,7 @@ func (s *Settings) copy() (copied Settings) {
Version: s.Version.copy(), Version: s.Version.copy(),
VPN: s.VPN.Copy(), VPN: s.VPN.Copy(),
Pprof: s.Pprof.Copy(), Pprof: s.Pprof.Copy(),
IPv6: s.IPv6.copy(),
} }
} }
@@ -106,6 +109,7 @@ func (s *Settings) OverrideWith(other Settings,
patchedSettings.Version.overrideWith(other.Version) patchedSettings.Version.overrideWith(other.Version)
patchedSettings.VPN.OverrideWith(other.VPN) patchedSettings.VPN.OverrideWith(other.VPN)
patchedSettings.Pprof.OverrideWith(other.Pprof) patchedSettings.Pprof.OverrideWith(other.Pprof)
patchedSettings.IPv6.overrideWith(other.IPv6)
err = patchedSettings.Validate(filterChoicesGetter, ipv6Supported, warner) err = patchedSettings.Validate(filterChoicesGetter, ipv6Supported, warner)
if err != nil { if err != nil {
return err return err
@@ -121,6 +125,7 @@ func (s *Settings) SetDefaults() {
s.Health.SetDefaults() s.Health.SetDefaults()
s.HTTPProxy.setDefaults() s.HTTPProxy.setDefaults()
s.Log.setDefaults() s.Log.setDefaults()
s.IPv6.setDefaults()
s.PublicIP.setDefaults() s.PublicIP.setDefaults()
s.Shadowsocks.setDefaults() s.Shadowsocks.setDefaults()
s.Storage.setDefaults() s.Storage.setDefaults()
@@ -142,6 +147,7 @@ func (s Settings) toLinesNode() (node *gotree.Node) {
node.AppendNode(s.DNS.toLinesNode()) node.AppendNode(s.DNS.toLinesNode())
node.AppendNode(s.Firewall.toLinesNode()) node.AppendNode(s.Firewall.toLinesNode())
node.AppendNode(s.Log.toLinesNode()) node.AppendNode(s.Log.toLinesNode())
node.AppendNode(s.IPv6.toLinesNode())
node.AppendNode(s.Health.toLinesNode()) node.AppendNode(s.Health.toLinesNode())
node.AppendNode(s.Shadowsocks.toLinesNode()) node.AppendNode(s.Shadowsocks.toLinesNode())
node.AppendNode(s.HTTPProxy.toLinesNode()) node.AppendNode(s.HTTPProxy.toLinesNode())
@@ -208,6 +214,7 @@ func (s *Settings) Read(r *reader.Reader, warner Warner) (err error) {
"updater": s.Updater.read, "updater": s.Updater.read,
"version": s.Version.read, "version": s.Version.read,
"VPN": s.VPN.read, "VPN": s.VPN.read,
"IPv6": s.IPv6.read,
"profiling": s.Pprof.Read, "profiling": s.Pprof.Read,
} }

View File

@@ -55,6 +55,8 @@ func Test_Settings_String(t *testing.T) {
| └── Enabled: yes | └── Enabled: yes
├── Log settings: ├── Log settings:
| └── Log level: INFO | └── Log level: INFO
├── IPv6 settings:
| └── Check address: [2606:4700::6810:84e5]:443
├── Health settings: ├── Health settings:
| ├── Server listening address: 127.0.0.1:9999 | ├── Server listening address: 127.0.0.1:9999
| ├── Target address: cloudflare.com:443 | ├── Target address: cloudflare.com:443

View File

@@ -162,6 +162,24 @@ func (c *Config) acceptOutputTrafficToVPN(ctx context.Context,
return c.runIP6tablesInstruction(ctx, instruction) return c.runIP6tablesInstruction(ctx, instruction)
} }
func (c *Config) AcceptOutput(ctx context.Context,
protocol, intf string, ip netip.Addr, port uint16, remove bool,
) error {
interfaceFlag := "-o " + intf
if intf == "*" { // all interfaces
interfaceFlag = ""
}
instruction := fmt.Sprintf("%s OUTPUT -d %s %s -p %s -m %s --dport %d -j ACCEPT",
appendOrDelete(remove), ip, interfaceFlag, protocol, protocol, port)
if ip.Is4() {
return c.runIptablesInstruction(ctx, instruction)
} else if c.ip6Tables == "" {
return fmt.Errorf("accept output to VPN server: %w", ErrNeedIP6Tables)
}
return c.runIP6tablesInstruction(ctx, instruction)
}
// Thanks to @npawelek. // Thanks to @npawelek.
func (c *Config) acceptOutputFromIPToSubnet(ctx context.Context, func (c *Config) acceptOutputFromIPToSubnet(ctx context.Context,
intf string, sourceIP netip.Addr, destinationSubnet netip.Prefix, remove bool, intf string, sourceIP netip.Addr, destinationSubnet netip.Prefix, remove bool,

View File

@@ -1,9 +1,19 @@
package netlink package netlink
import "github.com/qdm12/log" import (
"context"
"net/netip"
"github.com/qdm12/log"
)
type DebugLogger interface { type DebugLogger interface {
Debug(message string) Debug(message string)
Debugf(format string, args ...any) Debugf(format string, args ...any)
Patch(options ...log.Option) Patch(options ...log.Option)
} }
type Firewall interface {
AcceptOutput(ctx context.Context, protocol, intf string, ip netip.Addr,
port uint16, remove bool) (err error)
}

View File

@@ -1,7 +1,11 @@
package netlink package netlink
import ( import (
"context"
"fmt" "fmt"
"net"
"net/netip"
"time"
) )
type IPv6SupportLevel uint8 type IPv6SupportLevel uint8
@@ -21,7 +25,9 @@ func (i IPv6SupportLevel) IsSupported() bool {
return i == IPv6Supported || i == IPv6Internet return i == IPv6Supported || i == IPv6Internet
} }
func (n *NetLink) FindIPv6SupportLevel() (level IPv6SupportLevel, err error) { func (n *NetLink) FindIPv6SupportLevel(ctx context.Context,
checkAddress netip.AddrPort, firewall Firewall,
) (level IPv6SupportLevel, err error) {
routes, err := n.RouteList(FamilyV6) routes, err := n.RouteList(FamilyV6)
if err != nil { if err != nil {
return IPv6Unsupported, fmt.Errorf("listing IPv6 routes: %w", err) return IPv6Unsupported, fmt.Errorf("listing IPv6 routes: %w", err)
@@ -44,7 +50,14 @@ func (n *NetLink) FindIPv6SupportLevel() (level IPv6SupportLevel, err error) {
case sourceIsIPv4 && destinationIsIPv4, case sourceIsIPv4 && destinationIsIPv4,
destinationIsIPv6 && route.Dst.Addr().IsLoopback(): destinationIsIPv6 && route.Dst.Addr().IsLoopback():
case route.Dst.Addr().IsUnspecified(): // default ipv6 route case route.Dst.Addr().IsUnspecified(): // default ipv6 route
n.debugLogger.Debugf("IPv6 internet access is enabled on link %s", link.Name) n.debugLogger.Debugf("IPv6 default route found on link %s", link.Name)
err = dialAddrThroughFirewall(ctx, link.Name, checkAddress, firewall)
if err != nil {
n.debugLogger.Debugf("IPv6 query failed on %s: %w", link.Name, err)
level = IPv6Supported
continue
}
n.debugLogger.Debugf("IPv6 internet is accessible through link %s", link.Name)
return IPv6Internet, nil return IPv6Internet, nil
default: // non-default ipv6 route found default: // non-default ipv6 route found
n.debugLogger.Debugf("IPv6 is supported by link %s", link.Name) n.debugLogger.Debugf("IPv6 is supported by link %s", link.Name)
@@ -57,3 +70,37 @@ func (n *NetLink) FindIPv6SupportLevel() (level IPv6SupportLevel, err error) {
} }
return level, nil return level, nil
} }
func dialAddrThroughFirewall(ctx context.Context, intf string,
checkAddress netip.AddrPort, firewall Firewall,
) (err error) {
const protocol = "tcp"
remove := false
err = firewall.AcceptOutput(ctx, protocol, intf,
checkAddress.Addr(), checkAddress.Port(), remove)
if err != nil {
return fmt.Errorf("accepting output traffic: %w", err)
}
defer func() {
remove = true
firewallErr := firewall.AcceptOutput(ctx, protocol, intf,
checkAddress.Addr(), checkAddress.Port(), remove)
if err == nil && firewallErr != nil {
err = fmt.Errorf("removing output traffic rule: %w", firewallErr)
}
}()
dialer := &net.Dialer{
Timeout: time.Second,
}
conn, err := dialer.DialContext(ctx, protocol, checkAddress.String())
if err != nil {
return fmt.Errorf("dialing: %w", err)
}
err = conn.Close()
if err != nil {
return fmt.Errorf("closing connection: %w", err)
}
return nil
}

View File

@@ -0,0 +1,166 @@
package netlink
import (
"context"
"errors"
"net"
"net/netip"
"strings"
"testing"
"time"
gomock "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func isIPv6LocallySupported() bool {
dialer := net.Dialer{Timeout: time.Millisecond}
_, err := dialer.Dial("tcp6", "[::1]:9999")
return !strings.HasSuffix(err.Error(), "connect: cannot assign requested address")
}
// Susceptible to TOCTOU but it should be fine for the use case.
func findAvailableTCPPort(t *testing.T) (port uint16) {
t.Helper()
listener, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
addr := listener.Addr().String()
err = listener.Close()
require.NoError(t, err)
addrPort, err := netip.ParseAddrPort(addr)
require.NoError(t, err)
return addrPort.Port()
}
func Test_dialAddrThroughFirewall(t *testing.T) {
t.Parallel()
errTest := errors.New("test error")
const ipv6InternetWorks = false
testCases := map[string]struct {
getIPv6CheckAddr func(t *testing.T) netip.AddrPort
firewallAddErr error
firewallRemoveErr error
errMessageRegex func() string
}{
"cloudflare.com": {
getIPv6CheckAddr: func(_ *testing.T) netip.AddrPort {
return netip.MustParseAddrPort("[2606:4700::6810:84e5]:443")
},
errMessageRegex: func() string {
if ipv6InternetWorks {
return ""
}
return "dialing: dial tcp \\[2606:4700::6810:84e5\\]:443: " +
"connect: (cannot assign requested address|network is unreachable)"
},
},
"local_server": {
getIPv6CheckAddr: func(t *testing.T) netip.AddrPort {
t.Helper()
network := "tcp6"
loopback := netip.MustParseAddr("::1")
if !isIPv6LocallySupported() {
network = "tcp4"
loopback = netip.MustParseAddr("127.0.0.1")
}
listener, err := net.ListenTCP(network, nil)
require.NoError(t, err)
t.Cleanup(func() {
err := listener.Close()
require.NoError(t, err)
})
addrPort := netip.MustParseAddrPort(listener.Addr().String())
return netip.AddrPortFrom(loopback, addrPort.Port())
},
},
"no_local_server": {
getIPv6CheckAddr: func(t *testing.T) netip.AddrPort {
t.Helper()
loopback := netip.MustParseAddr("::1")
if !ipv6InternetWorks {
loopback = netip.MustParseAddr("127.0.0.1")
}
availablePort := findAvailableTCPPort(t)
return netip.AddrPortFrom(loopback, availablePort)
},
errMessageRegex: func() string {
return "dialing: dial tcp (\\[::1\\]|127\\.0\\.0\\.1):[1-9][0-9]{1,4}: " +
"connect: connection refused"
},
},
"firewall_add_error": {
firewallAddErr: errTest,
errMessageRegex: func() string {
return "accepting output traffic: test error"
},
},
"firewall_remove_error": {
getIPv6CheckAddr: func(t *testing.T) netip.AddrPort {
t.Helper()
network := "tcp4"
loopback := netip.MustParseAddr("127.0.0.1")
listener, err := net.ListenTCP(network, nil)
require.NoError(t, err)
t.Cleanup(func() {
err := listener.Close()
require.NoError(t, err)
})
addrPort := netip.MustParseAddrPort(listener.Addr().String())
return netip.AddrPortFrom(loopback, addrPort.Port())
},
firewallRemoveErr: errTest,
errMessageRegex: func() string {
return "removing output traffic rule: test error"
},
},
}
for name, testCase := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
var checkAddr netip.AddrPort
if testCase.getIPv6CheckAddr != nil {
checkAddr = testCase.getIPv6CheckAddr(t)
}
ctx := context.Background()
const intf = "eth0"
firewall := NewMockFirewall(ctrl)
call := firewall.EXPECT().AcceptOutput(ctx, "tcp", intf,
checkAddr.Addr(), checkAddr.Port(), false).
Return(testCase.firewallAddErr)
if testCase.firewallAddErr == nil {
firewall.EXPECT().AcceptOutput(ctx, "tcp", intf,
checkAddr.Addr(), checkAddr.Port(), true).
Return(testCase.firewallRemoveErr).After(call)
}
err := dialAddrThroughFirewall(ctx, intf, checkAddr, firewall)
var errMessageRegex string
if testCase.errMessageRegex != nil {
errMessageRegex = testCase.errMessageRegex()
}
if errMessageRegex == "" {
assert.NoError(t, err)
} else {
require.Error(t, err)
assert.Regexp(t, errMessageRegex, err.Error())
}
})
}
}

View File

@@ -0,0 +1,3 @@
package netlink
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Firewall

View File

@@ -0,0 +1,50 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/netlink (interfaces: Firewall)
// Package netlink is a generated GoMock package.
package netlink
import (
context "context"
netip "net/netip"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockFirewall is a mock of Firewall interface.
type MockFirewall struct {
ctrl *gomock.Controller
recorder *MockFirewallMockRecorder
}
// MockFirewallMockRecorder is the mock recorder for MockFirewall.
type MockFirewallMockRecorder struct {
mock *MockFirewall
}
// NewMockFirewall creates a new mock instance.
func NewMockFirewall(ctrl *gomock.Controller) *MockFirewall {
mock := &MockFirewall{ctrl: ctrl}
mock.recorder = &MockFirewallMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockFirewall) EXPECT() *MockFirewallMockRecorder {
return m.recorder
}
// AcceptOutput mocks base method.
func (m *MockFirewall) AcceptOutput(arg0 context.Context, arg1, arg2 string, arg3 netip.Addr, arg4 uint16, arg5 bool) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcceptOutput", arg0, arg1, arg2, arg3, arg4, arg5)
ret0, _ := ret[0].(error)
return ret0
}
// AcceptOutput indicates an expected call of AcceptOutput.
func (mr *MockFirewallMockRecorder) AcceptOutput(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptOutput", reflect.TypeOf((*MockFirewall)(nil).AcceptOutput), arg0, arg1, arg2, arg3, arg4, arg5)
}