IP_STATUS_FILE and routing improvements (#130)

- Obtains VPN public IP address from routing table
- Logs and writes VPN Public IP address to `/ip` as soon as VPN is up
- Obtain port forward, logs it and writes it as soon as VPN is up
- Routing fully refactored and tested
- Routing reads from `/proc/net/route`
- Routing mutates the routes using `ip route ...`
This commit is contained in:
Quentin McGaw
2020-04-12 08:55:13 -04:00
committed by GitHub
parent da8391e9ae
commit 3ac3e5022c
21 changed files with 1309 additions and 299 deletions

View File

@@ -38,6 +38,7 @@ ENV VPNSP=pia \
TZ= \ TZ= \
UID=1000 \ UID=1000 \
GID=1000 \ GID=1000 \
IP_STATUS_FILE="/ip" \
# PIA only # PIA only
PASSWORD= \ PASSWORD= \
REGION="Austria" \ REGION="Austria" \

View File

@@ -157,6 +157,7 @@ docker run --rm --network=container:pia alpine:3.11 wget -qO- https://ipinfo.io
| `EXTRA_SUBNETS` | | Optional | ✅ | ✅ | ✅ | Comma separated subnets allowed in the container firewall | In example `192.168.1.0/24,192.168.10.121,10.0.0.5/28` | | `EXTRA_SUBNETS` | | Optional | ✅ | ✅ | ✅ | Comma separated subnets allowed in the container firewall | In example `192.168.1.0/24,192.168.10.121,10.0.0.5/28` |
| `PORT_FORWARDING` | `off` | | ✅ | ❌ | ❌ | Enable port forwarding on the VPN server | `on`, `off` | | `PORT_FORWARDING` | `off` | | ✅ | ❌ | ❌ | Enable port forwarding on the VPN server | `on`, `off` |
| `PORT_FORWARDING_STATUS_FILE` | `/forwarded_port` | | ✅ | ❌ | ❌ | File path to store the forwarded port number | Any valid file path | | `PORT_FORWARDING_STATUS_FILE` | `/forwarded_port` | | ✅ | ❌ | ❌ | File path to store the forwarded port number | Any valid file path |
| `IP_STATUS_FILE` | `/ip` | | ✅ | ✅ | ✅ | File path to store the public IP address assigned | Any valid file path |
| `TINYPROXY` | `off` | | ✅ | ✅ | ✅ | Enable the internal HTTP proxy tinyproxy | `on`, `off` | | `TINYPROXY` | `off` | | ✅ | ✅ | ✅ | Enable the internal HTTP proxy tinyproxy | `on`, `off` |
| `TINYPROXY_LOG` | `Info` | | ✅ | ✅ | ✅ | Tinyproxy log level | `Info`, `Connect`, `Notice`, `Warning`, `Error`, `Critical` | | `TINYPROXY_LOG` | `Info` | | ✅ | ✅ | ✅ | Tinyproxy log level | `Info`, `Connect`, `Notice`, `Warning`, `Error`, `Critical` |
| `TINYPROXY_PORT` | `8888` | | ✅ | ✅ | ✅ | Internal port number for Tinyproxy to listen on | `1024` to `65535` | | `TINYPROXY_PORT` | `8888` | | ✅ | ✅ | ✅ | Internal port number for Tinyproxy to listen on | `1024` to `65535` |

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net" "net"
"os" "os"
"strings"
"time" "time"
"github.com/qdm12/golibs/command" "github.com/qdm12/golibs/command"
@@ -24,6 +25,7 @@ import (
"github.com/qdm12/private-internet-access-docker/internal/openvpn" "github.com/qdm12/private-internet-access-docker/internal/openvpn"
"github.com/qdm12/private-internet-access-docker/internal/params" "github.com/qdm12/private-internet-access-docker/internal/params"
"github.com/qdm12/private-internet-access-docker/internal/pia" "github.com/qdm12/private-internet-access-docker/internal/pia"
"github.com/qdm12/private-internet-access-docker/internal/routing"
"github.com/qdm12/private-internet-access-docker/internal/settings" "github.com/qdm12/private-internet-access-docker/internal/settings"
"github.com/qdm12/private-internet-access-docker/internal/shadowsocks" "github.com/qdm12/private-internet-access-docker/internal/shadowsocks"
"github.com/qdm12/private-internet-access-docker/internal/splash" "github.com/qdm12/private-internet-access-docker/internal/splash"
@@ -52,7 +54,8 @@ func main() {
alpineConf := alpine.NewConfigurator(logger, fileManager) alpineConf := alpine.NewConfigurator(logger, fileManager)
ovpnConf := openvpn.NewConfigurator(logger, fileManager) ovpnConf := openvpn.NewConfigurator(logger, fileManager)
dnsConf := dns.NewConfigurator(logger, client, fileManager) dnsConf := dns.NewConfigurator(logger, client, fileManager)
firewallConf := firewall.NewConfigurator(logger, fileManager) firewallConf := firewall.NewConfigurator(logger)
routingConf := routing.NewRouting(logger, fileManager)
piaConf := pia.NewConfigurator(client, fileManager, firewallConf, logger) piaConf := pia.NewConfigurator(client, fileManager, firewallConf, logger)
mullvadConf := mullvad.NewConfigurator(fileManager, logger) mullvadConf := mullvad.NewConfigurator(fileManager, logger)
windscribeConf := windscribe.NewConfigurator(fileManager) windscribeConf := windscribe.NewConfigurator(fileManager)
@@ -100,6 +103,9 @@ func main() {
err = ovpnConf.WriteAuthFile(openVPNUser, openVPNPassword, allSettings.System.UID, allSettings.System.GID) err = ovpnConf.WriteAuthFile(openVPNUser, openVPNPassword, allSettings.System.UID, allSettings.System.GID)
e.FatalOnError(err) e.FatalOnError(err)
defaultInterface, defaultGateway, defaultSubnet, err := routingConf.DefaultRoute()
e.FatalOnError(err)
// Temporarily reset chain policies allowing Kubernetes sidecar to // Temporarily reset chain policies allowing Kubernetes sidecar to
// successfully restart the container. Without this, the existing rules will // successfully restart the container. Without this, the existing rules will
// pre-exist, preventing the nslookup of the PIA region address. These will // pre-exist, preventing the nslookup of the PIA region address. These will
@@ -111,7 +117,19 @@ func main() {
go func() { go func() {
// Blocking line merging reader for all programs: openvpn, tinyproxy, unbound and shadowsocks // Blocking line merging reader for all programs: openvpn, tinyproxy, unbound and shadowsocks
logger.Info("Launching standard output merger") logger.Info("Launching standard output merger")
err = streamMerger.CollectLines(func(line string) { logger.Info(line) }) err = streamMerger.CollectLines(func(line string) {
logger.Info(line)
if strings.Contains(line, "Initialization Sequence Completed") {
onConnected(logger, routingConf, fileManager, piaConf,
defaultInterface,
allSettings.VPNSP,
allSettings.PIA.PortForwarding.Enabled,
allSettings.PIA.PortForwarding.Filepath,
allSettings.System.IPStatusFilepath,
allSettings.System.UID,
allSettings.System.GID)
}
})
e.FatalOnError(err) e.FatalOnError(err)
}() }()
@@ -191,9 +209,7 @@ func main() {
e.FatalOnError(err) e.FatalOnError(err)
} }
defaultInterface, defaultGateway, defaultSubnet, err := firewallConf.GetDefaultRoute() err = routingConf.AddRoutesVia(allSettings.Firewall.AllowedSubnets, defaultGateway, defaultInterface)
e.FatalOnError(err)
err = firewallConf.AddRoutesVia(allSettings.Firewall.AllowedSubnets, defaultGateway, defaultInterface)
e.FatalOnError(err) e.FatalOnError(err)
err = firewallConf.Clear() err = firewallConf.Clear()
e.FatalOnError(err) e.FatalOnError(err)
@@ -247,28 +263,6 @@ func main() {
go streamMerger.Merge("shadowsocks", stream) go streamMerger.Merge("shadowsocks", stream)
} }
if allSettings.VPNSP == "pia" && allSettings.PIA.PortForwarding.Enabled {
time.AfterFunc(10*time.Second, func() {
port, err := piaConf.GetPortForward()
if err != nil {
logger.Error("port forwarding:", err)
return
}
if err := piaConf.WritePortForward(
allSettings.PIA.PortForwarding.Filepath,
port,
allSettings.System.UID,
allSettings.System.GID); err != nil {
logger.Error("port forwarding:", err)
return
}
if err := piaConf.AllowPortForwardFirewall(constants.TUN, port); err != nil {
logger.Error("port forwarding:", err)
return
}
})
}
stream, waitFn, err := ovpnConf.Start() stream, waitFn, err := ovpnConf.Start()
e.FatalOnError(err) e.FatalOnError(err)
go streamMerger.Merge("openvpn", stream) go streamMerger.Merge("openvpn", stream)
@@ -284,3 +278,48 @@ func main() {
}) })
e.FatalOnError(waitFn()) e.FatalOnError(waitFn())
} }
func onConnected(
logger logging.Logger,
routingConf routing.Routing,
fileManager files.FileManager,
piaConf pia.Configurator,
defaultInterface string,
VPNSP string,
portForwarding bool,
portForwardingFilepath models.Filepath,
ipStatusFilepath models.Filepath,
uid, gid int,
) {
ip, err := routingConf.CurrentPublicIP(defaultInterface)
if err != nil {
logger.Error(err)
} else {
logger.Info("Tunnel IP is %s, see more information at https://ipinfo.io/%s", ip, ip)
err := fileManager.WriteLinesToFile(
string(ipStatusFilepath),
[]string{ip.String()},
files.Ownership(uid, gid),
files.Permissions(400))
if err != nil {
logger.Error(err)
}
}
if VPNSP != "pia" || !portForwarding {
return
}
port, err := piaConf.GetPortForward()
if err != nil {
logger.Error("port forwarding:", err)
return
}
logger.Info("port forwarding: Port %d", port)
if err := piaConf.WritePortForward(portForwardingFilepath, port, uid, gid); err != nil {
logger.Error("port forwarding:", err)
return
}
if err := piaConf.AllowPortForwardFirewall(constants.TUN, port); err != nil {
logger.Error("port forwarding:", err)
return
}
}

