chore(netlink): define own types with minimal fields

- Allow to swap `github.com/vishvananda/netlink`
- Allow to add build tags for each platform
- One step closer to development on non-Linux platforms
This commit is contained in:
Quentin McGaw
2023-05-29 06:44:58 +00:00
parent 163ac48ce4
commit 38ddcfa756
34 changed files with 828 additions and 493 deletions

View File

@@ -531,30 +531,30 @@ type netLinker interface {
type Addresser interface { type Addresser interface {
AddrList(link netlink.Link, family int) ( AddrList(link netlink.Link, family int) (
addresses []netlink.Addr, err error) addresses []netlink.Addr, err error)
AddrReplace(link netlink.Link, addr *netlink.Addr) error AddrReplace(link netlink.Link, addr netlink.Addr) error
} }
type Router interface { type Router interface {
RouteList(link netlink.Link, family int) ( RouteList(link *netlink.Link, family int) (
routes []netlink.Route, err error) routes []netlink.Route, err error)
RouteAdd(route *netlink.Route) error RouteAdd(route netlink.Route) error
RouteDel(route *netlink.Route) error RouteDel(route netlink.Route) error
RouteReplace(route *netlink.Route) error RouteReplace(route netlink.Route) error
} }
type Ruler interface { type Ruler interface {
RuleList(family int) (rules []netlink.Rule, err error) RuleList(family int) (rules []netlink.Rule, err error)
RuleAdd(rule *netlink.Rule) error RuleAdd(rule netlink.Rule) error
RuleDel(rule *netlink.Rule) error RuleDel(rule netlink.Rule) error
} }
type Linker interface { type Linker interface {
LinkList() (links []netlink.Link, err error) LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error) LinkByName(name string) (link netlink.Link, err error)
LinkByIndex(index int) (link netlink.Link, err error) LinkByIndex(index int) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (err error) LinkAdd(link netlink.Link) (linkIndex int, err error)
LinkDel(link netlink.Link) (err error) LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (err error) LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) (err error) LinkSetDown(link netlink.Link) (err error)
} }

View File

@@ -1,14 +1,40 @@
package netlink package netlink
import "github.com/vishvananda/netlink" import (
"net/netip"
type Addr = netlink.Addr "github.com/vishvananda/netlink"
)
type Addr struct {
Network netip.Prefix
}
func (a Addr) String() string {
return a.Network.String()
}
func (n *NetLink) AddrList(link Link, family int) ( func (n *NetLink) AddrList(link Link, family int) (
addresses []Addr, err error) { addresses []Addr, err error) {
return netlink.AddrList(link, family) netlinkLink := linkToNetlinkLink(&link)
netlinkAddresses, err := netlink.AddrList(netlinkLink, family)
if err != nil {
return nil, err
}
addresses = make([]Addr, len(netlinkAddresses))
for i := range netlinkAddresses {
addresses[i].Network = netIPNetToNetipPrefix(netlinkAddresses[i].IPNet)
}
return addresses, nil
} }
func (n *NetLink) AddrReplace(link Link, addr *Addr) error { func (n *NetLink) AddrReplace(link Link, addr Addr) error {
return netlink.AddrReplace(link, addr) netlinkLink := linkToNetlinkLink(&link)
netlinkAddress := netlink.Addr{
IPNet: netipPrefixToIPNet(addr.Network),
}
return netlink.AddrReplace(netlinkLink, &netlinkAddress)
} }

View File

@@ -0,0 +1,62 @@
package netlink
import (
"fmt"
"net"
"net/netip"
)
func netipPrefixToIPNet(prefix netip.Prefix) (ipNet *net.IPNet) {
if !prefix.IsValid() {
return nil
}
prefixAddr := prefix.Addr().Unmap()
ipMask := net.CIDRMask(prefix.Bits(), prefixAddr.BitLen())
ip := netipAddrToNetIP(prefixAddr)
return &net.IPNet{
IP: ip,
Mask: ipMask,
}
}
func netIPNetToNetipPrefix(ipNet *net.IPNet) (prefix netip.Prefix) {
if ipNet == nil || (len(ipNet.IP) != net.IPv4len && len(ipNet.IP) != net.IPv6len) {
return prefix
}
var ip netip.Addr
if ipv4 := ipNet.IP.To4(); ipv4 != nil {
ip = netip.AddrFrom4([4]byte(ipv4))
} else {
ip = netip.AddrFrom16([16]byte(ipNet.IP))
}
bits, _ := ipNet.Mask.Size()
return netip.PrefixFrom(ip, bits)
}
func netipAddrToNetIP(address netip.Addr) (ip net.IP) {
switch {
case !address.IsValid():
return nil
case address.Is4() || address.Is4In6():
bytes := address.As4()
return net.IP(bytes[:])
default:
bytes := address.As16()
return net.IP(bytes[:])
}
}
func netIPToNetipAddress(ip net.IP) (address netip.Addr) {
if len(ip) != net.IPv4len && len(ip) != net.IPv6len {
return address // invalid
}
address, ok := netip.AddrFromSlice(ip)
if !ok {
panic(fmt.Sprintf("converting %#v to netip.Addr failed", ip))
}
return address.Unmap()
}

View File

@@ -0,0 +1,146 @@
package netlink
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_netipPrefixToIPNet(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
prefix netip.Prefix
ipNet *net.IPNet
}{
"empty_prefix": {},
"IPv4_prefix": {
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
ipNet: &net.IPNet{
IP: net.IP{1, 2, 3, 4},
Mask: net.IPv4Mask(255, 255, 255, 0),
},
},
"IPv4-in-IPv6_prefix": {
prefix: netip.PrefixFrom(netip.AddrFrom16(
[16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 1, 2, 3, 4}),
24),
ipNet: &net.IPNet{
IP: net.IP{1, 2, 3, 4},
Mask: net.IPv4Mask(255, 255, 255, 0),
},
},
"IPv6_prefix": {
prefix: netip.PrefixFrom(netip.IPv6Loopback(), 8),
ipNet: &net.IPNet{
IP: net.IPv6loopback,
Mask: net.IPMask{0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
ipNet := netipPrefixToIPNet(testCase.prefix)
assert.Equal(t, testCase.ipNet, ipNet)
})
}
}
func Test_netIPNetToNetipPrefix(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
ipNet *net.IPNet
prefix netip.Prefix
}{
"empty ipnet": {},
"custom sized IP in ipnet": {
ipNet: &net.IPNet{
IP: net.IP{1},
},
},
"IPv4 ipnet": {
ipNet: &net.IPNet{
IP: net.IP{1, 2, 3, 4},
Mask: net.IPMask{255, 255, 255, 0},
},
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
},
"IPv4-in-IPv6 ipnet": {
ipNet: &net.IPNet{
IP: net.IPv4(1, 2, 3, 4),
Mask: net.IPMask{255, 255, 255, 0},
},
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
},
"IPv6 ipnet": {
ipNet: &net.IPNet{
IP: net.IPv6loopback,
Mask: net.IPMask{0xff},
},
prefix: netip.PrefixFrom(netip.IPv6Loopback(), 8),
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
prefix := netIPNetToNetipPrefix(testCase.ipNet)
assert.Equal(t, testCase.prefix, prefix)
})
}
}
func Test_netIPToNetipAddress(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
ip net.IP
address netip.Addr
panicMessage string
}{
"nil_ip": {},
"ip_not_ipv4_or_ipv6": {
ip: net.IP{1},
},
"IPv4": {
ip: net.IPv4(1, 2, 3, 4),
address: netip.AddrFrom4([4]byte{1, 2, 3, 4}),
},
"IPv6": {
ip: net.IPv6zero,
address: netip.AddrFrom16([16]byte{}),
},
"IPv4 prefixed with 0xffff": {
ip: net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 1, 2, 3, 4},
address: netip.AddrFrom4([4]byte{1, 2, 3, 4}),
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
if testCase.panicMessage != "" {
assert.PanicsWithValue(t, testCase.panicMessage, func() {
netIPToNetipAddress(testCase.ip)
})
return
}
address := netIPToNetipAddress(testCase.ip)
assert.Equal(t, testCase.address, address)
})
}
}

View File

