diff --git a/internal/routing/default.go b/internal/routing/default.go new file mode 100644 index 00000000..8315a468 --- /dev/null +++ b/internal/routing/default.go @@ -0,0 +1,68 @@ +package routing + +import ( + "errors" + "fmt" + "net" + + "github.com/qdm12/gluetun/internal/netlink" +) + +var ( + ErrRouteDefaultNotFound = errors.New("default route not found") +) + +type DefaultRouteGetter interface { + DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) +} + +func (r *Routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) { + routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) + if err != nil { + return "", nil, fmt.Errorf("%w: %s", ErrRoutesList, err) + } + for _, route := range routes { + if route.Dst == nil { + defaultGateway = route.Gw + linkIndex := route.LinkIndex + link, err := r.netLinker.LinkByIndex(linkIndex) + if err != nil { + return "", nil, fmt.Errorf("%w: for default route at index %d: %s", ErrLinkByIndex, linkIndex, err) + } + attributes := link.Attrs() + defaultInterface = attributes.Name + r.logger.Info("default route found: interface " + defaultInterface + + ", gateway " + defaultGateway.String()) + return defaultInterface, defaultGateway, nil + } + } + return "", nil, fmt.Errorf("%w: in %d route(s)", ErrRouteDefaultNotFound, len(routes)) +} + +type DefaultIPGetter interface { + DefaultIP() (defaultIP net.IP, err error) +} + +func (r *Routing) DefaultIP() (ip net.IP, err error) { + routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrRoutesList, err) + } + + defaultLinkName := "" + for _, route := range routes { + if route.Dst == nil { + linkIndex := route.LinkIndex + link, err := r.netLinker.LinkByIndex(linkIndex) + if err != nil { + return nil, fmt.Errorf("%w: for default route at index %d: %s", ErrLinkByIndex, linkIndex, err) + } + defaultLinkName = link.Attrs().Name + } + } + if defaultLinkName == "" { + return nil, fmt.Errorf("%w: in %d route(s)", ErrLinkDefaultNotFound, len(routes)) + } + + return r.assignedIP(defaultLinkName) +} diff --git a/internal/routing/enable.go b/internal/routing/enable.go index f963360b..890ae534 100644 --- a/internal/routing/enable.go +++ b/internal/routing/enable.go @@ -17,6 +17,7 @@ var ( ErrIPRuleAdd = errors.New("cannot add IP rule") ErrIPRuleDelete = errors.New("cannot delete IP rule") ErrRouteAdd = errors.New("cannot add route") + ErrRouteDelete = errors.New("cannot delete route") ErrSubnetsOutboundSet = errors.New("cannot set outbound subnets routes") ) diff --git a/internal/routing/errors.go b/internal/routing/errors.go new file mode 100644 index 00000000..95cc7767 --- /dev/null +++ b/internal/routing/errors.go @@ -0,0 +1,11 @@ +package routing + +import ( + "errors" +) + +var ( + ErrLinkByIndex = errors.New("cannot obtain link by index") + ErrLinkDefaultNotFound = errors.New("default link not found") + ErrRoutesList = errors.New("cannot list routes") +) diff --git a/internal/routing/ip.go b/internal/routing/ip.go new file mode 100644 index 00000000..c9f3a7b6 --- /dev/null +++ b/internal/routing/ip.go @@ -0,0 +1,39 @@ +package routing + +import ( + "errors" + "fmt" + "net" +) + +func IPIsPrivate(ip net.IP) bool { + return ip.IsPrivate() || ip.IsLoopback() || + ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() +} + +var ( + errInterfaceIPNotFound = errors.New("IP address not found for interface") + errInterfaceListAddr = errors.New("cannot list interface addresses") + errInterfaceNotFound = errors.New("network interface not found") +) + +func (r *Routing) assignedIP(interfaceName string) (ip net.IP, err error) { + iface, err := net.InterfaceByName(interfaceName) + if err != nil { + return nil, fmt.Errorf("%w: %s: %s", errInterfaceNotFound, interfaceName, err) + } + addresses, err := iface.Addrs() + if err != nil { + return nil, fmt.Errorf("%w: %s: %s", errInterfaceListAddr, interfaceName, err) + } + for _, address := range addresses { + switch value := address.(type) { + case *net.IPAddr: + return value.IP, nil + case *net.IPNet: + return value.IP, nil + } + } + return nil, fmt.Errorf("%w: interface %s in %d addresses", + errInterfaceIPNotFound, interfaceName, len(addresses)) +} diff --git a/internal/routing/reader_test.go b/internal/routing/ip_test.go similarity index 100% rename from internal/routing/reader_test.go rename to internal/routing/ip_test.go diff --git a/internal/routing/local.go b/internal/routing/local.go new file mode 100644 index 00000000..7601e310 --- /dev/null +++ b/internal/routing/local.go @@ -0,0 +1,121 @@ +package routing + +import ( + "errors" + "fmt" + "net" + + "github.com/qdm12/gluetun/internal/netlink" +) + +var ( + ErrLinkList = errors.New("cannot list links") + ErrLinkLocalNotFound = errors.New("local link not found") + ErrSubnetDefaultNotFound = errors.New("default subnet not found") + ErrSubnetLocalNotFound = errors.New("local subnet not found") +) + +type LocalNetwork struct { + IPNet *net.IPNet + InterfaceName string + IP net.IP +} + +type LocalSubnetGetter interface { + LocalSubnet() (defaultSubnet net.IPNet, err error) +} + +func (r *Routing) LocalSubnet() (defaultSubnet net.IPNet, err error) { + routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) + if err != nil { + return defaultSubnet, fmt.Errorf("%w: %s", ErrRoutesList, err) + } + + defaultLinkIndex := -1 + for _, route := range routes { + if route.Dst == nil { + defaultLinkIndex = route.LinkIndex + break + } + } + if defaultLinkIndex == -1 { + return defaultSubnet, fmt.Errorf("%w: in %d route(s)", ErrLinkDefaultNotFound, len(routes)) + } + + for _, route := range routes { + if route.Gw != nil || route.LinkIndex != defaultLinkIndex { + continue + } + defaultSubnet = *route.Dst + r.logger.Info("local subnet found: " + defaultSubnet.String()) + return defaultSubnet, nil + } + + return defaultSubnet, fmt.Errorf("%w: in %d routes", ErrSubnetDefaultNotFound, len(routes)) +} + +type LocalNetworksGetter interface { + LocalNetworks() (localNetworks []LocalNetwork, err error) +} + +func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) { + links, err := r.netLinker.LinkList() + if err != nil { + return localNetworks, fmt.Errorf("%w: %s", ErrLinkList, err) + } + + localLinks := make(map[int]struct{}) + + for _, link := range links { + if link.Attrs().EncapType != "ether" { + continue + } + + localLinks[link.Attrs().Index] = struct{}{} + r.logger.Info("local ethernet link found: " + link.Attrs().Name) + } + + if len(localLinks) == 0 { + return localNetworks, fmt.Errorf("%w: in %d links", ErrLinkLocalNotFound, len(links)) + } + + routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_V4) + if err != nil { + return localNetworks, fmt.Errorf("%w: %s", ErrRoutesList, err) + } + + for _, route := range routes { + if route.Gw != nil || route.Dst == nil { + continue + } else if _, ok := localLinks[route.LinkIndex]; !ok { + continue + } + + var localNet LocalNetwork + + localNet.IPNet = route.Dst + r.logger.Info("local ipnet found: " + localNet.IPNet.String()) + + link, err := r.netLinker.LinkByIndex(route.LinkIndex) + if err != nil { + return localNetworks, fmt.Errorf("%w: at index %d: %s", ErrLinkByIndex, route.LinkIndex, err) + } + + localNet.InterfaceName = link.Attrs().Name + + ip, err := r.assignedIP(localNet.InterfaceName) + if err != nil { + return localNetworks, err + } + + localNet.IP = ip + + localNetworks = append(localNetworks, localNet) + } + + if len(localNetworks) == 0 { + return localNetworks, fmt.Errorf("%w: in %d routes", ErrSubnetLocalNotFound, len(routes)) + } + + return localNetworks, nil +} diff --git a/internal/routing/mutate.go b/internal/routing/mutate.go deleted file mode 100644 index 6c41af20..00000000 --- a/internal/routing/mutate.go +++ /dev/null @@ -1,126 +0,0 @@ -package routing - -import ( - "bytes" - "errors" - "fmt" - "net" - "strconv" - - "github.com/qdm12/gluetun/internal/netlink" -) - -var ( - ErrRouteReplace = errors.New("cannot replace route") - ErrRouteDelete = errors.New("cannot delete route") - ErrRuleAdd = errors.New("cannot add routing rule") - ErrRuleDel = errors.New("cannot delete routing rule") -) - -func (r *Routing) addRouteVia(destination net.IPNet, gateway net.IP, iface string, table int) error { - destinationStr := destination.String() - r.logger.Info("adding route for " + destinationStr) - r.logger.Debug("ip route replace " + destinationStr + - " via " + gateway.String() + - " dev " + iface + - " table " + strconv.Itoa(table)) - - link, err := r.netLinker.LinkByName(iface) - if err != nil { - return fmt.Errorf("%w: interface %s: %s", ErrLinkByName, iface, err) - } - route := netlink.Route{ - Dst: &destination, - Gw: gateway, - LinkIndex: link.Attrs().Index, - Table: table, - } - if err := r.netLinker.RouteReplace(&route); err != nil { - return fmt.Errorf("%w: for subnet %s at interface %s: %s", - ErrRouteReplace, destinationStr, iface, err) - } - return nil -} - -func (r *Routing) deleteRouteVia(destination net.IPNet, gateway net.IP, iface string, table int) (err error) { - destinationStr := destination.String() - r.logger.Info("deleting route for " + destinationStr) - r.logger.Debug("ip route delete " + destinationStr + - " via " + gateway.String() + - " dev " + iface + - " table " + strconv.Itoa(table)) - - link, err := r.netLinker.LinkByName(iface) - if err != nil { - return fmt.Errorf("%w: for interface %s: %s", ErrLinkByName, iface, err) - } - route := netlink.Route{ - Dst: &destination, - Gw: gateway, - LinkIndex: link.Attrs().Index, - Table: table, - } - if err := r.netLinker.RouteDel(&route); err != nil { - return fmt.Errorf("%w: for subnet %s at interface %s: %s", - ErrRouteDelete, destinationStr, iface, err) - } - return nil -} - -func (r *Routing) addIPRule(src net.IP, table, priority int) error { - r.logger.Debug("ip rule add from " + src.String() + - " lookup " + strconv.Itoa(table) + - " pref " + strconv.Itoa(priority)) - - rule := netlink.NewRule() - rule.Src = netlink.NewIPNet(src) - rule.Priority = priority - rule.Table = table - - rules, err := r.netLinker.RuleList(netlink.FAMILY_ALL) - if err != nil { - return fmt.Errorf("%w: %s", ErrRulesList, err) - } - for _, existingRule := range rules { - if existingRule.Src != nil && - existingRule.Src.IP.Equal(rule.Src.IP) && - bytes.Equal(existingRule.Src.Mask, rule.Src.Mask) && - existingRule.Priority == rule.Priority && - existingRule.Table == rule.Table { - return nil // already exists - } - } - - if err := r.netLinker.RuleAdd(rule); err != nil { - return fmt.Errorf("%w: for rule %q: %s", ErrRuleAdd, rule, err) - } - return nil -} - -func (r *Routing) deleteIPRule(src net.IP, table, priority int) error { - r.logger.Debug("ip rule del from " + src.String() + - " lookup " + strconv.Itoa(table) + - " pref " + strconv.Itoa(priority)) - - rule := netlink.NewRule() - rule.Src = netlink.NewIPNet(src) - rule.Priority = priority - rule.Table = table - - rules, err := r.netLinker.RuleList(netlink.FAMILY_ALL) - if err != nil { - return fmt.Errorf("%w: %s", ErrRulesList, err) - } - for _, existingRule := range rules { - if existingRule.Src != nil && - existingRule.Src.IP.Equal(rule.Src.IP) && - bytes.Equal(existingRule.Src.Mask, rule.Src.Mask) && - existingRule.Priority == rule.Priority && - existingRule.Table == rule.Table { - if err := r.netLinker.RuleDel(rule); err != nil { - return fmt.Errorf("%w: for rule %q: %s", ErrRuleDel, rule, err) - } - } - } - return nil -} diff --git a/internal/routing/outboundsubnets.go b/internal/routing/outboundsubnets.go index 61a14f4c..804240b4 100644 --- a/internal/routing/outboundsubnets.go +++ b/internal/routing/outboundsubnets.go @@ -9,7 +9,7 @@ import ( ) var ( - ErrAddOutboundSubnet = errors.New("cannot add outbound subnet to routes") + errAddOutboundSubnet = errors.New("cannot add outbound subnet to routes") ) type OutboundRoutesSetter interface { @@ -25,7 +25,7 @@ func (r *Routing) SetOutboundRoutes(outboundSubnets []net.IPNet) error { } func (r *Routing) setOutboundRoutes(outboundSubnets []net.IPNet, - defaultInterfaceName string, defaultGateway net.IP) error { + defaultInterfaceName string, defaultGateway net.IP) (err error) { r.stateMutex.Lock() defer r.stateMutex.Unlock() @@ -36,20 +36,32 @@ func (r *Routing) setOutboundRoutes(outboundSubnets []net.IPNet, return nil } - r.removeOutboundSubnets(subnetsToRemove, defaultInterfaceName, defaultGateway) - return r.addOutboundSubnets(subnetsToAdd, defaultInterfaceName, defaultGateway) + warnings := r.removeOutboundSubnets(subnetsToRemove, defaultInterfaceName, defaultGateway) + for _, warning := range warnings { + r.logger.Warn("cannot remove outdated outbound subnet from routing: " + warning) + } + + err = r.addOutboundSubnets(subnetsToAdd, defaultInterfaceName, defaultGateway) + if err != nil { + return fmt.Errorf("%w: %s", errAddOutboundSubnet, err) + } + + return nil } func (r *Routing) removeOutboundSubnets(subnets []net.IPNet, - defaultInterfaceName string, defaultGateway net.IP) { + defaultInterfaceName string, defaultGateway net.IP) (warnings []string) { for _, subNet := range subnets { const table = 0 if err := r.deleteRouteVia(subNet, defaultGateway, defaultInterfaceName, table); err != nil { - r.logger.Error("cannot remove outdated outbound subnet from routing: " + err.Error()) + warnings = append(warnings, err.Error()) continue } + r.outboundSubnets = subnet.RemoveSubnetFromSubnets(r.outboundSubnets, subNet) } + + return warnings } func (r *Routing) addOutboundSubnets(subnets []net.IPNet, @@ -57,7 +69,7 @@ func (r *Routing) addOutboundSubnets(subnets []net.IPNet, for _, subnet := range subnets { const table = 0 if err := r.addRouteVia(subnet, defaultGateway, defaultInterfaceName, table); err != nil { - return fmt.Errorf("%w: %s: %s", ErrAddOutboundSubnet, subnet, err) + return fmt.Errorf("%w: for subnet %s", err, subnet) } r.outboundSubnets = append(r.outboundSubnets, subnet) } diff --git a/internal/routing/reader.go b/internal/routing/reader.go deleted file mode 100644 index d376dde4..00000000 --- a/internal/routing/reader.go +++ /dev/null @@ -1,270 +0,0 @@ -package routing - -import ( - "bytes" - "errors" - "fmt" - "net" - - "github.com/qdm12/gluetun/internal/netlink" -) - -type LocalNetwork struct { - IPNet *net.IPNet - InterfaceName string - IP net.IP -} - -var ( - ErrInterfaceIPNotFound = errors.New("IP address not found for interface") - ErrInterfaceListAddr = errors.New("cannot list interface addresses") - ErrInterfaceNotFound = errors.New("network interface not found") - ErrLinkByIndex = errors.New("cannot obtain link by index") - ErrLinkByName = errors.New("cannot obtain link by name") - ErrLinkDefaultNotFound = errors.New("default link not found") - ErrLinkList = errors.New("cannot list links") - ErrLinkLocalNotFound = errors.New("local link not found") - ErrRouteDefaultNotFound = errors.New("default route not found") - ErrRoutesList = errors.New("cannot list routes") - ErrRulesList = errors.New("cannot list rules") - ErrSubnetDefaultNotFound = errors.New("default subnet not found") - ErrSubnetLocalNotFound = errors.New("local subnet not found") - ErrVPNDestinationIPNotFound = errors.New("VPN destination IP address not found") - ErrVPNLocalGatewayIPNotFound = errors.New("VPN local gateway IP address not found") -) - -type DefaultRouteGetter interface { - DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) -} - -func (r *Routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) { - routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) - if err != nil { - return "", nil, fmt.Errorf("%w: %s", ErrRoutesList, err) - } - for _, route := range routes { - if route.Dst == nil { - defaultGateway = route.Gw - linkIndex := route.LinkIndex - link, err := r.netLinker.LinkByIndex(linkIndex) - if err != nil { - return "", nil, fmt.Errorf("%w: for default route at index %d: %s", ErrLinkByIndex, linkIndex, err) - } - attributes := link.Attrs() - defaultInterface = attributes.Name - r.logger.Info("default route found: interface " + defaultInterface + - ", gateway " + defaultGateway.String()) - return defaultInterface, defaultGateway, nil - } - } - return "", nil, fmt.Errorf("%w: in %d route(s)", ErrRouteDefaultNotFound, len(routes)) -} - -type DefaultIPGetter interface { - DefaultIP() (defaultIP net.IP, err error) -} - -func (r *Routing) DefaultIP() (ip net.IP, err error) { - routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) - if err != nil { - return nil, fmt.Errorf("%w: %s", ErrRoutesList, err) - } - - defaultLinkName := "" - for _, route := range routes { - if route.Dst == nil { - linkIndex := route.LinkIndex - link, err := r.netLinker.LinkByIndex(linkIndex) - if err != nil { - return nil, fmt.Errorf("%w: for default route at index %d: %s", ErrLinkByIndex, linkIndex, err) - } - defaultLinkName = link.Attrs().Name - } - } - if defaultLinkName == "" { - return nil, fmt.Errorf("%w: in %d route(s)", ErrLinkDefaultNotFound, len(routes)) - } - - return r.assignedIP(defaultLinkName) -} - -type LocalSubnetGetter interface { - LocalSubnet() (defaultSubnet net.IPNet, err error) -} - -func (r *Routing) LocalSubnet() (defaultSubnet net.IPNet, err error) { - routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) - if err != nil { - return defaultSubnet, fmt.Errorf("%w: %s", ErrRoutesList, err) - } - - defaultLinkIndex := -1 - for _, route := range routes { - if route.Dst == nil { - defaultLinkIndex = route.LinkIndex - break - } - } - if defaultLinkIndex == -1 { - return defaultSubnet, fmt.Errorf("%w: in %d route(s)", ErrLinkDefaultNotFound, len(routes)) - } - - for _, route := range routes { - if route.Gw != nil || route.LinkIndex != defaultLinkIndex { - continue - } - defaultSubnet = *route.Dst - r.logger.Info("local subnet found: " + defaultSubnet.String()) - return defaultSubnet, nil - } - - return defaultSubnet, fmt.Errorf("%w: in %d routes", ErrSubnetDefaultNotFound, len(routes)) -} - -type LocalNetworksGetter interface { - LocalNetworks() (localNetworks []LocalNetwork, err error) -} - -func (r *Routing) LocalNetworks() (localNetworks []LocalNetwork, err error) { - links, err := r.netLinker.LinkList() - if err != nil { - return localNetworks, fmt.Errorf("%w: %s", ErrLinkList, err) - } - - localLinks := make(map[int]struct{}) - - for _, link := range links { - if link.Attrs().EncapType != "ether" { - continue - } - - localLinks[link.Attrs().Index] = struct{}{} - r.logger.Info("local ethernet link found: " + link.Attrs().Name) - } - - if len(localLinks) == 0 { - return localNetworks, fmt.Errorf("%w: in %d links", ErrLinkLocalNotFound, len(links)) - } - - routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_V4) - if err != nil { - return localNetworks, fmt.Errorf("%w: %s", ErrRoutesList, err) - } - - for _, route := range routes { - if route.Gw != nil || route.Dst == nil { - continue - } else if _, ok := localLinks[route.LinkIndex]; !ok { - continue - } - - var localNet LocalNetwork - - localNet.IPNet = route.Dst - r.logger.Info("local ipnet found: " + localNet.IPNet.String()) - - link, err := r.netLinker.LinkByIndex(route.LinkIndex) - if err != nil { - return localNetworks, fmt.Errorf("%w: at index %d: %s", ErrLinkByIndex, route.LinkIndex, err) - } - - localNet.InterfaceName = link.Attrs().Name - - ip, err := r.assignedIP(localNet.InterfaceName) - if err != nil { - return localNetworks, err - } - - localNet.IP = ip - - localNetworks = append(localNetworks, localNet) - } - - if len(localNetworks) == 0 { - return localNetworks, fmt.Errorf("%w: in %d routes", ErrSubnetLocalNotFound, len(routes)) - } - - return localNetworks, nil -} - -func (r *Routing) assignedIP(interfaceName string) (ip net.IP, err error) { - iface, err := net.InterfaceByName(interfaceName) - if err != nil { - return nil, fmt.Errorf("%w: %s: %s", ErrInterfaceNotFound, interfaceName, err) - } - addresses, err := iface.Addrs() - if err != nil { - return nil, fmt.Errorf("%w: %s: %s", ErrInterfaceListAddr, interfaceName, err) - } - for _, address := range addresses { - switch value := address.(type) { - case *net.IPAddr: - return value.IP, nil - case *net.IPNet: - return value.IP, nil - } - } - return nil, fmt.Errorf("%w: interface %s in %d addresses", - ErrInterfaceIPNotFound, interfaceName, len(addresses)) -} - -type VPNDestinationIPGetter interface { - VPNDestinationIP() (ip net.IP, err error) -} - -func (r *Routing) VPNDestinationIP() (ip net.IP, err error) { - routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) - if err != nil { - return nil, fmt.Errorf("%w: %s", ErrRoutesList, err) - } - - defaultLinkIndex := -1 - for _, route := range routes { - if route.Dst == nil { - defaultLinkIndex = route.LinkIndex - break - } - } - if defaultLinkIndex == -1 { - return nil, fmt.Errorf("%w: in %d route(s)", ErrLinkDefaultNotFound, len(routes)) - } - - for _, route := range routes { - if route.LinkIndex == defaultLinkIndex && - route.Dst != nil && - !IPIsPrivate(route.Dst.IP) && - bytes.Equal(route.Dst.Mask, net.IPMask{255, 255, 255, 255}) { - return route.Dst.IP, nil - } - } - return nil, fmt.Errorf("%w: in %d routes", ErrVPNDestinationIPNotFound, len(routes)) -} - -type VPNLocalGatewayIPGetter interface { - VPNLocalGatewayIP(vpnIntf string) (ip net.IP, err error) -} - -func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip net.IP, err error) { - routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) - if err != nil { - return nil, fmt.Errorf("%w: %s", ErrRoutesList, err) - } - for _, route := range routes { - link, err := r.netLinker.LinkByIndex(route.LinkIndex) - if err != nil { - return nil, fmt.Errorf("%w: %s", ErrLinkByIndex, err) - } - interfaceName := link.Attrs().Name - if interfaceName == vpnIntf && - route.Dst != nil && - route.Dst.IP.Equal(net.IP{0, 0, 0, 0}) { - return route.Gw, nil - } - } - return nil, fmt.Errorf("%w: in %d routes", ErrVPNLocalGatewayIPNotFound, len(routes)) -} - -func IPIsPrivate(ip net.IP) bool { - return ip.IsPrivate() || ip.IsLoopback() || - ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() -} diff --git a/internal/routing/routes.go b/internal/routing/routes.go new file mode 100644 index 00000000..df4f222e --- /dev/null +++ b/internal/routing/routes.go @@ -0,0 +1,70 @@ +package routing + +import ( + "errors" + "fmt" + "net" + "strconv" + + "github.com/qdm12/gluetun/internal/netlink" +) + +var ( + errLinkByName = errors.New("cannot obtain link by name") +) + +func (r *Routing) addRouteVia(destination net.IPNet, gateway net.IP, + iface string, table int) error { + destinationStr := destination.String() + r.logger.Info("adding route for " + destinationStr) + r.logger.Debug("ip route replace " + destinationStr + + " via " + gateway.String() + + " dev " + iface + + " table " + strconv.Itoa(table)) + + link, err := r.netLinker.LinkByName(iface) + if err != nil { + return fmt.Errorf("%w: interface %s: %s", errLinkByName, iface, err) + } + + route := netlink.Route{ + Dst: &destination, + Gw: gateway, + LinkIndex: link.Attrs().Index, + Table: table, + } + if err := r.netLinker.RouteReplace(&route); err != nil { + return fmt.Errorf("%w: for subnet %s at interface %s", + err, destinationStr, iface) + } + + return nil +} + +func (r *Routing) deleteRouteVia(destination net.IPNet, gateway net.IP, + iface string, table int) (err error) { + destinationStr := destination.String() + r.logger.Info("deleting route for " + destinationStr) + r.logger.Debug("ip route delete " + destinationStr + + " via " + gateway.String() + + " dev " + iface + + " table " + strconv.Itoa(table)) + + link, err := r.netLinker.LinkByName(iface) + if err != nil { + return fmt.Errorf("%w: for interface %s: %s", errLinkByName, iface, err) + } + + route := netlink.Route{ + Dst: &destination, + Gw: gateway, + LinkIndex: link.Attrs().Index, + Table: table, + } + if err := r.netLinker.RouteDel(&route); err != nil { + return fmt.Errorf("%w: for subnet %s at interface %s", + err, destinationStr, iface) + } + + return nil +} diff --git a/internal/routing/rules.go b/internal/routing/rules.go new file mode 100644 index 00000000..fe5aa291 --- /dev/null +++ b/internal/routing/rules.go @@ -0,0 +1,73 @@ +package routing + +import ( + "bytes" + "errors" + "fmt" + "net" + "strconv" + + "github.com/qdm12/gluetun/internal/netlink" +) + +var ( + errRulesList = errors.New("cannot list rules") +) + +func (r *Routing) addIPRule(src net.IP, table, priority int) error { + r.logger.Debug("ip rule add from " + src.String() + + " lookup " + strconv.Itoa(table) + + " pref " + strconv.Itoa(priority)) + + rule := netlink.NewRule() + rule.Src = netlink.NewIPNet(src) + rule.Priority = priority + rule.Table = table + + rules, err := r.netLinker.RuleList(netlink.FAMILY_ALL) + if err != nil { + return fmt.Errorf("%w: %s", errRulesList, err) + } + for _, existingRule := range rules { + if existingRule.Src != nil && + existingRule.Src.IP.Equal(rule.Src.IP) && + bytes.Equal(existingRule.Src.Mask, rule.Src.Mask) && + existingRule.Priority == rule.Priority && + existingRule.Table == rule.Table { + return nil // already exists + } + } + + if err := r.netLinker.RuleAdd(rule); err != nil { + return fmt.Errorf("%w: for rule: %s", err, rule) + } + return nil +} + +func (r *Routing) deleteIPRule(src net.IP, table, priority int) error { + r.logger.Debug("ip rule del from " + src.String() + + " lookup " + strconv.Itoa(table) + + " pref " + strconv.Itoa(priority)) + + rule := netlink.NewRule() + rule.Src = netlink.NewIPNet(src) + rule.Priority = priority + rule.Table = table + + rules, err := r.netLinker.RuleList(netlink.FAMILY_ALL) + if err != nil { + return fmt.Errorf("%w: %s", errRulesList, err) + } + for _, existingRule := range rules { + if existingRule.Src != nil && + existingRule.Src.IP.Equal(rule.Src.IP) && + bytes.Equal(existingRule.Src.Mask, rule.Src.Mask) && + existingRule.Priority == rule.Priority && + existingRule.Table == rule.Table { + if err := r.netLinker.RuleDel(rule); err != nil { + return fmt.Errorf("%w: for rule: %s", err, rule) + } + } + } + return nil +} diff --git a/internal/routing/vpn.go b/internal/routing/vpn.go new file mode 100644 index 00000000..908c262b --- /dev/null +++ b/internal/routing/vpn.go @@ -0,0 +1,71 @@ +package routing + +import ( + "bytes" + "errors" + "fmt" + "net" + + "github.com/qdm12/gluetun/internal/netlink" +) + +var ( + ErrVPNDestinationIPNotFound = errors.New("VPN destination IP address not found") + ErrVPNLocalGatewayIPNotFound = errors.New("VPN local gateway IP address not found") +) + +type VPNDestinationIPGetter interface { + VPNDestinationIP() (ip net.IP, err error) +} + +func (r *Routing) VPNDestinationIP() (ip net.IP, err error) { + routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrRoutesList, err) + } + + defaultLinkIndex := -1 + for _, route := range routes { + if route.Dst == nil { + defaultLinkIndex = route.LinkIndex + break + } + } + if defaultLinkIndex == -1 { + return nil, fmt.Errorf("%w: in %d route(s)", ErrLinkDefaultNotFound, len(routes)) + } + + for _, route := range routes { + if route.LinkIndex == defaultLinkIndex && + route.Dst != nil && + !IPIsPrivate(route.Dst.IP) && + bytes.Equal(route.Dst.Mask, net.IPMask{255, 255, 255, 255}) { + return route.Dst.IP, nil + } + } + return nil, fmt.Errorf("%w: in %d routes", ErrVPNDestinationIPNotFound, len(routes)) +} + +type VPNLocalGatewayIPGetter interface { + VPNLocalGatewayIP(vpnIntf string) (ip net.IP, err error) +} + +func (r *Routing) VPNLocalGatewayIP(vpnIntf string) (ip net.IP, err error) { + routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrRoutesList, err) + } + for _, route := range routes { + link, err := r.netLinker.LinkByIndex(route.LinkIndex) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrLinkByIndex, err) + } + interfaceName := link.Attrs().Name + if interfaceName == vpnIntf && + route.Dst != nil && + route.Dst.IP.Equal(net.IP{0, 0, 0, 0}) { + return route.Gw, nil + } + } + return nil, fmt.Errorf("%w: in %d routes", ErrVPNLocalGatewayIPNotFound, len(routes)) +}