1
go.mod
View File

@@ -3,6 +3,7 @@ module github.com/qdm12/private-internet-access-docker
go 1.13 go 1.13
require ( require (
github.com/golang/mock v1.4.3
github.com/kyokomi/emoji v2.1.0+incompatible github.com/kyokomi/emoji v2.1.0+incompatible
github.com/qdm12/golibs v0.0.0-20200329231626-f55b47cd4e96 github.com/qdm12/golibs v0.0.0-20200329231626-f55b47cd4e96
github.com/stretchr/testify v1.5.1 github.com/stretchr/testify v1.5.1

6
go.sum
View File

@@ -35,6 +35,8 @@ github.com/go-openapi/swag v0.17.0 h1:iqrgMg7Q7SvtbWLlltPrkMs0UBJI6oTSs79JFRUi88
github.com/go-openapi/swag v0.17.0/go.mod h1:AByQ+nYG6gQg71GINrmuDXCPWdL640yX49/kXLo40Tg= github.com/go-openapi/swag v0.17.0/go.mod h1:AByQ+nYG6gQg71GINrmuDXCPWdL640yX49/kXLo40Tg=
github.com/go-openapi/validate v0.17.0 h1:pqoViQz3YLOGIhAmD0N4Lt6pa/3Gnj3ymKqQwq8iS6U= github.com/go-openapi/validate v0.17.0 h1:pqoViQz3YLOGIhAmD0N4Lt6pa/3Gnj3ymKqQwq8iS6U=
github.com/go-openapi/validate v0.17.0/go.mod h1:Uh4HdOzKt19xGIGm1qHf/ofbX1YQ4Y+MYsct2VUrAJ4= github.com/go-openapi/validate v0.17.0/go.mod h1:Uh4HdOzKt19xGIGm1qHf/ofbX1YQ4Y+MYsct2VUrAJ4=
github.com/golang/mock v1.4.3 h1:GV+pQPG/EUUbkh47niozDcADz6go/dUwhVzdUQHIVRw=
github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
@@ -97,9 +99,11 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200327173247-9dae0f8f5775 h1:TC0v2RSO1u2kn1ZugjrFXkRZAEaqMN/RW+OTZkBzmLE= golang.org/x/sys v0.0.0-20200327173247-9dae0f8f5775 h1:TC0v2RSO1u2kn1ZugjrFXkRZAEaqMN/RW+OTZkBzmLE=
golang.org/x/sys v0.0.0-20200327173247-9dae0f8f5775/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200327173247-9dae0f8f5775/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5 h1:hKsoRgsbwY1NafxrwTs+k64bikrLBkAgPir1TNCj3Zs= golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5 h1:hKsoRgsbwY1NafxrwTs+k64bikrLBkAgPir1TNCj3Zs=
@@ -115,3 +119,5 @@ gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM=
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=

View File

@@ -4,7 +4,6 @@ import (
"net" "net"
"github.com/qdm12/golibs/command" "github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/private-internet-access-docker/internal/models" "github.com/qdm12/private-internet-access-docker/internal/models"
) )
@@ -20,8 +19,6 @@ type Configurator interface {
CreateGeneralRules() error CreateGeneralRules() error
CreateVPNRules(dev models.VPNDevice, defaultInterface string, connections []models.OpenVPNConnection) error CreateVPNRules(dev models.VPNDevice, defaultInterface string, connections []models.OpenVPNConnection) error
CreateLocalSubnetsRules(subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error CreateLocalSubnetsRules(subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error
AddRoutesVia(subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error
GetDefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error)
AllowInputTrafficOnPort(device models.VPNDevice, port uint16) error AllowInputTrafficOnPort(device models.VPNDevice, port uint16) error
AllowAnyIncomingOnPort(port uint16) error AllowAnyIncomingOnPort(port uint16) error
} }
@@ -29,14 +26,12 @@ type Configurator interface {
type configurator struct { type configurator struct {
commander command.Commander commander command.Commander
logger logging.Logger logger logging.Logger
fileManager files.FileManager
} }
// NewConfigurator creates a new Configurator instance // NewConfigurator creates a new Configurator instance
func NewConfigurator(logger logging.Logger, fileManager files.FileManager) Configurator { func NewConfigurator(logger logging.Logger) Configurator {
return &configurator{ return &configurator{
commander: command.NewCommander(), commander: command.NewCommander(),
logger: logger, logger: logger,
fileManager: fileManager,
} }
} }

View File

@@ -1,88 +0,0 @@
package firewall
import (
"encoding/hex"
"net"
"fmt"
"strings"
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
func (c *configurator) AddRoutesVia(subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error {
for _, subnet := range subnets {
subnetStr := subnet.String()
output, err := c.commander.Run("ip", "route", "show", subnetStr)
if err != nil {
return fmt.Errorf("cannot read route %s: %s: %w", subnetStr, output, err)
} else if len(output) > 0 { // thanks to @npawelek https://github.com/npawelek
continue // already exists
// TODO remove it instead and continue execution below
}
c.logger.Info("%s: adding %s as route via %s", logPrefix, subnetStr, defaultInterface)
output, err = c.commander.Run("ip", "route", "add", subnetStr, "via", defaultGateway.String(), "dev", defaultInterface)
if err != nil {
return fmt.Errorf("cannot add route for %s via %s %s %s: %s: %w", subnetStr, defaultGateway.String(), "dev", defaultInterface, output, err)
}
}
return nil
}
func (c *configurator) GetDefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error) {
c.logger.Info("%s: detecting default network route", logPrefix)
data, err := c.fileManager.ReadFile(string(constants.NetRoute))
if err != nil {
return "", nil, defaultSubnet, err
}
// Verify number of lines and fields
lines := strings.Split(string(data), "\n")
if len(lines) < 3 {
return "", nil, defaultSubnet, fmt.Errorf("not enough lines (%d) found in %s", len(lines), constants.NetRoute)
}
fieldsLine1 := strings.Fields(lines[1])
if len(fieldsLine1) < 3 {
return "", nil, defaultSubnet, fmt.Errorf("not enough fields in %q", lines[1])
}
fieldsLine2 := strings.Fields(lines[2])
if len(fieldsLine2) < 8 {
return "", nil, defaultSubnet, fmt.Errorf("not enough fields in %q", lines[2])
}
// get information
defaultInterface = fieldsLine1[0]
defaultGateway, err = reversedHexToIPv4(fieldsLine1[2])
if err != nil {
return "", nil, defaultSubnet, err
}
netNumber, err := reversedHexToIPv4(fieldsLine2[1])
if err != nil {
return "", nil, defaultSubnet, err
}
netMask, err := hexToIPv4Mask(fieldsLine2[7])
if err != nil {
return "", nil, defaultSubnet, err
}
subnet := net.IPNet{IP: netNumber, Mask: netMask}
c.logger.Info("%s: default route found: interface %s, gateway %s, subnet %s", logPrefix, defaultInterface, defaultGateway.String(), subnet.String())
return defaultInterface, defaultGateway, subnet, nil
}
func reversedHexToIPv4(reversedHex string) (IP net.IP, err error) {
bytes, err := hex.DecodeString(reversedHex)
if err != nil {
return nil, fmt.Errorf("cannot parse reversed IP hex %q: %s", reversedHex, err)
} else if len(bytes) != 4 {
return nil, fmt.Errorf("hex string contains %d bytes instead of 4", len(bytes))
}
return []byte{bytes[3], bytes[2], bytes[1], bytes[0]}, nil
}
func hexToIPv4Mask(hexString string) (mask net.IPMask, err error) {
bytes, err := hex.DecodeString(hexString)
if err != nil {
return nil, fmt.Errorf("cannot parse hex mask %q: %s", hexString, err)
} else if len(bytes) != 4 {
return nil, fmt.Errorf("hex string contains %d bytes instead of 4", len(bytes))
}
return []byte{bytes[3], bytes[2], bytes[1], bytes[0]}, nil
}

View File

@@ -1,171 +0,0 @@
package firewall
import (
"fmt"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
filesmocks "github.com/qdm12/golibs/files/mocks"
loggingmocks "github.com/qdm12/golibs/logging/mocks"
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
func Test_getDefaultRoute(t *testing.T) {
t.Parallel()
tests := map[string]struct {
data []byte
readErr error
defaultInterface string
defaultGateway net.IP
defaultSubnet net.IPNet
err error
}{
"no data": {
err: fmt.Errorf("not enough lines (1) found in %s", constants.NetRoute)},
"read error": {
readErr: fmt.Errorf("error"),
err: fmt.Errorf("error")},
"not enough fields line 1": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 00000000
eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0`),
err: fmt.Errorf("not enough fields in \"eth0 00000000\"")},
"not enough fields line 2": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
eth0 000011AC 00000000 0001 0 0 0`),
err: fmt.Errorf("not enough fields in \"eth0 000011AC 00000000 0001 0 0 0\"")},
"bad gateway": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 00000000 x 0003 0 0 0 00000000 0 0 0
eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0`),
err: fmt.Errorf("cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'")},
"bad net number": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
eth0 x 00000000 0001 0 0 0 0000FFFF 0 0 0`),
err: fmt.Errorf("cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'")},
"bad net mask": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
eth0 000011AC 00000000 0001 0 0 0 x 0 0 0`),
err: fmt.Errorf("cannot parse hex mask \"x\": encoding/hex: invalid byte: U+0078 'x'")},
"success": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0`),
defaultInterface: "eth0",
defaultGateway: net.IP{0xac, 0x11, 0x0, 0x1},
defaultSubnet: net.IPNet{
IP: net.IP{0xac, 0x11, 0x0, 0x0},
Mask: net.IPMask{0xff, 0xff, 0x0, 0x0},
}},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
fileManager := &filesmocks.FileManager{}
fileManager.On("ReadFile", string(constants.NetRoute)).
Return(tc.data, tc.readErr).Once()
logger := &loggingmocks.Logger{}
logger.On("Info", "%s: detecting default network route", logPrefix).Once()
if tc.err == nil {
logger.On("Info", "%s: default route found: interface %s, gateway %s, subnet %s",
logPrefix, tc.defaultInterface, tc.defaultGateway.String(), tc.defaultSubnet.String()).Once()
}
c := &configurator{logger: logger, fileManager: fileManager}
defaultInterface, defaultGateway, defaultSubnet, err := c.GetDefaultRoute()
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.defaultInterface, defaultInterface)
assert.Equal(t, tc.defaultGateway, defaultGateway)
assert.Equal(t, tc.defaultSubnet, defaultSubnet)
fileManager.AssertExpectations(t)
logger.AssertExpectations(t)
})
}
}
func Test_reversedHexToIPv4(t *testing.T) {
t.Parallel()
tests := map[string]struct {
reversedHex string
IP net.IP
err error
}{
"empty hex": {
err: fmt.Errorf("hex string contains 0 bytes instead of 4")},
"bad hex": {
reversedHex: "x",
err: fmt.Errorf("cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'")},
"3 bytes hex": {
reversedHex: "9abcde",
err: fmt.Errorf("hex string contains 3 bytes instead of 4")},
"correct hex": {
reversedHex: "010011AC",
IP: []byte{0xac, 0x11, 0x0, 0x1},
err: nil},
"correct hex 2": {
reversedHex: "000011AC",
IP: []byte{0xac, 0x11, 0x0, 0x0},
err: nil},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
IP, err := reversedHexToIPv4(tc.reversedHex)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.IP, IP)
})
}
}
func Test_hexMaskToDecMask(t *testing.T) {
t.Parallel()
tests := map[string]struct {
hexString string
mask net.IPMask
err error
}{
"empty hex": {
err: fmt.Errorf("hex string contains 0 bytes instead of 4")},
"bad hex": {
hexString: "x",
err: fmt.Errorf("cannot parse hex mask \"x\": encoding/hex: invalid byte: U+0078 'x'")},
"3 bytes hex": {
hexString: "9abcde",
err: fmt.Errorf("hex string contains 3 bytes instead of 4")},
"16": {
hexString: "0000FFFF",
mask: []byte{0xff, 0xff, 0x0, 0x0},
err: nil},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
mask, err := hexToIPv4Mask(tc.hexString)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.mask, mask)
})
}
}