@@ -6,20 +6,19 @@ import (
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
) )
//nolint:revive
const ( const (
FAMILY_ALL = netlink.FAMILY_ALL FamilyAll = 0
FAMILY_V4 = netlink.FAMILY_V4 FamilyV4 = 2
FAMILY_V6 = netlink.FAMILY_V6 FamilyV6 = 10
) )
func FamilyToString(family int) string { func FamilyToString(family int) string {
switch family { switch family {
case FAMILY_ALL: case FamilyAll:
return "all" return "all" //nolint:goconst
case FAMILY_V4: case FamilyV4:
return "v4" return "v4"
case FAMILY_V6: case FamilyV6:
return "v6" return "v6"
default: default:
return fmt.Sprint(family) return fmt.Sprint(family)

View File

@@ -14,20 +14,21 @@ func (n *NetLink) IsIPv6Supported() (supported bool, err error) {
var totalRoutes uint var totalRoutes uint
for _, link := range links { for _, link := range links {
routes, err := n.RouteList(link, netlink.FAMILY_V6) link := link
routes, err := n.RouteList(&link, netlink.FAMILY_V6)
if err != nil { if err != nil {
return false, fmt.Errorf("listing IPv6 routes for link %s: %w", return false, fmt.Errorf("listing IPv6 routes for link %s: %w",
link.Attrs().Name, err) link.Name, err)
} }
// Check each route for IPv6 due to Podman bug listing IPv4 routes // Check each route for IPv6 due to Podman bug listing IPv4 routes
// as IPv6 routes at container start, see: // as IPv6 routes at container start, see:
// https://github.com/qdm12/gluetun/issues/1241#issuecomment-1333405949 // https://github.com/qdm12/gluetun/issues/1241#issuecomment-1333405949
for _, route := range routes { for _, route := range routes {
sourceIsIPv6 := route.Src != nil && route.Src.To4() == nil sourceIsIPv6 := route.Src.IsValid() && route.Src.Is6()
destinationIsIPv6 := route.Dst != nil && route.Dst.IP.To4() == nil destinationIsIPv6 := route.Dst.IsValid() && route.Dst.Addr().Is6()
if sourceIsIPv6 || destinationIsIPv6 { if sourceIsIPv6 || destinationIsIPv6 {
n.debugLogger.Debugf("IPv6 is supported by link %s", link.Attrs().Name) n.debugLogger.Debugf("IPv6 is supported by link %s", link.Name)
return true, nil return true, nil
} }
totalRoutes++ totalRoutes++

View File

@@ -2,36 +2,117 @@ package netlink
import "github.com/vishvananda/netlink" import "github.com/vishvananda/netlink"
type ( type Link struct {
Link = netlink.Link Type string
Bridge = netlink.Bridge Name string
Wireguard = netlink.Wireguard Index int
) EncapType string
MTU uint16
NetNsID int
TxQLen int
}
func (n *NetLink) LinkList() (links []Link, err error) { func (n *NetLink) LinkList() (links []Link, err error) {
return netlink.LinkList() netlinkLinks, err := netlink.LinkList()
if err != nil {
return nil, err
}
links = make([]Link, len(netlinkLinks))
for i := range netlinkLinks {
links[i] = netlinkLinkToLink(netlinkLinks[i])
}
return links, nil
} }
func (n *NetLink) LinkByName(name string) (link Link, err error) { func (n *NetLink) LinkByName(name string) (link Link, err error) {
return netlink.LinkByName(name) netlinkLink, err := netlink.LinkByName(name)
if err != nil {
return Link{}, err
}
return netlinkLinkToLink(netlinkLink), nil
} }
func (n *NetLink) LinkByIndex(index int) (link Link, err error) { func (n *NetLink) LinkByIndex(index int) (link Link, err error) {
return netlink.LinkByIndex(index) netlinkLink, err := netlink.LinkByIndex(index)
if err != nil {
return Link{}, err
}
return netlinkLinkToLink(netlinkLink), nil
} }
func (n *NetLink) LinkAdd(link Link) (err error) { func (n *NetLink) LinkAdd(link Link) (linkIndex int, err error) {
return netlink.LinkAdd(link) netlinkLink := linkToNetlinkLink(&link)
err = netlink.LinkAdd(netlinkLink)
if err != nil {
return 0, err
}
return netlinkLink.Attrs().Index, nil
} }
func (n *NetLink) LinkDel(link Link) (err error) { func (n *NetLink) LinkDel(link Link) (err error) {
return netlink.LinkDel(link) return netlink.LinkDel(linkToNetlinkLink(&link))
} }
func (n *NetLink) LinkSetUp(link Link) (err error) { func (n *NetLink) LinkSetUp(link Link) (linkIndex int, err error) {
return netlink.LinkSetUp(link) netlinkLink := linkToNetlinkLink(&link)
err = netlink.LinkSetUp(netlinkLink)
if err != nil {
return 0, err
}
return netlinkLink.Attrs().Index, nil
} }
func (n *NetLink) LinkSetDown(link Link) (err error) { func (n *NetLink) LinkSetDown(link Link) (err error) {
return netlink.LinkSetDown(link) return netlink.LinkSetDown(linkToNetlinkLink(&link))
}
type netlinkLinkImpl struct {
attrs *netlink.LinkAttrs
linkType string
}
func (n *netlinkLinkImpl) Attrs() *netlink.LinkAttrs {
return n.attrs
}
func (n *netlinkLinkImpl) Type() string {
return n.linkType
}
func netlinkLinkToLink(netlinkLink netlink.Link) Link {
attributes := netlinkLink.Attrs()
return Link{
Type: netlinkLink.Type(),
Name: attributes.Name,
Index: attributes.Index,
EncapType: attributes.EncapType,
MTU: uint16(attributes.MTU),
NetNsID: attributes.NetNsID,
TxQLen: attributes.TxQLen,
}
}
// Warning: we must return `netlink.Link` and not `netlinkLinkImpl`
// so that the vishvananda/netlink package can compare the returned
// value against an untyped nil.
func linkToNetlinkLink(link *Link) netlink.Link {
if link == nil {
return nil
}
return &netlinkLinkImpl{
linkType: link.Type,
attrs: &netlink.LinkAttrs{ // TODO get all original attributes
Name: link.Name,
Index: link.Index,
EncapType: link.EncapType,
MTU: int(link.MTU),
NetNsID: link.NetNsID,
TxQLen: link.TxQLen,
},
}
} }

View File

@@ -1,9 +0,0 @@
package netlink
import "github.com/vishvananda/netlink"
type LinkAttrs = netlink.LinkAttrs
func NewLinkAttrs() LinkAttrs {
return netlink.NewLinkAttrs()
}

View File

@@ -1,22 +1,74 @@
package netlink package netlink
import "github.com/vishvananda/netlink" import (
"net/netip"
type Route = netlink.Route "github.com/vishvananda/netlink"
)
func (n *NetLink) RouteList(link Link, family int) ( type Route struct {
LinkIndex int
Dst netip.Prefix
Src netip.Addr
Gw netip.Addr
Priority int
Family int
Table int
Type int
}
func (n *NetLink) RouteList(link *Link, family int) (
routes []Route, err error) { routes []Route, err error) {
return netlink.RouteList(link, family) netlinkLink := linkToNetlinkLink(link)
netlinkRoutes, err := netlink.RouteList(netlinkLink, family)
if err != nil {
return nil, err
}
routes = make([]Route, len(netlinkRoutes))
for i := range netlinkRoutes {
routes[i] = netlinkRouteToRoute(netlinkRoutes[i])
}
return routes, nil
} }
func (n *NetLink) RouteAdd(route *Route) error { func (n *NetLink) RouteAdd(route Route) error {
return netlink.RouteAdd(route) netlinkRoute := routeToNetlinkRoute(route)
return netlink.RouteAdd(&netlinkRoute)
} }
func (n *NetLink) RouteDel(route *Route) error { func (n *NetLink) RouteDel(route Route) error {
return netlink.RouteDel(route) netlinkRoute := routeToNetlinkRoute(route)
return netlink.RouteDel(&netlinkRoute)
} }
func (n *NetLink) RouteReplace(route *Route) error { func (n *NetLink) RouteReplace(route Route) error {
return netlink.RouteReplace(route) netlinkRoute := routeToNetlinkRoute(route)
return netlink.RouteReplace(&netlinkRoute)
}
func netlinkRouteToRoute(netlinkRoute netlink.Route) (route Route) {
return Route{
LinkIndex: netlinkRoute.LinkIndex,
Dst: netIPNetToNetipPrefix(netlinkRoute.Dst),
Src: netIPToNetipAddress(netlinkRoute.Src),
Gw: netIPToNetipAddress(netlinkRoute.Gw),
Priority: netlinkRoute.Priority,
Family: netlinkRoute.Family,
Table: netlinkRoute.Table,
Type: netlinkRoute.Type,
}
}
func routeToNetlinkRoute(route Route) (netlinkRoute netlink.Route) {
return netlink.Route{
LinkIndex: route.LinkIndex,
Dst: netipPrefixToIPNet(route.Dst),
Src: netipAddrToNetIP(route.Src),
Gw: netipAddrToNetIP(route.Gw),
Priority: route.Priority,
Family: route.Family,
Table: route.Table,
Type: route.Type,
}
} }

View File

@@ -1,21 +1,90 @@
package netlink package netlink
import "github.com/vishvananda/netlink" import (
"fmt"
"net/netip"
type Rule = netlink.Rule "github.com/vishvananda/netlink"
)
func NewRule() *Rule { type Rule struct {
return netlink.NewRule() Priority int
Family int
Table int
Mark int
Src netip.Prefix
Dst netip.Prefix
Invert bool
}
func (r Rule) String() string {
from := "all"
if r.Src.IsValid() {
from = r.Src.String()
}
to := "all"
if r.Dst.IsValid() {
to = r.Dst.String()
}
return fmt.Sprintf("ip rule %d: from %s to %s table %d",
r.Priority, from, to, r.Table)
}
func NewRule() Rule {
// defaults found from netlink.NewRule() for fields we use,
// the rest of the defaults is set when converting from a `Rule`
// to a `netlink.Rule`
return Rule{
Priority: -1,
Mark: -1,
}
} }
func (n *NetLink) RuleList(family int) (rules []Rule, err error) { func (n *NetLink) RuleList(family int) (rules []Rule, err error) {
return netlink.RuleList(family) netlinkRules, err := netlink.RuleList(family)
if err != nil {
return nil, err
}
rules = make([]Rule, len(netlinkRules))
for i := range netlinkRules {
rules[i] = netlinkRuleToRule(netlinkRules[i])
}
return rules, nil
} }
func (n *NetLink) RuleAdd(rule *Rule) error { func (n *NetLink) RuleAdd(rule Rule) error {
return netlink.RuleAdd(rule) netlinkRule := ruleToNetlinkRule(rule)
return netlink.RuleAdd(&netlinkRule)
} }
func (n *NetLink) RuleDel(rule *Rule) error { func (n *NetLink) RuleDel(rule Rule) error {
return netlink.RuleDel(rule) netlinkRule := ruleToNetlinkRule(rule)
return netlink.RuleDel(&netlinkRule)
}
func ruleToNetlinkRule(rule Rule) (netlinkRule netlink.Rule) {
netlinkRule = *netlink.NewRule()
netlinkRule.Priority = rule.Priority
netlinkRule.Family = rule.Family
netlinkRule.Table = rule.Table
netlinkRule.Mark = rule.Mark
netlinkRule.Src = netipPrefixToIPNet(rule.Src)
netlinkRule.Dst = netipPrefixToIPNet(rule.Dst)
netlinkRule.Invert = rule.Invert
return netlinkRule
}
func netlinkRuleToRule(netlinkRule netlink.Rule) (rule Rule) {
return Rule{
Priority: netlinkRule.Priority,
Family: netlinkRule.Family,
Table: netlinkRule.Table,
Mark: netlinkRule.Mark,
Src: netIPNetToNetipPrefix(netlinkRule.Src),
Dst: netIPNetToNetipPrefix(netlinkRule.Dst),
Invert: netlinkRule.Invert,
}
} }

View File

@@ -6,34 +6,6 @@ import (
"net/netip" "net/netip"
) )
func NetipPrefixToIPNet(prefix *netip.Prefix) (ipNet *net.IPNet) {
if prefix == nil {
return nil
}
s := prefix.String()
ip, ipNet, err := net.ParseCIDR(s)
if err != nil {
panic(err)
}
ipNet.IP = ip
return ipNet
}
func netIPNetToNetipPrefix(ipNet net.IPNet) (prefix netip.Prefix) {
if len(ipNet.IP) != net.IPv4len && len(ipNet.IP) != net.IPv6len {
return prefix
}
var ip netip.Addr
if ipv4 := ipNet.IP.To4(); ipv4 != nil {
ip = netip.AddrFrom4([4]byte(ipv4))
} else {
ip = netip.AddrFrom16([16]byte(ipNet.IP))
}
bits, _ := ipNet.Mask.Size()
return netip.PrefixFrom(ip, bits)
}
func netIPToNetipAddress(ip net.IP) (address netip.Addr) { func netIPToNetipAddress(ip net.IP) (address netip.Addr) {
address, ok := netip.AddrFromSlice(ip) address, ok := netip.AddrFromSlice(ip)
if !ok { if !ok {

View File

@@ -8,54 +8,6 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test_netIPNetToNetipPrefix(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
ipNet net.IPNet
prefix netip.Prefix
}{
"empty ipnet": {},
"custom sized IP in ipnet": {
ipNet: net.IPNet{
IP: net.IP{1},
},
},
"IPv4 ipnet": {
ipNet: net.IPNet{
IP: net.IP{1, 2, 3, 4},
Mask: net.IPMask{255, 255, 255, 0},
},
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
},
"IPv4-in-IPv6 ipnet": {
ipNet: net.IPNet{
IP: net.IPv4(1, 2, 3, 4),
Mask: net.IPMask{255, 255, 255, 0},
},
prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 24),
},
"IPv6 ipnet": {
ipNet: net.IPNet{
IP: net.IPv6loopback,
Mask: net.IPMask{0xff},
},
prefix: netip.PrefixFrom(netip.IPv6Loopback(), 8),
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
prefix := netIPNetToNetipPrefix(testCase.ipNet)
assert.Equal(t, testCase.prefix, prefix)
})
}
}
func Test_netIPToNetipAddress(t *testing.T) { func Test_netIPToNetipAddress(t *testing.T) {
t.Parallel() t.Parallel()

View File

@@ -25,17 +25,17 @@ func (d DefaultRoute) String() string {
} }
func (r *Routing) DefaultRoutes() (defaultRoutes []DefaultRoute, err error) { func (r *Routing) DefaultRoutes() (defaultRoutes []DefaultRoute, err error) {
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) routes, err := r.netLinker.RouteList(nil, netlink.FamilyAll)
if err != nil { if err != nil {
return nil, fmt.Errorf("listing routes: %w", err) return nil, fmt.Errorf("listing routes: %w", err)
} }
for _, route := range routes { for _, route := range routes {
if route.Dst != nil { if route.Dst.IsValid() {
continue continue
} }
defaultRoute := DefaultRoute{ defaultRoute := DefaultRoute{
Gateway: netIPToNetipAddress(route.Gw), Gateway: route.Gw,
Family: route.Family, Family: route.Family,
} }
linkIndex := route.LinkIndex linkIndex := route.LinkIndex
@@ -43,11 +43,10 @@ func (r *Routing) DefaultRoutes() (defaultRoutes []DefaultRoute, err error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("obtaining link by index: for default route at index %d: %w", linkIndex, err) return nil, fmt.Errorf("obtaining link by index: for default route at index %d: %w", linkIndex, err)
} }
attributes := link.Attrs() defaultRoute.NetInterface = link.Name
defaultRoute.NetInterface = attributes.Name family := netlink.FamilyV6
family := netlink.FAMILY_V6 if route.Gw.Is4() {
if route.Gw.To4() != nil { family = netlink.FamilyV4
family = netlink.FAMILY_V4
} }
defaultRoute.AssignedIP, err = r.assignedIP(defaultRoute.NetInterface, family) defaultRoute.AssignedIP, err = r.assignedIP(defaultRoute.NetInterface, family)
if err != nil { if err != nil {

View File

@@ -23,7 +23,7 @@ func (r *Routing) routeInboundFromDefault(defaultRoutes []DefaultRoute) (err err
for _, defaultRoute := range defaultRoutes { for _, defaultRoute := range defaultRoutes {
defaultDestination := defaultDestinationIPv4 defaultDestination := defaultDestinationIPv4
if defaultRoute.Family == netlink.FAMILY_V6 { if defaultRoute.Family == netlink.FamilyV6 {
defaultDestination = defaultDestinationIPv6 defaultDestination = defaultDestinationIPv6
} }
@@ -43,7 +43,7 @@ func (r *Routing) unrouteInboundFromDefault(defaultRoutes []DefaultRoute) (err e
for _, defaultRoute := range defaultRoutes { for _, defaultRoute := range defaultRoutes {
defaultDestination := defaultDestinationIPv4 defaultDestination := defaultDestinationIPv4
if defaultRoute.Family == netlink.FAMILY_V6 { if defaultRoute.Family == netlink.FamilyV6 {
defaultDestination = defaultDestinationIPv6 defaultDestination = defaultDestinationIPv6
} }
@@ -68,8 +68,8 @@ func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRo
bits = 128 bits = 128
} }
defaultIPMasked := netip.PrefixFrom(assignedIP, bits) defaultIPMasked := netip.PrefixFrom(assignedIP, bits)
ruleDstNet := (*netip.Prefix)(nil) ruleDstNet := netip.Prefix{}
err = r.addIPRule(&defaultIPMasked, ruleDstNet, table, inboundPriority) err = r.addIPRule(defaultIPMasked, ruleDstNet, table, inboundPriority)
if err != nil { if err != nil {
return fmt.Errorf("adding rule for default route %s: %w", defaultRoute, err) return fmt.Errorf("adding rule for default route %s: %w", defaultRoute, err)
} }
@@ -86,8 +86,8 @@ func (r *Routing) delRuleInboundFromDefault(table int, defaultRoutes []DefaultRo
bits = 128 bits = 128
} }
defaultIPMasked := netip.PrefixFrom(assignedIP, bits) defaultIPMasked := netip.PrefixFrom(assignedIP, bits)
ruleDstNet := (*netip.Prefix)(nil) ruleDstNet := netip.Prefix{}
err = r.deleteIPRule(&defaultIPMasked, ruleDstNet, table, inboundPriority) err = r.deleteIPRule(defaultIPMasked, ruleDstNet, table, inboundPriority)
if err != nil { if err != nil {
return fmt.Errorf("deleting rule for default route %s: %w", defaultRoute, err) return fmt.Errorf("deleting rule for default route %s: %w", defaultRoute, err)
} }

View File

@@ -19,8 +19,8 @@ var (
) )
func ipMatchesFamily(ip netip.Addr, family int) bool { func ipMatchesFamily(ip netip.Addr, family int) bool {
return (family == netlink.FAMILY_V4 && ip.Is4()) || return (family == netlink.FamilyV4 && ip.Is4()) ||
(family == netlink.FAMILY_V6 && ip.Is6()) (family == netlink.FamilyV6 && ip.Is6())
} }
func (r *Routing) assignedIP(interfaceName string, family int) (ip netip.Addr, err error) { func (r *Routing) assignedIP(interfaceName string, family int) (ip netip.Addr, err error) {

View File

@@ -29,25 +29,25 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
localLinks := make(map[int]struct{}) localLinks := make(map[int]struct{})
for _, link := range links { for _, link := range links {
if link.Attrs().EncapType != "ether" { if link.EncapType != "ether" {
continue continue
} }
localLinks[link.Attrs().Index] = struct{}{} localLinks[link.Index] = struct{}{}
r.logger.Info("local ethernet link found: " + link.Attrs().Name) r.logger.Info("local ethernet link found: " + link.Name)
} }
if len(localLinks) == 0 { if len(localLinks) == 0 {
return localNetworks, fmt.Errorf("%w: in %d links", ErrLinkLocalNotFound, len(links)) return localNetworks, fmt.Errorf("%w: in %d links", ErrLinkLocalNotFound, len(links))
} }
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) routes, err := r.netLinker.RouteList(nil, netlink.FamilyAll)
if err != nil { if err != nil {
return localNetworks, fmt.Errorf("listing routes: %w", err) return localNetworks, fmt.Errorf("listing routes: %w", err)
} }
for _, route := range routes { for _, route := range routes {
if route.Gw != nil || route.Dst == nil { if route.Gw.IsValid() || !route.Dst.IsValid() {
continue continue
} else if _, ok := localLinks[route.LinkIndex]; !ok { } else if _, ok := localLinks[route.LinkIndex]; !ok {
continue continue
@@ -55,7 +55,7 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
var localNet LocalNetwork var localNet LocalNetwork
localNet.IPNet = netIPNetToNetipPrefix(*route.Dst) localNet.IPNet = route.Dst
r.logger.Info("local ipnet found: " + localNet.IPNet.String()) r.logger.Info("local ipnet found: " + localNet.IPNet.String())
link, err := r.netLinker.LinkByIndex(route.LinkIndex) link, err := r.netLinker.LinkByIndex(route.LinkIndex)
@@ -63,11 +63,11 @@ func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
return localNetworks, fmt.Errorf("finding link at index %d: %w", route.LinkIndex, err) return localNetworks, fmt.Errorf("finding link at index %d: %w", route.LinkIndex, err)
} }
localNet.InterfaceName = link.Attrs().Name localNet.InterfaceName = link.Name
family := netlink.FAMILY_V6 family := netlink.FamilyV6
if localNet.IPNet.Addr().Is4() { if localNet.IPNet.Addr().Is4() {
family = netlink.FAMILY_V4 family = netlink.FamilyV4
} }
ip, err := r.assignedIP(localNet.InterfaceName, family) ip, err := r.assignedIP(localNet.InterfaceName, family)
if err != nil { if err != nil {
@@ -96,7 +96,8 @@ func (r *Routing) AddLocalRules(subnets []LocalNetwork) (err error) {
const localPriority = 98 const localPriority = 98
// Main table was setup correctly by Docker, just need to add rules to use it // Main table was setup correctly by Docker, just need to add rules to use it
err = r.addIPRule(nil, &subnet.IPNet, mainTable, localPriority) src := netip.Prefix{}
err = r.addIPRule(src, subnet.IPNet, mainTable, localPriority)
if err != nil { if err != nil {
return fmt.Errorf("adding rule: %v: %w", subnet.IPNet, err) return fmt.Errorf("adding rule: %v: %w", subnet.IPNet, err)
} }

View File

@@ -8,7 +8,7 @@ import (
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
netlink "github.com/vishvananda/netlink" netlink "github.com/qdm12/gluetun/internal/netlink"
) )
// MockNetLinker is a mock of NetLinker interface. // MockNetLinker is a mock of NetLinker interface.
@@ -50,7 +50,7 @@ func (mr *MockNetLinkerMockRecorder) AddrList(arg0, arg1 interface{}) *gomock.Ca
} }
// AddrReplace mocks base method. // AddrReplace mocks base method.
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 *netlink.Addr) error { func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 netlink.Addr) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1) ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@@ -79,11 +79,12 @@ func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call {
} }
// LinkAdd mocks base method. // LinkAdd mocks base method.
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) error { func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (int, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkAdd", arg0) ret := m.ctrl.Call(m, "LinkAdd", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(int)
return ret0 ret1, _ := ret[1].(error)
return ret0, ret1
} }
// LinkAdd indicates an expected call of LinkAdd. // LinkAdd indicates an expected call of LinkAdd.
@@ -166,11 +167,12 @@ func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call
} }
// LinkSetUp mocks base method. // LinkSetUp mocks base method.
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) error { func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) (int, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetUp", arg0) ret := m.ctrl.Call(m, "LinkSetUp", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(int)
return ret0 ret1, _ := ret[1].(error)
return ret0, ret1
} }
// LinkSetUp indicates an expected call of LinkSetUp. // LinkSetUp indicates an expected call of LinkSetUp.
@@ -180,7 +182,7 @@ func (mr *MockNetLinkerMockRecorder) LinkSetUp(arg0 interface{}) *gomock.Call {
} }
// RouteAdd mocks base method. // RouteAdd mocks base method.
func (m *MockNetLinker) RouteAdd(arg0 *netlink.Route) error { func (m *MockNetLinker) RouteAdd(arg0 netlink.Route) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteAdd", arg0) ret := m.ctrl.Call(m, "RouteAdd", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@@ -194,7 +196,7 @@ func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call {
} }
// RouteDel mocks base method. // RouteDel mocks base method.
func (m *MockNetLinker) RouteDel(arg0 *netlink.Route) error { func (m *MockNetLinker) RouteDel(arg0 netlink.Route) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteDel", arg0) ret := m.ctrl.Call(m, "RouteDel", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@@ -208,7 +210,7 @@ func (mr *MockNetLinkerMockRecorder) RouteDel(arg0 interface{}) *gomock.Call {
} }
// RouteList mocks base method. // RouteList mocks base method.
func (m *MockNetLinker) RouteList(arg0 netlink.Link, arg1 int) ([]netlink.Route, error) { func (m *MockNetLinker) RouteList(arg0 *netlink.Link, arg1 int) ([]netlink.Route, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteList", arg0, arg1) ret := m.ctrl.Call(m, "RouteList", arg0, arg1)
ret0, _ := ret[0].([]netlink.Route) ret0, _ := ret[0].([]netlink.Route)
@@ -223,7 +225,7 @@ func (mr *MockNetLinkerMockRecorder) RouteList(arg0, arg1 interface{}) *gomock.C
} }
// RouteReplace mocks base method. // RouteReplace mocks base method.
func (m *MockNetLinker) RouteReplace(arg0 *netlink.Route) error { func (m *MockNetLinker) RouteReplace(arg0 netlink.Route) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteReplace", arg0) ret := m.ctrl.Call(m, "RouteReplace", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@@ -237,7 +239,7 @@ func (mr *MockNetLinkerMockRecorder) RouteReplace(arg0 interface{}) *gomock.Call
} }
// RuleAdd mocks base method. // RuleAdd mocks base method.
func (m *MockNetLinker) RuleAdd(arg0 *netlink.Rule) error { func (m *MockNetLinker) RuleAdd(arg0 netlink.Rule) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RuleAdd", arg0) ret := m.ctrl.Call(m, "RuleAdd", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@@ -251,7 +253,7 @@ func (mr *MockNetLinkerMockRecorder) RuleAdd(arg0 interface{}) *gomock.Call {
} }
// RuleDel mocks base method. // RuleDel mocks base method.
func (m *MockNetLinker) RuleDel(arg0 *netlink.Rule) error { func (m *MockNetLinker) RuleDel(arg0 netlink.Rule) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RuleDel", arg0) ret := m.ctrl.Call(m, "RuleDel", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)

View File

@@ -56,8 +56,8 @@ func (r *Routing) removeOutboundSubnets(subnets []netip.Prefix,
} }
} }
ruleSrcNet := (*netip.Prefix)(nil) ruleSrcNet := netip.Prefix{}
ruleDstNet := &subnets[i] ruleDstNet := subnets[i]
err := r.deleteIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority) err := r.deleteIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority)
if err != nil { if err != nil {
warnings = append(warnings, warnings = append(warnings,
@@ -81,8 +81,8 @@ func (r *Routing) addOutboundSubnets(subnets []netip.Prefix,
} }
} }
ruleSrcNet := (*netip.Prefix)(nil) ruleSrcNet := netip.Prefix{}
ruleDstNet := &subnets[i] ruleDstNet := subnets[i]
err = r.addIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority) err = r.addIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority)
if err != nil { if err != nil {
return fmt.Errorf("adding rule: for subnet %s: %w", subnet, err) return fmt.Errorf("adding rule: for subnet %s: %w", subnet, err)

View File

@@ -23,12 +23,12 @@ func (r *Routing) addRouteVia(destination netip.Prefix, gateway netip.Addr,
} }
route := netlink.Route{ route := netlink.Route{
Dst: NetipPrefixToIPNet(&destination), Dst: destination,
Gw: gateway.AsSlice(), Gw: gateway,
LinkIndex: link.Attrs().Index, LinkIndex: link.Index,
Table: table, Table: table,
} }
if err := r.netLinker.RouteReplace(&route); err != nil { if err := r.netLinker.RouteReplace(route); err != nil {
return fmt.Errorf("replacing route for subnet %s at interface %s: %w", return fmt.Errorf("replacing route for subnet %s at interface %s: %w",
destinationStr, iface, err) destinationStr, iface, err)
} }
@@ -51,12 +51,12 @@ func (r *Routing) deleteRouteVia(destination netip.Prefix, gateway netip.Addr,
} }
route := netlink.Route{ route := netlink.Route{
Dst: NetipPrefixToIPNet(&destination), Dst: destination,
Gw: gateway.AsSlice(), Gw: gateway,
LinkIndex: link.Attrs().Index, LinkIndex: link.Index,
Table: table, Table: table,
} }
if err := r.netLinker.RouteDel(&route); err != nil { if err := r.netLinker.RouteDel(route); err != nil {
return fmt.Errorf("deleting route: for subnet %s at interface %s: %w", return fmt.Errorf("deleting route: for subnet %s at interface %s: %w",
destinationStr, iface, err) destinationStr, iface, err)
} }

View File

@@ -18,30 +18,30 @@ type NetLinker interface {
type Addresser interface { type Addresser interface {
AddrList(link netlink.Link, family int) ( AddrList(link netlink.Link, family int) (
addresses []netlink.Addr, err error) addresses []netlink.Addr, err error)
AddrReplace(link netlink.Link, addr *netlink.Addr) error AddrReplace(link netlink.Link, addr netlink.Addr) error
} }
type Router interface { type Router interface {
RouteList(link netlink.Link, family int) ( RouteList(link *netlink.Link, family int) (
routes []netlink.Route, err error) routes []netlink.Route, err error)
RouteAdd(route *netlink.Route) error RouteAdd(route netlink.Route) error
RouteDel(route *netlink.Route) error RouteDel(route netlink.Route) error
RouteReplace(route *netlink.Route) error RouteReplace(route netlink.Route) error
} }
type Ruler interface { type Ruler interface {
RuleList(family int) (rules []netlink.Rule, err error) RuleList(family int) (rules []netlink.Rule, err error)
RuleAdd(rule *netlink.Rule) error RuleAdd(rule netlink.Rule) error
RuleDel(rule *netlink.Rule) error RuleDel(rule netlink.Rule) error
} }
type Linker interface { type Linker interface {
LinkList() (links []netlink.Link, err error) LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error) LinkByName(name string) (link netlink.Link, err error)
LinkByIndex(index int) (link netlink.Link, err error) LinkByIndex(index int) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (err error) LinkAdd(link netlink.Link) (linkIndex int, err error)
LinkDel(link netlink.Link) (err error) LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (err error) LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) (err error) LinkSetDown(link netlink.Link) (err error)
} }

