diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 526b6f85..9c8b4432 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -81,7 +81,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go alpineConf := alpine.NewConfigurator(fileManager) ovpnConf := openvpn.NewConfigurator(logger, fileManager) dnsConf := dns.NewConfigurator(logger, client, fileManager) - routingConf := routing.NewRouting(logger, fileManager) + routingConf := routing.NewRouting(logger) firewallConf := firewall.NewConfigurator(logger, routingConf, fileManager) tinyProxyConf := tinyproxy.NewConfigurator(fileManager, logger) streamMerger := command.NewStreamMerger() @@ -364,7 +364,6 @@ func collectStreamLines(ctx context.Context, streamMerger command.StreamMerger, }) } -//nolint:gocognit func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dnsReadyCh <-chan struct{}, unboundLooper dns.Looper, updaterLooper updater.Looper, publicIPLooper publicip.Looper, routing routing.Routing, logger logging.Logger, httpClient *http.Client, @@ -388,16 +387,11 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dn tickerWg.Add(2) //nolint:gomnd go unboundLooper.RunRestartTicker(restartTickerContext, tickerWg) go updaterLooper.RunRestartTicker(restartTickerContext, tickerWg) - defaultInterface, _, err := routing.DefaultRoute() + vpnDestination, err := routing.VPNDestinationIP() if err != nil { logger.Warn(err) } else { - vpnDestination, err := routing.VPNDestinationIP(defaultInterface) - if err != nil { - logger.Warn(err) - } else { - logger.Info("VPN routing IP address: %s", vpnDestination) - } + logger.Info("VPN routing IP address: %s", vpnDestination) } if portForwardingEnabled { // TODO make instantaneous once v3 go out of service diff --git a/go.mod b/go.mod index 9d02aaf9..92ad7b1d 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/qdm12/golibs v0.0.0-20201018204514-1d5986880422 github.com/qdm12/ss-server v0.0.0-20200819124651-6428e626ee83 github.com/stretchr/testify v1.6.1 + github.com/vishvananda/netlink v1.1.0 golang.org/x/net v0.0.0-20201016165138-7b1cca2348c0 golang.org/x/sys v0.0.0-20201018121011-98379d014ca7 ) diff --git a/go.sum b/go.sum index aaf25827..3e344919 100644 --- a/go.sum +++ b/go.sum @@ -87,6 +87,10 @@ github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJy github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJH8j0= +github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE= +github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df h1:OviZH7qLw/7ZovXvuNyL3XQl8UFofeikI1NW1Gypu7k= +github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU= go.uber.org/atomic v1.5.0 h1:OI5t8sDa1Or+q8AeE+yKeB/SDYioSHAgcVljj9JIETY= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/multierr v1.3.0 h1:sFPn2GLc3poCkfrpIXGhBD2X0CMIo4Q/zSULXrj/+uc= @@ -116,6 +120,7 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201018121011-98379d014ca7 h1:CNOpL+H7PSxBI7dF/EIUsfOguRSzWp6CQ91yxZE6PG4= diff --git a/internal/firewall/enable.go b/internal/firewall/enable.go index 620caeb6..c87b5fc0 100644 --- a/internal/firewall/enable.go +++ b/internal/firewall/enable.go @@ -109,7 +109,7 @@ func (c *configurator) enable(ctx context.Context) (err error) { } // Re-ensure all routes exist for _, subnet := range c.allowedSubnets { - if err := c.routing.AddRouteVia(ctx, subnet, c.defaultGateway, c.defaultInterface); err != nil { + if err := c.routing.AddRouteVia(subnet, c.defaultGateway, c.defaultInterface); err != nil { return fmt.Errorf("cannot enable firewall: %w", err) } } diff --git a/internal/firewall/subnets.go b/internal/firewall/subnets.go index d5a77738..ef0c2fab 100644 --- a/internal/firewall/subnets.go +++ b/internal/firewall/subnets.go @@ -12,7 +12,7 @@ func (c *configurator) SetAllowedSubnets(ctx context.Context, subnets []net.IPNe if !c.enabled { c.logger.Info("firewall disabled, only updating allowed subnets internal list and updating routes") - c.updateSubnetRoutes(ctx, c.allowedSubnets, subnets) + c.updateSubnetRoutes(c.allowedSubnets, subnets) c.allowedSubnets = make([]net.IPNet, len(subnets)) copy(c.allowedSubnets, subnets) return nil @@ -95,7 +95,7 @@ func (c *configurator) removeSubnets(ctx context.Context, subnets []net.IPNet, d failed = true c.logger.Error("cannot remove outdated allowed subnet through firewall: %s", err) } - if err := c.routing.DeleteRouteVia(ctx, subnet); err != nil { + if err := c.routing.DeleteRouteVia(subnet); err != nil { failed = true c.logger.Error("cannot remove outdated allowed subnet route: %s", err) } @@ -116,7 +116,7 @@ func (c *configurator) addSubnets(ctx context.Context, subnets []net.IPNet, defa if err := c.acceptOutputFromSubnetToSubnet(ctx, defaultInterface, localSubnet, subnet, remove); err != nil { return fmt.Errorf("cannot add allowed subnet through firewall: %w", err) } - if err := c.routing.AddRouteVia(ctx, subnet, defaultGateway, defaultInterface); err != nil { + if err := c.routing.AddRouteVia(subnet, defaultGateway, defaultInterface); err != nil { return fmt.Errorf("cannot add route for allowed subnet: %w", err) } c.allowedSubnets = append(c.allowedSubnets, subnet) @@ -125,19 +125,19 @@ func (c *configurator) addSubnets(ctx context.Context, subnets []net.IPNet, defa } // updateSubnetRoutes does not return an error in order to try to run as many route commands as possible. -func (c *configurator) updateSubnetRoutes(ctx context.Context, oldSubnets, newSubnets []net.IPNet) { +func (c *configurator) updateSubnetRoutes(oldSubnets, newSubnets []net.IPNet) { subnetsToAdd := findSubnetsToAdd(oldSubnets, newSubnets) subnetsToRemove := findSubnetsToRemove(oldSubnets, newSubnets) if len(subnetsToAdd) == 0 && len(subnetsToRemove) == 0 { return } for _, subnet := range subnetsToRemove { - if err := c.routing.DeleteRouteVia(ctx, subnet); err != nil { + if err := c.routing.DeleteRouteVia(subnet); err != nil { c.logger.Error("cannot remove outdated route for subnet: %s", err) } } for _, subnet := range subnetsToAdd { - if err := c.routing.AddRouteVia(ctx, subnet, c.defaultGateway, c.defaultInterface); err != nil { + if err := c.routing.AddRouteVia(subnet, c.defaultGateway, c.defaultInterface); err != nil { c.logger.Error("cannot add route for subnet: %s", err) } } diff --git a/internal/routing/entry.go b/internal/routing/entry.go deleted file mode 100644 index 760be48d..00000000 --- a/internal/routing/entry.go +++ /dev/null @@ -1,95 +0,0 @@ -package routing - -import ( - "encoding/hex" - "fmt" - "net" - "strconv" - "strings" -) - -type routingEntry struct { - iface string - destination net.IP - gateway net.IP - flags string - refCount int - use int - metric int - mask net.IPMask - mtu int - window int - irtt int -} - -func parseRoutingEntry(s string) (r routingEntry, err error) { - wrapError := func(err error) error { - return fmt.Errorf("line %q: %w", s, err) - } - fields := strings.Fields(s) - const minFields = 11 - if len(fields) < minFields { - return r, wrapError(fmt.Errorf("not enough fields")) - } - r.iface = fields[0] - r.destination, err = reversedHexToIPv4(fields[1]) - if err != nil { - return r, wrapError(err) - } - r.gateway, err = reversedHexToIPv4(fields[2]) - if err != nil { - return r, wrapError(err) - } - r.flags = fields[3] - r.refCount, err = strconv.Atoi(fields[4]) - if err != nil { - return r, wrapError(err) - } - r.use, err = strconv.Atoi(fields[5]) - if err != nil { - return r, wrapError(err) - } - r.metric, err = strconv.Atoi(fields[6]) - if err != nil { - return r, wrapError(err) - } - r.mask, err = hexToIPv4Mask(fields[7]) - if err != nil { - return r, wrapError(err) - } - r.mtu, err = strconv.Atoi(fields[8]) - if err != nil { - return r, wrapError(err) - } - r.window, err = strconv.Atoi(fields[9]) - if err != nil { - return r, wrapError(err) - } - r.irtt, err = strconv.Atoi(fields[10]) - if err != nil { - return r, wrapError(err) - } - return r, nil -} - -func reversedHexToIPv4(reversedHex string) (ip net.IP, err error) { - bytes, err := hex.DecodeString(reversedHex) - const nBytesRequired = 4 - if err != nil { - return nil, fmt.Errorf("cannot parse reversed IP hex %q: %s", reversedHex, err) - } else if L := len(bytes); L != nBytesRequired { - return nil, fmt.Errorf("hex string contains %d bytes instead of %d", L, nBytesRequired) - } - return []byte{bytes[3], bytes[2], bytes[1], bytes[0]}, nil -} - -func hexToIPv4Mask(hexString string) (mask net.IPMask, err error) { - bytes, err := hex.DecodeString(hexString) - const nBytesRequired = 4 - if err != nil { - return nil, fmt.Errorf("cannot parse hex mask %q: %s", hexString, err) - } else if L := len(bytes); L != nBytesRequired { - return nil, fmt.Errorf("hex string contains %d bytes instead of %d", L, nBytesRequired) - } - return []byte{bytes[3], bytes[2], bytes[1], bytes[0]}, nil -} diff --git a/internal/routing/entry_test.go b/internal/routing/entry_test.go deleted file mode 100644 index 195eb003..00000000 --- a/internal/routing/entry_test.go +++ /dev/null @@ -1,164 +0,0 @@ -package routing - -import ( - "fmt" - "net" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -//nolint:lll -func Test_parseRoutingEntry(t *testing.T) { - t.Parallel() - tests := map[string]struct { - s string - r routingEntry - err error - }{ - "empty string": { - err: fmt.Errorf("line \"\": not enough fields"), - }, - "not enough fields": { - s: "a b c d e", - err: fmt.Errorf("line \"a b c d e\": not enough fields"), - }, - "bad destination": { - s: "eth0 x 0100000A 0003 0 0 0 00FFFFFF 0 0 0", - err: fmt.Errorf("line \"eth0 x 0100000A 0003 0 0 0 00FFFFFF 0 0 0\": cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'"), - }, - "bad gateway": { - s: "eth0 0002A8C0 x 0003 0 0 0 00FFFFFF 0 0 0", - err: fmt.Errorf("line \"eth0 0002A8C0 x 0003 0 0 0 00FFFFFF 0 0 0\": cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'"), - }, - "bad ref count": { - s: "eth0 0002A8C0 0100000A 0003 x 0 0 00FFFFFF 0 0 0", - err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 x 0 0 00FFFFFF 0 0 0\": strconv.Atoi: parsing \"x\": invalid syntax"), - }, - "bad use": { - s: "eth0 0002A8C0 0100000A 0003 0 x 0 00FFFFFF 0 0 0", - err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 x 0 00FFFFFF 0 0 0\": strconv.Atoi: parsing \"x\": invalid syntax"), - }, - "bad metric": { - s: "eth0 0002A8C0 0100000A 0003 0 0 x 00FFFFFF 0 0 0", - err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 0 x 00FFFFFF 0 0 0\": strconv.Atoi: parsing \"x\": invalid syntax"), - }, - "bad mask": { - s: "eth0 0002A8C0 0100000A 0003 0 0 0 x 0 0 0", - err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 0 0 x 0 0 0\": cannot parse hex mask \"x\": encoding/hex: invalid byte: U+0078 'x'"), - }, - "bad mtu": { - s: "eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF x 0 0", - err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF x 0 0\": strconv.Atoi: parsing \"x\": invalid syntax"), - }, - "bad window": { - s: "eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 x 0", - err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 x 0\": strconv.Atoi: parsing \"x\": invalid syntax"), - }, - "bad irtt": { - s: "eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 x", - err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 x\": strconv.Atoi: parsing \"x\": invalid syntax"), - }, - "success": { - s: "eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0", - r: routingEntry{ - iface: "eth0", - destination: net.IP{192, 168, 2, 0}, - gateway: net.IP{10, 0, 0, 1}, - flags: "0003", - mask: net.IPMask{255, 255, 255, 0}, - }, - }, - } - for name, tc := range tests { - tc := tc - t.Run(name, func(t *testing.T) { - t.Parallel() - r, err := parseRoutingEntry(tc.s) - if tc.err != nil { - require.Error(t, err) - assert.Equal(t, tc.err.Error(), err.Error()) - } else { - assert.NoError(t, err) - assert.Equal(t, tc.r, r) - } - }) - } -} - -func Test_reversedHexToIPv4(t *testing.T) { - t.Parallel() - tests := map[string]struct { - reversedHex string - IP net.IP - err error - }{ - "empty hex": { - err: fmt.Errorf("hex string contains 0 bytes instead of 4")}, - "bad hex": { - reversedHex: "x", - err: fmt.Errorf("cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'")}, - "3 bytes hex": { - reversedHex: "9abcde", - err: fmt.Errorf("hex string contains 3 bytes instead of 4")}, - "correct hex": { - reversedHex: "010011AC", - IP: []byte{0xac, 0x11, 0x0, 0x1}, - err: nil}, - "correct hex 2": { - reversedHex: "000011AC", - IP: []byte{0xac, 0x11, 0x0, 0x0}, - err: nil}, - } - for name, tc := range tests { - tc := tc - t.Run(name, func(t *testing.T) { - t.Parallel() - IP, err := reversedHexToIPv4(tc.reversedHex) - if tc.err != nil { - require.Error(t, err) - assert.Equal(t, tc.err.Error(), err.Error()) - } else { - assert.NoError(t, err) - } - assert.Equal(t, tc.IP, IP) - }) - } -} - -func Test_hexMaskToDecMask(t *testing.T) { - t.Parallel() - tests := map[string]struct { - hexString string - mask net.IPMask - err error - }{ - "empty hex": { - err: fmt.Errorf("hex string contains 0 bytes instead of 4")}, - "bad hex": { - hexString: "x", - err: fmt.Errorf("cannot parse hex mask \"x\": encoding/hex: invalid byte: U+0078 'x'")}, - "3 bytes hex": { - hexString: "9abcde", - err: fmt.Errorf("hex string contains 3 bytes instead of 4")}, - "16": { - hexString: "0000FFFF", - mask: []byte{0xff, 0xff, 0x0, 0x0}, - err: nil}, - } - for name, tc := range tests { - tc := tc - t.Run(name, func(t *testing.T) { - t.Parallel() - mask, err := hexToIPv4Mask(tc.hexString) - if tc.err != nil { - require.Error(t, err) - assert.Equal(t, tc.err.Error(), err.Error()) - } else { - assert.NoError(t, err) - } - assert.Equal(t, tc.mask, mask) - }) - } -} diff --git a/internal/routing/mutate.go b/internal/routing/mutate.go index 0cc3fc56..232d2350 100644 --- a/internal/routing/mutate.go +++ b/internal/routing/mutate.go @@ -1,48 +1,45 @@ package routing import ( - "context" "fmt" "net" + + "github.com/vishvananda/netlink" ) -func (r *routing) AddRouteVia(ctx context.Context, - subnet net.IPNet, defaultGateway net.IP, defaultInterface string) error { - subnetStr := subnet.String() - r.logger.Info("adding %s as route via %s %s", subnetStr, defaultGateway, defaultInterface) - exists, err := r.routeExists(subnet) - if err != nil { - return err - } else if exists { - return nil - } +func (r *routing) AddRouteVia(destination net.IPNet, gateway net.IP, iface string) error { + destinationStr := destination.String() + r.logger.Info("adding route for %s", destinationStr) if r.debug { - fmt.Printf("ip route add %s via %s dev %s\n", subnetStr, defaultGateway, defaultInterface) + fmt.Printf("ip route add %s via %s dev %s\n", destinationStr, gateway, iface) } - output, err := r.commander.Run(ctx, - "ip", "route", "add", subnetStr, "via", defaultGateway.String(), "dev", defaultInterface) + + link, err := netlink.LinkByName(iface) if err != nil { - return fmt.Errorf("cannot add route for %s via %s %s %s: %s: %w", - subnetStr, defaultGateway, "dev", defaultInterface, output, err) + return fmt.Errorf("cannot add route for %s: %w", destinationStr, err) + } + route := netlink.Route{ + Dst: &destination, + Gw: gateway, + LinkIndex: link.Attrs().Index, + } + if err := netlink.RouteReplace(&route); err != nil { + return fmt.Errorf("cannot add route for %s: %w", destinationStr, err) } return nil } -func (r *routing) DeleteRouteVia(ctx context.Context, subnet net.IPNet) (err error) { - subnetStr := subnet.String() - r.logger.Info("deleting route for %s", subnetStr) - exists, err := r.routeExists(subnet) - if err != nil { - return err - } else if !exists { // thanks to @npawelek https://github.com/npawelek - return nil - } +func (r *routing) DeleteRouteVia(destination net.IPNet) (err error) { + destinationStr := destination.String() + r.logger.Info("deleting route for %s", destinationStr) if r.debug { - fmt.Printf("ip route del %s\n", subnetStr) + fmt.Printf("ip route del %s\n", destinationStr) } - output, err := r.commander.Run(ctx, "ip", "route", "del", subnetStr) - if err != nil { - return fmt.Errorf("cannot delete route for %s: %s: %w", subnetStr, output, err) + route := netlink.Route{ + Dst: &destination, + } + if err := netlink.RouteDel(&route); err != nil { + return fmt.Errorf("cannot delete route for %s: %w", destinationStr, err) } return nil } diff --git a/internal/routing/mutate_test.go b/internal/routing/mutate_test.go deleted file mode 100644 index 53d9f08e..00000000 --- a/internal/routing/mutate_test.go +++ /dev/null @@ -1,86 +0,0 @@ -package routing - -import ( - "context" - "fmt" - "net" - "testing" - - "github.com/golang/mock/gomock" - "github.com/qdm12/gluetun/internal/constants" - "github.com/qdm12/golibs/command/mock_command" - "github.com/qdm12/golibs/files/mock_files" - "github.com/qdm12/golibs/logging/mock_logging" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func Test_DeleteRouteVia(t *testing.T) { - t.Parallel() - ctx := context.Background() - tests := map[string]struct { - subnet net.IPNet - runOutput string - runErr error - err error - }{ - "no output no error": { - subnet: net.IPNet{ - IP: net.IP{192, 168, 2, 0}, - Mask: net.IPMask{255, 255, 255, 0}, - }, - }, - "error only": { - subnet: net.IPNet{ - IP: net.IP{192, 168, 2, 0}, - Mask: net.IPMask{255, 255, 255, 0}, - }, - runErr: fmt.Errorf("error"), - err: fmt.Errorf("cannot delete route for 192.168.2.0/24: : error"), - }, - "error and output": { - subnet: net.IPNet{ - IP: net.IP{192, 168, 2, 0}, - Mask: net.IPMask{255, 255, 255, 0}, - }, - runErr: fmt.Errorf("error"), - runOutput: "output", - err: fmt.Errorf("cannot delete route for 192.168.2.0/24: output: error"), - }, - } - for name, tc := range tests { - tc := tc - t.Run(name, func(t *testing.T) { - t.Parallel() - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - subnetStr := tc.subnet.String() - - logger := mock_logging.NewMockLogger(mockCtrl) - logger.EXPECT().Info("deleting route for %s") - commander := mock_command.NewMockCommander(mockCtrl) - commander.EXPECT().Run(ctx, "ip", "route", "del", subnetStr). - Return(tc.runOutput, tc.runErr).Times(1) - fileManager := mock_files.NewMockFileManager(mockCtrl) - //nolint:lll - routesData := []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT -eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0 -`) - fileManager.EXPECT().ReadFile(string(constants.NetRoute)).Return(routesData, nil) - r := &routing{ - logger: logger, - commander: commander, - fileManager: fileManager, - } - - err := r.DeleteRouteVia(ctx, tc.subnet) - if tc.err != nil { - require.Error(t, err) - assert.Equal(t, tc.err.Error(), err.Error()) - } else { - assert.NoError(t, err) - } - }) - } -} diff --git a/internal/routing/reader.go b/internal/routing/reader.go index 4a6d175b..3799db75 100644 --- a/internal/routing/reader.go +++ b/internal/routing/reader.go @@ -4,120 +4,105 @@ import ( "bytes" "fmt" "net" - "strings" "github.com/qdm12/gluetun/internal/constants" - "github.com/qdm12/golibs/files" + "github.com/vishvananda/netlink" ) -func parseRoutingTable(data []byte) (entries []routingEntry, err error) { - lines := strings.Split(strings.TrimSuffix(string(data), "\n"), "\n") - lines = lines[1:] - entries = make([]routingEntry, len(lines)) - for i := range lines { - entries[i], err = parseRoutingEntry(lines[i]) - if err != nil { - return nil, fmt.Errorf("line %d in %s: %w", i+1, constants.NetRoute, err) - } - } - return entries, nil -} - -func getRoutingEntries(fileManager files.FileManager) (entries []routingEntry, err error) { - data, err := fileManager.ReadFile(string(constants.NetRoute)) - if err != nil { - return nil, err - } - return parseRoutingTable(data) -} - func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) { - entries, err := getRoutingEntries(r.fileManager) + routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL) if err != nil { - return "", nil, err + return "", nil, fmt.Errorf("cannot list routes: %w", err) } - const minEntries = 2 - if len(entries) < minEntries { - return "", nil, fmt.Errorf("not enough entries (%d) found in %s", len(entries), constants.NetRoute) - } - var defaultRouteEntry routingEntry - for _, entry := range entries { - if entry.mask.String() == "00000000" { - defaultRouteEntry = entry - break + for _, route := range routes { + if route.Dst == nil { + defaultGateway = route.Gw + linkIndex := route.LinkIndex + link, err := netlink.LinkByIndex(linkIndex) + if err != nil { + return "", nil, fmt.Errorf("cannot obtain link with index %d for default route: %w", linkIndex, err) + } + attributes := link.Attrs() + defaultInterface = attributes.Name + r.logger.Info("default route found: interface %s, gateway %s", defaultInterface, defaultGateway.String()) + return defaultInterface, defaultGateway, nil } } - if defaultRouteEntry.iface == "" { - return "", nil, fmt.Errorf("cannot find default route") - } - defaultInterface = defaultRouteEntry.iface - defaultGateway = defaultRouteEntry.gateway - r.logger.Info("default route found: interface %s, gateway %s", defaultInterface, defaultGateway.String()) - return defaultInterface, defaultGateway, nil + return "", nil, fmt.Errorf("cannot find default route in %d routes", len(routes)) } func (r *routing) LocalSubnet() (defaultSubnet net.IPNet, err error) { - entries, err := getRoutingEntries(r.fileManager) + routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL) if err != nil { - return defaultSubnet, err + return defaultSubnet, fmt.Errorf("cannot find local subnet: %w", err) } - const minEntries = 2 - if len(entries) < minEntries { - return defaultSubnet, fmt.Errorf("not enough entries (%d) found in %s", len(entries), constants.NetRoute) - } - var localSubnetEntry routingEntry - for _, entry := range entries { - if entry.gateway.Equal(net.IP{0, 0, 0, 0}) && !strings.HasPrefix(entry.iface, "tun") { - localSubnetEntry = entry + + defaultLinkIndex := -1 + for _, route := range routes { + if route.Dst == nil { + defaultLinkIndex = route.LinkIndex break } } - if localSubnetEntry.iface == "" { - return defaultSubnet, fmt.Errorf("cannot find local subnet route") + if defaultLinkIndex == -1 { + return defaultSubnet, fmt.Errorf("cannot find local subnet: cannot find default link") } - defaultSubnet = net.IPNet{IP: localSubnetEntry.destination, Mask: localSubnetEntry.mask} - r.logger.Info("local subnet found: %s", defaultSubnet.String()) - return defaultSubnet, nil + + for _, route := range routes { + if route.Gw != nil || route.LinkIndex != defaultLinkIndex { + continue + } + defaultSubnet = *route.Dst + r.logger.Info("local subnet found: %s", defaultSubnet.String()) + return defaultSubnet, nil + } + + return defaultSubnet, fmt.Errorf("cannot find default subnet in %d routes", len(routes)) } -func (r *routing) routeExists(subnet net.IPNet) (exists bool, err error) { - entries, err := getRoutingEntries(r.fileManager) +func (r *routing) VPNDestinationIP() (ip net.IP, err error) { + routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL) if err != nil { - return false, fmt.Errorf("cannot check route existence: %w", err) + return nil, fmt.Errorf("cannot find VPN destination IP: %w", err) } - for _, entry := range entries { - entrySubnet := net.IPNet{IP: entry.destination, Mask: entry.mask} - if entrySubnet.String() == subnet.String() { - return true, nil - } - } - return false, nil -} -func (r *routing) VPNDestinationIP(defaultInterface string) (ip net.IP, err error) { - entries, err := getRoutingEntries(r.fileManager) - if err != nil { - return nil, fmt.Errorf("cannot find VPN gateway IP address: %w", err) - } - for _, entry := range entries { - if entry.iface == defaultInterface && - !ipIsPrivate(entry.destination) && - bytes.Equal(entry.mask, net.IPMask{255, 255, 255, 255}) { - return entry.destination, nil + defaultLinkIndex := -1 + for _, route := range routes { + if route.Dst == nil { + defaultLinkIndex = route.LinkIndex + break } } - return nil, fmt.Errorf("cannot find VPN gateway IP address from ip routes") + if defaultLinkIndex == -1 { + return nil, fmt.Errorf("cannot find VPN destination IP: cannot find default link") + } + + for _, route := range routes { + if route.LinkIndex == defaultLinkIndex && + route.Dst != nil && + !ipIsPrivate(route.Dst.IP) && + bytes.Equal(route.Dst.Mask, net.IPMask{255, 255, 255, 255}) { + return route.Dst.IP, nil + } + } + return nil, fmt.Errorf("cannot find VPN destination IP address from ip routes") } func (r *routing) VPNLocalGatewayIP() (ip net.IP, err error) { - entries, err := getRoutingEntries(r.fileManager) + routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL) if err != nil { - return nil, fmt.Errorf("cannot find VPN local gateway IP address: %w", err) + return nil, fmt.Errorf("cannot find VPN local gateway IP: %w", err) } - for _, entry := range entries { - if entry.iface == string(constants.TUN) && - entry.destination.Equal(net.IP{0, 0, 0, 0}) { - return entry.gateway, nil + for _, route := range routes { + link, err := netlink.LinkByIndex(route.LinkIndex) + if err != nil { + return nil, fmt.Errorf("cannot find VPN local gateway IP: %w", err) + } + interfaceName := link.Attrs().Name + if interfaceName == string(constants.TUN) && + route.Dst != nil && + route.Dst.IP.Equal(net.IP{0, 0, 0, 0}) { + return route.Gw, nil } } return nil, fmt.Errorf("cannot find VPN local gateway IP address from ip routes") diff --git a/internal/routing/reader_test.go b/internal/routing/reader_test.go deleted file mode 100644 index 169e0e50..00000000 --- a/internal/routing/reader_test.go +++ /dev/null @@ -1,352 +0,0 @@ -package routing - -import ( - "fmt" - "net" - "testing" - - "github.com/golang/mock/gomock" - "github.com/qdm12/gluetun/internal/constants" - "github.com/qdm12/golibs/files/mock_files" - "github.com/qdm12/golibs/logging/mock_logging" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -//nolint:lll -const exampleRouteData = `Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT -tun0 00000000 050A030A 0003 0 0 0 00000080 0 0 0 -eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0 -tun0 010A030A 050A030A 0007 0 0 0 FFFFFFFF 0 0 0 -tun0 050A030A 00000000 0005 0 0 0 FFFFFFFF 0 0 0 -eth0 42196956 010011AC 0007 0 0 0 FFFFFFFF 0 0 0 -tun0 00000080 050A030A 0003 0 0 0 00000080 0 0 0 -eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0 -` - -//nolint:lll -func Test_parseRoutingTable(t *testing.T) { - t.Parallel() - tests := map[string]struct { - data []byte - entries []routingEntry - err error - }{ - "nil data": { - entries: []routingEntry{}, - }, - "legend only": { - data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT -`), - entries: []routingEntry{}, - }, - "legend and single line": { - data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT -eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0 -`), - entries: []routingEntry{{ - iface: "eth0", - destination: net.IP{192, 168, 2, 0}, - gateway: net.IP{10, 0, 0, 1}, - flags: "0003", - mask: net.IPMask{255, 255, 255, 0}, - }}, - }, - "legend and two lines": { - data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT -eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0 -eth0 0002A8C0 0100000A 0002 0 0 0 00FFFFFF 0 0 0 -`), - entries: []routingEntry{ - { - iface: "eth0", - destination: net.IP{192, 168, 2, 0}, - gateway: net.IP{10, 0, 0, 1}, - flags: "0003", - mask: net.IPMask{255, 255, 255, 0}, - }, - { - iface: "eth0", - destination: net.IP{192, 168, 2, 0}, - gateway: net.IP{10, 0, 0, 1}, - flags: "0002", - mask: net.IPMask{255, 255, 255, 0}, - }}, - }, - "error": { - data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT -eth0 x 0100000A 0003 0 0 0 00FFFFFF 0 0 0 -`), - entries: nil, - err: fmt.Errorf("line 1 in /proc/net/route: line \"eth0 x 0100000A 0003 0 0 0 00FFFFFF 0 0 0\": cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'"), - }, - } - for name, tc := range tests { - tc := tc - t.Run(name, func(t *testing.T) { - t.Parallel() - entries, err := parseRoutingTable(tc.data) - if tc.err != nil { - require.Error(t, err) - assert.Equal(t, tc.err.Error(), err.Error()) - } else { - assert.NoError(t, err) - } - assert.Equal(t, tc.entries, entries) - }) - } -} - -//nolint:lll -func Test_DefaultRoute(t *testing.T) { - t.Parallel() - tests := map[string]struct { - data []byte - readErr error - defaultInterface string - defaultGateway net.IP - err error - }{ - "no data": { - err: fmt.Errorf("not enough entries (0) found in %s", constants.NetRoute)}, - "read error": { - readErr: fmt.Errorf("error"), - err: fmt.Errorf("error")}, - "parse error": { - data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT -eth0 x -`), - err: fmt.Errorf("line 1 in /proc/net/route: line \"eth0 x\": not enough fields")}, - "single entry": { - data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT -eth0 00000000 050A090A 0003 0 0 0 00000080 0 0 0 -`), - err: fmt.Errorf("not enough entries (1) found in %s", constants.NetRoute)}, - "success": { - data: []byte(exampleRouteData), - defaultInterface: "eth0", - defaultGateway: net.IP{172, 17, 0, 1}, - }, - "not found": { - data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT -eth0 00000000 010011AC 0003 0 0 0 10000000 0 0 0 -eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0 -`), - err: fmt.Errorf("cannot find default route"), - }, - } - for name, tc := range tests { - tc := tc - t.Run(name, func(t *testing.T) { - t.Parallel() - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - logger := mock_logging.NewMockLogger(mockCtrl) - filemanager := mock_files.NewMockFileManager(mockCtrl) - - filemanager.EXPECT().ReadFile(string(constants.NetRoute)). - Return(tc.data, tc.readErr).Times(1) - if tc.err == nil { - logger.EXPECT().Info( - "default route found: interface %s, gateway %s", - tc.defaultInterface, tc.defaultGateway.String(), - ).Times(1) - } - r := &routing{logger: logger, fileManager: filemanager} - defaultInterface, defaultGateway, err := r.DefaultRoute() - if tc.err != nil { - require.Error(t, err) - assert.Equal(t, tc.err.Error(), err.Error()) - } else { - assert.NoError(t, err) - } - assert.Equal(t, tc.defaultInterface, defaultInterface) - assert.Equal(t, tc.defaultGateway, defaultGateway) - }) - } -} - -//nolint:lll -func Test_LocalSubnet(t *testing.T) { - t.Parallel() - tests := map[string]struct { - data []byte - readErr error - localSubnet net.IPNet - err error - }{ - "no data": { - err: fmt.Errorf("not enough entries (0) found in %s", constants.NetRoute)}, - "read error": { - readErr: fmt.Errorf("error"), - err: fmt.Errorf("error")}, - "parse error": { - data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT -eth0 x -`), - err: fmt.Errorf("line 1 in /proc/net/route: line \"eth0 x\": not enough fields")}, - "single entry": { - data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT -eth0 00000000 050A090A 0003 0 0 0 00000080 0 0 0 -`), - err: fmt.Errorf("not enough entries (1) found in %s", constants.NetRoute)}, - "success": { - data: []byte(exampleRouteData), - localSubnet: net.IPNet{ - IP: net.IP{172, 17, 0, 0}, - Mask: net.IPMask{255, 255, 0, 0}, - }, - }, - "not found": { - data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT -eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0 -eth0 000011AC 10000000 0001 0 0 0 0000FFFF 0 0 0 -`), - err: fmt.Errorf("cannot find local subnet route"), - }, - } - for name, tc := range tests { - tc := tc - t.Run(name, func(t *testing.T) { - t.Parallel() - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - logger := mock_logging.NewMockLogger(mockCtrl) - filemanager := mock_files.NewMockFileManager(mockCtrl) - - filemanager.EXPECT().ReadFile(string(constants.NetRoute)). - Return(tc.data, tc.readErr).Times(1) - if tc.err == nil { - logger.EXPECT().Info("local subnet found: %s", tc.localSubnet.String()).Times(1) - } - r := &routing{logger: logger, fileManager: filemanager} - localSubnet, err := r.LocalSubnet() - if tc.err != nil { - require.Error(t, err) - assert.Equal(t, tc.err.Error(), err.Error()) - } else { - assert.NoError(t, err) - } - assert.Equal(t, tc.localSubnet, localSubnet) - }) - } -} - -//nolint:lll -func Test_routeExists(t *testing.T) { - t.Parallel() - tests := map[string]struct { - subnet net.IPNet - data []byte - readErr error - exists bool - err error - }{ - "no data": {}, - "read error": { - readErr: fmt.Errorf("error"), - err: fmt.Errorf("cannot check route existence: error"), - }, - "parse error": { - data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT -eth0 x -`), - err: fmt.Errorf("cannot check route existence: line 1 in /proc/net/route: line \"eth0 x\": not enough fields"), - }, - "not existing": { - subnet: net.IPNet{ - IP: net.IP{192, 168, 2, 0}, - Mask: net.IPMask{255, 255, 255, 128}, - }, - data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT -eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0 -`), - }, - "existing": { - subnet: net.IPNet{ - IP: net.IP{192, 168, 2, 0}, - Mask: net.IPMask{255, 255, 255, 0}, - }, - data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT -eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0 -`), - exists: true, - }, - } - for name, tc := range tests { - tc := tc - t.Run(name, func(t *testing.T) { - t.Parallel() - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - filemanager := mock_files.NewMockFileManager(mockCtrl) - filemanager.EXPECT().ReadFile(string(constants.NetRoute)). - Return(tc.data, tc.readErr).Times(1) - r := &routing{fileManager: filemanager} - exists, err := r.routeExists(tc.subnet) - if tc.err != nil { - require.Error(t, err) - assert.Equal(t, tc.err.Error(), err.Error()) - } else { - assert.NoError(t, err) - } - assert.Equal(t, tc.exists, exists) - }) - } -} - -//nolint:lll -func Test_VPNDestinationIP(t *testing.T) { - t.Parallel() - tests := map[string]struct { - defaultInterface string - data []byte - readErr error - ip net.IP - err error - }{ - "no data": { - err: fmt.Errorf("cannot find VPN gateway IP address from ip routes"), - }, - "read error": { - readErr: fmt.Errorf("error"), - err: fmt.Errorf("cannot find VPN gateway IP address: error"), - }, - "parse error": { - data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT -eth0 x -`), - err: fmt.Errorf("cannot find VPN gateway IP address: line 1 in /proc/net/route: line \"eth0 x\": not enough fields"), - }, - "found eth0": { - defaultInterface: "eth0", - data: []byte(exampleRouteData), - ip: net.IP{86, 105, 25, 66}, - }, - "not found tun0": { - defaultInterface: "tun0", - data: []byte(exampleRouteData), - err: fmt.Errorf("cannot find VPN gateway IP address from ip routes"), - }, - } - for name, tc := range tests { - tc := tc - t.Run(name, func(t *testing.T) { - t.Parallel() - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - filemanager := mock_files.NewMockFileManager(mockCtrl) - filemanager.EXPECT().ReadFile(string(constants.NetRoute)). - Return(tc.data, tc.readErr).Times(1) - r := &routing{fileManager: filemanager} - ip, err := r.VPNDestinationIP(tc.defaultInterface) - if tc.err != nil { - require.Error(t, err) - assert.Equal(t, tc.err.Error(), err.Error()) - } else { - assert.NoError(t, err) - } - assert.Equal(t, tc.ip, ip) - }) - } -} diff --git a/internal/routing/routing.go b/internal/routing/routing.go index 845420e0..638f2dd1 100644 --- a/internal/routing/routing.go +++ b/internal/routing/routing.go @@ -1,37 +1,30 @@ package routing import ( - "context" "net" - "github.com/qdm12/golibs/command" - "github.com/qdm12/golibs/files" "github.com/qdm12/golibs/logging" ) type Routing interface { - AddRouteVia(ctx context.Context, subnet net.IPNet, defaultGateway net.IP, defaultInterface string) error - DeleteRouteVia(ctx context.Context, subnet net.IPNet) (err error) + AddRouteVia(destination net.IPNet, gateway net.IP, iface string) error + DeleteRouteVia(destination net.IPNet) (err error) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) LocalSubnet() (defaultSubnet net.IPNet, err error) - VPNDestinationIP(defaultInterface string) (ip net.IP, err error) + VPNDestinationIP() (ip net.IP, err error) VPNLocalGatewayIP() (ip net.IP, err error) SetDebug() } type routing struct { - commander command.Commander - logger logging.Logger - fileManager files.FileManager - debug bool + logger logging.Logger + debug bool } // NewConfigurator creates a new Configurator instance. -func NewRouting(logger logging.Logger, fileManager files.FileManager) Routing { +func NewRouting(logger logging.Logger) Routing { return &routing{ - commander: command.NewCommander(), - logger: logger.WithPrefix("routing: "), - fileManager: fileManager, + logger: logger.WithPrefix("routing: "), } }