View File

@@ -32,6 +32,7 @@ type ParamsReader interface {
GetUID() (uid int, err error) GetUID() (uid int, err error)
GetGID() (gid int, err error) GetGID() (gid int, err error)
GetTimezone() (timezone string, err error) GetTimezone() (timezone string, err error)
GetIPStatusFilepath() (filepath models.Filepath, err error)
// Firewall getters // Firewall getters
GetExtraSubnets() (extraSubnets []net.IPNet, err error) GetExtraSubnets() (extraSubnets []net.IPNet, err error)

View File

@@ -2,6 +2,7 @@ package params
import ( import (
libparams "github.com/qdm12/golibs/params" libparams "github.com/qdm12/golibs/params"
"github.com/qdm12/private-internet-access-docker/internal/models"
) )
// GetUID obtains the user ID to use from the environment variable UID // GetUID obtains the user ID to use from the environment variable UID
@@ -18,3 +19,10 @@ func (p *paramsReader) GetGID() (gid int, err error) {
func (p *paramsReader) GetTimezone() (timezone string, err error) { func (p *paramsReader) GetTimezone() (timezone string, err error) {
return p.envParams.GetEnv("TZ") return p.envParams.GetEnv("TZ")
} }
// GetIPStatusFilepath obtains the IP status file path
// from the environment variable IP_STATUS_FILE
func (p *paramsReader) GetIPStatusFilepath() (filepath models.Filepath, err error) {
filepathStr, err := p.envParams.GetPath("IP_STATUS_FILE", libparams.Default("/ip"), libparams.CaseSensitiveValue())
return models.Filepath(filepathStr), err
}

93
internal/routing/entry.go Normal file
View File

@@ -0,0 +1,93 @@
package routing
import (
"encoding/hex"
"fmt"
"net"
"strconv"
"strings"
)
type routingEntry struct {
iface string
destination net.IP
gateway net.IP
flags string
refCount int
use int
metric int
mask net.IPMask
mtu int
window int
irtt int
}
func parseRoutingEntry(s string) (r routingEntry, err error) {
wrapError := func(err error) error {
return fmt.Errorf("line %q: %w", s, err)
}
fields := strings.Fields(s)
if len(fields) < 11 {
return r, wrapError(fmt.Errorf("not enough fields"))
}
r.iface = fields[0]
r.destination, err = reversedHexToIPv4(fields[1])
if err != nil {
return r, wrapError(err)
}
r.gateway, err = reversedHexToIPv4(fields[2])
if err != nil {
return r, wrapError(err)
}
r.flags = fields[3]
r.refCount, err = strconv.Atoi(fields[4])
if err != nil {
return r, wrapError(err)
}
r.use, err = strconv.Atoi(fields[5])
if err != nil {
return r, wrapError(err)
}
r.metric, err = strconv.Atoi(fields[6])
if err != nil {
return r, wrapError(err)
}
r.mask, err = hexToIPv4Mask(fields[7])
if err != nil {
return r, wrapError(err)
}
r.mtu, err = strconv.Atoi(fields[8])
if err != nil {
return r, wrapError(err)
}
r.window, err = strconv.Atoi(fields[9])
if err != nil {
return r, wrapError(err)
}
r.irtt, err = strconv.Atoi(fields[10])
if err != nil {
return r, wrapError(err)
}
return r, nil
}
func reversedHexToIPv4(reversedHex string) (IP net.IP, err error) {
bytes, err := hex.DecodeString(reversedHex)
if err != nil {
return nil, fmt.Errorf("cannot parse reversed IP hex %q: %s", reversedHex, err)
} else if len(bytes) != 4 {
return nil, fmt.Errorf("hex string contains %d bytes instead of 4", len(bytes))
}
return []byte{bytes[3], bytes[2], bytes[1], bytes[0]}, nil
}
func hexToIPv4Mask(hexString string) (mask net.IPMask, err error) {
bytes, err := hex.DecodeString(hexString)
if err != nil {
return nil, fmt.Errorf("cannot parse hex mask %q: %s", hexString, err)
} else if len(bytes) != 4 {
return nil, fmt.Errorf("hex string contains %d bytes instead of 4", len(bytes))
}
return []byte{bytes[3], bytes[2], bytes[1], bytes[0]}, nil
}

View File

@@ -0,0 +1,163 @@
package routing
import (
"fmt"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_parseRoutingEntry(t *testing.T) {
t.Parallel()
tests := map[string]struct {
s string
r routingEntry
err error
}{
"empty string": {
err: fmt.Errorf("line \"\": not enough fields"),
},
"not enough fields": {
s: "a b c d e",
err: fmt.Errorf("line \"a b c d e\": not enough fields"),
},
"bad destination": {
s: "eth0 x 0100000A 0003 0 0 0 00FFFFFF 0 0 0",
err: fmt.Errorf("line \"eth0 x 0100000A 0003 0 0 0 00FFFFFF 0 0 0\": cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'"),
},
"bad gateway": {
s: "eth0 0002A8C0 x 0003 0 0 0 00FFFFFF 0 0 0",
err: fmt.Errorf("line \"eth0 0002A8C0 x 0003 0 0 0 00FFFFFF 0 0 0\": cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'"),
},
"bad ref count": {
s: "eth0 0002A8C0 0100000A 0003 x 0 0 00FFFFFF 0 0 0",
err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 x 0 0 00FFFFFF 0 0 0\": strconv.Atoi: parsing \"x\": invalid syntax"),
},
"bad use": {
s: "eth0 0002A8C0 0100000A 0003 0 x 0 00FFFFFF 0 0 0",
err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 x 0 00FFFFFF 0 0 0\": strconv.Atoi: parsing \"x\": invalid syntax"),
},
"bad metric": {
s: "eth0 0002A8C0 0100000A 0003 0 0 x 00FFFFFF 0 0 0",
err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 0 x 00FFFFFF 0 0 0\": strconv.Atoi: parsing \"x\": invalid syntax"),
},
"bad mask": {
s: "eth0 0002A8C0 0100000A 0003 0 0 0 x 0 0 0",
err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 0 0 x 0 0 0\": cannot parse hex mask \"x\": encoding/hex: invalid byte: U+0078 'x'"),
},
"bad mtu": {
s: "eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF x 0 0",
err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF x 0 0\": strconv.Atoi: parsing \"x\": invalid syntax"),
},
"bad window": {
s: "eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 x 0",
err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 x 0\": strconv.Atoi: parsing \"x\": invalid syntax"),
},
"bad irtt": {
s: "eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 x",
err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 x\": strconv.Atoi: parsing \"x\": invalid syntax"),
},
"success": {
s: "eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0",
r: routingEntry{
iface: "eth0",
destination: net.IP{192, 168, 2, 0},
gateway: net.IP{10, 0, 0, 1},
flags: "0003",
mask: net.IPMask{255, 255, 255, 0},
},
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
r, err := parseRoutingEntry(tc.s)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
assert.Equal(t, tc.r, r)
}
})
}
}
func Test_reversedHexToIPv4(t *testing.T) {
t.Parallel()
tests := map[string]struct {
reversedHex string
IP net.IP
err error
}{
"empty hex": {
err: fmt.Errorf("hex string contains 0 bytes instead of 4")},
"bad hex": {
reversedHex: "x",
err: fmt.Errorf("cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'")},
"3 bytes hex": {
reversedHex: "9abcde",
err: fmt.Errorf("hex string contains 3 bytes instead of 4")},
"correct hex": {
reversedHex: "010011AC",
IP: []byte{0xac, 0x11, 0x0, 0x1},
err: nil},
"correct hex 2": {
reversedHex: "000011AC",
IP: []byte{0xac, 0x11, 0x0, 0x0},
err: nil},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
IP, err := reversedHexToIPv4(tc.reversedHex)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.IP, IP)
})
}
}
func Test_hexMaskToDecMask(t *testing.T) {
t.Parallel()
tests := map[string]struct {
hexString string
mask net.IPMask
err error
}{
"empty hex": {
err: fmt.Errorf("hex string contains 0 bytes instead of 4")},
"bad hex": {
hexString: "x",
err: fmt.Errorf("cannot parse hex mask \"x\": encoding/hex: invalid byte: U+0078 'x'")},
"3 bytes hex": {
hexString: "9abcde",
err: fmt.Errorf("hex string contains 3 bytes instead of 4")},
"16": {
hexString: "0000FFFF",
mask: []byte{0xff, 0xff, 0x0, 0x0},
err: nil},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
mask, err := hexToIPv4Mask(tc.hexString)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.mask, mask)
})
}
}

