chore(netlink): define own types with minimal fields

- Allow to swap `github.com/vishvananda/netlink`
- Allow to add build tags for each platform
- One step closer to development on non-Linux platforms
This commit is contained in:
Quentin McGaw
2023-05-29 06:44:58 +00:00
parent 163ac48ce4
commit 38ddcfa756
34 changed files with 828 additions and 493 deletions

View File

@@ -531,30 +531,30 @@ type netLinker interface {
type Addresser interface {
AddrList(link netlink.Link, family int) (
addresses []netlink.Addr, err error)
AddrReplace(link netlink.Link, addr *netlink.Addr) error
AddrReplace(link netlink.Link, addr netlink.Addr) error
}
type Router interface {
RouteList(link netlink.Link, family int) (
RouteList(link *netlink.Link, family int) (
routes []netlink.Route, err error)
RouteAdd(route *netlink.Route) error
RouteDel(route *netlink.Route) error
RouteReplace(route *netlink.Route) error
RouteAdd(route netlink.Route) error
RouteDel(route netlink.Route) error
RouteReplace(route netlink.Route) error
}
type Ruler interface {
RuleList(family int) (rules []netlink.Rule, err error)
RuleAdd(rule *netlink.Rule) error
RuleDel(rule *netlink.Rule) error
RuleAdd(rule netlink.Rule) error
RuleDel(rule netlink.Rule) error
}
type Linker interface {
LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error)
LinkByIndex(index int) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (err error)
LinkAdd(link netlink.Link) (linkIndex int, err error)
LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) (err error)
}

View File

@@ -1,14 +1,40 @@
package netlink
import "github.com/vishvananda/netlink"
import (
"net/netip"
type Addr = netlink.Addr
"github.com/vishvananda/netlink"
)
type Addr struct {
Network netip.Prefix
}
func (a Addr) String() string {
return a.Network.String()
}
func (n *NetLink) AddrList(link Link, family int) (
addresses []Addr, err error) {
return netlink.AddrList(link, family)
netlinkLink := linkToNetlinkLink(&link)
netlinkAddresses, err := netlink.AddrList(netlinkLink, family)
if err != nil {
return nil, err
}
func (n *NetLink) AddrReplace(link Link, addr *Addr) error {
return netlink.AddrReplace(link, addr)
addresses = make([]Addr, len(netlinkAddresses))
for i := range netlinkAddresses {
addresses[i].Network = netIPNetToNetipPrefix(netlinkAddresses[i].IPNet)
}
return addresses, nil
}
func (n *NetLink) AddrReplace(link Link, addr Addr) error {
netlinkLink := linkToNetlinkLink(&link)
netlinkAddress := netlink.Addr{
IPNet: netipPrefixToIPNet(addr.Network),
}
return netlink.AddrReplace(netlinkLink, &netlinkAddress)
}

View File

@@ -0,0 +1,62 @@
package netlink
import (
"fmt"
"net"
"net/netip"
)
func netipPrefixToIPNet(prefix netip.Prefix) (ipNet *net.IPNet) {
if !prefix.IsValid() {
return nil
}
prefixAddr := prefix.Addr().Unmap()
ipMask := net.CIDRMask(prefix.Bits(), prefixAddr.BitLen())
ip := netipAddrToNetIP(prefixAddr)
return &net.IPNet{
IP: ip,
Mask: ipMask,
}
}
func netIPNetToNetipPrefix(ipNet *net.IPNet) (prefix netip.Prefix) {
if ipNet == nil || (len(ipNet.IP) != net.IPv4len && len(ipNet.IP) != net.IPv6len) {
return prefix
}
var ip netip.Addr
if ipv4 := ipNet.IP.To4(); ipv4 != nil {
ip = netip.AddrFrom4([4]byte(ipv4))
} else {
ip = netip.AddrFrom16([16]byte(ipNet.IP))
}
bits, _ := ipNet.Mask.Size()
return netip.PrefixFrom(ip, bits)
}
func netipAddrToNetIP(address netip.Addr) (ip net.IP) {
switch {
case !address.IsValid():
return nil
case address.Is4() || address.Is4In6():
bytes := address.As4()
return net.IP(bytes[:])
default:
bytes := address.As16()
return net.IP(bytes[:])
}
}
func netIPToNetipAddress(ip net.IP) (address netip.Addr) {
if len(ip) != net.IPv4len && len(ip) != net.IPv6len {
return address // invalid
}
address, ok := netip.AddrFromSlice(ip)
if !ok {
panic(fmt.Sprintf("converting %#v to netip.Addr failed", ip))
}
return address.Unmap()
}

View File

@@ -0,0 +1,146 @@
package netlink
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_netipPrefixToIPNet(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
prefix netip.Prefix
ipNet *net.IPNet
}{
"empty_prefix": {},
"IPv4_prefix": {
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
ipNet: &net.IPNet{
IP: net.IP{1, 2, 3, 4},
Mask: net.IPv4Mask(255, 255, 255, 0),
},
},
"IPv4-in-IPv6_prefix": {
prefix: netip.PrefixFrom(netip.AddrFrom16(
[16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 1, 2, 3, 4}),
24),
ipNet: &net.IPNet{
IP: net.IP{1, 2, 3, 4},
Mask: net.IPv4Mask(255, 255, 255, 0),
},
},
"IPv6_prefix": {
prefix: netip.PrefixFrom(netip.IPv6Loopback(), 8),
ipNet: &net.IPNet{
IP: net.IPv6loopback,
Mask: net.IPMask{0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ipNet := netipPrefixToIPNet(testCase.prefix)
assert.Equal(t, testCase.ipNet, ipNet)
})
}
}
func Test_netIPNetToNetipPrefix(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
ipNet *net.IPNet
prefix netip.Prefix
}{
"empty ipnet": {},
"custom sized IP in ipnet": {
ipNet: &net.IPNet{
IP: net.IP{1},
},
},
"IPv4 ipnet": {
ipNet: &net.IPNet{
IP: net.IP{1, 2, 3, 4},
Mask: net.IPMask{255, 255, 255, 0},
},
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
},
"IPv4-in-IPv6 ipnet": {
ipNet: &net.IPNet{
IP: net.IPv4(1, 2, 3, 4),
Mask: net.IPMask{255, 255, 255, 0},
},
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
},
"IPv6 ipnet": {
ipNet: &net.IPNet{
IP: net.IPv6loopback,
Mask: net.IPMask{0xff},
},
prefix: netip.PrefixFrom(netip.IPv6Loopback(), 8),
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
prefix := netIPNetToNetipPrefix(testCase.ipNet)
assert.Equal(t, testCase.prefix, prefix)
})
}
}
func Test_netIPToNetipAddress(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
ip net.IP
address netip.Addr
panicMessage string
}{
"nil_ip": {},
"ip_not_ipv4_or_ipv6": {
ip: net.IP{1},
},
"IPv4": {
ip: net.IPv4(1, 2, 3, 4),
address: netip.AddrFrom4([4]byte{1, 2, 3, 4}),
},
"IPv6": {
ip: net.IPv6zero,
address: netip.AddrFrom16([16]byte{}),
},
"IPv4 prefixed with 0xffff": {
ip: net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 1, 2, 3, 4},
address: netip.AddrFrom4([4]byte{1, 2, 3, 4}),
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
if testCase.panicMessage != "" {
assert.PanicsWithValue(t, testCase.panicMessage, func() {
netIPToNetipAddress(testCase.ip)
})
return
}
address := netIPToNetipAddress(testCase.ip)
assert.Equal(t, testCase.address, address)
})
}
}

View File

