feat(internal/wireguard): opportunistic kernelspace

- Auto detect if kernelspace implementation is available
- Fallback to Go userspace implementation if kernel is not available
This commit is contained in:
Quentin McGaw
2021-12-14 11:03:36 +00:00
parent b9a9319cb4
commit cfa3bb3b64
14 changed files with 229 additions and 79 deletions

View File

@@ -64,7 +64,7 @@ using Go, OpenVPN or Wireguard, iptables, DNS over TLS, ShadowSocks and an HTTP
- Based on Alpine 3.14 for a small Docker image of 33MB - Based on Alpine 3.14 for a small Docker image of 33MB
- Supports: **Cyberghost**, **ExpressVPN**, **FastestVPN**, **HideMyAss**, **IPVanish**, **IVPN**, **Mullvad**, **NordVPN**, **Perfect Privacy**, **Privado**, **Private Internet Access**, **PrivateVPN**, **ProtonVPN**, **PureVPN**, **Surfshark**, **TorGuard**, **VPNUnlimited**, **Vyprvpn**, **WeVPN**, **Windscribe** servers - Supports: **Cyberghost**, **ExpressVPN**, **FastestVPN**, **HideMyAss**, **IPVanish**, **IVPN**, **Mullvad**, **NordVPN**, **Perfect Privacy**, **Privado**, **Private Internet Access**, **PrivateVPN**, **ProtonVPN**, **PureVPN**, **Surfshark**, **TorGuard**, **VPNUnlimited**, **Vyprvpn**, **WeVPN**, **Windscribe** servers
- Supports OpenVPN for all providers listed - Supports OpenVPN for all providers listed
- Supports Wireguard - Supports Wireguard both kernelspace and userspace
- For **Mullvad**, **Ivpn** and **Windscribe** - For **Mullvad**, **Ivpn** and **Windscribe**
- For **Torguard**, **VPN Unlimited** and **WeVPN** using [the custom provider](https://github.com/qdm12/gluetun/wiki/Custom-provider) - For **Torguard**, **VPN Unlimited** and **WeVPN** using [the custom provider](https://github.com/qdm12/gluetun/wiki/Custom-provider)
- For custom Wireguard configurations using [the custom provider](https://github.com/qdm12/gluetun/wiki/Custom-provider) - For custom Wireguard configurations using [the custom provider](https://github.com/qdm12/gluetun/wiki/Custom-provider)

4
go.mod
View File

@@ -14,7 +14,7 @@ require (
github.com/qdm12/ss-server v0.3.0 github.com/qdm12/ss-server v0.3.0
github.com/qdm12/updated v0.0.0-20210603204757-205acfe6937e github.com/qdm12/updated v0.0.0-20210603204757-205acfe6937e
github.com/stretchr/testify v1.7.0 github.com/stretchr/testify v1.7.0
github.com/vishvananda/netlink v1.1.0 github.com/vishvananda/netlink v1.1.1-0.20211129163951-9ada19101fc5
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c
golang.zx2c4.com/wireguard v0.0.0-20210805125648-3957e9b9dd19 golang.zx2c4.com/wireguard v0.0.0-20210805125648-3957e9b9dd19
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20210803171230-4253848d036c golang.zx2c4.com/wireguard/wgctrl v0.0.0-20210803171230-4253848d036c
@@ -34,7 +34,7 @@ require (
github.com/mr-tron/base58 v1.2.0 // indirect github.com/mr-tron/base58 v1.2.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae // indirect
go4.org/intern v0.0.0-20210108033219-3eb7198706b2 // indirect go4.org/intern v0.0.0-20210108033219-3eb7198706b2 // indirect
go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063 // indirect go4.org/unsafe/assume-no-moving-gc v0.0.0-20201222180813-1025295fd063 // indirect
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 // indirect golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 // indirect

11
go.sum
View File

@@ -130,10 +130,10 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJH8j0= github.com/vishvananda/netlink v1.1.1-0.20211129163951-9ada19101fc5 h1:b/k/BVWzWRS5v6AB0gf2ckFSbFsHN5jR0HoNso1pN+w=
github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE= github.com/vishvananda/netlink v1.1.1-0.20211129163951-9ada19101fc5/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho=
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k= github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae h1:4hwBBUfQCFe3Cym0ZtKyq7L16eZUtYKs+BaHDN6mAns=
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU= github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
github.com/xanzy/ssh-agent v0.2.1/go.mod h1:mLlQY/MoOhWBj+gOGMQkOeiEvkx+8pJSI+0Bx9h2kr4= github.com/xanzy/ssh-agent v0.2.1/go.mod h1:mLlQY/MoOhWBj+gOGMQkOeiEvkx+8pJSI+0Bx9h2kr4=
github.com/yl2chen/cidranger v1.0.2/go.mod h1:9U1yz7WPYDwf0vpNWFaeRh0bjwz5RVgRy/9UEQfHl0g= github.com/yl2chen/cidranger v1.0.2/go.mod h1:9U1yz7WPYDwf0vpNWFaeRh0bjwz5RVgRy/9UEQfHl0g=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@@ -185,7 +185,6 @@ golang.org/x/sys v0.0.0-20190221075227-b4e8571b14e0/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190411185658-b44545bcd369/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190411185658-b44545bcd369/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -193,7 +192,9 @@ golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201009025420-dfb3f7c4e634/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201009025420-dfb3f7c4e634/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201118182958-a01c418693c7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201118182958-a01c418693c7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

View File

@@ -1,6 +1,10 @@
package netlink package netlink
import "github.com/vishvananda/netlink" import (
"fmt"
"github.com/vishvananda/netlink"
)
//nolint:revive //nolint:revive
const ( const (
@@ -8,3 +12,20 @@ const (
FAMILY_V4 = netlink.FAMILY_V4 FAMILY_V4 = netlink.FAMILY_V4
FAMILY_V6 = netlink.FAMILY_V6 FAMILY_V6 = netlink.FAMILY_V6
) )
type WireguardChecker interface {
IsWireguardSupported() (ok bool, err error)
}
func (n *NetLink) IsWireguardSupported() (ok bool, err error) {
families, err := netlink.GenlFamilyList()
if err != nil {
return false, fmt.Errorf("cannot list gen 1 families: %w", err)
}
for _, family := range families {
if family.Name == "wireguard" {
return true, nil
}
}
return false, nil
}

View File

@@ -0,0 +1,21 @@
package netlink
import (
"testing"
"github.com/stretchr/testify/require"
)
func Test_NetLink_IsWireguardSupported(t *testing.T) {
t.Skip() // TODO unskip once the data race problem with netlink.GenlFamilyList() is fixed
t.Parallel()
netLink := &NetLink{}
ok, err := netLink.IsWireguardSupported()
require.NoError(t, err)
if ok { // cannot assert since this depends on kernel
t.Log("wireguard is supported")
} else {
t.Log("wireguard is not supported")
}
}

View File

@@ -9,4 +9,5 @@ type NetLinker interface {
Linker Linker
Router Router
Ruler Ruler
WireguardChecker
} }

View File

@@ -5,6 +5,7 @@ import "github.com/vishvananda/netlink"
type ( type (
Link = netlink.Link Link = netlink.Link
Bridge = netlink.Bridge Bridge = netlink.Bridge
Wireguard = netlink.Wireguard
) )
var _ Linker = (*NetLink)(nil) var _ Linker = (*NetLink)(nil)

View File

@@ -63,6 +63,21 @@ func (mr *MockNetLinkerMockRecorder) AddrList(arg0, arg1 interface{}) *gomock.Ca
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddrList", reflect.TypeOf((*MockNetLinker)(nil).AddrList), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddrList", reflect.TypeOf((*MockNetLinker)(nil).AddrList), arg0, arg1)
} }
// IsWireguardSupported mocks base method.
func (m *MockNetLinker) IsWireguardSupported() (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsWireguardSupported")
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsWireguardSupported indicates an expected call of IsWireguardSupported.
func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsWireguardSupported", reflect.TypeOf((*MockNetLinker)(nil).IsWireguardSupported))
}
// LinkAdd mocks base method. // LinkAdd mocks base method.
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) error { func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@@ -88,7 +88,7 @@ func Test_Routing_addIPRule(t *testing.T) {
ruleToAdd: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99), ruleToAdd: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99),
err: errDummy, err: errDummy,
}, },
err: errors.New("dummy error: for rule: ip rule 99: from 1.1.1.0/24 table 99"), err: errors.New("dummy error: for rule: ip rule 99: from 1.1.1.0/24 to 2.2.2.0/24 table 99"),
}, },
"add rule success": { "add rule success": {
src: makeIPNet(t, 1), src: makeIPNet(t, 1),
@@ -193,7 +193,7 @@ func Test_Routing_deleteIPRule(t *testing.T) {
ruleToDel: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99), ruleToDel: makeIPRule(t, makeIPNet(t, 1), makeIPNet(t, 2), 99, 99),
err: errDummy, err: errDummy,
}, },
err: errors.New("dummy error: for rule: ip rule 99: from 1.1.1.0/24 table 99"), err: errors.New("dummy error: for rule: ip rule 99: from 1.1.1.0/24 to 2.2.2.0/24 table 99"),
}, },
"rule deleted": { "rule deleted": {
src: makeIPNet(t, 1), src: makeIPNet(t, 1),

View File

@@ -10,9 +10,11 @@ type NetLinker interface {
RouteAdd(route *netlink.Route) error RouteAdd(route *netlink.Route) error
RuleAdd(rule *netlink.Rule) error RuleAdd(rule *netlink.Rule) error
RuleDel(rule *netlink.Rule) error RuleDel(rule *netlink.Rule) error
LinkAdd(link netlink.Link) (err error)
LinkList() (links []netlink.Link, err error) LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error) LinkByName(name string) (link netlink.Link, err error)
LinkSetUp(link netlink.Link) error LinkSetUp(link netlink.Link) error
LinkSetDown(link netlink.Link) error LinkSetDown(link netlink.Link) error
LinkDel(link netlink.Link) error LinkDel(link netlink.Link) error
IsWireguardSupported() (ok bool, err error)
} }

