95 lines
2.3 KiB
Go
95 lines
2.3 KiB
Go
package wireguard
|
|
|
|
import (
|
|
"errors"
|
|
"net"
|
|
"testing"
|
|
|
|
"github.com/golang/mock/gomock"
|
|
"github.com/qdm12/gluetun/internal/netlink"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func Test_Wireguard_addAddresses(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ipNetOne := &net.IPNet{IP: net.IPv4(1, 2, 3, 4), Mask: net.IPv4Mask(255, 255, 255, 255)}
|
|
ipNetTwo := &net.IPNet{IP: net.IPv4(4, 5, 6, 7), Mask: net.IPv4Mask(255, 255, 255, 128)}
|
|
|
|
newLink := func() netlink.Link {
|
|
linkAttrs := netlink.NewLinkAttrs()
|
|
linkAttrs.Name = "a_bridge"
|
|
return &netlink.Bridge{
|
|
LinkAttrs: linkAttrs,
|
|
}
|
|
}
|
|
|
|
errDummy := errors.New("dummy")
|
|
|
|
testCases := map[string]struct {
|
|
link netlink.Link
|
|
addrs []*net.IPNet
|
|
expectedAddrs []*netlink.Addr
|
|
addrAddErrs []error
|
|
err error
|
|
}{
|
|
"success": {
|
|
link: newLink(),
|
|
addrs: []*net.IPNet{ipNetOne, ipNetTwo},
|
|
expectedAddrs: []*netlink.Addr{
|
|
{IPNet: ipNetOne}, {IPNet: ipNetTwo},
|
|
},
|
|
addrAddErrs: []error{nil, nil},
|
|
},
|
|
"first add error": {
|
|
link: newLink(),
|
|
addrs: []*net.IPNet{ipNetOne, ipNetTwo},
|
|
expectedAddrs: []*netlink.Addr{
|
|
{IPNet: ipNetOne},
|
|
},
|
|
addrAddErrs: []error{errDummy},
|
|
err: errors.New("dummy: when adding address 1.2.3.4/32 to link a_bridge"),
|
|
},
|
|
"second add error": {
|
|
link: newLink(),
|
|
addrs: []*net.IPNet{ipNetOne, ipNetTwo},
|
|
expectedAddrs: []*netlink.Addr{
|
|
{IPNet: ipNetOne}, {IPNet: ipNetTwo},
|
|
},
|
|
addrAddErrs: []error{nil, errDummy},
|
|
err: errors.New("dummy: when adding address 4.5.6.7/25 to link a_bridge"),
|
|
},
|
|
}
|
|
|
|
for name, testCase := range testCases {
|
|
testCase := testCase
|
|
t.Run(name, func(t *testing.T) {
|
|
t.Parallel()
|
|
ctrl := gomock.NewController(t)
|
|
|
|
require.Equal(t, len(testCase.expectedAddrs), len(testCase.addrAddErrs))
|
|
|
|
netLinker := NewMockNetLinker(ctrl)
|
|
wg := Wireguard{
|
|
netlink: netLinker,
|
|
}
|
|
|
|
for i := range testCase.expectedAddrs {
|
|
netLinker.EXPECT().
|
|
AddrAdd(testCase.link, testCase.expectedAddrs[i]).
|
|
Return(testCase.addrAddErrs[i])
|
|
}
|
|
|
|
err := wg.addAddresses(testCase.link, testCase.addrs)
|
|
|
|
if testCase.err != nil {
|
|
require.Error(t, err)
|
|
assert.Equal(t, testCase.err.Error(), err.Error())
|
|
} else {
|
|
require.NoError(t, err)
|
|
}
|
|
})
|
|
}
|
|
}
|