Feature: uplift the 'localSubnet' concept to cover all local ethernet interfaces (#413)

This commit is contained in:
Michael Robbins
2021-04-10 03:08:20 +10:00
committed by GitHub
parent cc4117e054
commit 8230596f98
5 changed files with 87 additions and 11 deletions

View File

@@ -191,7 +191,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
return err return err
} }
localSubnet, err := routingConf.LocalSubnet() localNetworks, err := routingConf.LocalNetworks()
if err != nil { if err != nil {
return err return err
} }
@@ -201,7 +201,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
return err return err
} }
firewallConf.SetNetworkInformation(defaultInterface, defaultGateway, localSubnet, defaultIP) firewallConf.SetNetworkInformation(defaultInterface, defaultGateway, localNetworks, defaultIP)
if err := routingConf.Setup(); err != nil { if err := routingConf.Setup(); err != nil {
return err return err

View File

@@ -94,8 +94,10 @@ func (c *configurator) enable(ctx context.Context) (err error) {
return fmt.Errorf("cannot enable firewall: %w", err) return fmt.Errorf("cannot enable firewall: %w", err)
} }
if err := c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, c.localSubnet, remove); err != nil { for _, network := range c.localNetworks {
return fmt.Errorf("cannot enable firewall: %w", err) if err := c.acceptOutputFromIPToSubnet(ctx, network.InterfaceName, network.IP, network.Subnet, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
} }
for _, subnet := range c.outboundSubnets { for _, subnet := range c.outboundSubnets {
@@ -106,8 +108,10 @@ func (c *configurator) enable(ctx context.Context) (err error) {
// Allows packets from any IP address to go through eth0 / local network // Allows packets from any IP address to go through eth0 / local network
// to reach Gluetun. // to reach Gluetun.
if err := c.acceptInputToSubnet(ctx, c.defaultInterface, c.localSubnet, remove); err != nil { for _, network := range c.localNetworks {
return fmt.Errorf("cannot enable firewall: %w", err) if err := c.acceptInputToSubnet(ctx, network.InterfaceName, network.Subnet, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
} }
for port, intf := range c.allowedInputPorts { for port, intf := range c.allowedInputPorts {

View File

@@ -24,7 +24,7 @@ type Configurator interface {
RemoveAllowedPort(ctx context.Context, port uint16) (err error) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
SetDebug() SetDebug()
// SetNetworkInformation is meant to be called only once // SetNetworkInformation is meant to be called only once
SetNetworkInformation(defaultInterface string, defaultGateway net.IP, localSubnet net.IPNet, localIP net.IP) SetNetworkInformation(defaultInterface string, defaultGateway net.IP, localNetworks []routing.LocalNetwork, localIP net.IP)
} }
type configurator struct { //nolint:maligned type configurator struct { //nolint:maligned
@@ -36,7 +36,7 @@ type configurator struct { //nolint:maligned
debug bool debug bool
defaultInterface string defaultInterface string
defaultGateway net.IP defaultGateway net.IP
localSubnet net.IPNet localNetworks []routing.LocalNetwork
localIP net.IP localIP net.IP
networkInfoMutex sync.Mutex networkInfoMutex sync.Mutex
@@ -64,11 +64,11 @@ func (c *configurator) SetDebug() {
} }
func (c *configurator) SetNetworkInformation( func (c *configurator) SetNetworkInformation(
defaultInterface string, defaultGateway net.IP, localSubnet net.IPNet, localIP net.IP) { defaultInterface string, defaultGateway net.IP, localNetworks []routing.LocalNetwork, localIP net.IP) {
c.networkInfoMutex.Lock() c.networkInfoMutex.Lock()
defer c.networkInfoMutex.Unlock() defer c.networkInfoMutex.Unlock()
c.defaultInterface = defaultInterface c.defaultInterface = defaultInterface
c.defaultGateway = defaultGateway c.defaultGateway = defaultGateway
c.localSubnet = localSubnet c.localNetworks = localNetworks
c.localIP = localIP c.localIP = localIP
} }

View File

@@ -9,6 +9,12 @@ import (
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
) )
type LocalNetwork struct {
Subnet net.IPNet
InterfaceName string
IP net.IP
}
func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) { func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) {
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL) routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
if err != nil { if err != nil {
@@ -88,6 +94,72 @@ func (r *routing) LocalSubnet() (defaultSubnet net.IPNet, err error) {
return defaultSubnet, fmt.Errorf("cannot find default subnet in %d routes", len(routes)) return defaultSubnet, fmt.Errorf("cannot find default subnet in %d routes", len(routes))
} }
func (r *routing) LocalNetworks() (localNetworks []LocalNetwork, err error) {
links, err := netlink.LinkList()
if err != nil {
return localNetworks, fmt.Errorf("cannot find local subnet: %w", err)
}
localLinks := make(map[int]struct{})
for _, link := range links {
if link.Attrs().EncapType != "ether" {
continue
}
localLinks[link.Attrs().Index] = struct{}{}
if r.verbose {
r.logger.Info("local ethernet link found: %s", link.Attrs().Name)
}
}
if len(localLinks) == 0 {
return localNetworks, fmt.Errorf("cannot find any local interfaces")
}
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
if err != nil {
return localNetworks, fmt.Errorf("cannot list local routes: %w", err)
}
for _, route := range routes {
if route.Gw != nil || route.Dst == nil {
continue
} else if _, ok := localLinks[route.LinkIndex]; !ok {
continue
}
var localNet LocalNetwork
localNet.Subnet = *route.Dst
if r.verbose {
r.logger.Info("local subnet found: %s", localNet.Subnet.String())
}
link, err := netlink.LinkByIndex(route.LinkIndex)
if err != nil {
return localNetworks, fmt.Errorf("cannot get link by index: %w", err)
}
localNet.InterfaceName = link.Attrs().Name
ip, err := r.assignedIP(localNet.InterfaceName)
if err != nil {
return localNetworks, fmt.Errorf("cannot get IP assigned to link: %w", err)
}
localNet.IP = ip
localNetworks = append(localNetworks, localNet)
}
if len(localNetworks) == 0 {
return localNetworks, fmt.Errorf("cannot find any local networks across %d routes", len(routes))
}
return localNetworks, nil
}
func (r *routing) assignedIP(interfaceName string) (ip net.IP, err error) { func (r *routing) assignedIP(interfaceName string) (ip net.IP, err error) {
iface, err := net.InterfaceByName(interfaceName) iface, err := net.InterfaceByName(interfaceName)
if err != nil { if err != nil {

View File

@@ -16,7 +16,7 @@ type Routing interface {
// Read only // Read only
DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error)
LocalSubnet() (defaultSubnet net.IPNet, err error) LocalNetworks() (localNetworks []LocalNetwork, err error)
DefaultIP() (defaultIP net.IP, err error) DefaultIP() (defaultIP net.IP, err error)
VPNDestinationIP() (ip net.IP, err error) VPNDestinationIP() (ip net.IP, err error)
VPNLocalGatewayIP() (ip net.IP, err error) VPNLocalGatewayIP() (ip net.IP, err error)