Fix routing reading issues
- Detect VPN gateway properly - Fix local subnet detection, refers to #188 - Split LocalSubnet from DefaultRoute (2 different routes actually)
This commit is contained in:
@@ -301,7 +301,7 @@ func onConnected(allSettings settings.Settings, logger logging.Logger, routingCo
|
|||||||
portForward <- struct{}{}
|
portForward <- struct{}{}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
defaultInterface, _, _, err := routingConf.DefaultRoute()
|
defaultInterface, _, err := routingConf.DefaultRoute()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn(err)
|
logger.Warn(err)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -63,7 +63,11 @@ func (c *configurator) fallbackToDisabled(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *configurator) enable(ctx context.Context) (err error) { //nolint:gocognit
|
func (c *configurator) enable(ctx context.Context) (err error) { //nolint:gocognit
|
||||||
defaultInterface, defaultGateway, defaultSubnet, err := c.routing.DefaultRoute()
|
defaultInterface, defaultGateway, err := c.routing.DefaultRoute()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
|
}
|
||||||
|
localSubnet, err := c.routing.LocalSubnet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
}
|
}
|
||||||
@@ -100,10 +104,10 @@ func (c *configurator) enable(ctx context.Context) (err error) { //nolint:gocogn
|
|||||||
if err = c.acceptOutputThroughInterface(ctx, string(constants.TUN), remove); err != nil {
|
if err = c.acceptOutputThroughInterface(ctx, string(constants.TUN), remove); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
}
|
}
|
||||||
if err := c.acceptInputFromToSubnet(ctx, defaultSubnet, "*", remove); err != nil {
|
if err := c.acceptInputFromToSubnet(ctx, localSubnet, "*", remove); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
}
|
}
|
||||||
if err := c.acceptOutputFromToSubnet(ctx, defaultSubnet, "*", remove); err != nil {
|
if err := c.acceptOutputFromToSubnet(ctx, localSubnet, "*", remove); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
}
|
}
|
||||||
for _, subnet := range c.allowedSubnets {
|
for _, subnet := range c.allowedSubnets {
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ func (c *configurator) SetAllowedSubnets(ctx context.Context, subnets []net.IPNe
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultInterface, defaultGateway, _, err := c.routing.DefaultRoute()
|
defaultInterface, defaultGateway, err := c.routing.DefaultRoute()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot set allowed subnets through firewall: %w", err)
|
return fmt.Errorf("cannot set allowed subnets through firewall: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ func (c *configurator) SetVPNConnections(ctx context.Context, connections []mode
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultInterface, _, _, err := c.routing.DefaultRoute()
|
defaultInterface, _, err := c.routing.DefaultRoute()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot set VPN connections through firewall: %w", err)
|
return fmt.Errorf("cannot set VPN connections through firewall: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,24 +23,59 @@ func parseRoutingTable(data []byte) (entries []routingEntry, err error) {
|
|||||||
return entries, nil
|
return entries, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error) {
|
func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) {
|
||||||
r.logger.Info("detecting default network route")
|
|
||||||
data, err := r.fileManager.ReadFile(string(constants.NetRoute))
|
data, err := r.fileManager.ReadFile(string(constants.NetRoute))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, defaultSubnet, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
entries, err := parseRoutingTable(data)
|
entries, err := parseRoutingTable(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, defaultSubnet, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
if len(entries) < 2 {
|
if len(entries) < 2 {
|
||||||
return "", nil, defaultSubnet, fmt.Errorf("not enough entries (%d) found in %s", len(entries), constants.NetRoute)
|
return "", nil, fmt.Errorf("not enough entries (%d) found in %s", len(entries), constants.NetRoute)
|
||||||
}
|
}
|
||||||
defaultInterface = entries[0].iface
|
var defaultRouteEntry routingEntry
|
||||||
defaultGateway = entries[0].gateway
|
for _, entry := range entries {
|
||||||
defaultSubnet = net.IPNet{IP: entries[1].destination, Mask: entries[1].mask}
|
if entry.mask.String() == "00000000" {
|
||||||
r.logger.Info("default route found: interface %s, gateway %s, subnet %s", defaultInterface, defaultGateway.String(), defaultSubnet.String())
|
defaultRouteEntry = entry
|
||||||
return defaultInterface, defaultGateway, defaultSubnet, nil
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if defaultRouteEntry.iface == "" {
|
||||||
|
return "", nil, fmt.Errorf("cannot find default route")
|
||||||
|
}
|
||||||
|
defaultInterface = defaultRouteEntry.iface
|
||||||
|
defaultGateway = defaultRouteEntry.gateway
|
||||||
|
r.logger.Info("default route found: interface %s, gateway %s", defaultInterface, defaultGateway.String())
|
||||||
|
return defaultInterface, defaultGateway, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *routing) LocalSubnet() (defaultSubnet net.IPNet, err error) {
|
||||||
|
data, err := r.fileManager.ReadFile(string(constants.NetRoute))
|
||||||
|
if err != nil {
|
||||||
|
return defaultSubnet, err
|
||||||
|
}
|
||||||
|
entries, err := parseRoutingTable(data)
|
||||||
|
if err != nil {
|
||||||
|
return defaultSubnet, err
|
||||||
|
}
|
||||||
|
if len(entries) < 2 {
|
||||||
|
return defaultSubnet, fmt.Errorf("not enough entries (%d) found in %s", len(entries), constants.NetRoute)
|
||||||
|
}
|
||||||
|
var localSubnetEntry routingEntry
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.gateway.Equal(net.IP{0, 0, 0, 0}) && !strings.HasPrefix(entry.iface, "tun") {
|
||||||
|
localSubnetEntry = entry
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if localSubnetEntry.iface == "" {
|
||||||
|
return defaultSubnet, fmt.Errorf("cannot find local subnet route")
|
||||||
|
}
|
||||||
|
defaultSubnet = net.IPNet{IP: localSubnetEntry.destination, Mask: localSubnetEntry.mask}
|
||||||
|
r.logger.Info("local subnet found: %s", defaultSubnet.String())
|
||||||
|
return defaultSubnet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *routing) routeExists(subnet net.IPNet) (exists bool, err error) {
|
func (r *routing) routeExists(subnet net.IPNet) (exists bool, err error) {
|
||||||
|
|||||||
@@ -14,6 +14,16 @@ import (
|
|||||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const exampleRouteData = `Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||||
|
tun0 00000000 050A030A 0003 0 0 0 00000080 0 0 0
|
||||||
|
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
|
||||||
|
tun0 010A030A 050A030A 0007 0 0 0 FFFFFFFF 0 0 0
|
||||||
|
tun0 050A030A 00000000 0005 0 0 0 FFFFFFFF 0 0 0
|
||||||
|
eth0 42196956 010011AC 0007 0 0 0 FFFFFFFF 0 0 0
|
||||||
|
tun0 00000080 050A030A 0003 0 0 0 00000080 0 0 0
|
||||||
|
eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0
|
||||||
|
`
|
||||||
|
|
||||||
func Test_parseRoutingTable(t *testing.T) {
|
func Test_parseRoutingTable(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
tests := map[string]struct {
|
tests := map[string]struct {
|
||||||
@@ -93,7 +103,6 @@ func Test_DefaultRoute(t *testing.T) {
|
|||||||
readErr error
|
readErr error
|
||||||
defaultInterface string
|
defaultInterface string
|
||||||
defaultGateway net.IP
|
defaultGateway net.IP
|
||||||
defaultSubnet net.IPNet
|
|
||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
"no data": {
|
"no data": {
|
||||||
@@ -104,6 +113,73 @@ func Test_DefaultRoute(t *testing.T) {
|
|||||||
"parse error": {
|
"parse error": {
|
||||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||||
eth0 x
|
eth0 x
|
||||||
|
`),
|
||||||
|
err: fmt.Errorf("line 1 in /proc/net/route: line \"eth0 x\": not enough fields")},
|
||||||
|
"single entry": {
|
||||||
|
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||||
|
eth0 00000000 050A090A 0003 0 0 0 00000080 0 0 0
|
||||||
|
`),
|
||||||
|
err: fmt.Errorf("not enough entries (1) found in %s", constants.NetRoute)},
|
||||||
|
"success": {
|
||||||
|
data: []byte(exampleRouteData),
|
||||||
|
defaultInterface: "eth0",
|
||||||
|
defaultGateway: net.IP{172, 17, 0, 1},
|
||||||
|
},
|
||||||
|
"not found": {
|
||||||
|
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||||
|
eth0 00000000 010011AC 0003 0 0 0 10000000 0 0 0
|
||||||
|
eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0
|
||||||
|
`),
|
||||||
|
err: fmt.Errorf("cannot find default route"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, tc := range tests {
|
||||||
|
tc := tc
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
mockCtrl := gomock.NewController(t)
|
||||||
|
defer mockCtrl.Finish()
|
||||||
|
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||||
|
filemanager := mock_files.NewMockFileManager(mockCtrl)
|
||||||
|
|
||||||
|
filemanager.EXPECT().ReadFile(string(constants.NetRoute)).
|
||||||
|
Return(tc.data, tc.readErr).Times(1)
|
||||||
|
if tc.err == nil {
|
||||||
|
logger.EXPECT().Info(
|
||||||
|
"default route found: interface %s, gateway %s",
|
||||||
|
tc.defaultInterface, tc.defaultGateway.String(),
|
||||||
|
).Times(1)
|
||||||
|
}
|
||||||
|
r := &routing{logger: logger, fileManager: filemanager}
|
||||||
|
defaultInterface, defaultGateway, err := r.DefaultRoute()
|
||||||
|
if tc.err != nil {
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Equal(t, tc.err.Error(), err.Error())
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
assert.Equal(t, tc.defaultInterface, defaultInterface)
|
||||||
|
assert.Equal(t, tc.defaultGateway, defaultGateway)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_LocalSubnet(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tests := map[string]struct {
|
||||||
|
data []byte
|
||||||
|
readErr error
|
||||||
|
localSubnet net.IPNet
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
"no data": {
|
||||||
|
err: fmt.Errorf("not enough entries (0) found in %s", constants.NetRoute)},
|
||||||
|
"read error": {
|
||||||
|
readErr: fmt.Errorf("error"),
|
||||||
|
err: fmt.Errorf("error")},
|
||||||
|
"parse error": {
|
||||||
|
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||||
|
eth0 x
|
||||||
`),
|
`),
|
||||||
err: fmt.Errorf("line 1 in /proc/net/route: line \"eth0 x\": not enough fields")},
|
err: fmt.Errorf("line 1 in /proc/net/route: line \"eth0 x\": not enough fields")},
|
||||||
"single entry": {
|
"single entry": {
|
||||||
@@ -112,16 +188,19 @@ eth0 00000000 050A090A 0003 0 0 0 00000080
|
|||||||
`),
|
`),
|
||||||
err: fmt.Errorf("not enough entries (1) found in %s", constants.NetRoute)},
|
err: fmt.Errorf("not enough entries (1) found in %s", constants.NetRoute)},
|
||||||
"success": {
|
"success": {
|
||||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
data: []byte(exampleRouteData),
|
||||||
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
|
localSubnet: net.IPNet{
|
||||||
eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0
|
|
||||||
`),
|
|
||||||
defaultInterface: "eth0",
|
|
||||||
defaultGateway: net.IP{172, 17, 0, 1},
|
|
||||||
defaultSubnet: net.IPNet{
|
|
||||||
IP: net.IP{172, 17, 0, 0},
|
IP: net.IP{172, 17, 0, 0},
|
||||||
Mask: net.IPMask{255, 255, 0, 0},
|
Mask: net.IPMask{255, 255, 0, 0},
|
||||||
}},
|
},
|
||||||
|
},
|
||||||
|
"not found": {
|
||||||
|
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||||
|
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
|
||||||
|
eth0 000011AC 10000000 0001 0 0 0 0000FFFF 0 0 0
|
||||||
|
`),
|
||||||
|
err: fmt.Errorf("cannot find local subnet route"),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for name, tc := range tests {
|
for name, tc := range tests {
|
||||||
tc := tc
|
tc := tc
|
||||||
@@ -134,24 +213,18 @@ eth0 000011AC 00000000 0001 0 0 0 0000FFFF
|
|||||||
|
|
||||||
filemanager.EXPECT().ReadFile(string(constants.NetRoute)).
|
filemanager.EXPECT().ReadFile(string(constants.NetRoute)).
|
||||||
Return(tc.data, tc.readErr).Times(1)
|
Return(tc.data, tc.readErr).Times(1)
|
||||||
logger.EXPECT().Info("detecting default network route").Times(1)
|
|
||||||
if tc.err == nil {
|
if tc.err == nil {
|
||||||
logger.EXPECT().Info(
|
logger.EXPECT().Info("local subnet found: %s", tc.localSubnet.String()).Times(1)
|
||||||
"default route found: interface %s, gateway %s, subnet %s",
|
|
||||||
tc.defaultInterface, tc.defaultGateway.String(), tc.defaultSubnet.String(),
|
|
||||||
).Times(1)
|
|
||||||
}
|
}
|
||||||
r := &routing{logger: logger, fileManager: filemanager}
|
r := &routing{logger: logger, fileManager: filemanager}
|
||||||
defaultInterface, defaultGateway, defaultSubnet, err := r.DefaultRoute()
|
localSubnet, err := r.LocalSubnet()
|
||||||
if tc.err != nil {
|
if tc.err != nil {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.Equal(t, tc.err.Error(), err.Error())
|
assert.Equal(t, tc.err.Error(), err.Error())
|
||||||
} else {
|
} else {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
assert.Equal(t, tc.defaultInterface, defaultInterface)
|
assert.Equal(t, tc.localSubnet, localSubnet)
|
||||||
assert.Equal(t, tc.defaultGateway, defaultGateway)
|
|
||||||
assert.Equal(t, tc.defaultSubnet, defaultSubnet)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -218,18 +291,8 @@ eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_CurrentIP(t *testing.T) {
|
func Test_VPNGatewayIP(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
const exampleRouteData = `Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
|
||||||
tun0 00000000 050A090A 0003 0 0 0 00000080 0 0 0
|
|
||||||
eth0 00000000 0100000A 0003 0 0 0 00000000 0 0 0
|
|
||||||
eth0 0000000A 00000000 0001 0 0 0 00FFFFFF 0 0 0
|
|
||||||
tun0 010A090A 050A090A 0007 0 0 0 FFFFFFFF 0 0 0
|
|
||||||
tun0 050A090A 00000000 0005 0 0 0 FFFFFFFF 0 0 0
|
|
||||||
eth0 2194B05F 0100000A 0007 0 0 0 FFFFFFFF 0 0 0
|
|
||||||
tun0 00000080 050A090A 0003 0 0 0 00000080 0 0 0
|
|
||||||
eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0
|
|
||||||
`
|
|
||||||
tests := map[string]struct {
|
tests := map[string]struct {
|
||||||
defaultInterface string
|
defaultInterface string
|
||||||
data []byte
|
data []byte
|
||||||
@@ -253,7 +316,7 @@ eth0 x
|
|||||||
"found eth0": {
|
"found eth0": {
|
||||||
defaultInterface: "eth0",
|
defaultInterface: "eth0",
|
||||||
data: []byte(exampleRouteData),
|
data: []byte(exampleRouteData),
|
||||||
ip: net.IP{95, 176, 148, 33},
|
ip: net.IP{86, 105, 25, 66},
|
||||||
},
|
},
|
||||||
"not found tun0": {
|
"not found tun0": {
|
||||||
defaultInterface: "tun0",
|
defaultInterface: "tun0",
|
||||||
|
|||||||
@@ -12,7 +12,8 @@ import (
|
|||||||
type Routing interface {
|
type Routing interface {
|
||||||
AddRouteVia(ctx context.Context, subnet net.IPNet, defaultGateway net.IP, defaultInterface string) error
|
AddRouteVia(ctx context.Context, subnet net.IPNet, defaultGateway net.IP, defaultInterface string) error
|
||||||
DeleteRouteVia(ctx context.Context, subnet net.IPNet) (err error)
|
DeleteRouteVia(ctx context.Context, subnet net.IPNet) (err error)
|
||||||
DefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error)
|
DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error)
|
||||||
|
LocalSubnet() (defaultSubnet net.IPNet, err error)
|
||||||
VPNGatewayIP(defaultInterface string) (ip net.IP, err error)
|
VPNGatewayIP(defaultInterface string) (ip net.IP, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user