Feat: support IPv6 routing for Wireguard

This commit is contained in:
Quentin McGaw (desktop)
2021-09-21 15:12:48 +00:00
parent 9f001bbc06
commit 59a3a072e0
5 changed files with 82 additions and 0 deletions

View File

@@ -6,4 +6,5 @@ import "github.com/vishvananda/netlink"
const ( const (
FAMILY_ALL = netlink.FAMILY_ALL FAMILY_ALL = netlink.FAMILY_ALL
FAMILY_V4 = netlink.FAMILY_V4 FAMILY_V4 = netlink.FAMILY_V4
FAMILY_V6 = netlink.FAMILY_V6
) )

View File

@@ -0,0 +1,33 @@
package wireguard
import (
"errors"
"fmt"
"github.com/qdm12/gluetun/internal/netlink"
)
var (
errLinkList = errors.New("cannot list links")
errRouteList = errors.New("cannot list routes")
)
func (w *Wireguard) isIPv6Supported() (supported bool, err error) {
links, err := w.netlink.LinkList()
if err != nil {
return false, fmt.Errorf("%w: %s", errLinkList, err)
}
for _, link := range links {
routes, err := w.netlink.RouteList(link, netlink.FAMILY_V6)
if err != nil {
return false, fmt.Errorf("%w: %s", errRouteList, err)
}
if len(routes) > 0 {
return true, nil
}
}
return false, nil
}

View File

@@ -6,9 +6,11 @@ import "github.com/qdm12/gluetun/internal/netlink"
type NetLinker interface { type NetLinker interface {
AddrAdd(link netlink.Link, addr *netlink.Addr) error AddrAdd(link netlink.Link, addr *netlink.Addr) error
RouteList(link netlink.Link, family int) (routes []netlink.Route, err error)
RouteAdd(route *netlink.Route) error RouteAdd(route *netlink.Route) error
RuleAdd(rule *netlink.Rule) error RuleAdd(rule *netlink.Rule) error
RuleDel(rule *netlink.Rule) error RuleDel(rule *netlink.Rule) error
LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error) LinkByName(name string) (link netlink.Link, err error)
LinkSetUp(link netlink.Link) error LinkSetUp(link netlink.Link) error
LinkSetDown(link netlink.Link) error LinkSetDown(link netlink.Link) error

View File

@@ -77,6 +77,21 @@ func (mr *MockNetLinkerMockRecorder) LinkDel(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkDel", reflect.TypeOf((*MockNetLinker)(nil).LinkDel), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkDel", reflect.TypeOf((*MockNetLinker)(nil).LinkDel), arg0)
} }
// LinkList mocks base method.
func (m *MockNetLinker) LinkList() ([]netlink.Link, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkList")
ret0, _ := ret[0].([]netlink.Link)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LinkList indicates an expected call of LinkList.
func (mr *MockNetLinkerMockRecorder) LinkList() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkList", reflect.TypeOf((*MockNetLinker)(nil).LinkList))
}
// LinkSetDown mocks base method. // LinkSetDown mocks base method.
func (m *MockNetLinker) LinkSetDown(arg0 netlink.Link) error { func (m *MockNetLinker) LinkSetDown(arg0 netlink.Link) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -119,6 +134,21 @@ func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RouteAdd", reflect.TypeOf((*MockNetLinker)(nil).RouteAdd), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RouteAdd", reflect.TypeOf((*MockNetLinker)(nil).RouteAdd), arg0)
} }
// RouteList mocks base method.
func (m *MockNetLinker) RouteList(arg0 netlink.Link, arg1 int) ([]netlink.Route, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteList", arg0, arg1)
ret0, _ := ret[0].([]netlink.Route)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RouteList indicates an expected call of RouteList.
func (mr *MockNetLinkerMockRecorder) RouteList(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RouteList", reflect.TypeOf((*MockNetLinker)(nil).RouteList), arg0, arg1)
}
// RuleAdd mocks base method. // RuleAdd mocks base method.
func (m *MockNetLinker) RuleAdd(arg0 *netlink.Rule) error { func (m *MockNetLinker) RuleAdd(arg0 *netlink.Rule) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@@ -14,6 +14,7 @@ import (
) )
var ( var (
ErrDetectIPv6 = errors.New("cannot detect IPv6 support")
ErrCreateTun = errors.New("cannot create TUN device") ErrCreateTun = errors.New("cannot create TUN device")
ErrFindLink = errors.New("cannot find link") ErrFindLink = errors.New("cannot find link")
ErrFindDevice = errors.New("cannot find Wireguard device") ErrFindDevice = errors.New("cannot find Wireguard device")
@@ -34,6 +35,12 @@ type Runner interface {
// See https://git.zx2c4.com/wireguard-go/tree/main.go // See https://git.zx2c4.com/wireguard-go/tree/main.go
func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) { func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<- struct{}) {
doIPv6, err := w.isIPv6Supported()
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrDetectIPv6, err)
return
}
client, err := wgctrl.New() client, err := wgctrl.New()
if err != nil { if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrWgctrlOpen, err) waitError <- fmt.Errorf("%w: %s", ErrWgctrlOpen, err)
@@ -131,6 +138,15 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
return return
} }
if doIPv6 {
// requires net.ipv6.conf.all.disable_ipv6=0
err = w.addRoute(link, allIPv6(), w.settings.FirewallMark)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrRouteAdd, err)
return
}
}
ruleCleanup, err := w.addRule( ruleCleanup, err := w.addRule(
w.settings.RulePriority, w.settings.FirewallMark) w.settings.RulePriority, w.settings.FirewallMark)
if err != nil { if err != nil {