View File

@@ -0,0 +1,76 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/golibs/command (interfaces: Commander)
// Package routing is a generated GoMock package.
package routing
import (
gomock "github.com/golang/mock/gomock"
io "io"
reflect "reflect"
)
// MockCommander is a mock of Commander interface
type MockCommander struct {
ctrl *gomock.Controller
recorder *MockCommanderMockRecorder
}
// MockCommanderMockRecorder is the mock recorder for MockCommander
type MockCommanderMockRecorder struct {
mock *MockCommander
}
// NewMockCommander creates a new mock instance
func NewMockCommander(ctrl *gomock.Controller) *MockCommander {
mock := &MockCommander{ctrl: ctrl}
mock.recorder = &MockCommanderMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockCommander) EXPECT() *MockCommanderMockRecorder {
return m.recorder
}
// Run mocks base method
func (m *MockCommander) Run(arg0 string, arg1 ...string) (string, error) {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Run", varargs...)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Run indicates an expected call of Run
func (mr *MockCommanderMockRecorder) Run(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockCommander)(nil).Run), varargs...)
}
// Start mocks base method
func (m *MockCommander) Start(arg0 string, arg1 ...string) (io.ReadCloser, io.ReadCloser, func() error, error) {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Start", varargs...)
ret0, _ := ret[0].(io.ReadCloser)
ret1, _ := ret[1].(io.ReadCloser)
ret2, _ := ret[2].(func() error)
ret3, _ := ret[3].(error)
return ret0, ret1, ret2, ret3
}
// Start indicates an expected call of Start
func (mr *MockCommanderMockRecorder) Start(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockCommander)(nil).Start), varargs...)
}

