feat(netlink): detect IPv6 using query to address
- If a default IPv6 route is found, query the ip:port defined by `IPV6_CHECK_ADDRESS` to check for internet access
This commit is contained in:
@@ -1,9 +1,19 @@
|
||||
package netlink
|
||||
|
||||
import "github.com/qdm12/log"
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
|
||||
"github.com/qdm12/log"
|
||||
)
|
||||
|
||||
type DebugLogger interface {
|
||||
Debug(message string)
|
||||
Debugf(format string, args ...any)
|
||||
Patch(options ...log.Option)
|
||||
}
|
||||
|
||||
type Firewall interface {
|
||||
AcceptOutput(ctx context.Context, protocol, intf string, ip netip.Addr,
|
||||
port uint16, remove bool) (err error)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
)
|
||||
|
||||
type IPv6SupportLevel uint8
|
||||
@@ -21,7 +25,9 @@ func (i IPv6SupportLevel) IsSupported() bool {
|
||||
return i == IPv6Supported || i == IPv6Internet
|
||||
}
|
||||
|
||||
func (n *NetLink) FindIPv6SupportLevel() (level IPv6SupportLevel, err error) {
|
||||
func (n *NetLink) FindIPv6SupportLevel(ctx context.Context,
|
||||
checkAddress netip.AddrPort, firewall Firewall,
|
||||
) (level IPv6SupportLevel, err error) {
|
||||
routes, err := n.RouteList(FamilyV6)
|
||||
if err != nil {
|
||||
return IPv6Unsupported, fmt.Errorf("listing IPv6 routes: %w", err)
|
||||
@@ -44,7 +50,14 @@ func (n *NetLink) FindIPv6SupportLevel() (level IPv6SupportLevel, err error) {
|
||||
case sourceIsIPv4 && destinationIsIPv4,
|
||||
destinationIsIPv6 && route.Dst.Addr().IsLoopback():
|
||||
case route.Dst.Addr().IsUnspecified(): // default ipv6 route
|
||||
n.debugLogger.Debugf("IPv6 internet access is enabled on link %s", link.Name)
|
||||
n.debugLogger.Debugf("IPv6 default route found on link %s", link.Name)
|
||||
err = dialAddrThroughFirewall(ctx, link.Name, checkAddress, firewall)
|
||||
if err != nil {
|
||||
n.debugLogger.Debugf("IPv6 query failed on %s: %w", link.Name, err)
|
||||
level = IPv6Supported
|
||||
continue
|
||||
}
|
||||
n.debugLogger.Debugf("IPv6 internet is accessible through link %s", link.Name)
|
||||
return IPv6Internet, nil
|
||||
default: // non-default ipv6 route found
|
||||
n.debugLogger.Debugf("IPv6 is supported by link %s", link.Name)
|
||||
@@ -57,3 +70,37 @@ func (n *NetLink) FindIPv6SupportLevel() (level IPv6SupportLevel, err error) {
|
||||
}
|
||||
return level, nil
|
||||
}
|
||||
|
||||
func dialAddrThroughFirewall(ctx context.Context, intf string,
|
||||
checkAddress netip.AddrPort, firewall Firewall,
|
||||
) (err error) {
|
||||
const protocol = "tcp"
|
||||
remove := false
|
||||
err = firewall.AcceptOutput(ctx, protocol, intf,
|
||||
checkAddress.Addr(), checkAddress.Port(), remove)
|
||||
if err != nil {
|
||||
return fmt.Errorf("accepting output traffic: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
remove = true
|
||||
firewallErr := firewall.AcceptOutput(ctx, protocol, intf,
|
||||
checkAddress.Addr(), checkAddress.Port(), remove)
|
||||
if err == nil && firewallErr != nil {
|
||||
err = fmt.Errorf("removing output traffic rule: %w", firewallErr)
|
||||
}
|
||||
}()
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: time.Second,
|
||||
}
|
||||
conn, err := dialer.DialContext(ctx, protocol, checkAddress.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing: %w", err)
|
||||
}
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("closing connection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
166
internal/netlink/ipv6_test.go
Normal file
166
internal/netlink/ipv6_test.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package netlink
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func isIPv6LocallySupported() bool {
|
||||
dialer := net.Dialer{Timeout: time.Millisecond}
|
||||
_, err := dialer.Dial("tcp6", "[::1]:9999")
|
||||
return !strings.HasSuffix(err.Error(), "connect: cannot assign requested address")
|
||||
}
|
||||
|
||||
// Susceptible to TOCTOU but it should be fine for the use case.
|
||||
func findAvailableTCPPort(t *testing.T) (port uint16) {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
addr := listener.Addr().String()
|
||||
err = listener.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
addrPort, err := netip.ParseAddrPort(addr)
|
||||
require.NoError(t, err)
|
||||
|
||||
return addrPort.Port()
|
||||
}
|
||||
|
||||
func Test_dialAddrThroughFirewall(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
errTest := errors.New("test error")
|
||||
|
||||
const ipv6InternetWorks = false
|
||||
|
||||
testCases := map[string]struct {
|
||||
getIPv6CheckAddr func(t *testing.T) netip.AddrPort
|
||||
firewallAddErr error
|
||||
firewallRemoveErr error
|
||||
errMessageRegex func() string
|
||||
}{
|
||||
"cloudflare.com": {
|
||||
getIPv6CheckAddr: func(_ *testing.T) netip.AddrPort {
|
||||
return netip.MustParseAddrPort("[2606:4700::6810:84e5]:443")
|
||||
},
|
||||
errMessageRegex: func() string {
|
||||
if ipv6InternetWorks {
|
||||
return ""
|
||||
}
|
||||
return "dialing: dial tcp \\[2606:4700::6810:84e5\\]:443: " +
|
||||
"connect: (cannot assign requested address|network is unreachable)"
|
||||
},
|
||||
},
|
||||
"local_server": {
|
||||
getIPv6CheckAddr: func(t *testing.T) netip.AddrPort {
|
||||
t.Helper()
|
||||
|
||||
network := "tcp6"
|
||||
loopback := netip.MustParseAddr("::1")
|
||||
if !isIPv6LocallySupported() {
|
||||
network = "tcp4"
|
||||
loopback = netip.MustParseAddr("127.0.0.1")
|
||||
}
|
||||
|
||||
listener, err := net.ListenTCP(network, nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := listener.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
addrPort := netip.MustParseAddrPort(listener.Addr().String())
|
||||
return netip.AddrPortFrom(loopback, addrPort.Port())
|
||||
},
|
||||
},
|
||||
"no_local_server": {
|
||||
getIPv6CheckAddr: func(t *testing.T) netip.AddrPort {
|
||||
t.Helper()
|
||||
|
||||
loopback := netip.MustParseAddr("::1")
|
||||
if !ipv6InternetWorks {
|
||||
loopback = netip.MustParseAddr("127.0.0.1")
|
||||
}
|
||||
|
||||
availablePort := findAvailableTCPPort(t)
|
||||
return netip.AddrPortFrom(loopback, availablePort)
|
||||
},
|
||||
errMessageRegex: func() string {
|
||||
return "dialing: dial tcp (\\[::1\\]|127\\.0\\.0\\.1):[1-9][0-9]{1,4}: " +
|
||||
"connect: connection refused"
|
||||
},
|
||||
},
|
||||
"firewall_add_error": {
|
||||
firewallAddErr: errTest,
|
||||
errMessageRegex: func() string {
|
||||
return "accepting output traffic: test error"
|
||||
},
|
||||
},
|
||||
"firewall_remove_error": {
|
||||
getIPv6CheckAddr: func(t *testing.T) netip.AddrPort {
|
||||
t.Helper()
|
||||
|
||||
network := "tcp4"
|
||||
loopback := netip.MustParseAddr("127.0.0.1")
|
||||
listener, err := net.ListenTCP(network, nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := listener.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
addrPort := netip.MustParseAddrPort(listener.Addr().String())
|
||||
return netip.AddrPortFrom(loopback, addrPort.Port())
|
||||
},
|
||||
firewallRemoveErr: errTest,
|
||||
errMessageRegex: func() string {
|
||||
return "removing output traffic rule: test error"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
var checkAddr netip.AddrPort
|
||||
if testCase.getIPv6CheckAddr != nil {
|
||||
checkAddr = testCase.getIPv6CheckAddr(t)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
const intf = "eth0"
|
||||
firewall := NewMockFirewall(ctrl)
|
||||
call := firewall.EXPECT().AcceptOutput(ctx, "tcp", intf,
|
||||
checkAddr.Addr(), checkAddr.Port(), false).
|
||||
Return(testCase.firewallAddErr)
|
||||
if testCase.firewallAddErr == nil {
|
||||
firewall.EXPECT().AcceptOutput(ctx, "tcp", intf,
|
||||
checkAddr.Addr(), checkAddr.Port(), true).
|
||||
Return(testCase.firewallRemoveErr).After(call)
|
||||
}
|
||||
|
||||
err := dialAddrThroughFirewall(ctx, intf, checkAddr, firewall)
|
||||
var errMessageRegex string
|
||||
if testCase.errMessageRegex != nil {
|
||||
errMessageRegex = testCase.errMessageRegex()
|
||||
}
|
||||
if errMessageRegex == "" {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
assert.Regexp(t, errMessageRegex, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
3
internal/netlink/mocks_generate_test.go
Normal file
3
internal/netlink/mocks_generate_test.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package netlink
|
||||
|
||||
//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Firewall
|
||||
50
internal/netlink/mocks_test.go
Normal file
50
internal/netlink/mocks_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/netlink (interfaces: Firewall)
|
||||
|
||||
// Package netlink is a generated GoMock package.
|
||||
package netlink
|
||||
|
||||
import (
|
||||
context "context"
|
||||
netip "net/netip"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockFirewall is a mock of Firewall interface.
|
||||
type MockFirewall struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockFirewallMockRecorder
|
||||
}
|
||||
|
||||
// MockFirewallMockRecorder is the mock recorder for MockFirewall.
|
||||
type MockFirewallMockRecorder struct {
|
||||
mock *MockFirewall
|
||||
}
|
||||
|
||||
// NewMockFirewall creates a new mock instance.
|
||||
func NewMockFirewall(ctrl *gomock.Controller) *MockFirewall {
|
||||
mock := &MockFirewall{ctrl: ctrl}
|
||||
mock.recorder = &MockFirewallMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockFirewall) EXPECT() *MockFirewallMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AcceptOutput mocks base method.
|
||||
func (m *MockFirewall) AcceptOutput(arg0 context.Context, arg1, arg2 string, arg3 netip.Addr, arg4 uint16, arg5 bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AcceptOutput", arg0, arg1, arg2, arg3, arg4, arg5)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AcceptOutput indicates an expected call of AcceptOutput.
|
||||
func (mr *MockFirewallMockRecorder) AcceptOutput(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcceptOutput", reflect.TypeOf((*MockFirewall)(nil).AcceptOutput), arg0, arg1, arg2, arg3, arg4, arg5)
|
||||
}
|
||||
Reference in New Issue
Block a user