Fix: wireguard cleanup preventing restarts

This commit is contained in:
Quentin McGaw (desktop)
2021-09-04 22:29:04 +00:00
parent 61afdce788
commit 82ac568ee3
6 changed files with 65 additions and 6 deletions

View File

@@ -16,6 +16,7 @@ type Linker interface {
LinkAdd(link netlink.Link) (err error)
LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (err error)
LinkSetDown(link netlink.Link) (err error)
}
func (n *NetLink) LinkList() (links []netlink.Link, err error) {
@@ -41,3 +42,7 @@ func (n *NetLink) LinkDel(link netlink.Link) (err error) {
func (n *NetLink) LinkSetUp(link netlink.Link) (err error) {
return netlink.LinkSetUp(link)
}
func (n *NetLink) LinkSetDown(link netlink.Link) (err error) {
return netlink.LinkSetDown(link)
}

View File

@@ -136,6 +136,20 @@ func (mr *MockNetLinkerMockRecorder) LinkList() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkList", reflect.TypeOf((*MockNetLinker)(nil).LinkList))
}
// LinkSetDown mocks base method.
func (m *MockNetLinker) LinkSetDown(arg0 netlink.Link) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetDown", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// LinkSetDown indicates an expected call of LinkSetDown.
func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkSetDown", reflect.TypeOf((*MockNetLinker)(nil).LinkSetDown), arg0)
}
// LinkSetUp mocks base method.
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) error {
m.ctrl.T.Helper()

View File

@@ -52,8 +52,12 @@ const (
stepTwo
// stepThree closes the UAPI file.
stepThree
// stepFour closes the Wireguard device.
// stepFour shuts down the Wireguard link.
stepFour
// stepFive closes the bind connection and the TUN device file.
// stepFive removes the Wireguard link.
stepFive
// stepSix closes the Wireguard device.
stepSix
// stepSeven closes the bind connection and the TUN device file.
stepSeven
)

View File

@@ -11,4 +11,6 @@ type NetLinker interface {
RuleDel(rule *netlink.Rule) error
LinkByName(name string) (link netlink.Link, err error)
LinkSetUp(link netlink.Link) error
LinkSetDown(link netlink.Link) error
LinkDel(link netlink.Link) error
}

View File

@@ -8,7 +8,7 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
netlink "github.com/qdm12/gluetun/internal/netlink"
netlink "github.com/vishvananda/netlink"
)
// MockNetLinker is a mock of NetLinker interface.
@@ -63,6 +63,34 @@ func (mr *MockNetLinkerMockRecorder) LinkByName(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkByName", reflect.TypeOf((*MockNetLinker)(nil).LinkByName), arg0)
}
// LinkDel mocks base method.
func (m *MockNetLinker) LinkDel(arg0 netlink.Link) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkDel", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// LinkDel indicates an expected call of LinkDel.
func (mr *MockNetLinkerMockRecorder) LinkDel(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkDel", reflect.TypeOf((*MockNetLinker)(nil).LinkDel), arg0)
}
// LinkSetDown mocks base method.
func (m *MockNetLinker) LinkSetDown(arg0 netlink.Link) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetDown", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// LinkSetDown indicates an expected call of LinkSetDown.
func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkSetDown", reflect.TypeOf((*MockNetLinker)(nil).LinkSetDown), arg0)
}
// LinkSetUp mocks base method.
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) error {
m.ctrl.T.Helper()

View File

@@ -51,7 +51,7 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
return
}
closers.add("closing TUN device", stepFive, tun.Close)
closers.add("closing TUN device", stepSeven, tun.Close)
tunName, err := tun.Name()
if err != nil {
@@ -71,12 +71,12 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
bind := conn.NewDefaultBind()
closers.add("closing bind", stepFive, bind.Close)
closers.add("closing bind", stepSeven, bind.Close)
deviceLogger := makeDeviceLogger(w.logger)
device := device.NewDevice(tun, bind, deviceLogger)
closers.add("closing Wireguard device", stepFour, func() error {
closers.add("closing Wireguard device", stepSix, func() error {
device.Close()
return nil
})
@@ -117,6 +117,12 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err)
return
}
closers.add("shutting down link", stepFour, func() error {
return w.netlink.LinkSetDown(link)
})
closers.add("deleting link", stepFive, func() error {
return w.netlink.LinkDel(link)
})
err = w.addRoute(link, allIPv4(), w.settings.FirewallMark)
if err != nil {