View File

@@ -0,0 +1,232 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/golibs/files (interfaces: FileManager)
// Package routing is a generated GoMock package.
package routing
import (
gomock "github.com/golang/mock/gomock"
files "github.com/qdm12/golibs/files"
os "os"
reflect "reflect"
)
// MockFileManager is a mock of FileManager interface
type MockFileManager struct {
ctrl *gomock.Controller
recorder *MockFileManagerMockRecorder
}
// MockFileManagerMockRecorder is the mock recorder for MockFileManager
type MockFileManagerMockRecorder struct {
mock *MockFileManager
}
// NewMockFileManager creates a new mock instance
func NewMockFileManager(ctrl *gomock.Controller) *MockFileManager {
mock := &MockFileManager{ctrl: ctrl}
mock.recorder = &MockFileManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockFileManager) EXPECT() *MockFileManagerMockRecorder {
return m.recorder
}
// CreateDir mocks base method
func (m *MockFileManager) CreateDir(arg0 string, arg1 ...files.WriteOptionSetter) error {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "CreateDir", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// CreateDir indicates an expected call of CreateDir
func (mr *MockFileManagerMockRecorder) CreateDir(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateDir", reflect.TypeOf((*MockFileManager)(nil).CreateDir), varargs...)
}
// DirectoryExists mocks base method
func (m *MockFileManager) DirectoryExists(arg0 string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DirectoryExists", arg0)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DirectoryExists indicates an expected call of DirectoryExists
func (mr *MockFileManagerMockRecorder) DirectoryExists(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DirectoryExists", reflect.TypeOf((*MockFileManager)(nil).DirectoryExists), arg0)
}
// FileExists mocks base method
func (m *MockFileManager) FileExists(arg0 string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FileExists", arg0)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// FileExists indicates an expected call of FileExists
func (mr *MockFileManagerMockRecorder) FileExists(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FileExists", reflect.TypeOf((*MockFileManager)(nil).FileExists), arg0)
}
// FilepathExists mocks base method
func (m *MockFileManager) FilepathExists(arg0 string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilepathExists", arg0)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// FilepathExists indicates an expected call of FilepathExists
func (mr *MockFileManagerMockRecorder) FilepathExists(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilepathExists", reflect.TypeOf((*MockFileManager)(nil).FilepathExists), arg0)
}
// GetOwnership mocks base method
func (m *MockFileManager) GetOwnership(arg0 string) (int, int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOwnership", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(int)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// GetOwnership indicates an expected call of GetOwnership
func (mr *MockFileManagerMockRecorder) GetOwnership(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOwnership", reflect.TypeOf((*MockFileManager)(nil).GetOwnership), arg0)
}
// GetUserPermissions mocks base method
func (m *MockFileManager) GetUserPermissions(arg0 string) (bool, bool, bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUserPermissions", arg0)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(bool)
ret2, _ := ret[2].(bool)
ret3, _ := ret[3].(error)
return ret0, ret1, ret2, ret3
}
// GetUserPermissions indicates an expected call of GetUserPermissions
func (mr *MockFileManagerMockRecorder) GetUserPermissions(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserPermissions", reflect.TypeOf((*MockFileManager)(nil).GetUserPermissions), arg0)
}
// ReadFile mocks base method
func (m *MockFileManager) ReadFile(arg0 string) ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReadFile", arg0)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ReadFile indicates an expected call of ReadFile
func (mr *MockFileManagerMockRecorder) ReadFile(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadFile", reflect.TypeOf((*MockFileManager)(nil).ReadFile), arg0)
}
// SetOwnership mocks base method
func (m *MockFileManager) SetOwnership(arg0 string, arg1, arg2 int) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetOwnership", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// SetOwnership indicates an expected call of SetOwnership
func (mr *MockFileManagerMockRecorder) SetOwnership(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetOwnership", reflect.TypeOf((*MockFileManager)(nil).SetOwnership), arg0, arg1, arg2)
}
// SetUserPermissions mocks base method
func (m *MockFileManager) SetUserPermissions(arg0 string, arg1 os.FileMode) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetUserPermissions", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// SetUserPermissions indicates an expected call of SetUserPermissions
func (mr *MockFileManagerMockRecorder) SetUserPermissions(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUserPermissions", reflect.TypeOf((*MockFileManager)(nil).SetUserPermissions), arg0, arg1)
}
// Touch mocks base method
func (m *MockFileManager) Touch(arg0 string, arg1 ...files.WriteOptionSetter) error {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Touch", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// Touch indicates an expected call of Touch
func (mr *MockFileManagerMockRecorder) Touch(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Touch", reflect.TypeOf((*MockFileManager)(nil).Touch), varargs...)
}
// WriteLinesToFile mocks base method
func (m *MockFileManager) WriteLinesToFile(arg0 string, arg1 []string, arg2 ...files.WriteOptionSetter) error {
m.ctrl.T.Helper()
varargs := []interface{}{arg0, arg1}
for _, a := range arg2 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "WriteLinesToFile", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// WriteLinesToFile indicates an expected call of WriteLinesToFile
func (mr *MockFileManagerMockRecorder) WriteLinesToFile(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0, arg1}, arg2...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteLinesToFile", reflect.TypeOf((*MockFileManager)(nil).WriteLinesToFile), varargs...)
}
// WriteToFile mocks base method
func (m *MockFileManager) WriteToFile(arg0 string, arg1 []byte, arg2 ...files.WriteOptionSetter) error {
m.ctrl.T.Helper()
varargs := []interface{}{arg0, arg1}
for _, a := range arg2 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "WriteToFile", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// WriteToFile indicates an expected call of WriteToFile
func (mr *MockFileManagerMockRecorder) WriteToFile(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0, arg1}, arg2...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteToFile", reflect.TypeOf((*MockFileManager)(nil).WriteToFile), varargs...)
}

View File

@@ -0,0 +1,126 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/golibs/logging (interfaces: Logger)
// Package routing is a generated GoMock package.
package routing
import (
gomock "github.com/golang/mock/gomock"
logging "github.com/qdm12/golibs/logging"
reflect "reflect"
)
// MockLogger is a mock of Logger interface
type MockLogger struct {
ctrl *gomock.Controller
recorder *MockLoggerMockRecorder
}
// MockLoggerMockRecorder is the mock recorder for MockLogger
type MockLoggerMockRecorder struct {
mock *MockLogger
}
// NewMockLogger creates a new mock instance
func NewMockLogger(ctrl *gomock.Controller) *MockLogger {
mock := &MockLogger{ctrl: ctrl}
mock.recorder = &MockLoggerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
return m.recorder
}
// Debug mocks base method
func (m *MockLogger) Debug(arg0 ...interface{}) {
m.ctrl.T.Helper()
varargs := []interface{}{}
for _, a := range arg0 {
varargs = append(varargs, a)
}
m.ctrl.Call(m, "Debug", varargs...)
}
// Debug indicates an expected call of Debug
func (mr *MockLoggerMockRecorder) Debug(arg0 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0...)
}
// Error mocks base method
func (m *MockLogger) Error(arg0 ...interface{}) {
m.ctrl.T.Helper()
varargs := []interface{}{}
for _, a := range arg0 {
varargs = append(varargs, a)
}
m.ctrl.Call(m, "Error", varargs...)
}
// Error indicates an expected call of Error
func (mr *MockLoggerMockRecorder) Error(arg0 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0...)
}
// Info mocks base method
func (m *MockLogger) Info(arg0 ...interface{}) {
m.ctrl.T.Helper()
varargs := []interface{}{}
for _, a := range arg0 {
varargs = append(varargs, a)
}
m.ctrl.Call(m, "Info", varargs...)
}
// Info indicates an expected call of Info
func (mr *MockLoggerMockRecorder) Info(arg0 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0...)
}
// Sync mocks base method
func (m *MockLogger) Sync() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Sync")
ret0, _ := ret[0].(error)
return ret0
}
// Sync indicates an expected call of Sync
func (mr *MockLoggerMockRecorder) Sync() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sync", reflect.TypeOf((*MockLogger)(nil).Sync))
}
// Warn mocks base method
func (m *MockLogger) Warn(arg0 ...interface{}) {
m.ctrl.T.Helper()
varargs := []interface{}{}
for _, a := range arg0 {
varargs = append(varargs, a)
}
m.ctrl.Call(m, "Warn", varargs...)
}
// Warn indicates an expected call of Warn
func (mr *MockLoggerMockRecorder) Warn(arg0 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warn", reflect.TypeOf((*MockLogger)(nil).Warn), arg0...)
}
// WithPrefix mocks base method
func (m *MockLogger) WithPrefix(arg0 string) logging.Logger {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "WithPrefix", arg0)
ret0, _ := ret[0].(logging.Logger)
return ret0
}
// WithPrefix indicates an expected call of WithPrefix
func (mr *MockLoggerMockRecorder) WithPrefix(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithPrefix", reflect.TypeOf((*MockLogger)(nil).WithPrefix), arg0)
}