View File

@@ -1,30 +1,28 @@
package routing package routing
import ( import (
"bytes"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/netlink"
) )
func (r *Routing) addIPRule(src, dst *netip.Prefix, table, priority int) error { func (r *Routing) addIPRule(src, dst netip.Prefix, table, priority int) error {
const add = true const add = true
r.logger.Debug(ruleDbgMsg(add, src, dst, table, priority)) r.logger.Debug(ruleDbgMsg(add, src, dst, table, priority))
rule := netlink.NewRule() rule := netlink.NewRule()
rule.Src = NetipPrefixToIPNet(src) rule.Src = src
rule.Dst = NetipPrefixToIPNet(dst) rule.Dst = dst
rule.Priority = priority rule.Priority = priority
rule.Table = table rule.Table = table
existingRules, err := r.netLinker.RuleList(netlink.FAMILY_ALL) existingRules, err := r.netLinker.RuleList(netlink.FamilyAll)
if err != nil { if err != nil {
return fmt.Errorf("listing rules: %w", err) return fmt.Errorf("listing rules: %w", err)
} }
for i := range existingRules { for i := range existingRules {
if !rulesAreEqual(&existingRules[i], rule) { if !rulesAreEqual(existingRules[i], rule) {
continue continue
} }
return nil // already exists return nil // already exists
@@ -36,22 +34,22 @@ func (r *Routing) addIPRule(src, dst *netip.Prefix, table, priority int) error {
return nil return nil
} }
func (r *Routing) deleteIPRule(src, dst *netip.Prefix, table, priority int) error { func (r *Routing) deleteIPRule(src, dst netip.Prefix, table, priority int) error {
const add = false const add = false
r.logger.Debug(ruleDbgMsg(add, src, dst, table, priority)) r.logger.Debug(ruleDbgMsg(add, src, dst, table, priority))
rule := netlink.NewRule() rule := netlink.NewRule()
rule.Src = NetipPrefixToIPNet(src) rule.Src = src
rule.Dst = NetipPrefixToIPNet(dst) rule.Dst = dst
rule.Priority = priority rule.Priority = priority
rule.Table = table rule.Table = table
existingRules, err := r.netLinker.RuleList(netlink.FAMILY_ALL) existingRules, err := r.netLinker.RuleList(netlink.FamilyAll)
if err != nil { if err != nil {
return fmt.Errorf("listing rules: %w", err) return fmt.Errorf("listing rules: %w", err)
} }
for i := range existingRules { for i := range existingRules {
if !rulesAreEqual(&existingRules[i], rule) { if !rulesAreEqual(existingRules[i], rule) {
continue continue
} }
if err := r.netLinker.RuleDel(rule); err != nil { if err := r.netLinker.RuleDel(rule); err != nil {
@@ -61,7 +59,7 @@ func (r *Routing) deleteIPRule(src, dst *netip.Prefix, table, priority int) erro
return nil return nil
} }
func ruleDbgMsg(add bool, src, dst *netip.Prefix, func ruleDbgMsg(add bool, src, dst netip.Prefix,
table, priority int) (debugMessage string) { table, priority int) (debugMessage string) {
debugMessage = "ip rule" debugMessage = "ip rule"
@@ -71,11 +69,11 @@ func ruleDbgMsg(add bool, src, dst *netip.Prefix,
debugMessage += " del" debugMessage += " del"
} }
if src != nil { if src.IsValid() {
debugMessage += " from " + src.String() debugMessage += " from " + src.String()
} }
if dst != nil { if dst.IsValid() {
debugMessage += " to " + dst.String() debugMessage += " to " + dst.String()
} }
@@ -90,25 +88,20 @@ func ruleDbgMsg(add bool, src, dst *netip.Prefix,
return debugMessage return debugMessage
} }
func rulesAreEqual(a, b *netlink.Rule) bool { func rulesAreEqual(a, b netlink.Rule) bool {
if a == nil && b == nil { return ipPrefixesAreEqual(a.Src, b.Src) &&
return true ipPrefixesAreEqual(a.Dst, b.Dst) &&
}
if a == nil || b == nil {
return false
}
return ipNetsAreEqual(a.Src, b.Src) &&
ipNetsAreEqual(a.Dst, b.Dst) &&
a.Priority == b.Priority && a.Priority == b.Priority &&
a.Table == b.Table a.Table == b.Table
} }
func ipNetsAreEqual(a, b *net.IPNet) bool { func ipPrefixesAreEqual(a, b netip.Prefix) bool {
if a == nil && b == nil { if !a.IsValid() && !b.IsValid() {
return true return true
} }
if a == nil || b == nil { if !a.IsValid() || !b.IsValid() {
return false return false
} }
return a.IP.Equal(b.IP) && bytes.Equal(a.Mask, b.Mask) return a.Bits() == b.Bits() &&
a.Addr().Compare(b.Addr()) == 0
} }

View File

@@ -2,7 +2,6 @@ package routing
import ( import (
"errors" "errors"
"net"
"net/netip" "net/netip"
"testing" "testing"
@@ -12,17 +11,16 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func makeNetipPrefix(n byte) *netip.Prefix { func makeNetipPrefix(n byte) netip.Prefix {
const bits = 24 const bits = 24
prefix := netip.PrefixFrom(netip.AddrFrom4([4]byte{n, n, n, 0}), bits) return netip.PrefixFrom(netip.AddrFrom4([4]byte{n, n, n, 0}), bits)
return &prefix
} }
func makeIPRule(src, dst *netip.Prefix, func makeIPRule(src, dst netip.Prefix,
table, priority int) *netlink.Rule { table, priority int) netlink.Rule {
rule := netlink.NewRule() rule := netlink.NewRule()
rule.Src = NetipPrefixToIPNet(src) rule.Src = src
rule.Dst = NetipPrefixToIPNet(dst) rule.Dst = dst
rule.Table = table rule.Table = table
rule.Priority = priority rule.Priority = priority
return rule return rule
@@ -40,13 +38,13 @@ func Test_Routing_addIPRule(t *testing.T) {
type ruleAddCall struct { type ruleAddCall struct {
expected bool expected bool
ruleToAdd *netlink.Rule ruleToAdd netlink.Rule
err error err error
} }
testCases := map[string]struct { testCases := map[string]struct {
src *netip.Prefix src netip.Prefix
dst *netip.Prefix dst netip.Prefix
table int table int
priority int priority int
dbgMsg string dbgMsg string
@@ -69,8 +67,8 @@ func Test_Routing_addIPRule(t *testing.T) {
dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99", dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleList: ruleListCall{ ruleList: ruleListCall{
rules: []netlink.Rule{ rules: []netlink.Rule{
*makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99), makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
*makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99), makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
}, },
}, },
}, },
@@ -95,8 +93,8 @@ func Test_Routing_addIPRule(t *testing.T) {
dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99", dbgMsg: "ip rule add from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleList: ruleListCall{ ruleList: ruleListCall{
rules: []netlink.Rule{ rules: []netlink.Rule{
*makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99), makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
*makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101), makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101),
}, },
}, },
ruleAdd: ruleAddCall{ ruleAdd: ruleAddCall{
@@ -116,7 +114,7 @@ func Test_Routing_addIPRule(t *testing.T) {
logger.EXPECT().Debug(testCase.dbgMsg) logger.EXPECT().Debug(testCase.dbgMsg)
netLinker := NewMockNetLinker(ctrl) netLinker := NewMockNetLinker(ctrl)
netLinker.EXPECT().RuleList(netlink.FAMILY_ALL). netLinker.EXPECT().RuleList(netlink.FamilyAll).
Return(testCase.ruleList.rules, testCase.ruleList.err) Return(testCase.ruleList.rules, testCase.ruleList.err)
if testCase.ruleAdd.expected { if testCase.ruleAdd.expected {
netLinker.EXPECT().RuleAdd(testCase.ruleAdd.ruleToAdd). netLinker.EXPECT().RuleAdd(testCase.ruleAdd.ruleToAdd).
@@ -153,13 +151,13 @@ func Test_Routing_deleteIPRule(t *testing.T) {
type ruleDelCall struct { type ruleDelCall struct {
expected bool expected bool
ruleToDel *netlink.Rule ruleToDel netlink.Rule
err error err error
} }
testCases := map[string]struct { testCases := map[string]struct {
src *netip.Prefix src netip.Prefix
dst *netip.Prefix dst netip.Prefix
table int table int
priority int priority int
dbgMsg string dbgMsg string
@@ -182,7 +180,7 @@ func Test_Routing_deleteIPRule(t *testing.T) {
dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99", dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleList: ruleListCall{ ruleList: ruleListCall{
rules: []netlink.Rule{ rules: []netlink.Rule{
*makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99), makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
}, },
}, },
ruleDel: ruleDelCall{ ruleDel: ruleDelCall{
@@ -200,8 +198,8 @@ func Test_Routing_deleteIPRule(t *testing.T) {
dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99", dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleList: ruleListCall{ ruleList: ruleListCall{
rules: []netlink.Rule{ rules: []netlink.Rule{
*makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99), makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
*makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99), makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 99, 99),
}, },
}, },
ruleDel: ruleDelCall{ ruleDel: ruleDelCall{
@@ -217,8 +215,8 @@ func Test_Routing_deleteIPRule(t *testing.T) {
dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99", dbgMsg: "ip rule del from 1.1.1.0/24 to 2.2.2.0/24 lookup 99 pref 99",
ruleList: ruleListCall{ ruleList: ruleListCall{
rules: []netlink.Rule{ rules: []netlink.Rule{
*makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99), makeIPRule(makeNetipPrefix(2), makeNetipPrefix(2), 99, 99),
*makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101), makeIPRule(makeNetipPrefix(1), makeNetipPrefix(2), 101, 101),
}, },
}, },
}, },
@@ -234,7 +232,7 @@ func Test_Routing_deleteIPRule(t *testing.T) {
logger.EXPECT().Debug(testCase.dbgMsg) logger.EXPECT().Debug(testCase.dbgMsg)
netLinker := NewMockNetLinker(ctrl) netLinker := NewMockNetLinker(ctrl)
netLinker.EXPECT().RuleList(netlink.FAMILY_ALL). netLinker.EXPECT().RuleList(netlink.FamilyAll).
Return(testCase.ruleList.rules, testCase.ruleList.err) Return(testCase.ruleList.rules, testCase.ruleList.err)
if testCase.ruleDel.expected { if testCase.ruleDel.expected {
netLinker.EXPECT().RuleDel(testCase.ruleDel.ruleToDel). netLinker.EXPECT().RuleDel(testCase.ruleDel.ruleToDel).
@@ -264,8 +262,8 @@ func Test_ruleDbgMsg(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
add bool add bool
src *netip.Prefix src netip.Prefix
dst *netip.Prefix dst netip.Prefix
table int table int
priority int priority int
dbgMsg string dbgMsg string
@@ -307,38 +305,79 @@ func Test_rulesAreEqual(t *testing.T) {
t.Parallel() t.Parallel()
testCases := map[string]struct { testCases := map[string]struct {
a *netlink.Rule a netlink.Rule
b *netlink.Rule b netlink.Rule
equal bool equal bool
}{ }{
"both nil": { "both_empty": {
equal: true, equal: true,
}, },
"first nil": { "not_equal_by_src": {
b: &netlink.Rule{}, a: netlink.Rule{
}, Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 24),
"second nil": { Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
a: &netlink.Rule{},
},
"both not nil": {
a: &netlink.Rule{},
b: &netlink.Rule{},
equal: true,
},
"both equal": {
a: &netlink.Rule{
Src: &net.IPNet{
IP: net.IPv4(1, 1, 1, 1),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
Priority: 100, Priority: 100,
Table: 101, Table: 101,
}, },
b: &netlink.Rule{ b: netlink.Rule{
Src: &net.IPNet{ Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
IP: net.IPv4(1, 1, 1, 1), Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Mask: net.IPv4Mask(255, 255, 255, 0), Priority: 100,
Table: 101,
}, },
},
"not_equal_by_dst": {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{9, 9, 9, 9}), 32),
Priority: 100,
Table: 101,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Table: 101,
},
},
"not_equal_by_priority": {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 999,
Table: 101,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Table: 101,
},
},
"not_equal_by_table": {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Table: 999,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Table: 101,
},
},
"equal": {
a: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100,
Table: 101,
},
b: netlink.Rule{
Src: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Dst: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
Priority: 100, Priority: 100,
Table: 101, Table: 101,
}, },
@@ -358,58 +397,39 @@ func Test_rulesAreEqual(t *testing.T) {
} }
} }
func Test_ipNetsAreEqual(t *testing.T) { func Test_ipPrefixesAreEqual(t *testing.T) {
t.Parallel() t.Parallel()
testCases := map[string]struct { testCases := map[string]struct {
a *net.IPNet a netip.Prefix
b *net.IPNet b netip.Prefix
equal bool equal bool
}{ }{
"both nil": { "both_not_valid": {
equal: true, equal: true,
}, },
"first nil": { "first_not_valid": {
b: &net.IPNet{}, b: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
}, },
"second nil": { "second_not_valid": {
a: &net.IPNet{}, a: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
}, },
"both not nil": { "both_equal": {
a: &net.IPNet{}, a: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
b: &net.IPNet{}, b: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
equal: true, equal: true,
}, },
"both equal": { "both_not_equal_by_IP": {
a: &net.IPNet{ a: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
IP: net.IPv4(1, 1, 1, 1), b: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 24),
Mask: net.IPv4Mask(255, 255, 255, 0),
}, },
b: &net.IPNet{ "both_not_equal_by_bits": {
IP: net.IPv4(1, 1, 1, 1), a: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
Mask: net.IPv4Mask(255, 255, 255, 0), b: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32),
},
equal: true,
},
"both not equal by IP": {
a: &net.IPNet{
IP: net.IPv4(1, 1, 1, 1),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
b: &net.IPNet{
IP: net.IPv4(2, 2, 2, 2),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
},
"both not equal by mask": {
a: &net.IPNet{
IP: net.IPv4(1, 1, 1, 1),
Mask: net.IPv4Mask(255, 255, 255, 255),
},
b: &net.IPNet{
IP: net.IPv4(1, 1, 1, 1),
Mask: net.IPv4Mask(255, 255, 0, 0),
}, },
"both_not_equal_by_IP_and_bits": {
a: netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 24),
b: netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32),
}, },
} }
@@ -418,7 +438,7 @@ func Test_ipNetsAreEqual(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Parallel() t.Parallel()
equal := ipNetsAreEqual(testCase.a, testCase.b) equal := ipPrefixesAreEqual(testCase.a, testCase.b)
assert.Equal(t, testCase.equal, equal) assert.Equal(t, testCase.equal, equal)
}) })

View File

@@ -1,10 +1,8 @@
package routing package routing
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/netlink"
@@ -16,14 +14,14 @@ var (
) )
func (r *Routing) VPNDestinationIP() (ip netip.Addr, err error) { func (r *Routing) VPNDestinationIP() (ip netip.Addr, err error) {
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) routes, err := r.netLinker.RouteList(nil, netlink.FamilyAll)
if err != nil { if err != nil {
return ip, fmt.Errorf("listing routes: %w", err) return ip, fmt.Errorf("listing routes: %w", err)
} }
defaultLinkIndex := -1 defaultLinkIndex := -1
for _, route := range routes { for _, route := range routes {
if route.Dst == nil { if !route.Dst.IsValid() {
defaultLinkIndex = route.LinkIndex defaultLinkIndex = route.LinkIndex
break break
} }
@@ -34,17 +32,17 @@ func (r *Routing) VPNDestinationIP() (ip netip.Addr, err error) {
for _, route := range routes { for _, route := range routes {
if route.LinkIndex == defaultLinkIndex && if route.LinkIndex == defaultLinkIndex &&
route.Dst != nil && route.Dst.IsValid() &&
!IPIsPrivate(netIPToNetipAddress(route.Dst.IP)) && !IPIsPrivate(route.Dst.Addr()) &&
bytes.Equal(route.Dst.Mask, net.IPMask{255, 255, 255, 255}) { route.Dst.IsSingleIP() {
return netIPToNetipAddress(route.Dst.IP), nil return route.Dst.Addr(), nil
} }
} }
return ip, fmt.Errorf("%w: in %d routes", ErrVPNDestinationIPNotFound, len(routes)) return ip, fmt.Errorf("%w: in %d routes", ErrVPNDestinationIPNotFound, len(routes))
} }
func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) { func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) {
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) routes, err := r.netLinker.RouteList(nil, netlink.FamilyAll)
if err != nil { if err != nil {
return ip, fmt.Errorf("listing routes: %w", err) return ip, fmt.Errorf("listing routes: %w", err)
} }
@@ -53,11 +51,11 @@ func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip netip.Addr, err error) {
if err != nil { if err != nil {
return ip, fmt.Errorf("finding link at index %d: %w", route.LinkIndex, err) return ip, fmt.Errorf("finding link at index %d: %w", route.LinkIndex, err)
} }
interfaceName := link.Attrs().Name interfaceName := link.Name
if interfaceName == vpnIntf && if interfaceName == vpnIntf &&
route.Dst != nil && route.Dst.IsValid() &&
route.Dst.IP.Equal(net.IP{0, 0, 0, 0}) { route.Dst.Addr().IsUnspecified() {
return netIPToNetipAddress(route.Gw), nil return route.Gw, nil
} }
} }
return ip, fmt.Errorf("%w: in %d routes", ErrVPNLocalGatewayIPNotFound, len(routes)) return ip, fmt.Errorf("%w: in %d routes", ErrVPNLocalGatewayIPNotFound, len(routes))

View File

@@ -42,7 +42,7 @@ type Storage interface {
} }
type NetLinker interface { type NetLinker interface {
AddrReplace(link netlink.Link, addr *netlink.Addr) error AddrReplace(link netlink.Link, addr netlink.Addr) error
Router Router
Ruler Ruler
Linker Linker
@@ -50,22 +50,22 @@ type NetLinker interface {
} }
type Router interface { type Router interface {
RouteList(link netlink.Link, family int) ( RouteList(link *netlink.Link, family int) (
routes []netlink.Route, err error) routes []netlink.Route, err error)
RouteAdd(route *netlink.Route) error RouteAdd(route netlink.Route) error
} }
type Ruler interface { type Ruler interface {
RuleAdd(rule *netlink.Rule) error RuleAdd(rule netlink.Rule) error
RuleDel(rule *netlink.Rule) error RuleDel(rule netlink.Rule) error
} }
type Linker interface { type Linker interface {
LinkList() (links []netlink.Link, err error) LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error) LinkByName(name string) (link netlink.Link, err error)
LinkAdd(link netlink.Link) (err error) LinkAdd(link netlink.Link) (linkIndex int, err error)
LinkDel(link netlink.Link) (err error) LinkDel(link netlink.Link) (err error)
LinkSetUp(link netlink.Link) (err error) LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) (err error) LinkSetDown(link netlink.Link) (err error)
} }

View File

@@ -5,7 +5,6 @@ import (
"net/netip" "net/netip"
"github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/routing"
) )
func (w *Wireguard) addAddresses(link netlink.Link, func (w *Wireguard) addAddresses(link netlink.Link,
@@ -15,15 +14,14 @@ func (w *Wireguard) addAddresses(link netlink.Link,
continue continue
} }
ipNet := ipNet address := netlink.Addr{
address := &netlink.Addr{ Network: ipNet,
IPNet: routing.NetipPrefixToIPNet(&ipNet),
} }
err = w.netlink.AddrReplace(link, address) err = w.netlink.AddrReplace(link, address)
if err != nil { if err != nil {
return fmt.Errorf("%w: when adding address %s to link %s", return fmt.Errorf("%w: when adding address %s to link %s",
err, address, link.Attrs().Name) err, address, link.Name)
} }
} }

View File

@@ -7,7 +7,6 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/netlink"
"github.com/qdm12/gluetun/internal/routing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -18,14 +17,6 @@ func Test_Wireguard_addAddresses(t *testing.T) {
ipNetOne := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 32) ipNetOne := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 32)
ipNetTwo := netip.PrefixFrom(netip.MustParseAddr("::1234"), 64) ipNetTwo := netip.PrefixFrom(netip.MustParseAddr("::1234"), 64)
newLink := func() netlink.Link {
linkAttrs := netlink.NewLinkAttrs()
linkAttrs.Name = "a_bridge"
return &netlink.Bridge{
LinkAttrs: linkAttrs,
}
}
errDummy := errors.New("dummy") errDummy := errors.New("dummy")
testCases := map[string]struct { testCases := map[string]struct {
@@ -35,15 +26,15 @@ func Test_Wireguard_addAddresses(t *testing.T) {
err error err error
}{ }{
"success": { "success": {
link: newLink(), link: netlink.Link{Type: "wireguard"},
addrs: []netip.Prefix{ipNetOne, ipNetTwo}, addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard { wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
netLinker := NewMockNetLinker(ctrl) netLinker := NewMockNetLinker(ctrl)
firstCall := netLinker.EXPECT(). firstCall := netLinker.EXPECT().
AddrReplace(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetOne)}). AddrReplace(link, netlink.Addr{Network: ipNetOne}).
Return(nil) Return(nil)
netLinker.EXPECT(). netLinker.EXPECT().
AddrReplace(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetTwo)}). AddrReplace(link, netlink.Addr{Network: ipNetTwo}).
Return(nil).After(firstCall) Return(nil).After(firstCall)
return &Wireguard{ return &Wireguard{
netlink: netLinker, netlink: netLinker,
@@ -54,12 +45,12 @@ func Test_Wireguard_addAddresses(t *testing.T) {
}, },
}, },
"first add error": { "first add error": {
link: newLink(), link: netlink.Link{Type: "wireguard", Name: "a_bridge"},
addrs: []netip.Prefix{ipNetOne, ipNetTwo}, addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard { wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
netLinker := NewMockNetLinker(ctrl) netLinker := NewMockNetLinker(ctrl)
netLinker.EXPECT(). netLinker.EXPECT().
AddrReplace(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetOne)}). AddrReplace(link, netlink.Addr{Network: ipNetOne}).
Return(errDummy) Return(errDummy)
return &Wireguard{ return &Wireguard{
netlink: netLinker, netlink: netLinker,
@@ -71,15 +62,15 @@ func Test_Wireguard_addAddresses(t *testing.T) {
err: errors.New("dummy: when adding address 1.2.3.4/32 to link a_bridge"), err: errors.New("dummy: when adding address 1.2.3.4/32 to link a_bridge"),
}, },
"second add error": { "second add error": {
link: newLink(), link: netlink.Link{Type: "wireguard", Name: "a_bridge"},
addrs: []netip.Prefix{ipNetOne, ipNetTwo}, addrs: []netip.Prefix{ipNetOne, ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard { wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
netLinker := NewMockNetLinker(ctrl) netLinker := NewMockNetLinker(ctrl)
firstCall := netLinker.EXPECT(). firstCall := netLinker.EXPECT().
AddrReplace(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetOne)}). AddrReplace(link, netlink.Addr{Network: ipNetOne}).
Return(nil) Return(nil)
netLinker.EXPECT(). netLinker.EXPECT().
AddrReplace(link, &netlink.Addr{IPNet: routing.NetipPrefixToIPNet(&ipNetTwo)}). AddrReplace(link, netlink.Addr{Network: ipNetTwo}).
Return(errDummy).After(firstCall) Return(errDummy).After(firstCall)
return &Wireguard{ return &Wireguard{
netlink: netLinker, netlink: netLinker,
@@ -91,7 +82,6 @@ func Test_Wireguard_addAddresses(t *testing.T) {
err: errors.New("dummy: when adding address ::1234/64 to link a_bridge"), err: errors.New("dummy: when adding address ::1234/64 to link a_bridge"),
}, },
"ignore IPv6": { "ignore IPv6": {
link: newLink(),
addrs: []netip.Prefix{ipNetTwo}, addrs: []netip.Prefix{ipNetTwo},
wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard { wgBuilder: func(ctrl *gomock.Controller, link netlink.Link) *Wireguard {
return &Wireguard{ return &Wireguard{

View File

@@ -3,6 +3,7 @@ package wireguard
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -53,8 +54,14 @@ func makeDeviceConfig(settings Settings) (config wgtypes.Config, err error) {
PublicKey: publicKey, PublicKey: publicKey,
PresharedKey: preSharedKey, PresharedKey: preSharedKey,
AllowedIPs: []net.IPNet{ AllowedIPs: []net.IPNet{
*allIPv4(), {
*allIPv6(), IP: net.IPv4(0, 0, 0, 0),
Mask: []byte{0, 0, 0, 0},
},
{
IP: net.IPv6zero,
Mask: []byte(net.IPv6zero),
},
}, },
ReplaceAllowedIPs: true, ReplaceAllowedIPs: true,
Endpoint: &net.UDPAddr{ Endpoint: &net.UDPAddr{
@@ -68,16 +75,12 @@ func makeDeviceConfig(settings Settings) (config wgtypes.Config, err error) {
return config, nil return config, nil
} }
func allIPv4() (ipNet *net.IPNet) { func allIPv4() (prefix netip.Prefix) {
return &net.IPNet{ const bits = 0
IP: net.IPv4(0, 0, 0, 0), return netip.PrefixFrom(netip.IPv4Unspecified(), bits)
Mask: []byte{0, 0, 0, 0},
}
} }
func allIPv6() (ipNet *net.IPNet) { func allIPv6() (prefix netip.Prefix) {
return &net.IPNet{ const bits = 0
IP: net.IPv6zero, return netip.PrefixFrom(netip.IPv6Unspecified(), bits)
Mask: []byte(net.IPv6zero),
}
} }

View File

@@ -24,10 +24,9 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
netlinker := netlink.New(&noopDebugLogger{}) netlinker := netlink.New(&noopDebugLogger{})
linkAttrs := netlink.NewLinkAttrs() link := netlink.Link{
linkAttrs.Name = "test_8081" Type: "bridge",
link := &netlink.Bridge{ Name: "test_8081",
LinkAttrs: linkAttrs,
} }
// Remove any previously created test interface from a crashed/panic // Remove any previously created test interface from a crashed/panic
@@ -37,8 +36,9 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
err = netlinker.LinkAdd(link) linkIndex, err := netlinker.LinkAdd(link)
require.NoError(t, err) require.NoError(t, err)
link.Index = linkIndex
defer func() { defer func() {
err = netlinker.LinkDel(link) err = netlinker.LinkDel(link)
@@ -63,14 +63,12 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) {
err = wg.addAddresses(link, addresses) err = wg.addAddresses(link, addresses)
require.NoError(t, err) require.NoError(t, err)
netlinkAddresses, err := netlinker.AddrList(link, netlink.FAMILY_ALL) netlinkAddresses, err := netlinker.AddrList(link, netlink.FamilyAll)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, len(addresses), len(netlinkAddresses)) require.Equal(t, len(addresses), len(netlinkAddresses))
for i, netlinkAddress := range netlinkAddresses { for i, netlinkAddress := range netlinkAddresses {
require.NotNil(t, netlinkAddress.IPNet) require.NotNil(t, netlinkAddress.Network)
ipNet, err := netip.ParsePrefix(netlinkAddress.IPNet.String()) assert.Equal(t, addresses[i], netlinkAddress.Network)
require.NoError(t, err)
assert.Equal(t, addresses[i], ipNet)
} }
} }
} }
@@ -95,7 +93,7 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
}() }()
rules, err := netlinker.RuleList(netlink.FAMILY_V4) rules, err := netlinker.RuleList(netlink.FamilyV4)
require.NoError(t, err) require.NoError(t, err)
var rule netlink.Rule var rule netlink.Rule
var ruleFound bool var ruleFound bool
@@ -111,11 +109,6 @@ func Test_netlink_Wireguard_addRule(t *testing.T) {
Priority: rulePriority, Priority: rulePriority,
Mark: firewallMark, Mark: firewallMark,
Table: firewallMark, Table: firewallMark,
Mask: 4294967295,
Goto: -1,
Flow: -1,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
} }
assert.Equal(t, expectedRule, rule) assert.Equal(t, expectedRule, rule)

View File

@@ -5,7 +5,7 @@ import "github.com/qdm12/gluetun/internal/netlink"
//go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker //go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker
type NetLinker interface { type NetLinker interface {
AddrReplace(link netlink.Link, addr *netlink.Addr) error AddrReplace(link netlink.Link, addr netlink.Addr) error
Router Router
Ruler Ruler
Linker Linker
@@ -13,21 +13,21 @@ type NetLinker interface {
} }
type Router interface { type Router interface {
RouteList(link netlink.Link, family int) ( RouteList(link *netlink.Link, family int) (
routes []netlink.Route, err error) routes []netlink.Route, err error)
RouteAdd(route *netlink.Route) error RouteAdd(route netlink.Route) error
} }
type Ruler interface { type Ruler interface {
RuleAdd(rule *netlink.Rule) error RuleAdd(rule netlink.Rule) error
RuleDel(rule *netlink.Rule) error RuleDel(rule netlink.Rule) error
} }
type Linker interface { type Linker interface {
LinkAdd(link netlink.Link) (err error) LinkAdd(link netlink.Link) (linkIndex int, err error)
LinkList() (links []netlink.Link, err error) LinkList() (links []netlink.Link, err error)
LinkByName(name string) (link netlink.Link, err error) LinkByName(name string) (link netlink.Link, err error)
LinkSetUp(link netlink.Link) error LinkSetUp(link netlink.Link) (linkIndex int, err error)
LinkSetDown(link netlink.Link) error LinkSetDown(link netlink.Link) error
LinkDel(link netlink.Link) error LinkDel(link netlink.Link) error
} }

View File

@@ -8,7 +8,7 @@ import (
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
netlink "github.com/vishvananda/netlink" netlink "github.com/qdm12/gluetun/internal/netlink"
) )
// MockNetLinker is a mock of NetLinker interface. // MockNetLinker is a mock of NetLinker interface.
@@ -35,7 +35,7 @@ func (m *MockNetLinker) EXPECT() *MockNetLinkerMockRecorder {
} }
// AddrReplace mocks base method. // AddrReplace mocks base method.
func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 *netlink.Addr) error { func (m *MockNetLinker) AddrReplace(arg0 netlink.Link, arg1 netlink.Addr) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1) ret := m.ctrl.Call(m, "AddrReplace", arg0, arg1)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@@ -64,11 +64,12 @@ func (mr *MockNetLinkerMockRecorder) IsWireguardSupported() *gomock.Call {
} }
// LinkAdd mocks base method. // LinkAdd mocks base method.
func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) error { func (m *MockNetLinker) LinkAdd(arg0 netlink.Link) (int, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkAdd", arg0) ret := m.ctrl.Call(m, "LinkAdd", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(int)
return ret0 ret1, _ := ret[1].(error)
return ret0, ret1
} }
// LinkAdd indicates an expected call of LinkAdd. // LinkAdd indicates an expected call of LinkAdd.
@@ -136,11 +137,12 @@ func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call
} }
// LinkSetUp mocks base method. // LinkSetUp mocks base method.
func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) error { func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) (int, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LinkSetUp", arg0) ret := m.ctrl.Call(m, "LinkSetUp", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(int)
return ret0 ret1, _ := ret[1].(error)
return ret0, ret1
} }
// LinkSetUp indicates an expected call of LinkSetUp. // LinkSetUp indicates an expected call of LinkSetUp.
@@ -150,7 +152,7 @@ func (mr *MockNetLinkerMockRecorder) LinkSetUp(arg0 interface{}) *gomock.Call {
} }
// RouteAdd mocks base method. // RouteAdd mocks base method.
func (m *MockNetLinker) RouteAdd(arg0 *netlink.Route) error { func (m *MockNetLinker) RouteAdd(arg0 netlink.Route) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteAdd", arg0) ret := m.ctrl.Call(m, "RouteAdd", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@@ -164,7 +166,7 @@ func (mr *MockNetLinkerMockRecorder) RouteAdd(arg0 interface{}) *gomock.Call {
} }
// RouteList mocks base method. // RouteList mocks base method.
func (m *MockNetLinker) RouteList(arg0 netlink.Link, arg1 int) ([]netlink.Route, error) { func (m *MockNetLinker) RouteList(arg0 *netlink.Link, arg1 int) ([]netlink.Route, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RouteList", arg0, arg1) ret := m.ctrl.Call(m, "RouteList", arg0, arg1)
ret0, _ := ret[0].([]netlink.Route) ret0, _ := ret[0].([]netlink.Route)
@@ -179,7 +181,7 @@ func (mr *MockNetLinkerMockRecorder) RouteList(arg0, arg1 interface{}) *gomock.C
} }
// RuleAdd mocks base method. // RuleAdd mocks base method.
func (m *MockNetLinker) RuleAdd(arg0 *netlink.Rule) error { func (m *MockNetLinker) RuleAdd(arg0 netlink.Rule) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RuleAdd", arg0) ret := m.ctrl.Call(m, "RuleAdd", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@@ -193,7 +195,7 @@ func (mr *MockNetLinkerMockRecorder) RuleAdd(arg0 interface{}) *gomock.Call {
} }
// RuleDel mocks base method. // RuleDel mocks base method.
func (m *MockNetLinker) RuleDel(arg0 *netlink.Rule) error { func (m *MockNetLinker) RuleDel(arg0 netlink.Rule) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RuleDel", arg0) ret := m.ctrl.Call(m, "RuleDel", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)

View File

@@ -2,17 +2,17 @@ package wireguard
import ( import (
"fmt" "fmt"
"net" "net/netip"
"github.com/qdm12/gluetun/internal/netlink" "github.com/qdm12/gluetun/internal/netlink"
) )
// TODO add IPv6 route if IPv6 is supported // TODO add IPv6 route if IPv6 is supported
func (w *Wireguard) addRoute(link netlink.Link, dst *net.IPNet, func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix,
firewallMark int) (err error) { firewallMark int) (err error) {
route := &netlink.Route{ route := netlink.Route{
LinkIndex: link.Attrs().Index, LinkIndex: link.Index,
Dst: dst, Dst: dst,
Table: firewallMark, Table: firewallMark,
} }
@@ -21,7 +21,7 @@ func (w *Wireguard) addRoute(link netlink.Link, dst *net.IPNet,
if err != nil { if err != nil {
return fmt.Errorf( return fmt.Errorf(
"adding route for link %s, destination %s and table %d: %w", "adding route for link %s, destination %s and table %d: %w",
link.Attrs().Name, dst, firewallMark, err) link.Name, dst, firewallMark, err)
} }
return err return err

View File

@@ -2,7 +2,7 @@ package wireguard
import ( import (
"errors" "errors"
"net" "net/netip"
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
@@ -15,41 +15,40 @@ func Test_Wireguard_addRoute(t *testing.T) {
t.Parallel() t.Parallel()
const linkIndex = 88 const linkIndex = 88
newLink := func() netlink.Link {
linkAttrs := netlink.NewLinkAttrs() ipPrefix := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 2, 3, 4}), 32)
linkAttrs.Name = "a_bridge"
linkAttrs.Index = linkIndex
return &netlink.Bridge{
LinkAttrs: linkAttrs,
}
}
ipNet := &net.IPNet{IP: net.IPv4(1, 2, 3, 4), Mask: net.IPv4Mask(255, 255, 255, 255)}
const firewallMark = 51820 const firewallMark = 51820
errDummy := errors.New("dummy") errDummy := errors.New("dummy")
testCases := map[string]struct { testCases := map[string]struct {
link netlink.Link link netlink.Link
dst *net.IPNet dst netip.Prefix
expectedRoute *netlink.Route expectedRoute netlink.Route
routeAddErr error routeAddErr error
err error err error
}{ }{
"success": { "success": {
link: newLink(), link: netlink.Link{
dst: ipNet, Index: linkIndex,
expectedRoute: &netlink.Route{ },
dst: ipPrefix,
expectedRoute: netlink.Route{
LinkIndex: linkIndex, LinkIndex: linkIndex,
Dst: ipNet, Dst: ipPrefix,
Table: firewallMark, Table: firewallMark,
}, },
}, },
"route add error": { "route add error": {
link: newLink(), link: netlink.Link{
dst: ipNet, Name: "a_bridge",
expectedRoute: &netlink.Route{ Index: linkIndex,
},
dst: ipPrefix,
expectedRoute: netlink.Route{
LinkIndex: linkIndex, LinkIndex: linkIndex,
Dst: ipNet, Dst: ipPrefix,
Table: firewallMark, Table: firewallMark,
}, },
routeAddErr: errDummy, routeAddErr: errDummy,

View File

@@ -21,53 +21,38 @@ func Test_Wireguard_addRule(t *testing.T) {
errDummy := errors.New("dummy") errDummy := errors.New("dummy")
testCases := map[string]struct { testCases := map[string]struct {
expectedRule *netlink.Rule expectedRule netlink.Rule
ruleAddErr error ruleAddErr error
err error err error
ruleDelErr error ruleDelErr error
cleanupErr error cleanupErr error
}{ }{
"success": { "success": {
expectedRule: &netlink.Rule{ expectedRule: netlink.Rule{
Invert: true, Invert: true,
Priority: rulePriority, Priority: rulePriority,
Mark: firewallMark, Mark: firewallMark,
Table: firewallMark, Table: firewallMark,
Mask: -1,
Goto: -1,
Flow: -1,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
Family: family, Family: family,
}, },
}, },
"rule add error": { "rule add error": {
expectedRule: &netlink.Rule{ expectedRule: netlink.Rule{
Invert: true, Invert: true,
Priority: rulePriority, Priority: rulePriority,
Mark: firewallMark, Mark: firewallMark,
Table: firewallMark, Table: firewallMark,
Mask: -1,
Goto: -1,
Flow: -1,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
Family: family, Family: family,
}, },
ruleAddErr: errDummy, ruleAddErr: errDummy,
err: errors.New("adding rule ip rule 987: from all to all table 456: dummy"), err: errors.New("adding rule ip rule 987: from all to all table 456: dummy"),
}, },
"rule delete error": { "rule delete error": {
expectedRule: &netlink.Rule{ expectedRule: netlink.Rule{
Invert: true, Invert: true,
Priority: rulePriority, Priority: rulePriority,
Mark: firewallMark, Mark: firewallMark,
Table: firewallMark, Table: firewallMark,
Mask: -1,
Goto: -1,
Flow: -1,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
Family: family, Family: family,
}, },
ruleDelErr: errDummy, ruleDelErr: errDummy,

View File

@@ -93,10 +93,12 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan<
return return
} }
if err := w.netlink.LinkSetUp(link); err != nil { linkIndex, err := w.netlink.LinkSetUp(link)
if err != nil {
waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err) waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err)
return return
} }
link.Index = linkIndex
closers.add("shutting down link", stepFour, func() error { closers.add("shutting down link", stepFour, func() error {
return w.netlink.LinkSetDown(link) return w.netlink.LinkSetDown(link)
}) })
@@ -161,17 +163,16 @@ func setupKernelSpace(ctx context.Context,
interfaceName string, netLinker NetLinker, mtu uint16, interfaceName string, netLinker NetLinker, mtu uint16,
closers *closers, logger Logger) ( closers *closers, logger Logger) (
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error) { link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error) {
linkAttrs := netlink.LinkAttrs{ link = netlink.Link{
Type: "wireguard",
Name: interfaceName, Name: interfaceName,
MTU: int(mtu), MTU: mtu,
} }
link = &netlink.Wireguard{ linkIndex, err := netLinker.LinkAdd(link)
LinkAttrs: linkAttrs,
}
err = netLinker.LinkAdd(link)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("%w: %s", ErrAddLink, err) return link, nil, fmt.Errorf("%w: %s", ErrAddLink, err)
} }
link.Index = linkIndex
closers.add("deleting link", stepFive, func() error { closers.add("deleting link", stepFive, func() error {
return netLinker.LinkDel(link) return netLinker.LinkDel(link)
}) })
@@ -191,22 +192,22 @@ func setupUserSpace(ctx context.Context,
link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error) { link netlink.Link, waitAndCleanup waitAndCleanupFunc, err error) {
tun, err := tun.CreateTUN(interfaceName, int(mtu)) tun, err := tun.CreateTUN(interfaceName, int(mtu))
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("%w: %s", ErrCreateTun, err) return link, nil, fmt.Errorf("%w: %s", ErrCreateTun, err)
} }
closers.add("closing TUN device", stepSeven, tun.Close) closers.add("closing TUN device", stepSeven, tun.Close)
tunName, err := tun.Name() tunName, err := tun.Name()
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err) return link, nil, fmt.Errorf("%w: cannot get TUN name: %s", ErrCreateTun, err)
} else if tunName != interfaceName { } else if tunName != interfaceName {
return nil, nil, fmt.Errorf("%w: names don't match: expected %q and got %q", return link, nil, fmt.Errorf("%w: names don't match: expected %q and got %q",
ErrCreateTun, interfaceName, tunName) ErrCreateTun, interfaceName, tunName)
} }
link, err = netLinker.LinkByName(interfaceName) link, err = netLinker.LinkByName(interfaceName)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err) return link, nil, fmt.Errorf("%w: %s: %s", ErrFindLink, interfaceName, err)
} }
closers.add("deleting link", stepFive, func() error { closers.add("deleting link", stepFive, func() error {
return netLinker.LinkDel(link) return netLinker.LinkDel(link)
@@ -226,14 +227,14 @@ func setupUserSpace(ctx context.Context,
uapiFile, err := ipc.UAPIOpen(interfaceName) uapiFile, err := ipc.UAPIOpen(interfaceName)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err) return link, nil, fmt.Errorf("%w: %s", ErrUAPISocketOpening, err)
} }
closers.add("closing UAPI file", stepThree, uapiFile.Close) closers.add("closing UAPI file", stepThree, uapiFile.Close)
uapiListener, err := ipc.UAPIListen(interfaceName, uapiFile) uapiListener, err := ipc.UAPIListen(interfaceName, uapiFile)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err) return link, nil, fmt.Errorf("%w: %s", ErrUAPIListen, err)
} }
closers.add("closing UAPI listener", stepTwo, uapiListener.Close) closers.add("closing UAPI listener", stepTwo, uapiListener.Close)