chore(netlink): define own types with minimal fields
- Allow to swap `github.com/vishvananda/netlink` - Allow to add build tags for each platform - One step closer to development on non-Linux platforms
This commit is contained in:
@@ -531,30 +531,30 @@ type netLinker interface {
|
|||||||
type Addresser interface {
|
type Addresser interface {
|
||||||
AddrList(link netlink.Link, family int) (
|
AddrList(link netlink.Link, family int) (
|
||||||
addresses []netlink.Addr, err error)
|
addresses []netlink.Addr, err error)
|
||||||
AddrReplace(link netlink.Link, addr *netlink.Addr) error
|
AddrReplace(link netlink.Link, addr netlink.Addr) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Router interface {
|
type Router interface {
|
||||||
RouteList(link netlink.Link, family int) (
|
RouteList(link *netlink.Link, family int) (
|
||||||
routes []netlink.Route, err error)
|
routes []netlink.Route, err error)
|
||||||
RouteAdd(route *netlink.Route) error
|
RouteAdd(route netlink.Route) error
|
||||||
RouteDel(route *netlink.Route) error
|
RouteDel(route netlink.Route) error
|
||||||
RouteReplace(route *netlink.Route) error
|
RouteReplace(route netlink.Route) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Ruler interface {
|
type Ruler interface {
|
||||||
RuleList(family int) (rules []netlink.Rule, err error)
|
RuleList(family int) (rules []netlink.Rule, err error)
|
||||||
RuleAdd(rule *netlink.Rule) error
|
RuleAdd(rule netlink.Rule) error
|
||||||
RuleDel(rule *netlink.Rule) error
|
RuleDel(rule netlink.Rule) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Linker interface {
|
type Linker interface {
|
||||||
LinkList() (links []netlink.Link, err error)
|
LinkList() (links []netlink.Link, err error)
|
||||||
LinkByName(name string) (link netlink.Link, err error)
|
LinkByName(name string) (link netlink.Link, err error)
|
||||||
LinkByIndex(index int) (link netlink.Link, err error)
|
LinkByIndex(index int) (link netlink.Link, err error)
|
||||||
LinkAdd(link netlink.Link) (err error)
|
LinkAdd(link netlink.Link) (linkIndex int, err error)
|
||||||
LinkDel(link netlink.Link) (err error)
|
LinkDel(link netlink.Link) (err error)
|
||||||
LinkSetUp(link netlink.Link) (err error)
|
LinkSetUp(link netlink.Link) (linkIndex int, err error)
|
||||||
LinkSetDown(link netlink.Link) (err error)
|
LinkSetDown(link netlink.Link) (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,40 @@
|
|||||||
package netlink
|
package netlink
|
||||||
|
|
||||||
import "github.com/vishvananda/netlink"
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
type Addr = netlink.Addr
|
"github.com/vishvananda/netlink"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Addr struct {
|
||||||
|
Network netip.Prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a Addr) String() string {
|
||||||
|
return a.Network.String()
|
||||||
|
}
|
||||||
|
|
||||||
func (n *NetLink) AddrList(link Link, family int) (
|
func (n *NetLink) AddrList(link Link, family int) (
|
||||||
addresses []Addr, err error) {
|
addresses []Addr, err error) {
|
||||||
return netlink.AddrList(link, family)
|
netlinkLink := linkToNetlinkLink(&link)
|
||||||
|
netlinkAddresses, err := netlink.AddrList(netlinkLink, family)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
addresses = make([]Addr, len(netlinkAddresses))
|
||||||
|
for i := range netlinkAddresses {
|
||||||
|
addresses[i].Network = netIPNetToNetipPrefix(netlinkAddresses[i].IPNet)
|
||||||
|
}
|
||||||
|
|
||||||
|
return addresses, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetLink) AddrReplace(link Link, addr *Addr) error {
|
func (n *NetLink) AddrReplace(link Link, addr Addr) error {
|
||||||
return netlink.AddrReplace(link, addr)
|
netlinkLink := linkToNetlinkLink(&link)
|
||||||
|
netlinkAddress := netlink.Addr{
|
||||||
|
IPNet: netipPrefixToIPNet(addr.Network),
|
||||||
|
}
|
||||||
|
|
||||||
|
return netlink.AddrReplace(netlinkLink, &netlinkAddress)
|
||||||
}
|
}
|
||||||
|
|||||||
62
internal/netlink/conversion.go
Normal file
62
internal/netlink/conversion.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package netlink
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
)
|
||||||
|
|
||||||
|
func netipPrefixToIPNet(prefix netip.Prefix) (ipNet *net.IPNet) {
|
||||||
|
if !prefix.IsValid() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
prefixAddr := prefix.Addr().Unmap()
|
||||||
|
ipMask := net.CIDRMask(prefix.Bits(), prefixAddr.BitLen())
|
||||||
|
ip := netipAddrToNetIP(prefixAddr)
|
||||||
|
|
||||||
|
return &net.IPNet{
|
||||||
|
IP: ip,
|
||||||
|
Mask: ipMask,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func netIPNetToNetipPrefix(ipNet *net.IPNet) (prefix netip.Prefix) {
|
||||||
|
if ipNet == nil || (len(ipNet.IP) != net.IPv4len && len(ipNet.IP) != net.IPv6len) {
|
||||||
|
return prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
var ip netip.Addr
|
||||||
|
if ipv4 := ipNet.IP.To4(); ipv4 != nil {
|
||||||
|
ip = netip.AddrFrom4([4]byte(ipv4))
|
||||||
|
} else {
|
||||||
|
ip = netip.AddrFrom16([16]byte(ipNet.IP))
|
||||||
|
}
|
||||||
|
bits, _ := ipNet.Mask.Size()
|
||||||
|
return netip.PrefixFrom(ip, bits)
|
||||||
|
}
|
||||||
|
|
||||||
|
func netipAddrToNetIP(address netip.Addr) (ip net.IP) {
|
||||||
|
switch {
|
||||||
|
case !address.IsValid():
|
||||||
|
return nil
|
||||||
|
case address.Is4() || address.Is4In6():
|
||||||
|
bytes := address.As4()
|
||||||
|
return net.IP(bytes[:])
|
||||||
|
default:
|
||||||
|
bytes := address.As16()
|
||||||
|
return net.IP(bytes[:])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func netIPToNetipAddress(ip net.IP) (address netip.Addr) {
|
||||||
|
if len(ip) != net.IPv4len && len(ip) != net.IPv6len {
|
||||||
|
return address // invalid
|
||||||
|
}
|
||||||
|
|
||||||
|
address, ok := netip.AddrFromSlice(ip)
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Sprintf("converting %#v to netip.Addr failed", ip))
|
||||||
|
}
|
||||||
|
return address.Unmap()
|
||||||
|
}
|
||||||
146
internal/netlink/conversion_test.go
Normal file
146
internal/netlink/conversion_test.go
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
package netlink
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_netipPrefixToIPNet(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
prefix netip.Prefix
|
||||||
|
ipNet *net.IPNet
|
||||||
|
}{
|
||||||
|
"empty_prefix": {},
|
||||||
|
"IPv4_prefix": {
|
||||||
|
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
|
||||||
|
ipNet: &net.IPNet{
|
||||||
|
IP: net.IP{1, 2, 3, 4},
|
||||||
|
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"IPv4-in-IPv6_prefix": {
|
||||||
|
prefix: netip.PrefixFrom(netip.AddrFrom16(
|
||||||
|
[16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 1, 2, 3, 4}),
|
||||||
|
24),
|
||||||
|
ipNet: &net.IPNet{
|
||||||
|
IP: net.IP{1, 2, 3, 4},
|
||||||
|
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"IPv6_prefix": {
|
||||||
|
prefix: netip.PrefixFrom(netip.IPv6Loopback(), 8),
|
||||||
|
ipNet: &net.IPNet{
|
||||||
|
IP: net.IPv6loopback,
|
||||||
|
Mask: net.IPMask{0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ipNet := netipPrefixToIPNet(testCase.prefix)
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.ipNet, ipNet)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_netIPNetToNetipPrefix(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
ipNet *net.IPNet
|
||||||
|
prefix netip.Prefix
|
||||||
|
}{
|
||||||
|
"empty ipnet": {},
|
||||||
|
"custom sized IP in ipnet": {
|
||||||
|
ipNet: &net.IPNet{
|
||||||
|
IP: net.IP{1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"IPv4 ipnet": {
|
||||||
|
ipNet: &net.IPNet{
|
||||||
|
IP: net.IP{1, 2, 3, 4},
|
||||||
|
Mask: net.IPMask{255, 255, 255, 0},
|
||||||
|
},
|
||||||
|
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
|
||||||
|
},
|
||||||
|
"IPv4-in-IPv6 ipnet": {
|
||||||
|
ipNet: &net.IPNet{
|
||||||
|
IP: net.IPv4(1, 2, 3, 4),
|
||||||
|
Mask: net.IPMask{255, 255, 255, 0},
|
||||||
|
},
|
||||||
|
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
|
||||||
|
},
|
||||||
|
"IPv6 ipnet": {
|
||||||
|
ipNet: &net.IPNet{
|
||||||
|
IP: net.IPv6loopback,
|
||||||
|
Mask: net.IPMask{0xff},
|
||||||
|
},
|
||||||
|
prefix: netip.PrefixFrom(netip.IPv6Loopback(), 8),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
prefix := netIPNetToNetipPrefix(testCase.ipNet)
|
||||||
|
|
||||||
|
assert.Equal(t, testCase.prefix, prefix)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_netIPToNetipAddress(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
ip net.IP
|
||||||
|
address netip.Addr
|
||||||
|
panicMessage string
|
||||||
|
}{
|
||||||
|
"nil_ip": {},
|
||||||
|
"ip_not_ipv4_or_ipv6": {
|
||||||
|
ip: net.IP{1},
|
||||||
|
},
|
||||||
|
"IPv4": {
|
||||||
|
ip: net.IPv4(1, 2, 3, 4),
|
||||||
|
address: netip.AddrFrom4([4]byte{1, 2, 3, 4}),
|
||||||
|
},
|
||||||
|
"IPv6": {
|
||||||
|
ip: net.IPv6zero,
|
||||||
|
address: netip.AddrFrom16([16]byte{}),
|
||||||
|
},
|
||||||
|
"IPv4 prefixed with 0xffff": {
|
||||||
|
ip: net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 1, 2, 3, 4},
|
||||||
|
address: netip.AddrFrom4([4]byte{1, 2, 3, 4}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
if testCase.panicMessage != "" {
|
||||||
|
assert.PanicsWithValue(t, testCase.panicMessage, func() {
|
||||||
|
netIPToNetipAddress(testCase.ip)
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
address := netIPToNetipAddress(testCase.ip)
|
||||||
|
assert.Equal(t, testCase.address, address)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,20 +6,19 @@ import (
|
|||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
)
|
)
|
||||||
|
|
||||||
//nolint:revive
|
|
||||||
const (
|
const (
|
||||||
FAMILY_ALL = netlink.FAMILY_ALL
|
FamilyAll = 0
|
||||||
FAMILY_V4 = netlink.FAMILY_V4
|
FamilyV4 = 2
|
||||||
FAMILY_V6 = netlink.FAMILY_V6
|
FamilyV6 = 10
|
||||||
)
|
)
|
||||||
|
|
||||||
func FamilyToString(family int) string {
|
func FamilyToString(family int) string {
|
||||||
switch family {
|
switch family {
|
||||||
case FAMILY_ALL:
|
case FamilyAll:
|
||||||
return "all"
|
return "all" //nolint:goconst
|
||||||
case FAMILY_V4:
|
case FamilyV4:
|
||||||
return "v4"
|
return "v4"
|
||||||
case FAMILY_V6:
|
case FamilyV6:
|
||||||
return "v6"
|
return "v6"
|
||||||
default:
|
default:
|
||||||
return fmt.Sprint(family)
|
return fmt.Sprint(family)
|
||||||
|
|||||||
@@ -14,20 +14,21 @@ func (n *NetLink) IsIPv6Supported() (supported bool, err error) {
|
|||||||
|
|
||||||
var totalRoutes uint
|
var totalRoutes uint
|
||||||
for _, link := range links {
|
for _, link := range links {
|
||||||
routes, err := n.RouteList(link, netlink.FAMILY_V6)
|
link := link
|
||||||
|
routes, err := n.RouteList(&link, netlink.FAMILY_V6)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("listing IPv6 routes for link %s: %w",
|
return false, fmt.Errorf("listing IPv6 routes for link %s: %w",
|
||||||
link.Attrs().Name, err)
|
link.Name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check each route for IPv6 due to Podman bug listing IPv4 routes
|
// Check each route for IPv6 due to Podman bug listing IPv4 routes
|
||||||
// as IPv6 routes at container start, see:
|
// as IPv6 routes at container start, see:
|
||||||
// https://github.com/qdm12/gluetun/issues/1241#issuecomment-1333405949
|
// https://github.com/qdm12/gluetun/issues/1241#issuecomment-1333405949
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
sourceIsIPv6 := route.Src != nil && route.Src.To4() == nil
|
sourceIsIPv6 := route.Src.IsValid() && route.Src.Is6()
|
||||||
destinationIsIPv6 := route.Dst != nil && route.Dst.IP.To4() == nil
|
destinationIsIPv6 := route.Dst.IsValid() && route.Dst.Addr().Is6()
|
||||||
if sourceIsIPv6 || destinationIsIPv6 {
|
if sourceIsIPv6 || destinationIsIPv6 {
|
||||||
n.debugLogger.Debugf("IPv6 is supported by link %s", link.Attrs().Name)
|
n.debugLogger.Debugf("IPv6 is supported by link %s", link.Name)
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
totalRoutes++
|
totalRoutes++
|
||||||
|
|||||||
@@ -2,36 +2,117 @@ package netlink
|
|||||||
|
|
||||||
import "github.com/vishvananda/netlink"
|
import "github.com/vishvananda/netlink"
|
||||||
|
|
||||||
type (
|
type Link struct {
|
||||||
Link = netlink.Link
|
Type string
|
||||||
Bridge = netlink.Bridge
|
Name string
|
||||||
Wireguard = netlink.Wireguard
|
Index int
|
||||||
)
|
EncapType string
|
||||||
|
MTU uint16
|
||||||
|
|
||||||
|
NetNsID int
|
||||||
|
TxQLen int
|
||||||
|
}
|
||||||
|
|
||||||
func (n *NetLink) LinkList() (links []Link, err error) {
|
func (n *NetLink) LinkList() (links []Link, err error) {
|
||||||
return netlink.LinkList()
|
netlinkLinks, err := netlink.LinkList()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
links = make([]Link, len(netlinkLinks))
|
||||||
|
for i := range netlinkLinks {
|
||||||
|
links[i] = netlinkLinkToLink(netlinkLinks[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
return links, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetLink) LinkByName(name string) (link Link, err error) {
|
func (n *NetLink) LinkByName(name string) (link Link, err error) {
|
||||||
return netlink.LinkByName(name)
|
netlinkLink, err := netlink.LinkByName(name)
|
||||||
|
if err != nil {
|
||||||
|
return Link{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return netlinkLinkToLink(netlinkLink), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetLink) LinkByIndex(index int) (link Link, err error) {
|
func (n *NetLink) LinkByIndex(index int) (link Link, err error) {
|
||||||
return netlink.LinkByIndex(index)
|
netlinkLink, err := netlink.LinkByIndex(index)
|
||||||
|
if err != nil {
|
||||||
|
return Link{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return netlinkLinkToLink(netlinkLink), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetLink) LinkAdd(link Link) (err error) {
|
func (n *NetLink) LinkAdd(link Link) (linkIndex int, err error) {
|
||||||
return netlink.LinkAdd(link)
|
netlinkLink := linkToNetlinkLink(&link)
|
||||||
|
err = netlink.LinkAdd(netlinkLink)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return netlinkLink.Attrs().Index, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetLink) LinkDel(link Link) (err error) {
|
func (n *NetLink) LinkDel(link Link) (err error) {
|
||||||
return netlink.LinkDel(link)
|
return netlink.LinkDel(linkToNetlinkLink(&link))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetLink) LinkSetUp(link Link) (err error) {
|
func (n *NetLink) LinkSetUp(link Link) (linkIndex int, err error) {
|
||||||
return netlink.LinkSetUp(link)
|
netlinkLink := linkToNetlinkLink(&link)
|
||||||
|
err = netlink.LinkSetUp(netlinkLink)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return netlinkLink.Attrs().Index, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetLink) LinkSetDown(link Link) (err error) {
|
func (n *NetLink) LinkSetDown(link Link) (err error) {
|
||||||
return netlink.LinkSetDown(link)
|
return netlink.LinkSetDown(linkToNetlinkLink(&link))
|
||||||
|
}
|
||||||
|
|
||||||
|
type netlinkLinkImpl struct {
|
||||||
|
attrs *netlink.LinkAttrs
|
||||||
|
linkType string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *netlinkLinkImpl) Attrs() *netlink.LinkAttrs {
|
||||||
|
return n.attrs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *netlinkLinkImpl) Type() string {
|
||||||
|
return n.linkType
|
||||||
|
}
|
||||||
|
|
||||||
|
func netlinkLinkToLink(netlinkLink netlink.Link) Link {
|
||||||
|
attributes := netlinkLink.Attrs()
|
||||||
|
return Link{
|
||||||
|
Type: netlinkLink.Type(),
|
||||||
|
Name: attributes.Name,
|
||||||
|
Index: attributes.Index,
|
||||||
|
EncapType: attributes.EncapType,
|
||||||
|
MTU: uint16(attributes.MTU),
|
||||||
|
NetNsID: attributes.NetNsID,
|
||||||
|
TxQLen: attributes.TxQLen,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warning: we must return `netlink.Link` and not `netlinkLinkImpl`
|
||||||
|
// so that the vishvananda/netlink package can compare the returned
|
||||||
|
// value against an untyped nil.
|
||||||
|
func linkToNetlinkLink(link *Link) netlink.Link {
|
||||||
|
if link == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &netlinkLinkImpl{
|
||||||
|
linkType: link.Type,
|
||||||
|
attrs: &netlink.LinkAttrs{ // TODO get all original attributes
|
||||||
|
Name: link.Name,
|
||||||
|
Index: link.Index,
|
||||||
|
EncapType: link.EncapType,
|
||||||
|
MTU: int(link.MTU),
|
||||||
|
NetNsID: link.NetNsID,
|
||||||
|
TxQLen: link.TxQLen,
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
package netlink
|
|
||||||
|
|
||||||
import "github.com/vishvananda/netlink"
|
|
||||||
|
|
||||||
type LinkAttrs = netlink.LinkAttrs
|
|
||||||
|
|
||||||
func NewLinkAttrs() LinkAttrs {
|
|
||||||
return netlink.NewLinkAttrs()
|
|
||||||
}
|
|
||||||
@@ -1,22 +1,74 @@
|
|||||||
package netlink
|
package netlink
|
||||||
|
|
||||||
import "github.com/vishvananda/netlink"
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
type Route = netlink.Route
|
"github.com/vishvananda/netlink"
|
||||||
|
)
|
||||||
|
|
||||||
func (n *NetLink) RouteList(link Link, family int) (
|
type Route struct {
|
||||||
|
LinkIndex int
|
||||||
|
Dst netip.Prefix
|
||||||
|
Src netip.Addr
|
||||||
|
Gw netip.Addr
|
||||||
|
Priority int
|
||||||
|
Family int
|
||||||
|
Table int
|
||||||
|
Type int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NetLink) RouteList(link *Link, family int) (
|
||||||
routes []Route, err error) {
|
routes []Route, err error) {
|
||||||
return netlink.RouteList(link, family)
|
netlinkLink := linkToNetlinkLink(link)
|
||||||
|
netlinkRoutes, err := netlink.RouteList(netlinkLink, family)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
routes = make([]Route, len(netlinkRoutes))
|
||||||
|
for i := range netlinkRoutes {
|
||||||
|
routes[i] = netlinkRouteToRoute(netlinkRoutes[i])
|
||||||
|
}
|
||||||
|
return routes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetLink) RouteAdd(route *Route) error {
|
func (n *NetLink) RouteAdd(route Route) error {
|
||||||
return netlink.RouteAdd(route)
|
netlinkRoute := routeToNetlinkRoute(route)
|
||||||
|
return netlink.RouteAdd(&netlinkRoute)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetLink) RouteDel(route *Route) error {
|
func (n *NetLink) RouteDel(route Route) error {
|
||||||
return netlink.RouteDel(route)
|
netlinkRoute := routeToNetlinkRoute(route)
|
||||||
|
return netlink.RouteDel(&netlinkRoute)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetLink) RouteReplace(route *Route) error {
|
func (n *NetLink) RouteReplace(route Route) error {
|
||||||
return netlink.RouteReplace(route)
|
netlinkRoute := routeToNetlinkRoute(route)
|
||||||
|
return netlink.RouteReplace(&netlinkRoute)
|
||||||
|
}
|
||||||
|
|
||||||
|
func netlinkRouteToRoute(netlinkRoute netlink.Route) (route Route) {
|
||||||
|
return Route{
|
||||||
|
LinkIndex: netlinkRoute.LinkIndex,
|
||||||
|
Dst: netIPNetToNetipPrefix(netlinkRoute.Dst),
|
||||||
|
Src: netIPToNetipAddress(netlinkRoute.Src),
|
||||||
|
Gw: netIPToNetipAddress(netlinkRoute.Gw),
|
||||||
|
Priority: netlinkRoute.Priority,
|
||||||
|
Family: netlinkRoute.Family,
|
||||||
|
Table: netlinkRoute.Table,
|
||||||
|
Type: netlinkRoute.Type,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func routeToNetlinkRoute(route Route) (netlinkRoute netlink.Route) {
|
||||||
|
return netlink.Route{
|
||||||
|
LinkIndex: route.LinkIndex,
|
||||||
|
Dst: netipPrefixToIPNet(route.Dst),
|
||||||
|
Src: netipAddrToNetIP(route.Src),
|
||||||
|
Gw: netipAddrToNetIP(route.Gw),
|
||||||
|
Priority: route.Priority,
|
||||||
|
Family: route.Family,
|
||||||
|
Table: route.Table,
|
||||||
|
Type: route.Type,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,21 +1,90 @@
|
|||||||
package netlink
|
package netlink
|
||||||
|
|
||||||
import "github.com/vishvananda/netlink"
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
type Rule = netlink.Rule
|
"github.com/vishvananda/netlink"
|
||||||
|
)
|
||||||
|
|
||||||
func NewRule() *Rule {
|
type Rule struct {
|
||||||
return netlink.NewRule()
|
Priority int
|
||||||
|
Family int
|
||||||
|
Table int
|
||||||
|
Mark int
|
||||||
|
Src netip.Prefix
|
||||||
|
Dst netip.Prefix
|
||||||
|
Invert bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r Rule) String() string {
|
||||||
|
from := "all"
|
||||||
|
if r.Src.IsValid() {
|
||||||
|
from = r.Src.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
to := "all"
|
||||||
|
if r.Dst.IsValid() {
|
||||||
|
to = r.Dst.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("ip rule %d: from %s to %s table %d",
|
||||||
|
r.Priority, from, to, r.Table)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRule() Rule {
|
||||||
|
// defaults found from netlink.NewRule() for fields we use,
|
||||||
|
// the rest of the defaults is set when converting from a `Rule`
|
||||||
|
// to a `netlink.Rule`
|
||||||
|
return Rule{
|
||||||
|
Priority: -1,
|
||||||
|
Mark: -1,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetLink) RuleList(family int) (rules []Rule, err error) {
|
func (n *NetLink) RuleList(family int) (rules []Rule, err error) {
|
||||||
return netlink.RuleList(family)
|
netlinkRules, err := netlink.RuleList(family)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rules = make([]Rule, len(netlinkRules))
|
||||||
|
for i := range netlinkRules {
|
||||||
|
rules[i] = netlinkRuleToRule(netlinkRules[i])
|
||||||
|
}
|
||||||
|
return rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetLink) RuleAdd(rule *Rule) error {
|
func (n *NetLink) RuleAdd(rule Rule) error {
|
||||||
return netlink.RuleAdd(rule)
|
netlinkRule := ruleToNetlinkRule(rule)
|
||||||
|
return netlink.RuleAdd(&netlinkRule)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *NetLink) RuleDel(rule *Rule) error {
|
func (n *NetLink) RuleDel(rule Rule) error {
|
||||||
return netlink.RuleDel(rule)
|
netlinkRule := ruleToNetlinkRule(rule)
|
||||||
|
return netlink.RuleDel(&netlinkRule)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ruleToNetlinkRule(rule Rule) (netlinkRule netlink.Rule) {
|
||||||
|
netlinkRule = *netlink.NewRule()
|
||||||
|
netlinkRule.Priority = rule.Priority
|
||||||
|
netlinkRule.Family = rule.Family
|
||||||
|
netlinkRule.Table = rule.Table
|
||||||
|
netlinkRule.Mark = rule.Mark
|
||||||
|
netlinkRule.Src = netipPrefixToIPNet(rule.Src)
|
||||||
|
netlinkRule.Dst = netipPrefixToIPNet(rule.Dst)
|
||||||
|
netlinkRule.Invert = rule.Invert
|
||||||
|
return netlinkRule
|
||||||
|
}
|
||||||
|
|
||||||
|
func netlinkRuleToRule(netlinkRule netlink.Rule) (rule Rule) {
|
||||||
|
return Rule{
|
||||||
|
Priority: netlinkRule.Priority,
|
||||||
|
Family: netlinkRule.Family,
|
||||||
|
Table: netlinkRule.Table,
|
||||||
|
Mark: netlinkRule.Mark,
|
||||||
|
Src: netIPNetToNetipPrefix(netlinkRule.Src),
|
||||||
|
Dst: netIPNetToNetipPrefix(netlinkRule.Dst),
|
||||||
|
Invert: netlinkRule.Invert,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,34 +6,6 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NetipPrefixToIPNet(prefix *netip.Prefix) (ipNet *net.IPNet) {
|
|
||||||
if prefix == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
s := prefix.String()
|
|
||||||
ip, ipNet, err := net.ParseCIDR(s)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
ipNet.IP = ip
|
|
||||||
return ipNet
|
|
||||||
}
|
|
||||||
|
|
||||||
func netIPNetToNetipPrefix(ipNet net.IPNet) (prefix netip.Prefix) {
|
|
||||||
if len(ipNet.IP) != net.IPv4len && len(ipNet.IP) != net.IPv6len {
|
|
||||||
return prefix
|
|
||||||
}
|
|
||||||
var ip netip.Addr
|
|
||||||
if ipv4 := ipNet.IP.To4(); ipv4 != nil {
|
|
||||||
ip = netip.AddrFrom4([4]byte(ipv4))
|
|
||||||
} else {
|
|
||||||
ip = netip.AddrFrom16([16]byte(ipNet.IP))
|
|
||||||
}
|
|
||||||
bits, _ := ipNet.Mask.Size()
|
|
||||||
return netip.PrefixFrom(ip, bits)
|
|
||||||
}
|
|
||||||
|
|
||||||
func netIPToNetipAddress(ip net.IP) (address netip.Addr) {
|
func netIPToNetipAddress(ip net.IP) (address netip.Addr) {
|
||||||
address, ok := netip.AddrFromSlice(ip)
|
address, ok := netip.AddrFromSlice(ip)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
|||||||
@@ -8,54 +8,6 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_netIPNetToNetipPrefix(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
testCases := map[string]struct {
|
|
||||||
ipNet net.IPNet
|
|
||||||
prefix netip.Prefix
|
|
||||||
}{
|
|
||||||
"empty ipnet": {},
|
|
||||||
"custom sized IP in ipnet": {
|
|
||||||
ipNet: net.IPNet{
|
|
||||||
IP: net.IP{1},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"IPv4 ipnet": {
|
|
||||||
ipNet: net.IPNet{
|
|
||||||
IP: net.IP{1, 2, 3, 4},
|
|
||||||
Mask: net.IPMask{255, 255, 255, 0},
|
|
||||||
},
|
|
||||||
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
|
|
||||||
},
|
|
||||||
"IPv4-in-IPv6 ipnet": {
|
|
||||||
ipNet: net.IPNet{
|
|
||||||
IP: net.IPv4(1, 2, 3, 4),
|
|
||||||
Mask: net.IPMask{255, 255, 255, 0},
|
|
||||||
},
|
|
||||||
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
|
|
||||||
},
|
|
||||||
"IPv6 ipnet": {
|
|
||||||
ipNet: net.IPNet{
|
|
||||||
IP: net.IPv6loopback,
|
|
||||||
Mask: net.IPMask{0xff},
|
|
||||||
},
|
|
||||||
prefix: netip.PrefixFrom(netip.IPv6Loopback(), 8),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, testCase := range testCases {
|
|
||||||
testCase := testCase
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
prefix := netIPNetToNetipPrefix(testCase.ipNet)
|
|
||||||
|
|
||||||
assert.Equal(t, testCase.prefix, prefix)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_netIPToNetipAddress(t *testing.T) {
|
func Test_netIPToNetipAddress(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
@@ -25,17 +25,17 @@ func (d DefaultRoute) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Routing) DefaultRoutes() (defaultRoutes []DefaultRoute, err error) {
|
func (r *Routing) DefaultRoutes() (defaultRoutes []DefaultRoute, err error) {
|
||||||
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
|
routes, err := r.netLinker.RouteList(nil, netlink.FamilyAll)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("listing routes: %w", err)
|
return nil, fmt.Errorf("listing routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
if route.Dst != nil {
|
if route.Dst.IsValid() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
defaultRoute := DefaultRoute{
|
defaultRoute := DefaultRoute{
|
||||||
Gateway: netIPToNetipAddress(route.Gw),
|
Gateway: route.Gw,
|
||||||
Family: route.Family,
|
Family: route.Family,
|
||||||
}
|
}
|
||||||
linkIndex := route.LinkIndex
|
linkIndex := route.LinkIndex
|
||||||
@@ -43,11 +43,10 @@ func (r *Routing) DefaultRoutes() (defaultRoutes []DefaultRoute, err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("obtaining link by index: for default route at index %d: %w", linkIndex, err)
|
return nil, fmt.Errorf("obtaining link by index: for default route at index %d: %w", linkIndex, err)
|
||||||
}
|
}
|
||||||
attributes := link.Attrs()
|
defaultRoute.NetInterface = link.Name
|
||||||
defaultRoute.NetInterface = attributes.Name
|
family := netlink.FamilyV6
|
||||||
family := netlink.FAMILY_V6
|
if route.Gw.Is4() {
|
||||||
if route.Gw.To4() != nil {
|
family = netlink.FamilyV4
|
||||||
family = netlink.FAMILY_V4
|
|
||||||
}
|
}
|
||||||
defaultRoute.AssignedIP, err = r.assignedIP(defaultRoute.NetInterface, family)
|
defaultRoute.AssignedIP, err = r.assignedIP(defaultRoute.NetInterface, family)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ func (r *Routing) routeInboundFromDefault(defaultRoutes []DefaultRoute) (err err
|
|||||||
|
|
||||||
for _, defaultRoute := range defaultRoutes {
|
for _, defaultRoute := range defaultRoutes {
|
||||||
defaultDestination := defaultDestinationIPv4
|
defaultDestination := defaultDestinationIPv4
|
||||||
if defaultRoute.Family == netlink.FAMILY_V6 {
|
if defaultRoute.Family == netlink.FamilyV6 {
|
||||||
defaultDestination = defaultDestinationIPv6
|
defaultDestination = defaultDestinationIPv6
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -43,7 +43,7 @@ func (r *Routing) unrouteInboundFromDefault(defaultRoutes []DefaultRoute) (err e
|
|||||||
|
|
||||||
for _, defaultRoute := range defaultRoutes {
|
for _, defaultRoute := range defaultRoutes {
|
||||||
defaultDestination := defaultDestinationIPv4
|
defaultDestination := defaultDestinationIPv4
|
||||||
if defaultRoute.Family == netlink.FAMILY_V6 {
|
if defaultRoute.Family == netlink.FamilyV6 {
|
||||||
defaultDestination = defaultDestinationIPv6
|
defaultDestination = defaultDestinationIPv6
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -68,8 +68,8 @@ func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRo
|
|||||||
bits = 128
|
bits = 128
|
||||||
}
|
}
|
||||||
defaultIPMasked := netip.PrefixFrom(assignedIP, bits)
|
defaultIPMasked := netip.PrefixFrom(assignedIP, bits)
|
||||||
ruleDstNet := (*netip.Prefix)(nil)
|
ruleDstNet := netip.Prefix{}
|
||||||
err = r.addIPRule(&defaultIPMasked, ruleDstNet, table, inboundPriority)
|
err = r.addIPRule(defaultIPMasked, ruleDstNet, table, inboundPriority)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("adding rule for default route %s: %w", defaultRoute, err)
|
return fmt.Errorf("adding rule for default route %s: %w", defaultRoute, err)
|
||||||
}
|
}
|
||||||
@@ -86,8 +86,8 @@ func (r *Routing) delRuleInboundFromDefault(table int, defaultRoutes []DefaultRo
|
|||||||
bits = 128
|
bits = 128
|
||||||
}
|
}
|
||||||
defaultIPMasked := netip.PrefixFrom(assignedIP, bits)
|
defaultIPMasked := netip.PrefixFrom(assignedIP, bits)
|
||||||
ruleDstNet := (*netip.Prefix)(nil)
|
ruleDstNet := netip.Prefix{}
|
||||||
err = r.deleteIPRule(&defaultIPMasked, ruleDstNet, table, inboundPriority)
|
err = r.deleteIPRule(defaultIPMasked, ruleDstNet, table, inboundPriority)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("deleting rule for default route %s: %w", defaultRoute, err)
|
return fmt.Errorf("deleting rule for default route %s: %w", defaultRoute, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,8 +19,8 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func ipMatchesFamily(ip netip.Addr, family int) bool {
|
func ipMatchesFamily(ip netip.Addr, family int) bool {
|
||||||
return (family == netlink.FAMILY_V4 && ip.Is4()) ||
|
return (family == netlink.FamilyV4 && ip.Is4()) ||
|
||||||
(family == netlink.FAMILY_V6 && ip.Is6())
|
(family == netlink.FamilyV6 && ip.Is6())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Routing) assignedIP(interfaceName string, family int) (ip netip.Addr, err error) {
|
func (r *Routing) assignedIP(interfaceName string, family int) (ip netip.Addr, err error) {
|
||||||
|
|||||||
@@ -29,25 +29,25 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
|
|||||||
localLinks := make(map[int]struct{})
|
localLinks := make(map[int]struct{})
|
||||||
|
|
||||||
for _, link := range links {
|
for _, link := range links {
|
||||||
if link.Attrs().EncapType != "ether" {
|
if link.EncapType != "ether" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
localLinks[link.Attrs().Index] = struct{}{}
|
localLinks[link.Index] = struct{}{}
|
||||||
r.logger.Info("local ethernet link found: " + link.Attrs().Name)
|
r.logger.Info("local ethernet link found: " + link.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(localLinks) == 0 {
|
if len(localLinks) == 0 {
|
||||||
return localNetworks, fmt.Errorf("%w: in %d links", ErrLinkLocalNotFound, len(links))
|
return localNetworks, fmt.Errorf("%w: in %d links", ErrLinkLocalNotFound, len(links))
|
||||||
}
|
}
|
||||||
|
|
||||||
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
|
routes, err := r.netLinker.RouteList(nil, netlink.FamilyAll)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return localNetworks, fmt.Errorf("listing routes: %w", err)
|
return localNetworks, fmt.Errorf("listing routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
if route.Gw != nil || route.Dst == nil {
|
if route.Gw.IsValid() || !route.Dst.IsValid() {
|
||||||
continue
|
continue
|
||||||
} else if _, ok := localLinks[route.LinkIndex]; !ok {
|
} else if _, ok := localLinks[route.LinkIndex]; !ok {
|
||||||
continue
|
continue
|
||||||
@@ -55,7 +55,7 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
|
|||||||
|
|
||||||
var localNet LocalNetwork
|
var localNet LocalNetwork
|
||||||
|
|
||||||
localNet.IPNet = netIPNetToNetipPrefix(*route.Dst)
|
localNet.IPNet = route.Dst
|
||||||
r.logger.Info("local ipnet found: " + localNet.IPNet.String())
|
r.logger.Info("local ipnet found: " + localNet.IPNet.String())
|
||||||
|
|
||||||
link, err := r.netLinker.LinkByIndex(route.LinkIndex)
|
link, err := r.netLinker.LinkByIndex(route.LinkIndex)
|
||||||
@@ -63,11 +63,11 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
|
|||||||
return localNetworks, fmt.Errorf("finding link at index %d: %w", route.LinkIndex, err)
|
return localNetworks, fmt.Errorf("finding link at index %d: %w", route.LinkIndex, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
localNet.InterfaceName = link.Attrs().Name
|
localNet.InterfaceName = link.Name
|
||||||
|
|
||||||
family := netlink.FAMILY_V6
|
family := netlink.FamilyV6
|
||||||
if localNet.IPNet.Addr().Is4() {
|
if localNet.IPNet.Addr().Is4() {
|
||||||
family = netlink.FAMILY_V4
|
family = netlink.FamilyV4
|
||||||
}
|
}
|
||||||
ip, err := r.assignedIP(localNet.InterfaceName, family)
|
ip, err := r.assignedIP(localNet.InterfaceName, family)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -96,7 +96,8 @@ func (r *Routing) AddLocalRules(subnets []LocalNetwork) (err error) {
|
|||||||
const localPriority = 98
|
const localPriority = 98
|
||||||
|
|
||||||
// Main table was setup correctly by Docker, just need to add rules to use it
|
// Main table was setup correctly by Docker, just need to add rules to use it
|
||||||
err = r.addIPRule(nil, &subnet.IPNet, mainTable, localPriority)
|
src := netip.Prefix{}
|
||||||
|
err = r.addIPRule(src, subnet.IPNet, mainTable, localPriority)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("adding rule: %v: %w", subnet.IPNet, err)
|
return fmt.Errorf("adding rule: %v: %w", subnet.IPNet, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
netlink "github.com/vishvananda/netlink"
|
netlink "github.com/qdm12/gluetun/internal/netlink"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockNetLinker is a mock of NetLinker interface.
|
// MockNetLinker is a mock of NetLinker interface.
|
||||||
@@ -50,7 +50,7 @@ func (mr *MockNetLinkerMockRecorder) AddrList(arg0, arg1 interface{}) *gomock.Ca
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddrReplace mocks base method.
|
// AddrReplace mocks base method.
|
||||||
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 *netlink.Addr) error {
|
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 netlink.Addr) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
|
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
@@ -79,11 +79,12 @@ func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LinkAdd mocks base method.
|
// LinkAdd mocks base method.
|
||||||
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) error {
|
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (int, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "LinkAdd", arg0)
|
ret := m.ctrl.Call(m, "LinkAdd", arg0)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(int)
|
||||||
return ret0
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// LinkAdd indicates an expected call of LinkAdd.
|
// LinkAdd indicates an expected call of LinkAdd.
|
||||||
@@ -166,11 +167,12 @@ func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LinkSetUp mocks base method.
|
// LinkSetUp mocks base method.
|
||||||
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) error {
|
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) (int, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "LinkSetUp", arg0)
|
ret := m.ctrl.Call(m, "LinkSetUp", arg0)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(int)
|
||||||
return ret0
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// LinkSetUp indicates an expected call of LinkSetUp.
|
// LinkSetUp indicates an expected call of LinkSetUp.
|
||||||
@@ -180,7 +182,7 @@ func (mr *MockNetLinkerMockRecorder) LinkSetUp(arg0 interface{}) *gomock.Call {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RouteAdd mocks base method.
|
// RouteAdd mocks base method.
|
||||||
func (m *MockNetLinker) RouteAdd(arg0 *netlink.Route) error {
|
func (m *MockNetLinker) RouteAdd(arg0 netlink.Route) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "RouteAdd", arg0)
|
ret := m.ctrl.Call(m, "RouteAdd", arg0)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
@@ -194,7 +196,7 @@ func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RouteDel mocks base method.
|
// RouteDel mocks base method.
|
||||||
func (m *MockNetLinker) RouteDel(arg0 *netlink.Route) error {
|
func (m *MockNetLinker) RouteDel(arg0 netlink.Route) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "RouteDel", arg0)
|
ret := m.ctrl.Call(m, "RouteDel", arg0)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
@@ -208,7 +210,7 @@ func (mr *MockNetLinkerMockRecorder) RouteDel(arg0 interface{}) *gomock.Call {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RouteList mocks base method.
|
// RouteList mocks base method.
|
||||||
func (m *MockNetLinker) RouteList(arg0 netlink.Link, arg1 int) ([]netlink.Route, error) {
|
func (m *MockNetLinker) RouteList(arg0 *netlink.Link, arg1 int) ([]netlink.Route, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "RouteList", arg0, arg1)
|
ret := m.ctrl.Call(m, "RouteList", arg0, arg1)
|
||||||
ret0, _ := ret[0].([]netlink.Route)
|
ret0, _ := ret[0].([]netlink.Route)
|
||||||
@@ -223,7 +225,7 @@ func (mr *MockNetLinkerMockRecorder) RouteList(arg0, arg1 interface{}) *gomock.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RouteReplace mocks base method.
|
// RouteReplace mocks base method.
|
||||||
func (m *MockNetLinker) RouteReplace(arg0 *netlink.Route) error {
|
func (m *MockNetLinker) RouteReplace(arg0 netlink.Route) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "RouteReplace", arg0)
|
ret := m.ctrl.Call(m, "RouteReplace", arg0)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
@@ -237,7 +239,7 @@ func (mr *MockNetLinkerMockRecorder) RouteReplace(arg0 interface{}) *gomock.Call
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RuleAdd mocks base method.
|
// RuleAdd mocks base method.
|
||||||
func (m *MockNetLinker) RuleAdd(arg0 *netlink.Rule) error {
|
func (m *MockNetLinker) RuleAdd(arg0 netlink.Rule) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "RuleAdd", arg0)
|
ret := m.ctrl.Call(m, "RuleAdd", arg0)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
@@ -251,7 +253,7 @@ func (mr *MockNetLinkerMockRecorder) RuleAdd(arg0 interface{}) *gomock.Call {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RuleDel mocks base method.
|
// RuleDel mocks base method.
|
||||||
func (m *MockNetLinker) RuleDel(arg0 *netlink.Rule) error {
|
func (m *MockNetLinker) RuleDel(arg0 netlink.Rule) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "RuleDel", arg0)
|
ret := m.ctrl.Call(m, "RuleDel", arg0)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
|
|||||||
@@ -56,8 +56,8 @@ func (r *Routing) removeOutboundSubnets(subnets []netip.Prefix,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleSrcNet := (*netip.Prefix)(nil)
|
ruleSrcNet := netip.Prefix{}
|
||||||
ruleDstNet := &subnets[i]
|
ruleDstNet := subnets[i]
|
||||||
err := r.deleteIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority)
|
err := r.deleteIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
warnings = append(warnings,
|
warnings = append(warnings,
|
||||||
@@ -81,8 +81,8 @@ func (r *Routing) addOutboundSubnets(subnets []netip.Prefix,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleSrcNet := (*netip.Prefix)(nil)
|
ruleSrcNet := netip.Prefix{}
|
||||||
ruleDstNet := &subnets[i]
|
ruleDstNet := subnets[i]
|
||||||
err = r.addIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority)
|
err = r.addIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("adding rule: for subnet %s: %w", subnet, err)
|
return fmt.Errorf("adding rule: for subnet %s: %w", subnet, err)
|
||||||
|
|||||||
@@ -23,12 +23,12 @@ func (r *Routing) addRouteVia(destination netip.Prefix, gateway netip.Addr,
|
|||||||
}
|
}
|
||||||
|
|
||||||
route := netlink.Route{
|
route := netlink.Route{
|
||||||
Dst: NetipPrefixToIPNet(&destination),
|
Dst: destination,
|
||||||
Gw: gateway.AsSlice(),
|
Gw: gateway,
|
||||||
LinkIndex: link.Attrs().Index,
|
LinkIndex: link.Index,
|
||||||
Table: table,
|
Table: table,
|
||||||
}
|
}
|
||||||
if err := r.netLinker.RouteReplace(&route); err != nil {
|
if err := r.netLinker.RouteReplace(route); err != nil {
|
||||||
return fmt.Errorf("replacing route for subnet %s at interface %s: %w",
|
return fmt.Errorf("replacing route for subnet %s at interface %s: %w",
|
||||||
destinationStr, iface, err)
|
destinationStr, iface, err)
|
||||||
}
|
}
|
||||||
@@ -51,12 +51,12 @@ func (r *Routing) deleteRouteVia(destination netip.Prefix, gateway netip.Addr,
|
|||||||
}
|
}
|
||||||
|
|
||||||
route := netlink.Route{
|
route := netlink.Route{
|
||||||
Dst: NetipPrefixToIPNet(&destination),
|
Dst: destination,
|
||||||
Gw: gateway.AsSlice(),
|
Gw: gateway,
|
||||||
LinkIndex: link.Attrs().Index,
|
LinkIndex: link.Index,
|
||||||
Table: table,
|
Table: table,
|
||||||
}
|
}
|
||||||
if err := r.netLinker.RouteDel(&route); err != nil {
|
if err := r.netLinker.RouteDel(route); err != nil {
|
||||||
return fmt.Errorf("deleting route: for subnet %s at interface %s: %w",
|
return fmt.Errorf("deleting route: for subnet %s at interface %s: %w",
|
||||||
destinationStr, iface, err)
|
destinationStr, iface, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,30 +18,30 @@ type NetLinker interface {
|
|||||||
type Addresser interface {
|
type Addresser interface {
|
||||||
AddrList(link netlink.Link, family int) (
|
AddrList(link netlink.Link, family int) (
|
||||||
addresses []netlink.Addr, err error)
|
addresses []netlink.Addr, err error)
|
||||||
AddrReplace(link netlink.Link, addr *netlink.Addr) error
|
AddrReplace(link netlink.Link, addr netlink.Addr) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Router interface {
|
type Router interface {
|
||||||
RouteList(link netlink.Link, family int) (
|
RouteList(link *netlink.Link, family int) (
|
||||||
routes []netlink.Route, err error)
|
routes []netlink.Route, err error)
|
||||||
RouteAdd(route *netlink.Route) error
|
RouteAdd(route netlink.Route) error
|
||||||
RouteDel(route *netlink.Route) error
|
RouteDel(route netlink.Route) error
|
||||||
RouteReplace(route *netlink.Route) error
|
RouteReplace(route netlink.Route) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Ruler interface {
|
type Ruler interface {
|
||||||
RuleList(family int) (rules []netlink.Rule, err error)
|
RuleList(family int) (rules []netlink.Rule, err error)
|
||||||
RuleAdd(rule *netlink.Rule) error
|
RuleAdd(rule netlink.Rule) error
|
||||||
RuleDel(rule *netlink.Rule) error
|
RuleDel(rule netlink.Rule) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Linker interface {
|
type Linker interface {
|
||||||
LinkList() (links []netlink.Link, err error)
|
LinkList() (links []netlink.Link, err error)
|
||||||
LinkByName(name string) (link netlink.Link, err error)
|
LinkByName(name string) (link netlink.Link, err error)
|
||||||
LinkByIndex(index int) (link netlink.Link, err error)
|
LinkByIndex(index int) (link netlink.Link, err error)
|
||||||
LinkAdd(link netlink.Link) (err error)
|
LinkAdd(link netlink.Link) (linkIndex int, err error)
|
||||||
LinkDel(link netlink.Link) (err error)
|
LinkDel(link netlink.Link) (err error)
|
||||||
LinkSetUp(link netlink.Link) (err error)
|
LinkSetUp(link netlink.Link) (linkIndex int, err error)
|
||||||
LinkSetDown(link netlink.Link) (err error)
|
LinkSetDown(link netlink.Link) (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,30 +1,28 @@
|
|||||||
package routing
|
package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *Routing) addIPRule(src, dst *netip.Prefix, table, priority int) error {
|
func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority int) error {
|
||||||
const add = true
|
const add = true
|
||||||
r.logger.Debug(ruleDbgMsg(add, src, dst, table, priority))
|
r.logger.Debug(ruleDbgMsg(add, src, dst, table, priority))
|
||||||
|
|
||||||
rule := netlink.NewRule()
|
rule := netlink.NewRule()
|
||||||
rule.Src = NetipPrefixToIPNet(src)
|
rule.Src = src
|
||||||
rule.Dst = NetipPrefixToIPNet(dst)
|
rule.Dst = dst
|
||||||
rule.Priority = priority
|
rule.Priority = priority
|
||||||
rule.Table = table
|
rule.Table = table
|
||||||
|
|
||||||
existingRules, err := r.netLinker.RuleList(netlink.FAMILY_ALL)
|
existingRules, err := r.netLinker.RuleList(netlink.FamilyAll)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("listing rules: %w", err)
|
return fmt.Errorf("listing rules: %w", err)
|
||||||
}
|
}
|
||||||
for i := range existingRules {
|
for i := range existingRules {
|
||||||
if !rulesAreEqual(&existingRules[i], rule) {
|
if !rulesAreEqual(existingRules[i], rule) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return nil // already exists
|
return nil // already exists
|
||||||
@@ -36,22 +34,22 @@ func (r *Routing) addIPRule(src, dst *netip.Prefix, table, priority int) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Routing) deleteIPRule(src, dst *netip.Prefix, table, priority int) error {
|
func (r *Routing) deleteIPRule(src, dst netip.Prefix, table, priority int) error {
|
||||||
const add = false
|
const add = false
|
||||||
r.logger.Debug(ruleDbgMsg(add, src, dst, table, priority))
|
r.logger.Debug(ruleDbgMsg(add, src, dst, table, priority))
|
||||||
|
|
||||||
rule := netlink.NewRule()
|
rule := netlink.NewRule()
|
||||||
rule.Src = NetipPrefixToIPNet(src)
|
rule.Src = src
|
||||||
rule.Dst = NetipPrefixToIPNet(dst)
|
rule.Dst = dst
|
||||||
rule.Priority = priority
|
rule.Priority = priority
|
||||||
rule.Table = table
|
rule.Table = table
|
||||||
|
|
||||||
existingRules, err := r.netLinker.RuleList(netlink.FAMILY_ALL)
|
existingRules, err := r.netLinker.RuleList(netlink.FamilyAll)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("listing rules: %w", err)
|
return fmt.Errorf("listing rules: %w", err)
|
||||||
}
|
}
|
||||||
for i := range existingRules {
|
for i := range existingRules {
|
||||||
if !rulesAreEqual(&existingRules[i], rule) {
|
if !rulesAreEqual(existingRules[i], rule) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if err := r.netLinker.RuleDel(rule); err != nil {
|
if err := r.netLinker.RuleDel(rule); err != nil {
|
||||||
@@ -61,7 +59,7 @@ func (r *Routing) deleteIPRule(src, dst *netip.Prefix, table, priority int) erro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ruleDbgMsg(add bool, src, dst *netip.Prefix,
|
func ruleDbgMsg(add bool, src, dst netip.Prefix,
|
||||||
table, priority int) (debugMessage string) {
|
table, priority int) (debugMessage string) {
|
||||||
debugMessage = "ip rule"
|
debugMessage = "ip rule"
|
||||||
|
|
||||||
@@ -71,11 +69,11 @@ func ruleDbgMsg(add bool, src, dst *netip.Prefix,
|
|||||||
debugMessage += " del"
|
debugMessage += " del"
|
||||||
}
|
}
|
||||||
|
|
||||||
if src != nil {
|
if src.IsValid() {
|
||||||
debugMessage += " from " + src.String()
|
debugMessage += " from " + src.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
if dst != nil {
|
if dst.IsValid() {
|
||||||
debugMessage += " to " + dst.String()
|
debugMessage += " to " + dst.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -90,25 +88,20 @@ func ruleDbgMsg(add bool, src, dst *netip.Prefix,
|
|||||||
return debugMessage
|
return debugMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
func rulesAreEqual(a, b *netlink.Rule) bool {
|
func rulesAreEqual(a, b netlink.Rule) bool {
|
||||||
if a == nil && b == nil {
|
return ipPrefixesAreEqual(a.Src, b.Src) &&
|
||||||
return true
|
ipPrefixesAreEqual(a.Dst, b.Dst) &&
|
||||||
}
|
|
||||||
if a == nil || b == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return ipNetsAreEqual(a.Src, b.Src) &&
|
|
||||||
ipNetsAreEqual(a.Dst, b.Dst) &&
|
|
||||||
a.Priority == b.Priority &&
|
a.Priority == b.Priority &&
|
||||||
a.Table == b.Table
|
a.Table == b.Table
|
||||||
}
|
}
|
||||||
|
|
||||||
func ipNetsAreEqual(a, b *net.IPNet) bool {
|
func ipPrefixesAreEqual(a, b netip.Prefix) bool {
|
||||||
if a == nil && b == nil {
|
if !a.IsValid() && !b.IsValid() {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if a == nil || b == nil {
|
if !a.IsValid() || !b.IsValid() {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return a.IP.Equal(b.IP) && bytes.Equal(a.Mask, b.Mask)
|
return a.Bits() == b.Bits() &&
|
||||||
|
a.Addr().Compare(b.Addr()) == 0
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package routing
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -12,17 +11,16 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func makeNetipPrefix(n byte) *netip.Prefix {
|
func makeNetipPrefix(n byte) netip.Prefix {
|
||||||
const bits = 24
|
const bits = 24
|
||||||
prefix := netip.PrefixFrom(netip.AddrFrom4([4]byte{n, n, n, 0}), bits)
|
return netip.PrefixFrom(netip.AddrFrom4([4]byte{n, n, n, 0}), bits)
|
||||||
return &prefix
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeIPRule(src, dst *netip.Prefix,
|
func makeIPRule(src, dst netip.Prefix,
|
||||||
table, priority int) *netlink.Rule {
|
table, priority int) netlink.Rule {
|
||||||
rule := netlink.NewRule()
|
rule := netlink.NewRule()
|
||||||
rule.Src = NetipPrefixToIPNet(src)
|
rule.Src = src
|
||||||
rule.Dst = NetipPrefixToIPNet(dst)
|
rule.Dst = dst
|
||||||
rule.Table = table
|
rule.Table = table
|
||||||
rule.Priority = priority
|
rule.Priority = priority
|
||||||
return rule
|
return rule
|
||||||
@@ -40,13 +38,13 @@ func Test_Routing_addIPRule(t *testing.T) {
|
|||||||
|
|
||||||
type ruleAddCall struct {
|
type ruleAddCall struct {
|
||||||
expected bool
|
expected bool
|
||||||
ruleToAdd *netlink.Rule
|
ruleToAdd netlink.Rule
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
src *netip.Prefix
|
src netip.Prefix
|
||||||
dst *netip.Prefix
|
dst netip.Prefix
|
||||||
table int
|
table int
|
||||||
priority int
|
priority int
|
||||||
dbgMsg string
|
dbgMsg string
|
||||||
@@ -69,8 +67,8 @@ func Test_Routing_addIPRule(t *testing.T) {
|
|||||||
dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
|
dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
|
||||||
ruleList: ruleListCall{
|
ruleList: ruleListCall{
|
||||||
rules: []netlink.Rule{
|
rules: []netlink.Rule{
|
||||||
*makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
|
makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
|
||||||
*makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
|
makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -95,8 +93,8 @@ func Test_Routing_addIPRule(t *testing.T) {
|
|||||||
dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
|
dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
|
||||||
ruleList: ruleListCall{
|
ruleList: ruleListCall{
|
||||||
rules: []netlink.Rule{
|
rules: []netlink.Rule{
|
||||||
*makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
|
makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
|
||||||
*makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101),
|
makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleAdd: ruleAddCall{
|
ruleAdd: ruleAddCall{
|
||||||
@@ -116,7 +114,7 @@ func Test_Routing_addIPRule(t *testing.T) {
|
|||||||
logger.EXPECT().Debug(testCase.dbgMsg)
|
logger.EXPECT().Debug(testCase.dbgMsg)
|
||||||
|
|
||||||
netLinker := NewMockNetLinker(ctrl)
|
netLinker := NewMockNetLinker(ctrl)
|
||||||
netLinker.EXPECT().RuleList(netlink.FAMILY_ALL).
|
netLinker.EXPECT().RuleList(netlink.FamilyAll).
|
||||||
Return(testCase.ruleList.rules, testCase.ruleList.err)
|
Return(testCase.ruleList.rules, testCase.ruleList.err)
|
||||||
if testCase.ruleAdd.expected {
|
if testCase.ruleAdd.expected {
|
||||||
netLinker.EXPECT().RuleAdd(testCase.ruleAdd.ruleToAdd).
|
netLinker.EXPECT().RuleAdd(testCase.ruleAdd.ruleToAdd).
|
||||||
@@ -153,13 +151,13 @@ func Test_Routing_deleteIPRule(t *testing.T) {
|
|||||||
|
|
||||||
type ruleDelCall struct {
|
type ruleDelCall struct {
|
||||||
expected bool
|
expected bool
|
||||||
ruleToDel *netlink.Rule
|
ruleToDel netlink.Rule
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
src *netip.Prefix
|
src netip.Prefix
|
||||||
dst *netip.Prefix
|
dst netip.Prefix
|
||||||
table int
|
table int
|
||||||
priority int
|
priority int
|
||||||
dbgMsg string
|
dbgMsg string
|
||||||
@@ -182,7 +180,7 @@ func Test_Routing_deleteIPRule(t *testing.T) {
|
|||||||
dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
|
dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
|
||||||
ruleList: ruleListCall{
|
ruleList: ruleListCall{
|
||||||
rules: []netlink.Rule{
|
rules: []netlink.Rule{
|
||||||
*makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
|
makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleDel: ruleDelCall{
|
ruleDel: ruleDelCall{
|
||||||
@@ -200,8 +198,8 @@ func Test_Routing_deleteIPRule(t *testing.T) {
|
|||||||
dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
|
dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
|
||||||
ruleList: ruleListCall{
|
ruleList: ruleListCall{
|
||||||
rules: []netlink.Rule{
|
rules: []netlink.Rule{
|
||||||
*makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
|
makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
|
||||||
*makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
|
makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ruleDel: ruleDelCall{
|
ruleDel: ruleDelCall{
|
||||||
@@ -217,8 +215,8 @@ func Test_Routing_deleteIPRule(t *testing.T) {
|
|||||||
dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
|
dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
|
||||||
ruleList: ruleListCall{
|
ruleList: ruleListCall{
|
||||||
rules: []netlink.Rule{
|
rules: []netlink.Rule{
|
||||||
*makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
|
makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
|
||||||
*makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101),
|
makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -234,7 +232,7 @@ func Test_Routing_deleteIPRule(t *testing.T) {
|
|||||||
logger.EXPECT().Debug(testCase.dbgMsg)
|
logger.EXPECT().Debug(testCase.dbgMsg)
|
||||||
|
|
||||||
netLinker := NewMockNetLinker(ctrl)
|
netLinker := NewMockNetLinker(ctrl)
|
||||||
netLinker.EXPECT().RuleList(netlink.FAMILY_ALL).
|
netLinker.EXPECT().RuleList(netlink.FamilyAll).
|
||||||
Return(testCase.ruleList.rules, testCase.ruleList.err)
|
Return(testCase.ruleList.rules, testCase.ruleList.err)
|
||||||
if testCase.ruleDel.expected {
|
if testCase.ruleDel.expected {
|
||||||
netLinker.EXPECT().RuleDel(testCase.ruleDel.ruleToDel).
|
netLinker.EXPECT().RuleDel(testCase.ruleDel.ruleToDel).
|
||||||
@@ -264,8 +262,8 @@ func Test_ruleDbgMsg(t *testing.T) {
|
|||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
add bool
|
add bool
|
||||||
src *netip.Prefix
|
src netip.Prefix
|
||||||
dst *netip.Prefix
|
dst netip.Prefix
|
||||||
table int
|
table int
|
||||||
priority int
|
priority int
|
||||||
dbgMsg string
|
dbgMsg string
|
||||||
@@ -307,38 +305,79 @@ func Test_rulesAreEqual(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
a *netlink.Rule
|
a netlink.Rule
|
||||||
b *netlink.Rule
|
b netlink.Rule
|
||||||
equal bool
|
equal bool
|
||||||
}{
|
}{
|
||||||
"both nil": {
|
"both_empty": {
|
||||||
equal: true,
|
equal: true,
|
||||||
},
|
},
|
||||||
"first nil": {
|
"not_equal_by_src": {
|
||||||
b: &netlink.Rule{},
|
a: netlink.Rule{
|
||||||
},
|
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 24),
|
||||||
"second nil": {
|
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||||
a: &netlink.Rule{},
|
|
||||||
},
|
|
||||||
"both not nil": {
|
|
||||||
a: &netlink.Rule{},
|
|
||||||
b: &netlink.Rule{},
|
|
||||||
equal: true,
|
|
||||||
},
|
|
||||||
"both equal": {
|
|
||||||
a: &netlink.Rule{
|
|
||||||
Src: &net.IPNet{
|
|
||||||
IP: net.IPv4(1, 1, 1, 1),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
Priority: 100,
|
Priority: 100,
|
||||||
Table: 101,
|
Table: 101,
|
||||||
},
|
},
|
||||||
b: &netlink.Rule{
|
b: netlink.Rule{
|
||||||
Src: &net.IPNet{
|
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||||
IP: net.IPv4(1, 1, 1, 1),
|
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
Priority: 100,
|
||||||
|
Table: 101,
|
||||||
},
|
},
|
||||||
|
},
|
||||||
|
"not_equal_by_dst": {
|
||||||
|
a: netlink.Rule{
|
||||||
|
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||||
|
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 32),
|
||||||
|
Priority: 100,
|
||||||
|
Table: 101,
|
||||||
|
},
|
||||||
|
b: netlink.Rule{
|
||||||
|
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||||
|
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||||
|
Priority: 100,
|
||||||
|
Table: 101,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"not_equal_by_priority": {
|
||||||
|
a: netlink.Rule{
|
||||||
|
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||||
|
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||||
|
Priority: 999,
|
||||||
|
Table: 101,
|
||||||
|
},
|
||||||
|
b: netlink.Rule{
|
||||||
|
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||||
|
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||||
|
Priority: 100,
|
||||||
|
Table: 101,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"not_equal_by_table": {
|
||||||
|
a: netlink.Rule{
|
||||||
|
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||||
|
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||||
|
Priority: 100,
|
||||||
|
Table: 999,
|
||||||
|
},
|
||||||
|
b: netlink.Rule{
|
||||||
|
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||||
|
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||||
|
Priority: 100,
|
||||||
|
Table: 101,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"equal": {
|
||||||
|
a: netlink.Rule{
|
||||||
|
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||||
|
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||||
|
Priority: 100,
|
||||||
|
Table: 101,
|
||||||
|
},
|
||||||
|
b: netlink.Rule{
|
||||||
|
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||||
|
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||||
Priority: 100,
|
Priority: 100,
|
||||||
Table: 101,
|
Table: 101,
|
||||||
},
|
},
|
||||||
@@ -358,58 +397,39 @@ func Test_rulesAreEqual(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_ipNetsAreEqual(t *testing.T) {
|
func Test_ipPrefixesAreEqual(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
a *net.IPNet
|
a netip.Prefix
|
||||||
b *net.IPNet
|
b netip.Prefix
|
||||||
equal bool
|
equal bool
|
||||||
}{
|
}{
|
||||||
"both nil": {
|
"both_not_valid": {
|
||||||
equal: true,
|
equal: true,
|
||||||
},
|
},
|
||||||
"first nil": {
|
"first_not_valid": {
|
||||||
b: &net.IPNet{},
|
b: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||||
},
|
},
|
||||||
"second nil": {
|
"second_not_valid": {
|
||||||
a: &net.IPNet{},
|
a: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||||
},
|
},
|
||||||
"both not nil": {
|
"both_equal": {
|
||||||
a: &net.IPNet{},
|
a: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||||
b: &net.IPNet{},
|
b: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||||
equal: true,
|
equal: true,
|
||||||
},
|
},
|
||||||
"both equal": {
|
"both_not_equal_by_IP": {
|
||||||
a: &net.IPNet{
|
a: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||||
IP: net.IPv4(1, 1, 1, 1),
|
b: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 24),
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
},
|
||||||
b: &net.IPNet{
|
"both_not_equal_by_bits": {
|
||||||
IP: net.IPv4(1, 1, 1, 1),
|
a: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
b: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32),
|
||||||
},
|
|
||||||
equal: true,
|
|
||||||
},
|
|
||||||
"both not equal by IP": {
|
|
||||||
a: &net.IPNet{
|
|
||||||
IP: net.IPv4(1, 1, 1, 1),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
b: &net.IPNet{
|
|
||||||
IP: net.IPv4(2, 2, 2, 2),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"both not equal by mask": {
|
|
||||||
a: &net.IPNet{
|
|
||||||
IP: net.IPv4(1, 1, 1, 1),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 255),
|
|
||||||
},
|
|
||||||
b: &net.IPNet{
|
|
||||||
IP: net.IPv4(1, 1, 1, 1),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 0, 0),
|
|
||||||
},
|
},
|
||||||
|
"both_not_equal_by_IP_and_bits": {
|
||||||
|
a: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
|
||||||
|
b: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -418,7 +438,7 @@ func Test_ipNetsAreEqual(t *testing.T) {
|
|||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
equal := ipNetsAreEqual(testCase.a, testCase.b)
|
equal := ipPrefixesAreEqual(testCase.a, testCase.b)
|
||||||
|
|
||||||
assert.Equal(t, testCase.equal, equal)
|
assert.Equal(t, testCase.equal, equal)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
package routing
|
package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
@@ -16,14 +14,14 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (r *Routing) VPNDestinationIP() (ip netip.Addr, err error) {
|
func (r *Routing) VPNDestinationIP() (ip netip.Addr, err error) {
|
||||||
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
|
routes, err := r.netLinker.RouteList(nil, netlink.FamilyAll)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ip, fmt.Errorf("listing routes: %w", err)
|
return ip, fmt.Errorf("listing routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultLinkIndex := -1
|
defaultLinkIndex := -1
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
if route.Dst == nil {
|
if !route.Dst.IsValid() {
|
||||||
defaultLinkIndex = route.LinkIndex
|
defaultLinkIndex = route.LinkIndex
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -34,17 +32,17 @@ func (r *Routing) VPNDestinationIP() (ip netip.Addr, err error) {
|
|||||||
|
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
if route.LinkIndex == defaultLinkIndex &&
|
if route.LinkIndex == defaultLinkIndex &&
|
||||||
route.Dst != nil &&
|
route.Dst.IsValid() &&
|
||||||
!IPIsPrivate(netIPToNetipAddress(route.Dst.IP)) &&
|
!IPIsPrivate(route.Dst.Addr()) &&
|
||||||
bytes.Equal(route.Dst.Mask, net.IPMask{255, 255, 255, 255}) {
|
route.Dst.IsSingleIP() {
|
||||||
return netIPToNetipAddress(route.Dst.IP), nil
|
return route.Dst.Addr(), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ip, fmt.Errorf("%w: in %d routes", ErrVPNDestinationIPNotFound, len(routes))
|
return ip, fmt.Errorf("%w: in %d routes", ErrVPNDestinationIPNotFound, len(routes))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) {
|
func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) {
|
||||||
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
|
routes, err := r.netLinker.RouteList(nil, netlink.FamilyAll)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ip, fmt.Errorf("listing routes: %w", err)
|
return ip, fmt.Errorf("listing routes: %w", err)
|
||||||
}
|
}
|
||||||
@@ -53,11 +51,11 @@ func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return ip, fmt.Errorf("finding link at index %d: %w", route.LinkIndex, err)
|
return ip, fmt.Errorf("finding link at index %d: %w", route.LinkIndex, err)
|
||||||
}
|
}
|
||||||
interfaceName := link.Attrs().Name
|
interfaceName := link.Name
|
||||||
if interfaceName == vpnIntf &&
|
if interfaceName == vpnIntf &&
|
||||||
route.Dst != nil &&
|
route.Dst.IsValid() &&
|
||||||
route.Dst.IP.Equal(net.IP{0, 0, 0, 0}) {
|
route.Dst.Addr().IsUnspecified() {
|
||||||
return netIPToNetipAddress(route.Gw), nil
|
return route.Gw, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ip, fmt.Errorf("%w: in %d routes", ErrVPNLocalGatewayIPNotFound, len(routes))
|
return ip, fmt.Errorf("%w: in %d routes", ErrVPNLocalGatewayIPNotFound, len(routes))
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ type Storage interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type NetLinker interface {
|
type NetLinker interface {
|
||||||
AddrReplace(link netlink.Link, addr *netlink.Addr) error
|
AddrReplace(link netlink.Link, addr netlink.Addr) error
|
||||||
Router
|
Router
|
||||||
Ruler
|
Ruler
|
||||||
Linker
|
Linker
|
||||||
@@ -50,22 +50,22 @@ type NetLinker interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Router interface {
|
type Router interface {
|
||||||
RouteList(link netlink.Link, family int) (
|
RouteList(link *netlink.Link, family int) (
|
||||||
routes []netlink.Route, err error)
|
routes []netlink.Route, err error)
|
||||||
RouteAdd(route *netlink.Route) error
|
RouteAdd(route netlink.Route) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Ruler interface {
|
type Ruler interface {
|
||||||
RuleAdd(rule *netlink.Rule) error
|
RuleAdd(rule netlink.Rule) error
|
||||||
RuleDel(rule *netlink.Rule) error
|
RuleDel(rule netlink.Rule) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Linker interface {
|
type Linker interface {
|
||||||
LinkList() (links []netlink.Link, err error)
|
LinkList() (links []netlink.Link, err error)
|
||||||
LinkByName(name string) (link netlink.Link, err error)
|
LinkByName(name string) (link netlink.Link, err error)
|
||||||
LinkAdd(link netlink.Link) (err error)
|
LinkAdd(link netlink.Link) (linkIndex int, err error)
|
||||||
LinkDel(link netlink.Link) (err error)
|
LinkDel(link netlink.Link) (err error)
|
||||||
LinkSetUp(link netlink.Link) (err error)
|
LinkSetUp(link netlink.Link) (linkIndex int, err error)
|
||||||
LinkSetDown(link netlink.Link) (err error)
|
LinkSetDown(link netlink.Link) (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
"github.com/qdm12/gluetun/internal/routing"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (w *Wireguard) addAddresses(link netlink.Link,
|
func (w *Wireguard) addAddresses(link netlink.Link,
|
||||||
@@ -15,15 +14,14 @@ func (w *Wireguard) addAddresses(link netlink.Link,
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
ipNet := ipNet
|
address := netlink.Addr{
|
||||||
address := &netlink.Addr{
|
Network: ipNet,
|
||||||
IPNet: routing.NetipPrefixToIPNet(&ipNet),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = w.netlink.AddrReplace(link, address)
|
err = w.netlink.AddrReplace(link, address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: when adding address %s to link %s",
|
return fmt.Errorf("%w: when adding address %s to link %s",
|
||||||
err, address, link.Attrs().Name)
|
err, address, link.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
"github.com/qdm12/gluetun/internal/routing"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@@ -18,14 +17,6 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
|||||||
ipNetOne := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 32)
|
ipNetOne := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 32)
|
||||||
ipNetTwo := netip.PrefixFrom(netip.MustParseAddr("::1234"), 64)
|
ipNetTwo := netip.PrefixFrom(netip.MustParseAddr("::1234"), 64)
|
||||||
|
|
||||||
newLink := func() netlink.Link {
|
|
||||||
linkAttrs := netlink.NewLinkAttrs()
|
|
||||||
linkAttrs.Name = "a_bridge"
|
|
||||||
return &netlink.Bridge{
|
|
||||||
LinkAttrs: linkAttrs,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
errDummy := errors.New("dummy")
|
errDummy := errors.New("dummy")
|
||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
@@ -35,15 +26,15 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
|||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
"success": {
|
"success": {
|
||||||
link: newLink(),
|
link: netlink.Link{Type: "wireguard"},
|
||||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||||
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
|
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
|
||||||
netLinker := NewMockNetLinker(ctrl)
|
netLinker := NewMockNetLinker(ctrl)
|
||||||
firstCall := netLinker.EXPECT().
|
firstCall := netLinker.EXPECT().
|
||||||
AddrReplace(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetOne)}).
|
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
|
||||||
Return(nil)
|
Return(nil)
|
||||||
netLinker.EXPECT().
|
netLinker.EXPECT().
|
||||||
AddrReplace(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetTwo)}).
|
AddrReplace(link, netlink.Addr{Network: ipNetTwo}).
|
||||||
Return(nil).After(firstCall)
|
Return(nil).After(firstCall)
|
||||||
return &Wireguard{
|
return &Wireguard{
|
||||||
netlink: netLinker,
|
netlink: netLinker,
|
||||||
@@ -54,12 +45,12 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
"first add error": {
|
"first add error": {
|
||||||
link: newLink(),
|
link: netlink.Link{Type: "wireguard", Name: "a_bridge"},
|
||||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||||
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
|
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
|
||||||
netLinker := NewMockNetLinker(ctrl)
|
netLinker := NewMockNetLinker(ctrl)
|
||||||
netLinker.EXPECT().
|
netLinker.EXPECT().
|
||||||
AddrReplace(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetOne)}).
|
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
|
||||||
Return(errDummy)
|
Return(errDummy)
|
||||||
return &Wireguard{
|
return &Wireguard{
|
||||||
netlink: netLinker,
|
netlink: netLinker,
|
||||||
@@ -71,15 +62,15 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
|||||||
err: errors.New("dummy: when adding address 1.2.3.4/32 to link a_bridge"),
|
err: errors.New("dummy: when adding address 1.2.3.4/32 to link a_bridge"),
|
||||||
},
|
},
|
||||||
"second add error": {
|
"second add error": {
|
||||||
link: newLink(),
|
link: netlink.Link{Type: "wireguard", Name: "a_bridge"},
|
||||||
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
|
||||||
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
|
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
|
||||||
netLinker := NewMockNetLinker(ctrl)
|
netLinker := NewMockNetLinker(ctrl)
|
||||||
firstCall := netLinker.EXPECT().
|
firstCall := netLinker.EXPECT().
|
||||||
AddrReplace(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetOne)}).
|
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
|
||||||
Return(nil)
|
Return(nil)
|
||||||
netLinker.EXPECT().
|
netLinker.EXPECT().
|
||||||
AddrReplace(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetTwo)}).
|
AddrReplace(link, netlink.Addr{Network: ipNetTwo}).
|
||||||
Return(errDummy).After(firstCall)
|
Return(errDummy).After(firstCall)
|
||||||
return &Wireguard{
|
return &Wireguard{
|
||||||
netlink: netLinker,
|
netlink: netLinker,
|
||||||
@@ -91,7 +82,6 @@ func Test_Wireguard_addAddresses(t *testing.T) {
|
|||||||
err: errors.New("dummy: when adding address ::1234/64 to link a_bridge"),
|
err: errors.New("dummy: when adding address ::1234/64 to link a_bridge"),
|
||||||
},
|
},
|
||||||
"ignore IPv6": {
|
"ignore IPv6": {
|
||||||
link: newLink(),
|
|
||||||
addrs: []netip.Prefix{ipNetTwo},
|
addrs: []netip.Prefix{ipNetTwo},
|
||||||
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
|
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
|
||||||
return &Wireguard{
|
return &Wireguard{
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package wireguard
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
@@ -53,8 +54,14 @@ func makeDeviceConfig(settings Settings) (config wgtypes.Config, err error) {
|
|||||||
PublicKey: publicKey,
|
PublicKey: publicKey,
|
||||||
PresharedKey: preSharedKey,
|
PresharedKey: preSharedKey,
|
||||||
AllowedIPs: []net.IPNet{
|
AllowedIPs: []net.IPNet{
|
||||||
*allIPv4(),
|
{
|
||||||
*allIPv6(),
|
IP: net.IPv4(0, 0, 0, 0),
|
||||||
|
Mask: []byte{0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
IP: net.IPv6zero,
|
||||||
|
Mask: []byte(net.IPv6zero),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
ReplaceAllowedIPs: true,
|
ReplaceAllowedIPs: true,
|
||||||
Endpoint: &net.UDPAddr{
|
Endpoint: &net.UDPAddr{
|
||||||
@@ -68,16 +75,12 @@ func makeDeviceConfig(settings Settings) (config wgtypes.Config, err error) {
|
|||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func allIPv4() (ipNet *net.IPNet) {
|
func allIPv4() (prefix netip.Prefix) {
|
||||||
return &net.IPNet{
|
const bits = 0
|
||||||
IP: net.IPv4(0, 0, 0, 0),
|
return netip.PrefixFrom(netip.IPv4Unspecified(), bits)
|
||||||
Mask: []byte{0, 0, 0, 0},
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func allIPv6() (ipNet *net.IPNet) {
|
func allIPv6() (prefix netip.Prefix) {
|
||||||
return &net.IPNet{
|
const bits = 0
|
||||||
IP: net.IPv6zero,
|
return netip.PrefixFrom(netip.IPv6Unspecified(), bits)
|
||||||
Mask: []byte(net.IPv6zero),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,10 +24,9 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
|||||||
|
|
||||||
netlinker := netlink.New(&noopDebugLogger{})
|
netlinker := netlink.New(&noopDebugLogger{})
|
||||||
|
|
||||||
linkAttrs := netlink.NewLinkAttrs()
|
link := netlink.Link{
|
||||||
linkAttrs.Name = "test_8081"
|
Type: "bridge",
|
||||||
link := &netlink.Bridge{
|
Name: "test_8081",
|
||||||
LinkAttrs: linkAttrs,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove any previously created test interface from a crashed/panic
|
// Remove any previously created test interface from a crashed/panic
|
||||||
@@ -37,8 +36,9 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = netlinker.LinkAdd(link)
|
linkIndex, err := netlinker.LinkAdd(link)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
link.Index = linkIndex
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err = netlinker.LinkDel(link)
|
err = netlinker.LinkDel(link)
|
||||||
@@ -63,14 +63,12 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
|
|||||||
err = wg.addAddresses(link, addresses)
|
err = wg.addAddresses(link, addresses)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
netlinkAddresses, err := netlinker.AddrList(link, netlink.FAMILY_ALL)
|
netlinkAddresses, err := netlinker.AddrList(link, netlink.FamilyAll)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, len(addresses), len(netlinkAddresses))
|
require.Equal(t, len(addresses), len(netlinkAddresses))
|
||||||
for i, netlinkAddress := range netlinkAddresses {
|
for i, netlinkAddress := range netlinkAddresses {
|
||||||
require.NotNil(t, netlinkAddress.IPNet)
|
require.NotNil(t, netlinkAddress.Network)
|
||||||
ipNet, err := netip.ParsePrefix(netlinkAddress.IPNet.String())
|
assert.Equal(t, addresses[i], netlinkAddress.Network)
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, addresses[i], ipNet)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -95,7 +93,7 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
rules, err := netlinker.RuleList(netlink.FAMILY_V4)
|
rules, err := netlinker.RuleList(netlink.FamilyV4)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
var rule netlink.Rule
|
var rule netlink.Rule
|
||||||
var ruleFound bool
|
var ruleFound bool
|
||||||
@@ -111,11 +109,6 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
|
|||||||
Priority: rulePriority,
|
Priority: rulePriority,
|
||||||
Mark: firewallMark,
|
Mark: firewallMark,
|
||||||
Table: firewallMark,
|
Table: firewallMark,
|
||||||
Mask: 4294967295,
|
|
||||||
Goto: -1,
|
|
||||||
Flow: -1,
|
|
||||||
SuppressIfgroup: -1,
|
|
||||||
SuppressPrefixlen: -1,
|
|
||||||
}
|
}
|
||||||
assert.Equal(t, expectedRule, rule)
|
assert.Equal(t, expectedRule, rule)
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import "github.com/qdm12/gluetun/internal/netlink"
|
|||||||
//go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker
|
//go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker
|
||||||
|
|
||||||
type NetLinker interface {
|
type NetLinker interface {
|
||||||
AddrReplace(link netlink.Link, addr *netlink.Addr) error
|
AddrReplace(link netlink.Link, addr netlink.Addr) error
|
||||||
Router
|
Router
|
||||||
Ruler
|
Ruler
|
||||||
Linker
|
Linker
|
||||||
@@ -13,21 +13,21 @@ type NetLinker interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Router interface {
|
type Router interface {
|
||||||
RouteList(link netlink.Link, family int) (
|
RouteList(link *netlink.Link, family int) (
|
||||||
routes []netlink.Route, err error)
|
routes []netlink.Route, err error)
|
||||||
RouteAdd(route *netlink.Route) error
|
RouteAdd(route netlink.Route) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Ruler interface {
|
type Ruler interface {
|
||||||
RuleAdd(rule *netlink.Rule) error
|
RuleAdd(rule netlink.Rule) error
|
||||||
RuleDel(rule *netlink.Rule) error
|
RuleDel(rule netlink.Rule) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Linker interface {
|
type Linker interface {
|
||||||
LinkAdd(link netlink.Link) (err error)
|
LinkAdd(link netlink.Link) (linkIndex int, err error)
|
||||||
LinkList() (links []netlink.Link, err error)
|
LinkList() (links []netlink.Link, err error)
|
||||||
LinkByName(name string) (link netlink.Link, err error)
|
LinkByName(name string) (link netlink.Link, err error)
|
||||||
LinkSetUp(link netlink.Link) error
|
LinkSetUp(link netlink.Link) (linkIndex int, err error)
|
||||||
LinkSetDown(link netlink.Link) error
|
LinkSetDown(link netlink.Link) error
|
||||||
LinkDel(link netlink.Link) error
|
LinkDel(link netlink.Link) error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
netlink "github.com/vishvananda/netlink"
|
netlink "github.com/qdm12/gluetun/internal/netlink"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockNetLinker is a mock of NetLinker interface.
|
// MockNetLinker is a mock of NetLinker interface.
|
||||||
@@ -35,7 +35,7 @@ func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddrReplace mocks base method.
|
// AddrReplace mocks base method.
|
||||||
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 *netlink.Addr) error {
|
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 netlink.Addr) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
|
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
@@ -64,11 +64,12 @@ func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LinkAdd mocks base method.
|
// LinkAdd mocks base method.
|
||||||
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) error {
|
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (int, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "LinkAdd", arg0)
|
ret := m.ctrl.Call(m, "LinkAdd", arg0)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(int)
|
||||||
return ret0
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// LinkAdd indicates an expected call of LinkAdd.
|
// LinkAdd indicates an expected call of LinkAdd.
|
||||||
@@ -136,11 +137,12 @@ func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LinkSetUp mocks base method.
|
// LinkSetUp mocks base method.
|
||||||
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) error {
|
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) (int, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "LinkSetUp", arg0)
|
ret := m.ctrl.Call(m, "LinkSetUp", arg0)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(int)
|
||||||
return ret0
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// LinkSetUp indicates an expected call of LinkSetUp.
|
// LinkSetUp indicates an expected call of LinkSetUp.
|
||||||
@@ -150,7 +152,7 @@ func (mr *MockNetLinkerMockRecorder) LinkSetUp(arg0 interface{}) *gomock.Call {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RouteAdd mocks base method.
|
// RouteAdd mocks base method.
|
||||||
func (m *MockNetLinker) RouteAdd(arg0 *netlink.Route) error {
|
func (m *MockNetLinker) RouteAdd(arg0 netlink.Route) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "RouteAdd", arg0)
|
ret := m.ctrl.Call(m, "RouteAdd", arg0)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
@@ -164,7 +166,7 @@ func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RouteList mocks base method.
|
// RouteList mocks base method.
|
||||||
func (m *MockNetLinker) RouteList(arg0 netlink.Link, arg1 int) ([]netlink.Route, error) {
|
func (m *MockNetLinker) RouteList(arg0 *netlink.Link, arg1 int) ([]netlink.Route, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "RouteList", arg0, arg1)
|
ret := m.ctrl.Call(m, "RouteList", arg0, arg1)
|
||||||
ret0, _ := ret[0].([]netlink.Route)
|
ret0, _ := ret[0].([]netlink.Route)
|
||||||
@@ -179,7 +181,7 @@ func (mr *MockNetLinkerMockRecorder) RouteList(arg0, arg1 interface{}) *gomock.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RuleAdd mocks base method.
|
// RuleAdd mocks base method.
|
||||||
func (m *MockNetLinker) RuleAdd(arg0 *netlink.Rule) error {
|
func (m *MockNetLinker) RuleAdd(arg0 netlink.Rule) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "RuleAdd", arg0)
|
ret := m.ctrl.Call(m, "RuleAdd", arg0)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
@@ -193,7 +195,7 @@ func (mr *MockNetLinkerMockRecorder) RuleAdd(arg0 interface{}) *gomock.Call {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RuleDel mocks base method.
|
// RuleDel mocks base method.
|
||||||
func (m *MockNetLinker) RuleDel(arg0 *netlink.Rule) error {
|
func (m *MockNetLinker) RuleDel(arg0 netlink.Rule) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "RuleDel", arg0)
|
ret := m.ctrl.Call(m, "RuleDel", arg0)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
|
|||||||
@@ -2,17 +2,17 @@ package wireguard
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/netlink"
|
"github.com/qdm12/gluetun/internal/netlink"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO add IPv6 route if IPv6 is supported
|
// TODO add IPv6 route if IPv6 is supported
|
||||||
|
|
||||||
func (w *Wireguard) addRoute(link netlink.Link, dst *net.IPNet,
|
func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
|
||||||
firewallMark int) (err error) {
|
firewallMark int) (err error) {
|
||||||
route := &netlink.Route{
|
route := netlink.Route{
|
||||||
LinkIndex: link.Attrs().Index,
|
LinkIndex: link.Index,
|
||||||
Dst: dst,
|
Dst: dst,
|
||||||
Table: firewallMark,
|
Table: firewallMark,
|
||||||
}
|
}
|
||||||
@@ -21,7 +21,7 @@ func (w *Wireguard) addRoute(link netlink.Link, dst *net.IPNet,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
"adding route for link %s, destination %s and table %d: %w",
|
"adding route for link %s, destination %s and table %d: %w",
|
||||||
link.Attrs().Name, dst, firewallMark, err)
|
link.Name, dst, firewallMark, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package wireguard
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
@@ -15,41 +15,40 @@ func Test_Wireguard_addRoute(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
const linkIndex = 88
|
const linkIndex = 88
|
||||||
newLink := func() netlink.Link {
|
|
||||||
linkAttrs := netlink.NewLinkAttrs()
|
ipPrefix := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 32)
|
||||||
linkAttrs.Name = "a_bridge"
|
|
||||||
linkAttrs.Index = linkIndex
|
|
||||||
return &netlink.Bridge{
|
|
||||||
LinkAttrs: linkAttrs,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ipNet := &net.IPNet{IP: net.IPv4(1, 2, 3, 4), Mask: net.IPv4Mask(255, 255, 255, 255)}
|
|
||||||
const firewallMark = 51820
|
const firewallMark = 51820
|
||||||
|
|
||||||
errDummy := errors.New("dummy")
|
errDummy := errors.New("dummy")
|
||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
link netlink.Link
|
link netlink.Link
|
||||||
dst *net.IPNet
|
dst netip.Prefix
|
||||||
expectedRoute *netlink.Route
|
expectedRoute netlink.Route
|
||||||
routeAddErr error
|
routeAddErr error
|
||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
"success": {
|
"success": {
|
||||||
link: newLink(),
|
link: netlink.Link{
|
||||||
dst: ipNet,
|
Index: linkIndex,
|
||||||
expectedRoute: &netlink.Route{
|
},
|
||||||
|
dst: ipPrefix,
|
||||||
|
expectedRoute: netlink.Route{
|
||||||
LinkIndex: linkIndex,
|
LinkIndex: linkIndex,
|
||||||
Dst: ipNet,
|
Dst: ipPrefix,
|
||||||
Table: firewallMark,
|
Table: firewallMark,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"route add error": {
|
"route add error": {
|
||||||
link: newLink(),
|
link: netlink.Link{
|
||||||
dst: ipNet,
|
Name: "a_bridge",
|
||||||
expectedRoute: &netlink.Route{
|
Index: linkIndex,
|
||||||
|
},
|
||||||
|
dst: ipPrefix,
|
||||||
|
expectedRoute: netlink.Route{
|
||||||
LinkIndex: linkIndex,
|
LinkIndex: linkIndex,
|
||||||
Dst: ipNet,
|
Dst: ipPrefix,
|
||||||
Table: firewallMark,
|
Table: firewallMark,
|
||||||
},
|
},
|
||||||
routeAddErr: errDummy,
|
routeAddErr: errDummy,
|
||||||
|
|||||||
@@ -21,53 +21,38 @@ func Test_Wireguard_addRule(t *testing.T) {
|
|||||||
errDummy := errors.New("dummy")
|
errDummy := errors.New("dummy")
|
||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
expectedRule *netlink.Rule
|
expectedRule netlink.Rule
|
||||||
ruleAddErr error
|
ruleAddErr error
|
||||||
err error
|
err error
|
||||||
ruleDelErr error
|
ruleDelErr error
|
||||||
cleanupErr error
|
cleanupErr error
|
||||||
}{
|
}{
|
||||||
"success": {
|
"success": {
|
||||||
expectedRule: &netlink.Rule{
|
expectedRule: netlink.Rule{
|
||||||
Invert: true,
|
Invert: true,
|
||||||
Priority: rulePriority,
|
Priority: rulePriority,
|
||||||
Mark: firewallMark,
|
Mark: firewallMark,
|
||||||
Table: firewallMark,
|
Table: firewallMark,
|
||||||
Mask: -1,
|
|
||||||
Goto: -1,
|
|
||||||
Flow: -1,
|
|
||||||
SuppressIfgroup: -1,
|
|
||||||
SuppressPrefixlen: -1,
|
|
||||||
Family: family,
|
Family: family,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"rule add error": {
|
"rule add error": {
|
||||||
expectedRule: &netlink.Rule{
|
expectedRule: netlink.Rule{
|
||||||
Invert: true,
|
Invert: true,
|
||||||
Priority: rulePriority,
|
Priority: rulePriority,
|
||||||
Mark: firewallMark,
|
Mark: firewallMark,
|
||||||
Table: firewallMark,
|
Table: firewallMark,
|
||||||
Mask: -1,
|
|
||||||
Goto: -1,
|
|
||||||
Flow: -1,
|
|
||||||
SuppressIfgroup: -1,
|
|
||||||
SuppressPrefixlen: -1,
|
|
||||||
Family: family,
|
Family: family,
|
||||||
},
|
},
|
||||||
ruleAddErr: errDummy,
|
ruleAddErr: errDummy,
|
||||||
err: errors.New("adding rule ip rule 987: from all to all table 456: dummy"),
|
err: errors.New("adding rule ip rule 987: from all to all table 456: dummy"),
|
||||||
},
|
},
|
||||||
"rule delete error": {
|
"rule delete error": {
|
||||||
expectedRule: &netlink.Rule{
|
expectedRule: netlink.Rule{
|
||||||
Invert: true,
|
Invert: true,
|
||||||
Priority: rulePriority,
|
Priority: rulePriority,
|
||||||
Mark: firewallMark,
|
Mark: firewallMark,
|
||||||
Table: firewallMark,
|
Table: firewallMark,
|
||||||
Mask: -1,
|
|
||||||
Goto: -1,
|
|
||||||
Flow: -1,
|
|
||||||
SuppressIfgroup: -1,
|
|
||||||
SuppressPrefixlen: -1,
|
|
||||||
Family: family,
|
Family: family,
|
||||||
},
|
},
|
||||||
ruleDelErr: errDummy,
|
ruleDelErr: errDummy,
|
||||||
|
|||||||
@@ -93,10 +93,12 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := w.netlink.LinkSetUp(link); err != nil {
|
linkIndex, err := w.netlink.LinkSetUp(link)
|
||||||
|
if err != nil {
|
||||||
waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err)
|
waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
link.Index = linkIndex
|
||||||
closers.add("shutting down link", stepFour, func() error {
|
closers.add("shutting down link", stepFour, func() error {
|
||||||
return w.netlink.LinkSetDown(link)
|
return w.netlink.LinkSetDown(link)
|
||||||
})
|
})
|
||||||
@@ -161,17 +163,16 @@ func setupKernelSpace(ctx context.Context,
|
|||||||
interfaceName string, netLinker NetLinker, mtu uint16,
|
interfaceName string, netLinker NetLinker, mtu uint16,
|
||||||
closers *closers, logger Logger) (
|
closers *closers, logger Logger) (
|
||||||
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error) {
|
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error) {
|
||||||
linkAttrs := netlink.LinkAttrs{
|
link = netlink.Link{
|
||||||
|
Type: "wireguard",
|
||||||
Name: interfaceName,
|
Name: interfaceName,
|
||||||
MTU: int(mtu),
|
MTU: mtu,
|
||||||
}
|
}
|
||||||
link = &netlink.Wireguard{
|
linkIndex, err := netLinker.LinkAdd(link)
|
||||||
LinkAttrs: linkAttrs,
|
|
||||||
}
|
|
||||||
err = netLinker.LinkAdd(link)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
|
return link, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
|
||||||
}
|
}
|
||||||
|
link.Index = linkIndex
|
||||||
closers.add("deleting link", stepFive, func() error {
|
closers.add("deleting link", stepFive, func() error {
|
||||||
return netLinker.LinkDel(link)
|
return netLinker.LinkDel(link)
|
||||||
})
|
})
|
||||||
@@ -191,22 +192,22 @@ func setupUserSpace(ctx context.Context,
|
|||||||
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error) {
|
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error) {
|
||||||
tun, err := tun.CreateTUN(interfaceName, int(mtu))
|
tun, err := tun.CreateTUN(interfaceName, int(mtu))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
|
return link, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
closers.add("closing TUN device", stepSeven, tun.Close)
|
closers.add("closing TUN device", stepSeven, tun.Close)
|
||||||
|
|
||||||
tunName, err := tun.Name()
|
tunName, err := tun.Name()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
|
return link, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
|
||||||
} else if tunName != interfaceName {
|
} else if tunName != interfaceName {
|
||||||
return nil, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
|
return link, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
|
||||||
ErrCreateTun, interfaceName, tunName)
|
ErrCreateTun, interfaceName, tunName)
|
||||||
}
|
}
|
||||||
|
|
||||||
link, err = netLinker.LinkByName(interfaceName)
|
link, err = netLinker.LinkByName(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
|
return link, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
|
||||||
}
|
}
|
||||||
closers.add("deleting link", stepFive, func() error {
|
closers.add("deleting link", stepFive, func() error {
|
||||||
return netLinker.LinkDel(link)
|
return netLinker.LinkDel(link)
|
||||||
@@ -226,14 +227,14 @@ func setupUserSpace(ctx context.Context,
|
|||||||
|
|
||||||
uapiFile, err := ipc.UAPIOpen(interfaceName)
|
uapiFile, err := ipc.UAPIOpen(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
|
return link, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
closers.add("closing UAPI file", stepThree, uapiFile.Close)
|
closers.add("closing UAPI file", stepThree, uapiFile.Close)
|
||||||
|
|
||||||
uapiListener, err := ipc.UAPIListen(interfaceName, uapiFile)
|
uapiListener, err := ipc.UAPIListen(interfaceName, uapiFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
|
return link, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
closers.add("closing UAPI listener", stepTwo, uapiListener.Close)
|
closers.add("closing UAPI listener", stepTwo, uapiListener.Close)
|
||||||
|
|||||||
Reference in New Issue
Block a user