@@ -16,7 +16,7 @@ const (
|
||||
)
|
||||
|
||||
func (r *routing) Setup() (err error) {
|
||||
defaultIP, err := r.defaultIP()
|
||||
defaultIP, err := r.DefaultIP()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", ErrSetup, err)
|
||||
}
|
||||
@@ -40,11 +40,19 @@ func (r *routing) Setup() (err error) {
|
||||
if err := r.addRouteVia(defaultDestination, defaultGateway, defaultInterfaceName, table); err != nil {
|
||||
return fmt.Errorf("%s: %w", ErrSetup, err)
|
||||
}
|
||||
|
||||
r.stateMutex.RLock()
|
||||
outboundSubnets := r.outboundSubnets
|
||||
r.stateMutex.RUnlock()
|
||||
if err := r.setOutboundRoutes(outboundSubnets, defaultInterfaceName, defaultGateway); err != nil {
|
||||
return fmt.Errorf("%s: %w", ErrSetup, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *routing) TearDown() error {
|
||||
defaultIP, err := r.defaultIP()
|
||||
defaultIP, err := r.DefaultIP()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", ErrTeardown, err)
|
||||
}
|
||||
@@ -60,5 +68,10 @@ func (r *routing) TearDown() error {
|
||||
if err := r.deleteIPRule(defaultIP, table, priority); err != nil {
|
||||
return fmt.Errorf("%s: %w", ErrTeardown, err)
|
||||
}
|
||||
|
||||
if err := r.setOutboundRoutes(nil, defaultInterfaceName, defaultGateway); err != nil {
|
||||
return fmt.Errorf("%s: %w", ErrSetup, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
58
internal/routing/outboundsubnets.go
Normal file
58
internal/routing/outboundsubnets.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
func (r *routing) SetOutboundRoutes(outboundSubnets []net.IPNet) error {
|
||||
defaultInterface, defaultGateway, err := r.DefaultRoute()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot set oubtound subnets in routing: %w", err)
|
||||
}
|
||||
return r.setOutboundRoutes(outboundSubnets, defaultInterface, defaultGateway)
|
||||
}
|
||||
|
||||
func (r *routing) setOutboundRoutes(outboundSubnets []net.IPNet,
|
||||
defaultInterfaceName string, defaultGateway net.IP) error {
|
||||
r.stateMutex.Lock()
|
||||
defer r.stateMutex.Unlock()
|
||||
|
||||
subnetsToRemove := findSubnetsToRemove(r.outboundSubnets, outboundSubnets)
|
||||
subnetsToAdd := findSubnetsToAdd(r.outboundSubnets, outboundSubnets)
|
||||
|
||||
if len(subnetsToAdd) == 0 && len(subnetsToRemove) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.removeOutboundSubnets(subnetsToRemove, defaultInterfaceName, defaultGateway)
|
||||
if err := r.addOutboundSubnets(subnetsToAdd, defaultInterfaceName, defaultGateway); err != nil {
|
||||
return fmt.Errorf("cannot set outbound subnets in routing: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *routing) removeOutboundSubnets(subnets []net.IPNet,
|
||||
defaultInterfaceName string, defaultGateway net.IP) {
|
||||
for _, subnet := range subnets {
|
||||
const table = 0
|
||||
if err := r.deleteRouteVia(subnet, defaultGateway, defaultInterfaceName, table); err != nil {
|
||||
r.logger.Error("cannot remove outdated outbound subnet from routing: %s", err)
|
||||
continue
|
||||
}
|
||||
r.outboundSubnets = removeSubnetFromSubnets(r.outboundSubnets, subnet)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *routing) addOutboundSubnets(subnets []net.IPNet,
|
||||
defaultInterfaceName string, defaultGateway net.IP) error {
|
||||
for _, subnet := range subnets {
|
||||
const table = 0
|
||||
if err := r.addRouteVia(subnet, defaultGateway, defaultInterfaceName, table); err != nil {
|
||||
return fmt.Errorf("cannot add outbound subnet %s to routing: %w", subnet, err)
|
||||
}
|
||||
r.outboundSubnets = append(r.outboundSubnets, subnet)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -33,7 +33,7 @@ func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP
|
||||
return "", nil, fmt.Errorf("cannot find default route in %d routes", len(routes))
|
||||
}
|
||||
|
||||
func (r *routing) defaultIP() (ip net.IP, err error) {
|
||||
func (r *routing) DefaultIP() (ip net.IP, err error) {
|
||||
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot get default IP address: %w", err)
|
||||
|
||||
@@ -2,25 +2,35 @@ package routing
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
type Routing interface {
|
||||
// Mutations
|
||||
Setup() (err error)
|
||||
TearDown() error
|
||||
SetOutboundRoutes(outboundSubnets []net.IPNet) error
|
||||
|
||||
// Read only
|
||||
DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error)
|
||||
LocalSubnet() (defaultSubnet net.IPNet, err error)
|
||||
DefaultIP() (defaultIP net.IP, err error)
|
||||
VPNDestinationIP() (ip net.IP, err error)
|
||||
VPNLocalGatewayIP() (ip net.IP, err error)
|
||||
|
||||
// Internal state
|
||||
SetVerbose(verbose bool)
|
||||
SetDebug()
|
||||
}
|
||||
|
||||
type routing struct {
|
||||
logger logging.Logger
|
||||
verbose bool
|
||||
debug bool
|
||||
logger logging.Logger
|
||||
verbose bool
|
||||
debug bool
|
||||
outboundSubnets []net.IPNet
|
||||
stateMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewConfigurator creates a new Configurator instance.
|
||||
|
||||
53
internal/routing/subnets.go
Normal file
53
internal/routing/subnets.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
func findSubnetsToAdd(oldSubnets, newSubnets []net.IPNet) (subnetsToAdd []net.IPNet) {
|
||||
for _, newSubnet := range newSubnets {
|
||||
found := false
|
||||
for _, oldSubnet := range oldSubnets {
|
||||
if subnetsAreEqual(oldSubnet, newSubnet) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
subnetsToAdd = append(subnetsToAdd, newSubnet)
|
||||
}
|
||||
}
|
||||
return subnetsToAdd
|
||||
}
|
||||
|
||||
func findSubnetsToRemove(oldSubnets, newSubnets []net.IPNet) (subnetsToRemove []net.IPNet) {
|
||||
for _, oldSubnet := range oldSubnets {
|
||||
found := false
|
||||
for _, newSubnet := range newSubnets {
|
||||
if subnetsAreEqual(oldSubnet, newSubnet) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
subnetsToRemove = append(subnetsToRemove, oldSubnet)
|
||||
}
|
||||
}
|
||||
return subnetsToRemove
|
||||
}
|
||||
|
||||
func subnetsAreEqual(a, b net.IPNet) bool {
|
||||
return a.IP.Equal(b.IP) && a.Mask.String() == b.Mask.String()
|
||||
}
|
||||
|
||||
func removeSubnetFromSubnets(subnets []net.IPNet, subnet net.IPNet) []net.IPNet {
|
||||
L := len(subnets)
|
||||
for i := range subnets {
|
||||
if subnetsAreEqual(subnet, subnets[i]) {
|
||||
subnets[i] = subnets[L-1]
|
||||
subnets = subnets[:L-1]
|
||||
break
|
||||
}
|
||||
}
|
||||
return subnets
|
||||
}
|
||||
Reference in New Issue
Block a user