View File

@@ -48,6 +48,35 @@ func (mr *MockNetLinkerMockRecorder) AddrAdd(arg0, arg1 interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddrAdd", reflect.TypeOf((*MockNetLinker)(nil).AddrAdd), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddrAdd", reflect.TypeOf((*MockNetLinker)(nil).AddrAdd), arg0, arg1)
} }
// IsWireguardSupported mocks base method.
func (m *MockNetLinker) IsWireguardSupported() (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsWireguardSupported")
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsWireguardSupported indicates an expected call of IsWireguardSupported.
func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsWireguardSupported", reflect.TypeOf((*MockNetLinker)(nil).IsWireguardSupported))
}
// LinkAdd mocks base method.
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkAdd", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// LinkAdd indicates an expected call of LinkAdd.
func (mr *MockNetLinkerMockRecorder) LinkAdd(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkAdd", reflect.TypeOf((*MockNetLinker)(nil).LinkAdd), arg0)
}
// LinkByName mocks base method. // LinkByName mocks base method.
func (m *MockNetLinker) LinkByName(arg0 string) (netlink.Link, error) { func (m *MockNetLinker) LinkByName(arg0 string) (netlink.Link, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@@ -53,7 +53,7 @@ func Test_Wireguard_addRoute(t *testing.T) {
Table: firewallMark, Table: firewallMark,
}, },
routeAddErr: errDummy, routeAddErr: errDummy,
err: errors.New("dummy: when adding route: {Ifindex: 88 Dst: 1.2.3.4/32 Src: <nil> Gw: <nil> Flags: [] Table: 51820}"), //nolint:lll err: errors.New("dummy: when adding route: {Ifindex: 88 Dst: 1.2.3.4/32 Src: <nil> Gw: <nil> Flags: [] Table: 51820 Realm: 0}"), //nolint:lll
}, },
} }