View File

@@ -0,0 +1,34 @@
package routing
import (
"net"
"fmt"
)
func (r *routing) AddRoutesVia(subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error {
for _, subnet := range subnets {
exists, err := r.routeExists(subnet)
if err != nil {
return err
} else if exists { // thanks to @npawelek https://github.com/npawelek
if err := r.removeRoute(subnet); err != nil {
return err
}
}
r.logger.Info("adding %s as route via %s", subnet.String(), defaultInterface)
output, err := r.commander.Run("ip", "route", "add", subnet.String(), "via", defaultGateway.String(), "dev", defaultInterface)
if err != nil {
return fmt.Errorf("cannot add route for %s via %s %s %s: %s: %w", subnet.String(), defaultGateway.String(), "dev", defaultInterface, output, err)
}
}
return nil
}
func (r *routing) removeRoute(subnet net.IPNet) (err error) {
output, err := r.commander.Run("ip", "route", "del", subnet.String())
if err != nil {
return fmt.Errorf("cannot delete route for %s: %s: %w", subnet.String(), output, err)
}
return nil
}

View File

@@ -0,0 +1,67 @@
package routing
import (
"fmt"
"net"
"testing"
gomock "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
//go:generate mockgen -destination=mockCommander_test.go -package=routing github.com/qdm12/golibs/command Commander
func Test_removeRoute(t *testing.T) {
t.Parallel()
tests := map[string]struct {
subnet net.IPNet
runOutput string
runErr error
err error
}{
"no output no error": {
subnet: net.IPNet{
IP: net.IP{192, 168, 1, 0},
Mask: net.IPMask{255, 255, 255, 0},
},
},
"error only": {
subnet: net.IPNet{
IP: net.IP{192, 168, 1, 0},
Mask: net.IPMask{255, 255, 255, 0},
},
runErr: fmt.Errorf("error"),
err: fmt.Errorf("cannot delete route for 192.168.1.0/24: : error"),
},
"error and output": {
subnet: net.IPNet{
IP: net.IP{192, 168, 1, 0},
Mask: net.IPMask{255, 255, 255, 0},
},
runErr: fmt.Errorf("error"),
runOutput: "output",
err: fmt.Errorf("cannot delete route for 192.168.1.0/24: output: error"),
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
mockCommander := NewMockCommander(mockCtrl)
mockCommander.EXPECT().Run("ip", "route", "del", tc.subnet.String()).
Return(tc.runOutput, tc.runErr).Times(1)
r := &routing{commander: mockCommander}
err := r.removeRoute(tc.subnet)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
})
}
}

104
internal/routing/reader.go Normal file
View File

@@ -0,0 +1,104 @@
package routing
import (
"bytes"
"net"
"fmt"
"strings"
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
func parseRoutingTable(data []byte) (entries []routingEntry, err error) {
lines := strings.Split(strings.TrimSuffix(string(data), "\n"), "\n")
lines = lines[1:]
entries = make([]routingEntry, len(lines))
for i := range lines {
entries[i], err = parseRoutingEntry(lines[i])
if err != nil {
return nil, fmt.Errorf("line %d in %s: %w", i+1, constants.NetRoute, err)
}
}
return entries, nil
}
func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error) {
r.logger.Info("detecting default network route")
data, err := r.fileManager.ReadFile(string(constants.NetRoute))
if err != nil {
return "", nil, defaultSubnet, err
}
entries, err := parseRoutingTable(data)
if err != nil {
return "", nil, defaultSubnet, err
}
if len(entries) < 2 {
return "", nil, defaultSubnet, fmt.Errorf("not enough entries (%d) found in %s", len(entries), constants.NetRoute)
}
defaultInterface = entries[0].iface
defaultGateway = entries[0].gateway
defaultSubnet = net.IPNet{IP: entries[1].destination, Mask: entries[1].mask}
r.logger.Info("default route found: interface %s, gateway %s, subnet %s", defaultInterface, defaultGateway.String(), defaultSubnet.String())
return defaultInterface, defaultGateway, defaultSubnet, nil
}
func (r *routing) routeExists(subnet net.IPNet) (exists bool, err error) {
data, err := r.fileManager.ReadFile(string(constants.NetRoute))
if err != nil {
return false, fmt.Errorf("cannot check route existence: %w", err)
}
entries, err := parseRoutingTable(data)
if err != nil {
return false, fmt.Errorf("cannot check route existence: %w", err)
}
for _, entry := range entries {
entrySubnet := net.IPNet{IP: entry.destination, Mask: entry.mask}
if entrySubnet.String() == subnet.String() {
return true, nil
}
}
return false, nil
}
func (r *routing) CurrentPublicIP(defaultInterface string) (ip net.IP, err error) {
data, err := r.fileManager.ReadFile(string(constants.NetRoute))
if err != nil {
return nil, fmt.Errorf("cannot find current IP address: %w", err)
}
entries, err := parseRoutingTable(data)
if err != nil {
return nil, fmt.Errorf("cannot find current IP address: %w", err)
}
for _, entry := range entries {
if entry.iface == defaultInterface &&
!ipIsPrivate(entry.destination) &&
bytes.Equal(entry.mask, net.IPMask{255, 255, 255, 255}) {
return entry.destination, nil
}
}
return nil, fmt.Errorf("cannot find current IP address from ip routes")
}
func ipIsPrivate(ip net.IP) bool {
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
return true
}
privateCIDRBlocks := [8]string{
"127.0.0.0/8", // localhost
"10.0.0.0/8", // 24-bit block
"172.16.0.0/12", // 20-bit block
"192.168.0.0/16", // 16-bit block
"169.254.0.0/16", // link local address
"::1/128", // localhost IPv6
"fc00::/7", // unique local address IPv6
"fe80::/10", // link local address IPv6
}
for i := range privateCIDRBlocks {
_, CIDR, _ := net.ParseCIDR(privateCIDRBlocks[i])
if CIDR.Contains(ip) {
return true
}
}
return false
}

