IP_STATUS_FILE and routing improvements (#130)

- Obtains VPN public IP address from routing table
- Logs and writes VPN Public IP address to `/ip` as soon as VPN is up
- Obtain port forward, logs it and writes it as soon as VPN is up
- Routing fully refactored and tested
- Routing reads from `/proc/net/route`
- Routing mutates the routes using `ip route ...`
This commit is contained in:
Quentin McGaw
2020-04-12 08:55:13 -04:00
committed by GitHub
parent da8391e9ae
commit 3ac3e5022c
21 changed files with 1309 additions and 299 deletions

View File

@@ -4,7 +4,6 @@ import (
"net"
"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/private-internet-access-docker/internal/models"
)
@@ -20,23 +19,19 @@ type Configurator interface {
CreateGeneralRules() error
CreateVPNRules(dev models.VPNDevice, defaultInterface string, connections []models.OpenVPNConnection) error
CreateLocalSubnetsRules(subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error
AddRoutesVia(subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error
GetDefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error)
AllowInputTrafficOnPort(device models.VPNDevice, port uint16) error
AllowAnyIncomingOnPort(port uint16) error
}
type configurator struct {
commander command.Commander
logger logging.Logger
fileManager files.FileManager
commander command.Commander
logger logging.Logger
}
// NewConfigurator creates a new Configurator instance
func NewConfigurator(logger logging.Logger, fileManager files.FileManager) Configurator {
func NewConfigurator(logger logging.Logger) Configurator {
return &configurator{
commander: command.NewCommander(),
logger: logger,
fileManager: fileManager,
commander: command.NewCommander(),
logger: logger,
}
}

View File

@@ -1,88 +0,0 @@
package firewall
import (
"encoding/hex"
"net"
"fmt"
"strings"
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
func (c *configurator) AddRoutesVia(subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error {
for _, subnet := range subnets {
subnetStr := subnet.String()
output, err := c.commander.Run("ip", "route", "show", subnetStr)
if err != nil {
return fmt.Errorf("cannot read route %s: %s: %w", subnetStr, output, err)
} else if len(output) > 0 { // thanks to @npawelek https://github.com/npawelek
continue // already exists
// TODO remove it instead and continue execution below
}
c.logger.Info("%s: adding %s as route via %s", logPrefix, subnetStr, defaultInterface)
output, err = c.commander.Run("ip", "route", "add", subnetStr, "via", defaultGateway.String(), "dev", defaultInterface)
if err != nil {
return fmt.Errorf("cannot add route for %s via %s %s %s: %s: %w", subnetStr, defaultGateway.String(), "dev", defaultInterface, output, err)
}
}
return nil
}
func (c *configurator) GetDefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error) {
c.logger.Info("%s: detecting default network route", logPrefix)
data, err := c.fileManager.ReadFile(string(constants.NetRoute))
if err != nil {
return "", nil, defaultSubnet, err
}
// Verify number of lines and fields
lines := strings.Split(string(data), "\n")
if len(lines) < 3 {
return "", nil, defaultSubnet, fmt.Errorf("not enough lines (%d) found in %s", len(lines), constants.NetRoute)
}
fieldsLine1 := strings.Fields(lines[1])
if len(fieldsLine1) < 3 {
return "", nil, defaultSubnet, fmt.Errorf("not enough fields in %q", lines[1])
}
fieldsLine2 := strings.Fields(lines[2])
if len(fieldsLine2) < 8 {
return "", nil, defaultSubnet, fmt.Errorf("not enough fields in %q", lines[2])
}
// get information
defaultInterface = fieldsLine1[0]
defaultGateway, err = reversedHexToIPv4(fieldsLine1[2])
if err != nil {
return "", nil, defaultSubnet, err
}
netNumber, err := reversedHexToIPv4(fieldsLine2[1])
if err != nil {
return "", nil, defaultSubnet, err
}
netMask, err := hexToIPv4Mask(fieldsLine2[7])
if err != nil {
return "", nil, defaultSubnet, err
}
subnet := net.IPNet{IP: netNumber, Mask: netMask}
c.logger.Info("%s: default route found: interface %s, gateway %s, subnet %s", logPrefix, defaultInterface, defaultGateway.String(), subnet.String())
return defaultInterface, defaultGateway, subnet, nil
}
func reversedHexToIPv4(reversedHex string) (IP net.IP, err error) {
bytes, err := hex.DecodeString(reversedHex)
if err != nil {
return nil, fmt.Errorf("cannot parse reversed IP hex %q: %s", reversedHex, err)
} else if len(bytes) != 4 {
return nil, fmt.Errorf("hex string contains %d bytes instead of 4", len(bytes))
}
return []byte{bytes[3], bytes[2], bytes[1], bytes[0]}, nil
}
func hexToIPv4Mask(hexString string) (mask net.IPMask, err error) {
bytes, err := hex.DecodeString(hexString)
if err != nil {
return nil, fmt.Errorf("cannot parse hex mask %q: %s", hexString, err)
} else if len(bytes) != 4 {
return nil, fmt.Errorf("hex string contains %d bytes instead of 4", len(bytes))
}
return []byte{bytes[3], bytes[2], bytes[1], bytes[0]}, nil
}

View File

@@ -1,171 +0,0 @@
package firewall
import (
"fmt"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
filesmocks "github.com/qdm12/golibs/files/mocks"
loggingmocks "github.com/qdm12/golibs/logging/mocks"
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
func Test_getDefaultRoute(t *testing.T) {
t.Parallel()
tests := map[string]struct {
data []byte
readErr error
defaultInterface string
defaultGateway net.IP
defaultSubnet net.IPNet
err error
}{
"no data": {
err: fmt.Errorf("not enough lines (1) found in %s", constants.NetRoute)},
"read error": {
readErr: fmt.Errorf("error"),
err: fmt.Errorf("error")},
"not enough fields line 1": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 00000000
eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0`),
err: fmt.Errorf("not enough fields in \"eth0 00000000\"")},
"not enough fields line 2": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
eth0 000011AC 00000000 0001 0 0 0`),
err: fmt.Errorf("not enough fields in \"eth0 000011AC 00000000 0001 0 0 0\"")},
"bad gateway": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 00000000 x 0003 0 0 0 00000000 0 0 0
eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0`),
err: fmt.Errorf("cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'")},
"bad net number": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
eth0 x 00000000 0001 0 0 0 0000FFFF 0 0 0`),
err: fmt.Errorf("cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'")},
"bad net mask": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
eth0 000011AC 00000000 0001 0 0 0 x 0 0 0`),
err: fmt.Errorf("cannot parse hex mask \"x\": encoding/hex: invalid byte: U+0078 'x'")},
"success": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0`),
defaultInterface: "eth0",
defaultGateway: net.IP{0xac, 0x11, 0x0, 0x1},
defaultSubnet: net.IPNet{
IP: net.IP{0xac, 0x11, 0x0, 0x0},
Mask: net.IPMask{0xff, 0xff, 0x0, 0x0},
}},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
fileManager := &filesmocks.FileManager{}
fileManager.On("ReadFile", string(constants.NetRoute)).
Return(tc.data, tc.readErr).Once()
logger := &loggingmocks.Logger{}
logger.On("Info", "%s: detecting default network route", logPrefix).Once()
if tc.err == nil {
logger.On("Info", "%s: default route found: interface %s, gateway %s, subnet %s",
logPrefix, tc.defaultInterface, tc.defaultGateway.String(), tc.defaultSubnet.String()).Once()
}
c := &configurator{logger: logger, fileManager: fileManager}
defaultInterface, defaultGateway, defaultSubnet, err := c.GetDefaultRoute()
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.defaultInterface, defaultInterface)
assert.Equal(t, tc.defaultGateway, defaultGateway)
assert.Equal(t, tc.defaultSubnet, defaultSubnet)
fileManager.AssertExpectations(t)
logger.AssertExpectations(t)
})
}
}
func Test_reversedHexToIPv4(t *testing.T) {
t.Parallel()
tests := map[string]struct {
reversedHex string
IP net.IP
err error
}{
"empty hex": {
err: fmt.Errorf("hex string contains 0 bytes instead of 4")},
"bad hex": {
reversedHex: "x",
err: fmt.Errorf("cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'")},
"3 bytes hex": {
reversedHex: "9abcde",
err: fmt.Errorf("hex string contains 3 bytes instead of 4")},
"correct hex": {
reversedHex: "010011AC",
IP: []byte{0xac, 0x11, 0x0, 0x1},
err: nil},
"correct hex 2": {
reversedHex: "000011AC",
IP: []byte{0xac, 0x11, 0x0, 0x0},
err: nil},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
IP, err := reversedHexToIPv4(tc.reversedHex)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.IP, IP)
})
}
}
func Test_hexMaskToDecMask(t *testing.T) {
t.Parallel()
tests := map[string]struct {
hexString string
mask net.IPMask
err error
}{
"empty hex": {
err: fmt.Errorf("hex string contains 0 bytes instead of 4")},
"bad hex": {
hexString: "x",
err: fmt.Errorf("cannot parse hex mask \"x\": encoding/hex: invalid byte: U+0078 'x'")},
"3 bytes hex": {
hexString: "9abcde",
err: fmt.Errorf("hex string contains 3 bytes instead of 4")},
"16": {
hexString: "0000FFFF",
mask: []byte{0xff, 0xff, 0x0, 0x0},
err: nil},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
mask, err := hexToIPv4Mask(tc.hexString)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.mask, mask)
})
}
}