@@ -6,20 +6,19 @@ import (
"github.com/vishvananda/netlink"
)
//nolint:revive
const (
FAMILY_ALL = netlink.FAMILY_ALL
FAMILY_V4 = netlink.FAMILY_V4
FAMILY_V6 = netlink.FAMILY_V6
FamilyAll = 0
FamilyV4 = 2
FamilyV6 = 10
)
func FamilyToString(family int) string {
switch family {
case FAMILY_ALL:
return "all"
case FAMILY_V4:
case FamilyAll:
return "all" //nolint:goconst
case FamilyV4:
return "v4"
case FAMILY_V6:
case FamilyV6:
return "v6"
default:
return fmt.Sprint(family)

View File

@@ -14,20 +14,21 @@ func (n *NetLink) IsIPv6Supported() (supported bool, err error) {
var totalRoutes uint
for _, link := range links {
routes, err := n.RouteList(link, netlink.FAMILY_V6)
link := link
routes, err := n.RouteList(&link, netlink.FAMILY_V6)
if err != nil {
return false, fmt.Errorf("listing IPv6 routes for link %s: %w",
link.Attrs().Name, err)
link.Name, err)
}
// Check each route for IPv6 due to Podman bug listing IPv4 routes
// as IPv6 routes at container start, see:
// https://github.com/qdm12/gluetun/issues/1241#issuecomment-1333405949
for _, route := range routes {
sourceIsIPv6 := route.Src != nil && route.Src.To4() == nil
destinationIsIPv6 := route.Dst != nil && route.Dst.IP.To4() == nil
sourceIsIPv6 := route.Src.IsValid() && route.Src.Is6()
destinationIsIPv6 := route.Dst.IsValid() && route.Dst.Addr().Is6()
if sourceIsIPv6 || destinationIsIPv6 {
n.debugLogger.Debugf("IPv6 is supported by link %s", link.Attrs().Name)
n.debugLogger.Debugf("IPv6 is supported by link %s", link.Name)
return true, nil
}
totalRoutes++

View File

@@ -2,36 +2,117 @@ package netlink
import "github.com/vishvananda/netlink"
type (
Link = netlink.Link
Bridge = netlink.Bridge
Wireguard = netlink.Wireguard
)
type Link struct {
Type string
Name string
Index int
EncapType string
MTU uint16
NetNsID int
TxQLen int
}
func (n *NetLink) LinkList() (links []Link, err error) {
return netlink.LinkList()
netlinkLinks, err := netlink.LinkList()
if err != nil {
return nil, err
}
links = make([]Link, len(netlinkLinks))
for i := range netlinkLinks {
links[i] = netlinkLinkToLink(netlinkLinks[i])
}
return links, nil
}
func (n *NetLink) LinkByName(name string) (link Link, err error) {
return netlink.LinkByName(name)
netlinkLink, err := netlink.LinkByName(name)
if err != nil {
return Link{}, err
}
return netlinkLinkToLink(netlinkLink), nil
}
func (n *NetLink) LinkByIndex(index int) (link Link, err error) {
return netlink.LinkByIndex(index)
netlinkLink, err := netlink.LinkByIndex(index)
if err != nil {
return Link{}, err
}
func (n *NetLink) LinkAdd(link Link) (err error) {
return netlink.LinkAdd(link)
return netlinkLinkToLink(netlinkLink), nil
}
func (n *NetLink) LinkAdd(link Link) (linkIndex int, err error) {
netlinkLink := linkToNetlinkLink(&link)
err = netlink.LinkAdd(netlinkLink)
if err != nil {
return 0, err
}
return netlinkLink.Attrs().Index, nil
}
func (n *NetLink) LinkDel(link Link) (err error) {
return netlink.LinkDel(link)
return netlink.LinkDel(linkToNetlinkLink(&link))
}
func (n *NetLink) LinkSetUp(link Link) (err error) {
return netlink.LinkSetUp(link)
func (n *NetLink) LinkSetUp(link Link) (linkIndex int, err error) {
netlinkLink := linkToNetlinkLink(&link)
err = netlink.LinkSetUp(netlinkLink)
if err != nil {
return 0, err
}
return netlinkLink.Attrs().Index, nil
}
func (n *NetLink) LinkSetDown(link Link) (err error) {
return netlink.LinkSetDown(link)
return netlink.LinkSetDown(linkToNetlinkLink(&link))
}
type netlinkLinkImpl struct {
attrs *netlink.LinkAttrs
linkType string
}
func (n *netlinkLinkImpl) Attrs() *netlink.LinkAttrs {
return n.attrs
}
func (n *netlinkLinkImpl) Type() string {
return n.linkType
}
func netlinkLinkToLink(netlinkLink netlink.Link) Link {
attributes := netlinkLink.Attrs()
return Link{
Type: netlinkLink.Type(),
Name: attributes.Name,
Index: attributes.Index,
EncapType: attributes.EncapType,
MTU: uint16(attributes.MTU),
NetNsID: attributes.NetNsID,
TxQLen: attributes.TxQLen,
}
}
// Warning: we must return `netlink.Link` and not `netlinkLinkImpl`
// so that the vishvananda/netlink package can compare the returned
// value against an untyped nil.
func linkToNetlinkLink(link *Link) netlink.Link {
if link == nil {
return nil
}
return &netlinkLinkImpl{
linkType: link.Type,
attrs: &netlink.LinkAttrs{ // TODO get all original attributes
Name: link.Name,
Index: link.Index,
EncapType: link.EncapType,
MTU: int(link.MTU),
NetNsID: link.NetNsID,
TxQLen: link.TxQLen,
},
}
}

View File

@@ -1,9 +0,0 @@
package netlink
import "github.com/vishvananda/netlink"
type LinkAttrs = netlink.LinkAttrs
func NewLinkAttrs() LinkAttrs {
return netlink.NewLinkAttrs()
}

View File

@@ -1,22 +1,74 @@
package netlink
import "github.com/vishvananda/netlink"
import (
"net/netip"
type Route = netlink.Route
"github.com/vishvananda/netlink"
)
func (n *NetLink) RouteList(link Link, family int) (
type Route struct {
LinkIndex int
Dst netip.Prefix
Src netip.Addr
Gw netip.Addr
Priority int
Family int
Table int
Type int
}
func (n *NetLink) RouteList(link *Link, family int) (
routes []Route, err error) {
return netlink.RouteList(link, family)
netlinkLink := linkToNetlinkLink(link)
netlinkRoutes, err := netlink.RouteList(netlinkLink, family)
if err != nil {
return nil, err
}
func (n *NetLink) RouteAdd(route *Route) error {
return netlink.RouteAdd(route)
routes = make([]Route, len(netlinkRoutes))
for i := range netlinkRoutes {
routes[i] = netlinkRouteToRoute(netlinkRoutes[i])
}
return routes, nil
}
func (n *NetLink) RouteDel(route *Route) error {
return netlink.RouteDel(route)
func (n *NetLink) RouteAdd(route Route) error {
netlinkRoute := routeToNetlinkRoute(route)
return netlink.RouteAdd(&netlinkRoute)
}
func (n *NetLink) RouteReplace(route *Route) error {
return netlink.RouteReplace(route)
func (n *NetLink) RouteDel(route Route) error {
netlinkRoute := routeToNetlinkRoute(route)
return netlink.RouteDel(&netlinkRoute)
}
func (n *NetLink) RouteReplace(route Route) error {
netlinkRoute := routeToNetlinkRoute(route)
return netlink.RouteReplace(&netlinkRoute)
}
func netlinkRouteToRoute(netlinkRoute netlink.Route) (route Route) {
return Route{
LinkIndex: netlinkRoute.LinkIndex,
Dst: netIPNetToNetipPrefix(netlinkRoute.Dst),
Src: netIPToNetipAddress(netlinkRoute.Src),
Gw: netIPToNetipAddress(netlinkRoute.Gw),
Priority: netlinkRoute.Priority,
Family: netlinkRoute.Family,
Table: netlinkRoute.Table,
Type: netlinkRoute.Type,
}
}
func routeToNetlinkRoute(route Route) (netlinkRoute netlink.Route) {
return netlink.Route{
LinkIndex: route.LinkIndex,
Dst: netipPrefixToIPNet(route.Dst),
Src: netipAddrToNetIP(route.Src),
Gw: netipAddrToNetIP(route.Gw),
Priority: route.Priority,
Family: route.Family,
Table: route.Table,
Type: route.Type,
}
}

View File

@@ -1,21 +1,90 @@
package netlink
import "github.com/vishvananda/netlink"
import (
"fmt"
"net/netip"
type Rule = netlink.Rule
"github.com/vishvananda/netlink"
)
func NewRule() *Rule {
return netlink.NewRule()
type Rule struct {
Priority int
Family int
Table int
Mark int
Src netip.Prefix
Dst netip.Prefix
Invert bool
}
func (r Rule) String() string {
from := "all"
if r.Src.IsValid() {
from = r.Src.String()
}
to := "all"
if r.Dst.IsValid() {
to = r.Dst.String()
}
return fmt.Sprintf("ip rule %d: from %s to %s table %d",
r.Priority, from, to, r.Table)
}
func NewRule() Rule {
// defaults found from netlink.NewRule() for fields we use,
// the rest of the defaults is set when converting from a `Rule`
// to a `netlink.Rule`
return Rule{
Priority: -1,
Mark: -1,
}
}
func (n *NetLink) RuleList(family int) (rules []Rule, err error) {
return netlink.RuleList(family)
netlinkRules, err := netlink.RuleList(family)
if err != nil {
return nil, err
}
func (n *NetLink) RuleAdd(rule *Rule) error {
return netlink.RuleAdd(rule)
rules = make([]Rule, len(netlinkRules))
for i := range netlinkRules {
rules[i] = netlinkRuleToRule(netlinkRules[i])
}
return rules, nil
}
func (n *NetLink) RuleDel(rule *Rule) error {
return netlink.RuleDel(rule)
func (n *NetLink) RuleAdd(rule Rule) error {
netlinkRule := ruleToNetlinkRule(rule)
return netlink.RuleAdd(&netlinkRule)
}
func (n *NetLink) RuleDel(rule Rule) error {
netlinkRule := ruleToNetlinkRule(rule)
return netlink.RuleDel(&netlinkRule)
}
func ruleToNetlinkRule(rule Rule) (netlinkRule netlink.Rule) {
netlinkRule = *netlink.NewRule()
netlinkRule.Priority = rule.Priority
netlinkRule.Family = rule.Family
netlinkRule.Table = rule.Table
netlinkRule.Mark = rule.Mark
netlinkRule.Src = netipPrefixToIPNet(rule.Src)
netlinkRule.Dst = netipPrefixToIPNet(rule.Dst)
netlinkRule.Invert = rule.Invert
return netlinkRule
}
func netlinkRuleToRule(netlinkRule netlink.Rule) (rule Rule) {
return Rule{
Priority: netlinkRule.Priority,
Family: netlinkRule.Family,
Table: netlinkRule.Table,
Mark: netlinkRule.Mark,
Src: netIPNetToNetipPrefix(netlinkRule.Src),
Dst: netIPNetToNetipPrefix(netlinkRule.Dst),
Invert: netlinkRule.Invert,
}
}

View File

@@ -6,34 +6,6 @@ import (
"net/netip"
)
func NetipPrefixToIPNet(prefix *netip.Prefix) (ipNet *net.IPNet) {
if prefix == nil {
return nil
}
s := prefix.String()
ip, ipNet, err := net.ParseCIDR(s)
if err != nil {
panic(err)
}
ipNet.IP = ip
return ipNet
}
func netIPNetToNetipPrefix(ipNet net.IPNet) (prefix netip.Prefix) {
if len(ipNet.IP) != net.IPv4len && len(ipNet.IP) != net.IPv6len {
return prefix
}
var ip netip.Addr
if ipv4 := ipNet.IP.To4(); ipv4 != nil {
ip = netip.AddrFrom4([4]byte(ipv4))
} else {
ip = netip.AddrFrom16([16]byte(ipNet.IP))
}
bits, _ := ipNet.Mask.Size()
return netip.PrefixFrom(ip, bits)
}
func netIPToNetipAddress(ip net.IP) (address netip.Addr) {
address, ok := netip.AddrFromSlice(ip)
if !ok {

View File

@@ -8,54 +8,6 @@ import (
"github.com/stretchr/testify/assert"
)
func Test_netIPNetToNetipPrefix(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
ipNet net.IPNet
prefix netip.Prefix
}{
"empty ipnet": {},
"custom sized IP in ipnet": {
ipNet: net.IPNet{
IP: net.IP{1},
},
},
"IPv4 ipnet": {
ipNet: net.IPNet{
IP: net.IP{1, 2, 3, 4},
Mask: net.IPMask{255, 255, 255, 0},
},
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
},
"IPv4-in-IPv6 ipnet": {
ipNet: net.IPNet{
IP: net.IPv4(1, 2, 3, 4),
Mask: net.IPMask{255, 255, 255, 0},
},
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
},
"IPv6 ipnet": {
ipNet: net.IPNet{
IP: net.IPv6loopback,
Mask: net.IPMask{0xff},
},
prefix: netip.PrefixFrom(netip.IPv6Loopback(), 8),
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
prefix := netIPNetToNetipPrefix(testCase.ipNet)
assert.Equal(t, testCase.prefix, prefix)
})
}
}
func Test_netIPToNetipAddress(t *testing.T) {
t.Parallel()

View File

@@ -25,17 +25,17 @@ func (d DefaultRoute) String() string {
}
func (r *Routing) DefaultRoutes() (defaultRoutes []DefaultRoute, err error) {
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
routes, err := r.netLinker.RouteList(nil, netlink.FamilyAll)
if err != nil {
return nil, fmt.Errorf("listing routes: %w", err)
}
for _, route := range routes {
if route.Dst != nil {
if route.Dst.IsValid() {
continue
}
defaultRoute := DefaultRoute{
Gateway: netIPToNetipAddress(route.Gw),
Gateway: route.Gw,
Family: route.Family,
}
linkIndex := route.LinkIndex
@@ -43,11 +43,10 @@ func (r *Routing) DefaultRoutes() (defaultRoutes []DefaultRoute, err error) {
if err != nil {
return nil, fmt.Errorf("obtaining link by index: for default route at index %d: %w", linkIndex, err)
}
attributes := link.Attrs()
defaultRoute.NetInterface = attributes.Name
family := netlink.FAMILY_V6
if route.Gw.To4() != nil {
family = netlink.FAMILY_V4
defaultRoute.NetInterface = link.Name
family := netlink.FamilyV6
if route.Gw.Is4() {
family = netlink.FamilyV4
}
defaultRoute.AssignedIP, err = r.assignedIP(defaultRoute.NetInterface, family)
if err != nil {

View File

@@ -23,7 +23,7 @@ func (r *Routing) routeInboundFromDefault(defaultRoutes []DefaultRoute) (err err
for _, defaultRoute := range defaultRoutes {
defaultDestination := defaultDestinationIPv4
if defaultRoute.Family == netlink.FAMILY_V6 {
if defaultRoute.Family == netlink.FamilyV6 {
defaultDestination = defaultDestinationIPv6
}
@@ -43,7 +43,7 @@ func (r *Routing) unrouteInboundFromDefault(defaultRoutes []DefaultRoute) (err e
for _, defaultRoute := range defaultRoutes {
defaultDestination := defaultDestinationIPv4
if defaultRoute.Family == netlink.FAMILY_V6 {
if defaultRoute.Family == netlink.FamilyV6 {
defaultDestination = defaultDestinationIPv6
}
@@ -68,8 +68,8 @@ func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRo
bits = 128
}
defaultIPMasked := netip.PrefixFrom(assignedIP, bits)
ruleDstNet := (*netip.Prefix)(nil)
err = r.addIPRule(&defaultIPMasked, ruleDstNet, table, inboundPriority)
ruleDstNet := netip.Prefix{}
err = r.addIPRule(defaultIPMasked, ruleDstNet, table, inboundPriority)
if err != nil {
return fmt.Errorf("adding rule for default route %s: %w", defaultRoute, err)
}
@@ -86,8 +86,8 @@ func (r *Routing) delRuleInboundFromDefault(table int, defaultRoutes []DefaultRo
bits = 128
}
defaultIPMasked := netip.PrefixFrom(assignedIP, bits)
ruleDstNet := (*netip.Prefix)(nil)
err = r.deleteIPRule(&defaultIPMasked, ruleDstNet, table, inboundPriority)
ruleDstNet := netip.Prefix{}
err = r.deleteIPRule(defaultIPMasked, ruleDstNet, table, inboundPriority)
if err != nil {
return fmt.Errorf("deleting rule for default route %s: %w", defaultRoute, err)
}

View File

@@ -19,8 +19,8 @@ var (
)
func ipMatchesFamily(ip netip.Addr, family int) bool {
return (family == netlink.FAMILY_V4 && ip.Is4()) ||
(family == netlink.FAMILY_V6 && ip.Is6())
return (family == netlink.FamilyV4 && ip.Is4()) ||
(family == netlink.FamilyV6 && ip.Is6())
}
func (r *Routing) assignedIP(interfaceName string, family int) (ip netip.Addr, err error) {

View File

@@ -29,25 +29,25 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
localLinks := make(map[int]struct{})
for _, link := range links {
if link.Attrs().EncapType != "ether" {
if link.EncapType != "ether" {
continue
}
localLinks[link.Attrs().Index] = struct{}{}
r.logger.Info("local ethernet link found: " + link.Attrs().Name)
localLinks[link.Index] = struct{}{}
r.logger.Info("local ethernet link found: " + link.Name)
}
if len(localLinks) == 0 {
return localNetworks, fmt.Errorf("%w: in %d links", ErrLinkLocalNotFound, len(links))
}
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
routes, err := r.netLinker.RouteList(nil, netlink.FamilyAll)
if err != nil {
return localNetworks, fmt.Errorf("listing routes: %w", err)
}
for _, route := range routes {
if route.Gw != nil || route.Dst == nil {
if route.Gw.IsValid() || !route.Dst.IsValid() {
continue
} else if _, ok := localLinks[route.LinkIndex]; !ok {
continue
@@ -55,7 +55,7 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
var localNet LocalNetwork
localNet.IPNet = netIPNetToNetipPrefix(*route.Dst)
localNet.IPNet = route.Dst
r.logger.Info("local ipnet found: " + localNet.IPNet.String())
link, err := r.netLinker.LinkByIndex(route.LinkIndex)
@@ -63,11 +63,11 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
return localNetworks, fmt.Errorf("finding link at index %d: %w", route.LinkIndex, err)
}
localNet.InterfaceName = link.Attrs().Name
localNet.InterfaceName = link.Name
family := netlink.FAMILY_V6
family := netlink.FamilyV6
if localNet.IPNet.Addr().Is4() {
family = netlink.FAMILY_V4
family = netlink.FamilyV4
}
ip, err := r.assignedIP(localNet.InterfaceName, family)
if err != nil {
@@ -96,7 +96,8 @@ func (r *Routing) AddLocalRules(subnets []LocalNetwork) (err error) {
const localPriority = 98
// Main table was setup correctly by Docker, just need to add rules to use it
err = r.addIPRule(nil, &subnet.IPNet, mainTable, localPriority)
src := netip.Prefix{}
err = r.addIPRule(src, subnet.IPNet, mainTable, localPriority)
if err != nil {
return fmt.Errorf("adding rule: %v: %w", subnet.IPNet, err)
}

View File

@@ -8,7 +8,7 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
netlink "github.com/vishvananda/netlink"
netlink "github.com/qdm12/gluetun/internal/netlink"
)
// MockNetLinker is a mock of NetLinker interface.
@@ -50,7 +50,7 @@ func (mr *MockNetLinkerMockRecorder) AddrList(arg0, arg1 interface{}) *gomock.Ca
}
// AddrReplace mocks base method.
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 *netlink.Addr) error {
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 netlink.Addr) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
ret0, _ := ret[0].(error)
@@ -79,11 +79,12 @@ func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call {
}
// LinkAdd mocks base method.
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) error {
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkAdd", arg0)
ret0, _ := ret[0].(error)
return ret0
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LinkAdd indicates an expected call of LinkAdd.
@@ -166,11 +167,12 @@ func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call
}
// LinkSetUp mocks base method.
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) error {
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetUp", arg0)
ret0, _ := ret[0].(error)
return ret0
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LinkSetUp indicates an expected call of LinkSetUp.
@@ -180,7 +182,7 @@ func (mr *MockNetLinkerMockRecorder) LinkSetUp(arg0 interface{}) *gomock.Call {
}
// RouteAdd mocks base method.
func (m *MockNetLinker) RouteAdd(arg0 *netlink.Route) error {
func (m *MockNetLinker) RouteAdd(arg0 netlink.Route) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteAdd", arg0)
ret0, _ := ret[0].(error)
@@ -194,7 +196,7 @@ func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call {
}
// RouteDel mocks base method.
func (m *MockNetLinker) RouteDel(arg0 *netlink.Route) error {
func (m *MockNetLinker) RouteDel(arg0 netlink.Route) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteDel", arg0)
ret0, _ := ret[0].(error)
@@ -208,7 +210,7 @@ func (mr *MockNetLinkerMockRecorder) RouteDel(arg0 interface{}) *gomock.Call {
}
// RouteList mocks base method.
func (m *MockNetLinker) RouteList(arg0 netlink.Link, arg1 int) ([]netlink.Route, error) {
func (m *MockNetLinker) RouteList(arg0 *netlink.Link, arg1 int) ([]netlink.Route, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteList", arg0, arg1)
ret0, _ := ret[0].([]netlink.Route)
@@ -223,7 +225,7 @@ func (mr *MockNetLinkerMockRecorder) RouteList(arg0, arg1 interface{}) *gomock.C
}
// RouteReplace mocks base method.
func (m *MockNetLinker) RouteReplace(arg0 *netlink.Route) error {
func (m *MockNetLinker) RouteReplace(arg0 netlink.Route) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteReplace", arg0)
ret0, _ := ret[0].(error)
@@ -237,7 +239,7 @@ func (mr *MockNetLinkerMockRecorder) RouteReplace(arg0 interface{}) *gomock.Call
}
// RuleAdd mocks base method.
func (m *MockNetLinker) RuleAdd(arg0 *netlink.Rule) error {
func (m *MockNetLinker) RuleAdd(arg0 netlink.Rule) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RuleAdd", arg0)
ret0, _ := ret[0].(error)
@@ -251,7 +253,7 @@ func (mr *MockNetLinkerMockRecorder) RuleAdd(arg0 interface{}) *gomock.Call {
}
// RuleDel mocks base method.
func (m *MockNetLinker) RuleDel(arg0 *netlink.Rule) error {
func (m *MockNetLinker) RuleDel(arg0 netlink.Rule) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RuleDel", arg0)
ret0, _ := ret[0].(error)

View File

@@ -56,8 +56,8 @@ func (r *Routing) removeOutboundSubnets(subnets []netip.Prefix,
}
}
ruleSrcNet := (*netip.Prefix)(nil)
ruleDstNet := &subnets[i]
ruleSrcNet := netip.Prefix{}
ruleDstNet := subnets[i]
err := r.deleteIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority)
if err != nil {
warnings = append(warnings,
@@ -81,8 +81,8 @@ func (r *Routing) addOutboundSubnets(subnets []netip.Prefix,
}
}
ruleSrcNet := (*netip.Prefix)(nil)
ruleDstNet := &subnets[i]
ruleSrcNet := netip.Prefix{}
ruleDstNet := subnets[i]
err = r.addIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority)
if err != nil {
return fmt.Errorf("adding rule: for subnet %s: %w", subnet, err)

View File

@@ -23,12 +23,12 @@ func (r *Routing) addRouteVia(destination netip.Prefix, gateway netip.Addr,
}
route := netlink.Route{
Dst: NetipPrefixToIPNet(&destination),
Gw: gateway.AsSlice(),
LinkIndex: link.Attrs().Index,
Dst: destination,
Gw: gateway,
LinkIndex: link.Index,
Table: table,
}
if err := r.netLinker.RouteReplace(&route); err != nil {
if err := r.netLinker.RouteReplace(route); err != nil {
return fmt.Errorf("replacing route for subnet %s at interface %s: %w",
destinationStr, iface, err)
}
@@ -51,12 +51,12 @@ func (r *Routing) deleteRouteVia(destination netip.Prefix, gateway netip.Addr,
}
route := netlink.Route{
Dst: NetipPrefixToIPNet(&destination),
Gw: gateway.AsSlice(),
LinkIndex: link.Attrs().Index,
Dst: destination,
Gw: gateway,
LinkIndex: link.Index,
Table: table,
}
if err := r.netLinker.RouteDel(&route); err != nil {
if err := r.netLinker.RouteDel(route); err != nil {
return fmt.Errorf("deleting route: for subnet %s at interface %s: %w",
destinationStr, iface, err)
}

View File

@@ -18,30 +18,30 @@ type NetLinker interface {
type Addresser interface {
AddrList(link netlink.Link, family int) (
addresses []netlink.Addr, err error)
AddrReplace(link netlink.Link, addr *netlink.Addr) error
AddrReplace(link netlink.Link, addr netlink.Addr) error
}
type Router interface {
RouteList(link netlink.Link, family int) (
RouteList(link *netlink.Link, family int) (
routes []netlink.Route, err error)
RouteAdd(route *netlink.Route) error
RouteDel(route *netlink.Route) error
RouteReplace(route *netlink.Route) error
RouteAdd(route netlink.Route) error
RouteDel(route netlink.Route) error
RouteReplace(route netlink.Route) error
}
type Ruler interface {
RuleList(family int) (rules []netlink.Rule, err error)
RuleAdd(rule *netlink.Rule) error
RuleDel(rule *netlink.Rule) error
RuleAdd(rule netlink.Rule) error
RuleDel(rule netlink.Rule) error
}
type Linker interface {
LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error)
LinkByIndex(index int) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (err error)
LinkAdd(link netlink.Link) (linkIndex int, err error)
LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) (err error)
}

View File

@@ -1,30 +1,28 @@
package routing
import (
"bytes"
"fmt"
"net"
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
)
func (r *Routing) addIPRule(src, dst *netip.Prefix, table, priority int) error {
func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority int) error {
const add = true
r.logger.Debug(ruleDbgMsg(add, src, dst, table, priority))
rule := netlink.NewRule()
rule.Src = NetipPrefixToIPNet(src)
rule.Dst = NetipPrefixToIPNet(dst)
rule.Src = src
rule.Dst = dst
rule.Priority = priority
rule.Table = table
existingRules, err := r.netLinker.RuleList(netlink.FAMILY_ALL)
existingRules, err := r.netLinker.RuleList(netlink.FamilyAll)
if err != nil {
return fmt.Errorf("listing rules: %w", err)
}
for i := range existingRules {
if !rulesAreEqual(&existingRules[i], rule) {
if !rulesAreEqual(existingRules[i], rule) {
continue
}
return nil // already exists
@@ -36,22 +34,22 @@ func (r *Routing) addIPRule(src, dst *netip.Prefix, table, priority int) error {
return nil
}
func (r *Routing) deleteIPRule(src, dst *netip.Prefix, table, priority int) error {
func (r *Routing) deleteIPRule(src, dst netip.Prefix, table, priority int) error {
const add = false
r.logger.Debug(ruleDbgMsg(add, src, dst, table, priority))
rule := netlink.NewRule()
rule.Src = NetipPrefixToIPNet(src)
rule.Dst = NetipPrefixToIPNet(dst)
rule.Src = src
rule.Dst = dst
rule.Priority = priority
rule.Table = table
existingRules, err := r.netLinker.RuleList(netlink.FAMILY_ALL)
existingRules, err := r.netLinker.RuleList(netlink.FamilyAll)
if err != nil {
return fmt.Errorf("listing rules: %w", err)
}
for i := range existingRules {
if !rulesAreEqual(&existingRules[i], rule) {
if !rulesAreEqual(existingRules[i], rule) {
continue
}
if err := r.netLinker.RuleDel(rule); err != nil {
@@ -61,7 +59,7 @@ func (r *Routing) deleteIPRule(src, dst *netip.Prefix, table, priority int) erro
return nil
}
func ruleDbgMsg(add bool, src, dst *netip.Prefix,
func ruleDbgMsg(add bool, src, dst netip.Prefix,
table, priority int) (debugMessage string) {
debugMessage = "ip rule"
@@ -71,11 +69,11 @@ func ruleDbgMsg(add bool, src, dst *netip.Prefix,
debugMessage += " del"
}
if src != nil {
if src.IsValid() {
debugMessage += " from " + src.String()
}
if dst != nil {
if dst.IsValid() {
debugMessage += " to " + dst.String()
}
@@ -90,25 +88,20 @@ func ruleDbgMsg(add bool, src, dst *netip.Prefix,
return debugMessage
}
func rulesAreEqual(a, b *netlink.Rule) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
return ipNetsAreEqual(a.Src, b.Src) &&
ipNetsAreEqual(a.Dst, b.Dst) &&
func rulesAreEqual(a, b netlink.Rule) bool {
return ipPrefixesAreEqual(a.Src, b.Src) &&
ipPrefixesAreEqual(a.Dst, b.Dst) &&
a.Priority == b.Priority &&
a.Table == b.Table
}
func ipNetsAreEqual(a, b *net.IPNet) bool {
if a == nil && b == nil {
func ipPrefixesAreEqual(a, b netip.Prefix) bool {
if !a.IsValid() && !b.IsValid() {
return true
}
if a == nil || b == nil {
if !a.IsValid() || !b.IsValid() {
return false
}
return a.IP.Equal(b.IP) && bytes.Equal(a.Mask, b.Mask)
return a.Bits() == b.Bits() &&
a.Addr().Compare(b.Addr()) == 0
}

View File

@@ -2,7 +2,6 @@ package routing
import (
"errors"
"net"
"net/netip"
"testing"
@@ -12,17 +11,16 @@ import (
"github.com/stretchr/testify/require"
)
func makeNetipPrefix(n byte) *netip.Prefix {
func makeNetipPrefix(n byte) netip.Prefix {
const bits = 24
prefix := netip.PrefixFrom(netip.AddrFrom4([4]byte{n, n, n, 0}), bits)
return &prefix
return netip.PrefixFrom(netip.AddrFrom4([4]byte{n, n, n, 0}), bits)
}
func makeIPRule(src, dst *netip.Prefix,
table, priority int) *netlink.Rule {
func makeIPRule(src, dst netip.Prefix,
table, priority int) netlink.Rule {
rule := netlink.NewRule()
rule.Src = NetipPrefixToIPNet(src)
rule.Dst = NetipPrefixToIPNet(dst)
rule.Src = src
rule.Dst = dst
rule.Table = table
rule.Priority = priority
return rule
@@ -40,13 +38,13 @@ func Test_Routing_addIPRule(t *testing.T) {
type ruleAddCall struct {
expected bool
ruleToAdd *netlink.Rule
ruleToAdd netlink.Rule
err error
}
testCases := map[string]struct {
src *netip.Prefix
dst *netip.Prefix
src netip.Prefix
dst netip.Prefix
table int
priority int
dbgMsg string
@@ -69,8 +67,8 @@ func Test_Routing_addIPRule(t *testing.T) {
dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleList: ruleListCall{
rules: []netlink.Rule{
*makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
*makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
},
},
},
@@ -95,8 +93,8 @@ func Test_Routing_addIPRule(t *testing.T) {
dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleList: ruleListCall{
rules: []netlink.Rule{
*makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
*makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101),
makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101),
},
},
ruleAdd: ruleAddCall{
@@ -116,7 +114,7 @@ func Test_Routing_addIPRule(t *testing.T) {
logger.EXPECT().Debug(testCase.dbgMsg)
netLinker := NewMockNetLinker(ctrl)
netLinker.EXPECT().RuleList(netlink.FAMILY_ALL).
netLinker.EXPECT().RuleList(netlink.FamilyAll).
Return(testCase.ruleList.rules, testCase.ruleList.err)
if testCase.ruleAdd.expected {
netLinker.EXPECT().RuleAdd(testCase.ruleAdd.ruleToAdd).
@@ -153,13 +151,13 @@ func Test_Routing_deleteIPRule(t *testing.T) {
type ruleDelCall struct {
expected bool
ruleToDel *netlink.Rule
ruleToDel netlink.Rule
err error
}
testCases := map[string]struct {
src *netip.Prefix
dst *netip.Prefix
src netip.Prefix
dst netip.Prefix
table int
priority int
dbgMsg string
@@ -182,7 +180,7 @@ func Test_Routing_deleteIPRule(t *testing.T) {
dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleList: ruleListCall{
rules: []netlink.Rule{
*makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
},
},
ruleDel: ruleDelCall{
@@ -200,8 +198,8 @@ func Test_Routing_deleteIPRule(t *testing.T) {
dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleList: ruleListCall{
rules: []netlink.Rule{
*makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
*makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
},
},
ruleDel: ruleDelCall{
@@ -217,8 +215,8 @@ func Test_Routing_deleteIPRule(t *testing.T) {
dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleList: ruleListCall{
rules: []netlink.Rule{
*makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
*makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101),
makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101),
},
},
},
@@ -234,7 +232,7 @@ func Test_Routing_deleteIPRule(t *testing.T) {
logger.EXPECT().Debug(testCase.dbgMsg)
netLinker := NewMockNetLinker(ctrl)
netLinker.EXPECT().RuleList(netlink.FAMILY_ALL).
netLinker.EXPECT().RuleList(netlink.FamilyAll).
Return(testCase.ruleList.rules, testCase.ruleList.err)
if testCase.ruleDel.expected {
netLinker.EXPECT().RuleDel(testCase.ruleDel.ruleToDel).
@@ -264,8 +262,8 @@ func Test_ruleDbgMsg(t *testing.T) {
testCases := map[string]struct {
add bool
src *netip.Prefix
dst *netip.Prefix
src netip.Prefix
dst netip.Prefix
table int
priority int
dbgMsg string
@@ -307,38 +305,79 @@ func Test_rulesAreEqual(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
a *netlink.Rule
b *netlink.Rule
a netlink.Rule
b netlink.Rule
equal bool
}{
"both nil": {
"both_empty": {
equal: true,
},
"first nil": {
b: &netlink.Rule{},
},
"second nil": {
a: &netlink.Rule{},
},
"both not nil": {
a: &netlink.Rule{},
b: &netlink.Rule{},
equal: true,
},
"both equal": {
a: &netlink.Rule{
Src: &net.IPNet{
IP: net.IPv4(1, 1, 1, 1),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
"not_equal_by_src": {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Table: 101,
},
b: &netlink.Rule{
Src: &net.IPNet{
IP: net.IPv4(1, 1, 1, 1),
Mask: net.IPv4Mask(255, 255, 255, 0),
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Table: 101,
},
},
"not_equal_by_dst": {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 32),
Priority: 100,
Table: 101,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Table: 101,
},
},
"not_equal_by_priority": {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 999,
Table: 101,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Table: 101,
},
},
"not_equal_by_table": {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Table: 999,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Table: 101,
},
},
"equal": {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Table: 101,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Table: 101,
},
@@ -358,58 +397,39 @@ func Test_rulesAreEqual(t *testing.T) {
}
}
func Test_ipNetsAreEqual(t *testing.T) {
func Test_ipPrefixesAreEqual(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
a *net.IPNet
b *net.IPNet
a netip.Prefix
b netip.Prefix
equal bool
}{
"both nil": {
"both_not_valid": {
equal: true,
},
"first nil": {
b: &net.IPNet{},
"first_not_valid": {
b: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
},
"second nil": {
a: &net.IPNet{},
"second_not_valid": {
a: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
},
"both not nil": {
a: &net.IPNet{},
b: &net.IPNet{},
"both_equal": {
a: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
b: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
equal: true,
},
"both equal": {
a: &net.IPNet{
IP: net.IPv4(1, 1, 1, 1),
Mask: net.IPv4Mask(255, 255, 255, 0),
"both_not_equal_by_IP": {
a: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
b: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 24),
},
b: &net.IPNet{
IP: net.IPv4(1, 1, 1, 1),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
equal: true,
},
"both not equal by IP": {
a: &net.IPNet{
IP: net.IPv4(1, 1, 1, 1),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
b: &net.IPNet{
IP: net.IPv4(2, 2, 2, 2),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
},
"both not equal by mask": {
a: &net.IPNet{
IP: net.IPv4(1, 1, 1, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
b: &net.IPNet{
IP: net.IPv4(1, 1, 1, 1),
Mask: net.IPv4Mask(255, 255, 0, 0),
"both_not_equal_by_bits": {
a: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
b: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32),
},
"both_not_equal_by_IP_and_bits": {
a: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
b: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
},
}
@@ -418,7 +438,7 @@ func Test_ipNetsAreEqual(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()
equal := ipNetsAreEqual(testCase.a, testCase.b)
equal := ipPrefixesAreEqual(testCase.a, testCase.b)
assert.Equal(t, testCase.equal, equal)
})

View File

@@ -1,10 +1,8 @@
package routing
import (
"bytes"
"errors"
"fmt"
"net"
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
@@ -16,14 +14,14 @@ var (
)
func (r *Routing) VPNDestinationIP() (ip netip.Addr, err error) {
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
routes, err := r.netLinker.RouteList(nil, netlink.FamilyAll)
if err != nil {
return ip, fmt.Errorf("listing routes: %w", err)
}
defaultLinkIndex := -1
for _, route := range routes {
if route.Dst == nil {
if !route.Dst.IsValid() {
defaultLinkIndex = route.LinkIndex
break
}
@@ -34,17 +32,17 @@ func (r *Routing) VPNDestinationIP() (ip netip.Addr, err error) {
for _, route := range routes {
if route.LinkIndex == defaultLinkIndex &&
route.Dst != nil &&
!IPIsPrivate(netIPToNetipAddress(route.Dst.IP)) &&
bytes.Equal(route.Dst.Mask, net.IPMask{255, 255, 255, 255}) {
return netIPToNetipAddress(route.Dst.IP), nil
route.Dst.IsValid() &&
!IPIsPrivate(route.Dst.Addr()) &&
route.Dst.IsSingleIP() {
return route.Dst.Addr(), nil
}
}
return ip, fmt.Errorf("%w: in %d routes", ErrVPNDestinationIPNotFound, len(routes))
}
func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) {
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
routes, err := r.netLinker.RouteList(nil, netlink.FamilyAll)
if err != nil {
return ip, fmt.Errorf("listing routes: %w", err)
}
@@ -53,11 +51,11 @@ func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) {
if err != nil {
return ip, fmt.Errorf("finding link at index %d: %w", route.LinkIndex, err)
}
interfaceName := link.Attrs().Name
interfaceName := link.Name
if interfaceName == vpnIntf &&
route.Dst != nil &&
route.Dst.IP.Equal(net.IP{0, 0, 0, 0}) {
return netIPToNetipAddress(route.Gw), nil
route.Dst.IsValid() &&
route.Dst.Addr().IsUnspecified() {
return route.Gw, nil
}
}
return ip, fmt.Errorf("%w: in %d routes", ErrVPNLocalGatewayIPNotFound, len(routes))

View File

@@ -42,7 +42,7 @@ type Storage interface {
}
type NetLinker interface {
AddrReplace(link netlink.Link, addr *netlink.Addr) error
AddrReplace(link netlink.Link, addr netlink.Addr) error
Router
Ruler
Linker
@@ -50,22 +50,22 @@ type NetLinker interface {
}
type Router interface {
RouteList(link netlink.Link, family int) (
RouteList(link *netlink.Link, family int) (
routes []netlink.Route, err error)
RouteAdd(route *netlink.Route) error
RouteAdd(route netlink.Route) error
}
type Ruler interface {
RuleAdd(rule *netlink.Rule) error
RuleDel(rule *netlink.Rule) error
RuleAdd(rule netlink.Rule) error
RuleDel(rule netlink.Rule) error
}
type Linker interface {
LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (err error)
LinkAdd(link netlink.Link) (linkIndex int, err error)
LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) (err error)
}

View File

@@ -5,7 +5,6 @@ import (
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/routing"
)
func (w *Wireguard) addAddresses(link netlink.Link,
@@ -15,15 +14,14 @@ func (w *Wireguard) addAddresses(link netlink.Link,
continue
}
ipNet := ipNet
address := &netlink.Addr{
IPNet: routing.NetipPrefixToIPNet(&ipNet),
address := netlink.Addr{
Network: ipNet,
}
err = w.netlink.AddrReplace(link, address)
if err != nil {
return fmt.Errorf("%w: when adding address %s to link %s",
err, address, link.Attrs().Name)
err, address, link.Name)
}
}

View File

@@ -7,7 +7,6 @@ import (
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/routing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -18,14 +17,6 @@ func Test_Wireguard_addAddresses(t *testing.T) {
ipNetOne := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 32)
ipNetTwo := netip.PrefixFrom(netip.MustParseAddr("::1234"), 64)
newLink := func() netlink.Link {
linkAttrs := netlink.NewLinkAttrs()
linkAttrs.Name = "a_bridge"
return &netlink.Bridge{
LinkAttrs: linkAttrs,
}
}
errDummy := errors.New("dummy")
testCases := map[string]struct {
@@ -35,15 +26,15 @@ func Test_Wireguard_addAddresses(t *testing.T) {
err error
}{
"success": {
link: newLink(),
link: netlink.Link{Type: "wireguard"},
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
netLinker := NewMockNetLinker(ctrl)
firstCall := netLinker.EXPECT().
AddrReplace(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetOne)}).
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
Return(nil)
netLinker.EXPECT().
AddrReplace(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetTwo)}).
AddrReplace(link, netlink.Addr{Network: ipNetTwo}).
Return(nil).After(firstCall)
return &Wireguard{
netlink: netLinker,
@@ -54,12 +45,12 @@ func Test_Wireguard_addAddresses(t *testing.T) {
},
},
"first add error": {
link: newLink(),
link: netlink.Link{Type: "wireguard", Name: "a_bridge"},
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
netLinker := NewMockNetLinker(ctrl)
netLinker.EXPECT().
AddrReplace(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetOne)}).
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
Return(errDummy)
return &Wireguard{
netlink: netLinker,
@@ -71,15 +62,15 @@ func Test_Wireguard_addAddresses(t *testing.T) {
err: errors.New("dummy: when adding address 1.2.3.4/32 to link a_bridge"),
},
"second add error": {
link: newLink(),
link: netlink.Link{Type: "wireguard", Name: "a_bridge"},
addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
netLinker := NewMockNetLinker(ctrl)
firstCall := netLinker.EXPECT().
AddrReplace(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetOne)}).
AddrReplace(link, netlink.Addr{Network: ipNetOne}).
Return(nil)
netLinker.EXPECT().
AddrReplace(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetTwo)}).
AddrReplace(link, netlink.Addr{Network: ipNetTwo}).
Return(errDummy).After(firstCall)
return &Wireguard{
netlink: netLinker,
@@ -91,7 +82,6 @@ func Test_Wireguard_addAddresses(t *testing.T) {
err: errors.New("dummy: when adding address ::1234/64 to link a_bridge"),
},
"ignore IPv6": {
link: newLink(),
addrs: []netip.Prefix{ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
return &Wireguard{

View File

@@ -3,6 +3,7 @@ package wireguard
import (
"fmt"
"net"
"net/netip"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -53,8 +54,14 @@ func makeDeviceConfig(settings Settings) (config wgtypes.Config, err error) {
PublicKey: publicKey,
PresharedKey: preSharedKey,
AllowedIPs: []net.IPNet{
*allIPv4(),
*allIPv6(),
{
IP: net.IPv4(0, 0, 0, 0),
Mask: []byte{0, 0, 0, 0},
},
{
IP: net.IPv6zero,
Mask: []byte(net.IPv6zero),
},
},
ReplaceAllowedIPs: true,
Endpoint: &net.UDPAddr{
@@ -68,16 +75,12 @@ func makeDeviceConfig(settings Settings) (config wgtypes.Config, err error) {
return config, nil
}
func allIPv4() (ipNet *net.IPNet) {
return &net.IPNet{
IP: net.IPv4(0, 0, 0, 0),
Mask: []byte{0, 0, 0, 0},
}
func allIPv4() (prefix netip.Prefix) {
const bits = 0
return netip.PrefixFrom(netip.IPv4Unspecified(), bits)
}
func allIPv6() (ipNet *net.IPNet) {
return &net.IPNet{
IP: net.IPv6zero,
Mask: []byte(net.IPv6zero),
}
func allIPv6() (prefix netip.Prefix) {
const bits = 0
return netip.PrefixFrom(netip.IPv6Unspecified(), bits)
}

View File

@@ -24,10 +24,9 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
netlinker := netlink.New(&noopDebugLogger{})
linkAttrs := netlink.NewLinkAttrs()
linkAttrs.Name = "test_8081"
link := &netlink.Bridge{
LinkAttrs: linkAttrs,
link := netlink.Link{
Type: "bridge",
Name: "test_8081",
}
// Remove any previously created test interface from a crashed/panic
@@ -37,8 +36,9 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
require.NoError(t, err)
}
err = netlinker.LinkAdd(link)
linkIndex, err := netlinker.LinkAdd(link)
require.NoError(t, err)
link.Index = linkIndex
defer func() {
err = netlinker.LinkDel(link)
@@ -63,14 +63,12 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
err = wg.addAddresses(link, addresses)
require.NoError(t, err)
netlinkAddresses, err := netlinker.AddrList(link, netlink.FAMILY_ALL)
netlinkAddresses, err := netlinker.AddrList(link, netlink.FamilyAll)
require.NoError(t, err)
require.Equal(t, len(addresses), len(netlinkAddresses))
for i, netlinkAddress := range netlinkAddresses {
require.NotNil(t, netlinkAddress.IPNet)
ipNet, err := netip.ParsePrefix(netlinkAddress.IPNet.String())
require.NoError(t, err)
assert.Equal(t, addresses[i], ipNet)
require.NotNil(t, netlinkAddress.Network)
assert.Equal(t, addresses[i], netlinkAddress.Network)
}
}
}
@@ -95,7 +93,7 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
assert.NoError(t, err)
}()
rules, err := netlinker.RuleList(netlink.FAMILY_V4)
rules, err := netlinker.RuleList(netlink.FamilyV4)
require.NoError(t, err)
var rule netlink.Rule
var ruleFound bool
@@ -111,11 +109,6 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
Mask: 4294967295,
Goto: -1,
Flow: -1,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
}
assert.Equal(t, expectedRule, rule)

View File

@@ -5,7 +5,7 @@ import "github.com/qdm12/gluetun/internal/netlink"
//go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker
type NetLinker interface {
AddrReplace(link netlink.Link, addr *netlink.Addr) error
AddrReplace(link netlink.Link, addr netlink.Addr) error
Router
Ruler
Linker
@@ -13,21 +13,21 @@ type NetLinker interface {
}
type Router interface {
RouteList(link netlink.Link, family int) (
RouteList(link *netlink.Link, family int) (
routes []netlink.Route, err error)
RouteAdd(route *netlink.Route) error
RouteAdd(route netlink.Route) error
}
type Ruler interface {
RuleAdd(rule *netlink.Rule) error
RuleDel(rule *netlink.Rule) error
RuleAdd(rule netlink.Rule) error
RuleDel(rule netlink.Rule) error
}
type Linker interface {
LinkAdd(link netlink.Link) (err error)
LinkAdd(link netlink.Link) (linkIndex int, err error)
LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error)
LinkSetUp(link netlink.Link) error
LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) error
LinkDel(link netlink.Link) error
}

View File

@@ -8,7 +8,7 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
netlink "github.com/vishvananda/netlink"
netlink "github.com/qdm12/gluetun/internal/netlink"
)
// MockNetLinker is a mock of NetLinker interface.
@@ -35,7 +35,7 @@ func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder {
}
// AddrReplace mocks base method.
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 *netlink.Addr) error {
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 netlink.Addr) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
ret0, _ := ret[0].(error)
@@ -64,11 +64,12 @@ func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call {
}
// LinkAdd mocks base method.
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) error {
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkAdd", arg0)
ret0, _ := ret[0].(error)
return ret0
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LinkAdd indicates an expected call of LinkAdd.
@@ -136,11 +137,12 @@ func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call
}
// LinkSetUp mocks base method.
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) error {
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetUp", arg0)
ret0, _ := ret[0].(error)
return ret0
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LinkSetUp indicates an expected call of LinkSetUp.
@@ -150,7 +152,7 @@ func (mr *MockNetLinkerMockRecorder) LinkSetUp(arg0 interface{}) *gomock.Call {
}
// RouteAdd mocks base method.
func (m *MockNetLinker) RouteAdd(arg0 *netlink.Route) error {
func (m *MockNetLinker) RouteAdd(arg0 netlink.Route) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteAdd", arg0)
ret0, _ := ret[0].(error)
@@ -164,7 +166,7 @@ func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call {
}
// RouteList mocks base method.
func (m *MockNetLinker) RouteList(arg0 netlink.Link, arg1 int) ([]netlink.Route, error) {
func (m *MockNetLinker) RouteList(arg0 *netlink.Link, arg1 int) ([]netlink.Route, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteList", arg0, arg1)
ret0, _ := ret[0].([]netlink.Route)
@@ -179,7 +181,7 @@ func (mr *MockNetLinkerMockRecorder) RouteList(arg0, arg1 interface{}) *gomock.C
}
// RuleAdd mocks base method.
func (m *MockNetLinker) RuleAdd(arg0 *netlink.Rule) error {
func (m *MockNetLinker) RuleAdd(arg0 netlink.Rule) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RuleAdd", arg0)
ret0, _ := ret[0].(error)
@@ -193,7 +195,7 @@ func (mr *MockNetLinkerMockRecorder) RuleAdd(arg0 interface{}) *gomock.Call {
}
// RuleDel mocks base method.
func (m *MockNetLinker) RuleDel(arg0 *netlink.Rule) error {
func (m *MockNetLinker) RuleDel(arg0 netlink.Rule) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RuleDel", arg0)
ret0, _ := ret[0].(error)

View File

@@ -2,17 +2,17 @@ package wireguard
import (
"fmt"
"net"
"net/netip"
"github.com/qdm12/gluetun/internal/netlink"
)
// TODO add IPv6 route if IPv6 is supported
func (w *Wireguard) addRoute(link netlink.Link, dst *net.IPNet,
func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
firewallMark int) (err error) {
route := &netlink.Route{
LinkIndex: link.Attrs().Index,
route := netlink.Route{
LinkIndex: link.Index,
Dst: dst,
Table: firewallMark,
}
@@ -21,7 +21,7 @@ func (w *Wireguard) addRoute(link netlink.Link, dst *net.IPNet,
if err != nil {
return fmt.Errorf(
"adding route for link %s, destination %s and table %d: %w",
link.Attrs().Name, dst, firewallMark, err)
link.Name, dst, firewallMark, err)
}
return err

View File

@@ -2,7 +2,7 @@ package wireguard
import (
"errors"
"net"
"net/netip"
"testing"
"github.com/golang/mock/gomock"
@@ -15,41 +15,40 @@ func Test_Wireguard_addRoute(t *testing.T) {
t.Parallel()
const linkIndex = 88
newLink := func() netlink.Link {
linkAttrs := netlink.NewLinkAttrs()
linkAttrs.Name = "a_bridge"
linkAttrs.Index = linkIndex
return &netlink.Bridge{
LinkAttrs: linkAttrs,
}
}
ipNet := &net.IPNet{IP: net.IPv4(1, 2, 3, 4), Mask: net.IPv4Mask(255, 255, 255, 255)}
ipPrefix := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 32)
const firewallMark = 51820
errDummy := errors.New("dummy")
testCases := map[string]struct {
link netlink.Link
dst *net.IPNet
expectedRoute *netlink.Route
dst netip.Prefix
expectedRoute netlink.Route
routeAddErr error
err error
}{
"success": {
link: newLink(),
dst: ipNet,
expectedRoute: &netlink.Route{
link: netlink.Link{
Index: linkIndex,
},
dst: ipPrefix,
expectedRoute: netlink.Route{
LinkIndex: linkIndex,
Dst: ipNet,
Dst: ipPrefix,
Table: firewallMark,
},
},
"route add error": {
link: newLink(),
dst: ipNet,
expectedRoute: &netlink.Route{
link: netlink.Link{
Name: "a_bridge",
Index: linkIndex,
},
dst: ipPrefix,
expectedRoute: netlink.Route{
LinkIndex: linkIndex,
Dst: ipNet,
Dst: ipPrefix,
Table: firewallMark,
},
routeAddErr: errDummy,

View File

@@ -21,53 +21,38 @@ func Test_Wireguard_addRule(t *testing.T) {
errDummy := errors.New("dummy")
testCases := map[string]struct {
expectedRule *netlink.Rule
expectedRule netlink.Rule
ruleAddErr error
err error
ruleDelErr error
cleanupErr error
}{
"success": {
expectedRule: &netlink.Rule{
expectedRule: netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
Mask: -1,
Goto: -1,
Flow: -1,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
Family: family,
},
},
"rule add error": {
expectedRule: &netlink.Rule{
expectedRule: netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
Mask: -1,
Goto: -1,
Flow: -1,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
Family: family,
},
ruleAddErr: errDummy,
err: errors.New("adding rule ip rule 987: from all to all table 456: dummy"),
},
"rule delete error": {
expectedRule: &netlink.Rule{
expectedRule: netlink.Rule{
Invert: true,
Priority: rulePriority,
Mark: firewallMark,
Table: firewallMark,
Mask: -1,
Goto: -1,
Flow: -1,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
Family: family,
},
ruleDelErr: errDummy,

View File

@@ -93,10 +93,12 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
return
}
if err := w.netlink.LinkSetUp(link); err != nil {
linkIndex, err := w.netlink.LinkSetUp(link)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err)
return
}
link.Index = linkIndex
closers.add("shutting down link", stepFour, func() error {
return w.netlink.LinkSetDown(link)
})
@@ -161,17 +163,16 @@ func setupKernelSpace(ctx context.Context,
interfaceName string, netLinker NetLinker, mtu uint16,
closers *closers, logger Logger) (
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error) {
linkAttrs := netlink.LinkAttrs{
link = netlink.Link{
Type: "wireguard",
Name: interfaceName,
MTU: int(mtu),
MTU: mtu,
}
link = &netlink.Wireguard{
LinkAttrs: linkAttrs,
}
err = netLinker.LinkAdd(link)
linkIndex, err := netLinker.LinkAdd(link)
if err != nil {
return nil, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
return link, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
}
link.Index = linkIndex
closers.add("deleting link", stepFive, func() error {
return netLinker.LinkDel(link)
})
@@ -191,22 +192,22 @@ func setupUserSpace(ctx context.Context,
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error) {
tun, err := tun.CreateTUN(interfaceName, int(mtu))
if err != nil {
return nil, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
return link, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
}
closers.add("closing TUN device", stepSeven, tun.Close)
tunName, err := tun.Name()
if err != nil {
return nil, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
return link, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
} else if tunName != interfaceName {
return nil, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
return link, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
ErrCreateTun, interfaceName, tunName)
}
link, err = netLinker.LinkByName(interfaceName)
if err != nil {
return nil, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
return link, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
}
closers.add("deleting link", stepFive, func() error {
return netLinker.LinkDel(link)
@@ -226,14 +227,14 @@ func setupUserSpace(ctx context.Context,
uapiFile, err := ipc.UAPIOpen(interfaceName)
if err != nil {
return nil, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
return link, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
}
closers.add("closing UAPI file", stepThree, uapiFile.Close)
uapiListener, err := ipc.UAPIListen(interfaceName, uapiFile)
if err != nil {
return nil, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
return link, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
}
closers.add("closing UAPI listener", stepTwo, uapiListener.Close)