View File

@@ -0,0 +1,285 @@
package routing
import (
"fmt"
"net"
"testing"
gomock "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
//go:generate mockgen -destination=mockLogger_test.go -package=routing github.com/qdm12/golibs/logging Logger
//go:generate mockgen -destination=mockFilemanager_test.go -package=routing github.com/qdm12/golibs/files FileManager
func Test_parseRoutingTable(t *testing.T) {
t.Parallel()
tests := map[string]struct {
data []byte
entries []routingEntry
err error
}{
"nil data": {
entries: []routingEntry{},
},
"legend only": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
`),
entries: []routingEntry{},
},
"legend and single line": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0
`),
entries: []routingEntry{{
iface: "eth0",
destination: net.IP{192, 168, 2, 0},
gateway: net.IP{10, 0, 0, 1},
flags: "0003",
mask: net.IPMask{255, 255, 255, 0},
}},
},
"legend and two lines": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0
eth0 0002A8C0 0100000A 0002 0 0 0 00FFFFFF 0 0 0
`),
entries: []routingEntry{
{
iface: "eth0",
destination: net.IP{192, 168, 2, 0},
gateway: net.IP{10, 0, 0, 1},
flags: "0003",
mask: net.IPMask{255, 255, 255, 0},
},
{
iface: "eth0",
destination: net.IP{192, 168, 2, 0},
gateway: net.IP{10, 0, 0, 1},
flags: "0002",
mask: net.IPMask{255, 255, 255, 0},
}},
},
"error": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 x 0100000A 0003 0 0 0 00FFFFFF 0 0 0
`),
entries: nil,
err: fmt.Errorf("line 1 in /proc/net/route: line \"eth0 x 0100000A 0003 0 0 0 00FFFFFF 0 0 0\": cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'"),
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
entries, err := parseRoutingTable(tc.data)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.entries, entries)
})
}
}
func Test_DefaultRoute(t *testing.T) {
t.Parallel()
tests := map[string]struct {
data []byte
readErr error
defaultInterface string
defaultGateway net.IP
defaultSubnet net.IPNet
err error
}{
"no data": {
err: fmt.Errorf("not enough entries (0) found in %s", constants.NetRoute)},
"read error": {
readErr: fmt.Errorf("error"),
err: fmt.Errorf("error")},
"parse error": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 x
`),
err: fmt.Errorf("line 1 in /proc/net/route: line \"eth0 x\": not enough fields")},
"single entry": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 00000000 050A090A 0003 0 0 0 00000080 0 0 0
`),
err: fmt.Errorf("not enough entries (1) found in %s", constants.NetRoute)},
"success": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0
`),
defaultInterface: "eth0",
defaultGateway: net.IP{172, 17, 0, 1},
defaultSubnet: net.IPNet{
IP: net.IP{172, 17, 0, 0},
Mask: net.IPMask{255, 255, 0, 0},
}},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
mockLogger := NewMockLogger(mockCtrl)
mockFilemanager := NewMockFileManager(mockCtrl)
mockFilemanager.EXPECT().ReadFile(string(constants.NetRoute)).
Return(tc.data, tc.readErr).Times(1)
mockLogger.EXPECT().Info("detecting default network route").Times(1)
if tc.err == nil {
mockLogger.EXPECT().Info(
"default route found: interface %s, gateway %s, subnet %s",
tc.defaultInterface, tc.defaultGateway.String(), tc.defaultSubnet.String(),
).Times(1)
}
r := &routing{logger: mockLogger, fileManager: mockFilemanager}
defaultInterface, defaultGateway, defaultSubnet, err := r.DefaultRoute()
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.defaultInterface, defaultInterface)
assert.Equal(t, tc.defaultGateway, defaultGateway)
assert.Equal(t, tc.defaultSubnet, defaultSubnet)
})
}
}
func Test_routeExists(t *testing.T) {
t.Parallel()
tests := map[string]struct {
subnet net.IPNet
data []byte
readErr error
exists bool
err error
}{
"no data": {},
"read error": {
readErr: fmt.Errorf("error"),
err: fmt.Errorf("cannot check route existence: error"),
},
"parse error": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 x
`),
err: fmt.Errorf("cannot check route existence: line 1 in /proc/net/route: line \"eth0 x\": not enough fields"),
},
"not existing": {
subnet: net.IPNet{
IP: net.IP{192, 168, 2, 0},
Mask: net.IPMask{255, 255, 255, 128},
},
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0
`),
},
"existing": {
subnet: net.IPNet{
IP: net.IP{192, 168, 2, 0},
Mask: net.IPMask{255, 255, 255, 0},
},
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0
`),
exists: true,
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
mockFilemanager := NewMockFileManager(mockCtrl)
mockFilemanager.EXPECT().ReadFile(string(constants.NetRoute)).
Return(tc.data, tc.readErr).Times(1)
r := &routing{fileManager: mockFilemanager}
exists, err := r.routeExists(tc.subnet)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.exists, exists)
})
}
}
func Test_CurrentIP(t *testing.T) {
t.Parallel()
const exampleRouteData = `Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
tun0 00000000 050A090A 0003 0 0 0 00000080 0 0 0
eth0 00000000 0100000A 0003 0 0 0 00000000 0 0 0
eth0 0000000A 00000000 0001 0 0 0 00FFFFFF 0 0 0
tun0 010A090A 050A090A 0007 0 0 0 FFFFFFFF 0 0 0
tun0 050A090A 00000000 0005 0 0 0 FFFFFFFF 0 0 0
eth0 2194B05F 0100000A 0007 0 0 0 FFFFFFFF 0 0 0
tun0 00000080 050A090A 0003 0 0 0 00000080 0 0 0
eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0
`
tests := map[string]struct {
defaultInterface string
data []byte
readErr error
ip net.IP
err error
}{
"no data": {
err: fmt.Errorf("cannot find current IP address from ip routes"),
},
"read error": {
readErr: fmt.Errorf("error"),
err: fmt.Errorf("cannot find current IP address: error"),
},
"parse error": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 x
`),
err: fmt.Errorf("cannot find current IP address: line 1 in /proc/net/route: line \"eth0 x\": not enough fields"),
},
"found eth0": {
defaultInterface: "eth0",
data: []byte(exampleRouteData),
ip: net.IP{95, 176, 148, 33},
},
"not found tun0": {
defaultInterface: "tun0",
data: []byte(exampleRouteData),
err: fmt.Errorf("cannot find current IP address from ip routes"),
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
mockFilemanager := NewMockFileManager(mockCtrl)
mockFilemanager.EXPECT().ReadFile(string(constants.NetRoute)).
Return(tc.data, tc.readErr).Times(1)
r := &routing{fileManager: mockFilemanager}
ip, err := r.CurrentPublicIP(tc.defaultInterface)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.ip, ip)
})
}
}

View File

@@ -0,0 +1,30 @@
package routing
import (
"net"
"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
)
type Routing interface {
AddRoutesVia(subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error
DefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error)
CurrentPublicIP(defaultInterface string) (ip net.IP, err error)
}
type routing struct {
commander command.Commander
logger logging.Logger
fileManager files.FileManager
}
// NewConfigurator creates a new Configurator instance
func NewRouting(logger logging.Logger, fileManager files.FileManager) Routing {
return &routing{
commander: command.NewCommander(),
logger: logger.WithPrefix("routing: "),
fileManager: fileManager,
}
}

View File

@@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/qdm12/private-internet-access-docker/internal/models"
"github.com/qdm12/private-internet-access-docker/internal/params" "github.com/qdm12/private-internet-access-docker/internal/params"
) )
@@ -12,6 +13,7 @@ type System struct {
UID int UID int
GID int GID int
Timezone string Timezone string
IPStatusFilepath models.Filepath
} }
// GetSystemSettings obtains the System settings using the params functions // GetSystemSettings obtains the System settings using the params functions
@@ -28,6 +30,10 @@ func GetSystemSettings(params params.ParamsReader) (settings System, err error)
if err != nil { if err != nil {
return settings, err return settings, err
} }
settings.IPStatusFilepath, err = params.GetIPStatusFilepath()
if err != nil {
return settings, err
}
return settings, nil return settings, nil
} }
@@ -37,6 +43,7 @@ func (s *System) String() string {
fmt.Sprintf("User ID: %d", s.UID), fmt.Sprintf("User ID: %d", s.UID),
fmt.Sprintf("Group ID: %d", s.GID), fmt.Sprintf("Group ID: %d", s.GID),
fmt.Sprintf("Timezone: %s", s.Timezone), fmt.Sprintf("Timezone: %s", s.Timezone),
fmt.Sprintf("IP Status filepath: %s", s.IPStatusFilepath),
} }
return strings.Join(settingsList, "\n|--") return strings.Join(settingsList, "\n|--")
} }