View File

@@ -51,7 +51,7 @@ func Test_Wireguard_addRule(t *testing.T) {
SuppressPrefixlen: -1, SuppressPrefixlen: -1,
}, },
ruleAddErr: errDummy, ruleAddErr: errDummy,
err: errors.New("dummy: when adding rule: ip rule 987: from <nil> table 456"), err: errors.New("dummy: when adding rule: ip rule 987: from all to all table 456"),
}, },
"rule delete error": { "rule delete error": {
expectedRule: &netlink.Rule{ expectedRule: &netlink.Rule{
@@ -66,7 +66,7 @@ func Test_Wireguard_addRule(t *testing.T) {
SuppressPrefixlen: -1, SuppressPrefixlen: -1,
}, },
ruleDelErr: errDummy, ruleDelErr: errDummy,
cleanupErr: errors.New("dummy: when deleting rule: ip rule 987: from <nil> table 456"), cleanupErr: errors.New("dummy: when deleting rule: ip rule 987: from all to all table 456"),
}, },
} }

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net" "net"
"github.com/qdm12/gluetun/internal/netlink"
"golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/ipc"
@@ -15,7 +16,9 @@ import (
var ( var (
ErrDetectIPv6 = errors.New("cannot detect IPv6 support") ErrDetectIPv6 = errors.New("cannot detect IPv6 support")
ErrDetectKernel = errors.New("cannot detect Kernel support")
ErrCreateTun = errors.New("cannot create TUN device") ErrCreateTun = errors.New("cannot create TUN device")
ErrAddLink = errors.New("cannot add Wireguard link")
ErrFindLink = errors.New("cannot find link") ErrFindLink = errors.New("cannot find link")
ErrFindDevice = errors.New("cannot find Wireguard device") ErrFindDevice = errors.New("cannot find Wireguard device")
ErrUAPISocketOpening = errors.New("cannot open UAPI socket") ErrUAPISocketOpening = errors.New("cannot open UAPI socket")
@@ -23,6 +26,7 @@ var (
ErrUAPIListen = errors.New("cannot listen on UAPI socket") ErrUAPIListen = errors.New("cannot listen on UAPI socket")
ErrAddAddress = errors.New("cannot add address to wireguard interface") ErrAddAddress = errors.New("cannot add address to wireguard interface")
ErrConfigure = errors.New("cannot configure wireguard interface") ErrConfigure = errors.New("cannot configure wireguard interface")
ErrDeviceInfo = errors.New("cannot get wireguard device information")
ErrIfaceUp = errors.New("cannot set the interface to UP") ErrIfaceUp = errors.New("cannot set the interface to UP")
ErrRouteAdd = errors.New("cannot add route for interface") ErrRouteAdd = errors.New("cannot add route for interface")
ErrRuleAdd = errors.New("cannot add rule for interface") ErrRuleAdd = errors.New("cannot add rule for interface")
@@ -41,6 +45,12 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
return return
} }
doKernel, err := w.netlink.IsWireguardSupported()
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrDetectKernel, err)
return
}
client, err := wgctrl.New() client, err := wgctrl.New()
if err != nil { if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrWgctrlOpen, err) waitError <- fmt.Errorf("%w: %s", ErrWgctrlOpen, err)
@@ -52,62 +62,21 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
defer closers.cleanup(w.logger) defer closers.cleanup(w.logger)
tun, err := tun.CreateTUN(w.settings.InterfaceName, device.DefaultMTU) setupFunction := setupUserSpace
if err != nil { if doKernel {
waitError <- fmt.Errorf("%w: %s", ErrCreateTun, err) w.logger.Info("Using available kernelspace implementation")
return setupFunction = setupKernelSpace
} else {
w.logger.Info("Using userspace implementation since Kernel support does not exist")
} }
closers.add("closing TUN device", stepSeven, tun.Close) link, waitAndCleanup, err := setupFunction(ctx,
w.settings.InterfaceName, w.netlink, &closers, w.logger)
tunName, err := tun.Name()
if err != nil { if err != nil {
waitError <- fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err) waitError <- err
return
} else if tunName != w.settings.InterfaceName {
waitError <- fmt.Errorf("%w: names don't match: expected %q and got %q",
ErrCreateTun, w.settings.InterfaceName, tunName)
return return
} }
link, err := w.netlink.LinkByName(w.settings.InterfaceName)
if err != nil {
waitError <- fmt.Errorf("%w: %s: %s", ErrFindLink, w.settings.InterfaceName, err)
return
}
bind := conn.NewDefaultBind()
closers.add("closing bind", stepSeven, bind.Close)
deviceLogger := makeDeviceLogger(w.logger)
device := device.NewDevice(tun, bind, deviceLogger)
closers.add("closing Wireguard device", stepSix, func() error {
device.Close()
return nil
})
uapiFile, err := ipc.UAPIOpen(w.settings.InterfaceName)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
return
}
closers.add("closing UAPI file", stepThree, uapiFile.Close)
uapiListener, err := ipc.UAPIListen(w.settings.InterfaceName, uapiFile)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrUAPIListen, err)
return
}
closers.add("closing UAPI listener", stepTwo, uapiListener.Close)
// acceptAndHandle exits when uapiListener is closed
uapiAcceptErrorCh := make(chan error)
go acceptAndHandle(uapiListener, device, uapiAcceptErrorCh)
err = w.addAddresses(link, w.settings.Addresses) err = w.addAddresses(link, w.settings.Addresses)
if err != nil { if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrAddAddress, err) waitError <- fmt.Errorf("%w: %s", ErrAddAddress, err)
@@ -128,9 +97,6 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
closers.add("shutting down link", stepFour, func() error { closers.add("shutting down link", stepFour, func() error {
return w.netlink.LinkSetDown(link) return w.netlink.LinkSetDown(link)
}) })
closers.add("deleting link", stepFive, func() error {
return w.netlink.LinkDel(link)
})
err = w.addRoute(link, allIPv4(), w.settings.FirewallMark) err = w.addRoute(link, allIPv4(), w.settings.FirewallMark)
if err != nil { if err != nil {
@@ -158,6 +124,96 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
w.logger.Info("Wireguard is up") w.logger.Info("Wireguard is up")
ready <- struct{}{} ready <- struct{}{}
waitError <- waitAndCleanup()
}
type waitAndCleanupFunc func() error
func setupKernelSpace(ctx context.Context,
interfaceName string, netLinker NetLinker,
closers *closers, logger Logger) (
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error) {
linkAttrs := netlink.LinkAttrs{
Name: interfaceName,
MTU: device.DefaultMTU, // TODO
}
link = &netlink.Wireguard{
LinkAttrs: linkAttrs,
}
err = netLinker.LinkAdd(link)
if err != nil {
return nil, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
}
closers.add("deleting link", stepFive, func() error {
return netLinker.LinkDel(link)
})
waitAndCleanup = func() error {
<-ctx.Done()
closers.cleanup(logger)
return ctx.Err()
}
return link, waitAndCleanup, nil
}
func setupUserSpace(ctx context.Context,
interfaceName string, netLinker NetLinker,
closers *closers, logger Logger) (
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error) {
tun, err := tun.CreateTUN(interfaceName, device.DefaultMTU)
if err != nil {
return nil, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
}
closers.add("closing TUN device", stepSeven, tun.Close)
tunName, err := tun.Name()
if err != nil {
return nil, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
} else if tunName != interfaceName {
return nil, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
ErrCreateTun, interfaceName, tunName)
}
link, err = netLinker.LinkByName(interfaceName)
if err != nil {
return nil, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
}
closers.add("deleting link", stepFive, func() error {
return netLinker.LinkDel(link)
})
bind := conn.NewDefaultBind()
closers.add("closing bind", stepSeven, bind.Close)
deviceLogger := makeDeviceLogger(logger)
device := device.NewDevice(tun, bind, deviceLogger)
closers.add("closing Wireguard device", stepSix, func() error {
device.Close()
return nil
})
uapiFile, err := ipc.UAPIOpen(interfaceName)
if err != nil {
return nil, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
}
closers.add("closing UAPI file", stepThree, uapiFile.Close)
uapiListener, err := ipc.UAPIListen(interfaceName, uapiFile)
if err != nil {
return nil, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
}
closers.add("closing UAPI listener", stepTwo, uapiListener.Close)
// acceptAndHandle exits when uapiListener is closed
uapiAcceptErrorCh := make(chan error)
go acceptAndHandle(uapiListener, device, uapiAcceptErrorCh)
waitAndCleanup = func() error {
select { select {
case <-ctx.Done(): case <-ctx.Done():
err = ctx.Err() err = ctx.Err()
@@ -167,11 +223,14 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
err = ErrDeviceWaited err = ErrDeviceWaited
} }
closers.cleanup(w.logger) closers.cleanup(logger)
<-uapiAcceptErrorCh // wait for acceptAndHandle to exit <-uapiAcceptErrorCh // wait for acceptAndHandle to exit
waitError <- err return err
}
return link, waitAndCleanup, nil
} }
func acceptAndHandle(uapi net.Listener, device *device.Device, func acceptAndHandle(uapi net.Listener, device *device.Device,