Rewrite of the entrypoint in Golang (#71)
- General improvements
- Parallel download of only needed files at start
- Prettier console output with all streams merged (openvpn, unbound, shadowsocks etc.)
- Simplified Docker final image
- Faster bootup
- DNS over TLS
- Finer grain blocking at DNS level: malicious, ads and surveillance
- Choose your DNS over TLS providers
- Ability to use multiple DNS over TLS providers for DNS split horizon
- Environment variables for DNS logging
- DNS block lists needed are downloaded and built automatically at start, in parallel
- PIA
- A random region is selected if the REGION parameter is left empty (thanks @rorph for your PR)
- Routing and iptables adjusted so it can work as a Kubernetes pod sidecar (thanks @rorph for your PR)
This commit is contained in:
42
internal/firewall/firewall.go
Normal file
42
internal/firewall/firewall.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/qdm12/golibs/command"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
const logPrefix = "firewall configurator"
|
||||
|
||||
// Configurator allows to change firewall rules and modify network routes
|
||||
type Configurator interface {
|
||||
Version() (string, error)
|
||||
AcceptAll() error
|
||||
Clear() error
|
||||
BlockAll() error
|
||||
CreateGeneralRules() error
|
||||
CreateVPNRules(dev models.VPNDevice, serverIPs []net.IP, defaultInterface string,
|
||||
port uint16, protocol models.NetworkProtocol) error
|
||||
CreateLocalSubnetsRules(subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error
|
||||
AddRoutesVia(subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error
|
||||
GetDefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error)
|
||||
AllowInputTrafficOnPort(device models.VPNDevice, port uint16) error
|
||||
}
|
||||
|
||||
type configurator struct {
|
||||
commander command.Commander
|
||||
logger logging.Logger
|
||||
fileManager files.FileManager
|
||||
}
|
||||
|
||||
// NewConfigurator creates a new Configurator instance
|
||||
func NewConfigurator(logger logging.Logger, fileManager files.FileManager) Configurator {
|
||||
return &configurator{
|
||||
commander: command.NewCommander(),
|
||||
logger: logger,
|
||||
fileManager: fileManager,
|
||||
}
|
||||
}
|
||||
130
internal/firewall/iptables.go
Normal file
130
internal/firewall/iptables.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
// Version obtains the version of the installed iptables
|
||||
func (c *configurator) Version() (string, error) {
|
||||
output, err := c.commander.Run("iptables", "--version")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
words := strings.Fields(output)
|
||||
if len(words) < 2 {
|
||||
return "", fmt.Errorf("iptables --version: output is too short: %q", output)
|
||||
}
|
||||
return words[1], nil
|
||||
}
|
||||
|
||||
func (c *configurator) runIptablesInstructions(instructions []string) error {
|
||||
for _, instruction := range instructions {
|
||||
if err := c.runIptablesInstruction(instruction); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *configurator) runIptablesInstruction(instruction string) error {
|
||||
flags := strings.Fields(instruction)
|
||||
if output, err := c.commander.Run("iptables", flags...); err != nil {
|
||||
return fmt.Errorf("failed executing %q: %s: %w", instruction, output, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *configurator) Clear() error {
|
||||
c.logger.Info("%s: clearing all rules", logPrefix)
|
||||
return c.runIptablesInstructions([]string{
|
||||
"--flush",
|
||||
"--delete-chain",
|
||||
"-t nat --flush",
|
||||
"-t nat --delete-chain",
|
||||
})
|
||||
}
|
||||
|
||||
func (c *configurator) AcceptAll() error {
|
||||
c.logger.Info("%s: accepting all traffic", logPrefix)
|
||||
return c.runIptablesInstructions([]string{
|
||||
"-P INPUT ACCEPT",
|
||||
"-P OUTPUT ACCEPT",
|
||||
"-P FORWARD ACCEPT",
|
||||
})
|
||||
}
|
||||
|
||||
func (c *configurator) BlockAll() error {
|
||||
c.logger.Info("%s: blocking all traffic", logPrefix)
|
||||
return c.runIptablesInstructions([]string{
|
||||
"-P INPUT DROP",
|
||||
"-F OUTPUT",
|
||||
"-P OUTPUT DROP",
|
||||
"-P FORWARD DROP",
|
||||
})
|
||||
}
|
||||
|
||||
func (c *configurator) CreateGeneralRules() error {
|
||||
c.logger.Info("%s: creating general rules", logPrefix)
|
||||
return c.runIptablesInstructions([]string{
|
||||
"-A OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
|
||||
"-A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
|
||||
"-A OUTPUT -o lo -j ACCEPT",
|
||||
"-A INPUT -i lo -j ACCEPT",
|
||||
})
|
||||
}
|
||||
|
||||
func (c *configurator) CreateVPNRules(dev models.VPNDevice, serverIPs []net.IP,
|
||||
defaultInterface string, port uint16, protocol models.NetworkProtocol) error {
|
||||
for _, serverIP := range serverIPs {
|
||||
c.logger.Info("%s: allowing output traffic to VPN server %s through %s on port %s %d",
|
||||
logPrefix, serverIP, defaultInterface, protocol, port)
|
||||
if err := c.runIptablesInstruction(
|
||||
fmt.Sprintf("-A OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
|
||||
serverIP, defaultInterface, protocol, protocol, port)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := c.runIptablesInstruction(fmt.Sprintf("-A OUTPUT -o %s -j ACCEPT", dev)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *configurator) CreateLocalSubnetsRules(subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error {
|
||||
subnetStr := subnet.String()
|
||||
c.logger.Info("%s: accepting input and output traffic for %s", logPrefix, subnetStr)
|
||||
if err := c.runIptablesInstructions([]string{
|
||||
fmt.Sprintf("-A INPUT -s %s -d %s -j ACCEPT", subnetStr, subnetStr),
|
||||
fmt.Sprintf("-A OUTPUT -s %s -d %s -j ACCEPT", subnetStr, subnetStr),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, extraSubnet := range extraSubnets {
|
||||
extraSubnetStr := extraSubnet.String()
|
||||
c.logger.Info("%s: accepting input traffic through %s from %s to %s", logPrefix, defaultInterface, extraSubnetStr, subnetStr)
|
||||
if err := c.runIptablesInstruction(
|
||||
fmt.Sprintf("-A INPUT -i %s -s %s -d %s -j ACCEPT", defaultInterface, extraSubnetStr, subnetStr)); err != nil {
|
||||
return err
|
||||
}
|
||||
// Thanks to @npawelek
|
||||
c.logger.Info("%s: accepting output traffic through %s from %s to %s", logPrefix, defaultInterface, subnetStr, extraSubnetStr)
|
||||
if err := c.runIptablesInstruction(
|
||||
fmt.Sprintf("-A OUTPUT -o %s -s %s -d %s -j ACCEPT", defaultInterface, subnetStr, extraSubnetStr)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Used for port forwarding
|
||||
func (c *configurator) AllowInputTrafficOnPort(device models.VPNDevice, port uint16) error {
|
||||
c.logger.Info("%s: accepting input traffic through %s on port %d", logPrefix, device, port)
|
||||
return c.runIptablesInstructions([]string{
|
||||
fmt.Sprintf("-A INPUT -i %s -p tcp --dport %d -j ACCEPT", device, port),
|
||||
fmt.Sprintf("-A INPUT -i %s -p udp --dport %d -j ACCEPT", device, port),
|
||||
})
|
||||
}
|
||||
88
internal/firewall/route.go
Normal file
88
internal/firewall/route.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"net"
|
||||
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
)
|
||||
|
||||
func (c *configurator) AddRoutesVia(subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error {
|
||||
for _, subnet := range subnets {
|
||||
subnetStr := subnet.String()
|
||||
output, err := c.commander.Run("ip", "route", "show", subnetStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot read route %s: %s: %w", subnetStr, output, err)
|
||||
} else if len(output) > 0 { // thanks to @npawelek https://github.com/npawelek
|
||||
continue // already exists
|
||||
// TODO remove it instead and continue execution below
|
||||
}
|
||||
c.logger.Info("%s: adding %s as route via %s", logPrefix, subnetStr, defaultInterface)
|
||||
output, err = c.commander.Run("ip", "route", "add", subnetStr, "via", defaultGateway.String(), "dev", defaultInterface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot add route for %s via %s %s %s: %s: %w", subnetStr, defaultGateway.String(), "dev", defaultInterface, output, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *configurator) GetDefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error) {
|
||||
c.logger.Info("%s: detecting default network route", logPrefix)
|
||||
data, err := c.fileManager.ReadFile(string(constants.NetRoute))
|
||||
if err != nil {
|
||||
return "", nil, defaultSubnet, err
|
||||
}
|
||||
// Verify number of lines and fields
|
||||
lines := strings.Split(string(data), "\n")
|
||||
if len(lines) < 3 {
|
||||
return "", nil, defaultSubnet, fmt.Errorf("not enough lines (%d) found in %s", len(lines), constants.NetRoute)
|
||||
}
|
||||
fieldsLine1 := strings.Fields(lines[1])
|
||||
if len(fieldsLine1) < 3 {
|
||||
return "", nil, defaultSubnet, fmt.Errorf("not enough fields in %q", lines[1])
|
||||
}
|
||||
fieldsLine2 := strings.Fields(lines[2])
|
||||
if len(fieldsLine2) < 8 {
|
||||
return "", nil, defaultSubnet, fmt.Errorf("not enough fields in %q", lines[2])
|
||||
}
|
||||
// get information
|
||||
defaultInterface = fieldsLine1[0]
|
||||
defaultGateway, err = reversedHexToIPv4(fieldsLine1[2])
|
||||
if err != nil {
|
||||
return "", nil, defaultSubnet, err
|
||||
}
|
||||
netNumber, err := reversedHexToIPv4(fieldsLine2[1])
|
||||
if err != nil {
|
||||
return "", nil, defaultSubnet, err
|
||||
}
|
||||
netMask, err := hexToIPv4Mask(fieldsLine2[7])
|
||||
if err != nil {
|
||||
return "", nil, defaultSubnet, err
|
||||
}
|
||||
subnet := net.IPNet{IP: netNumber, Mask: netMask}
|
||||
c.logger.Info("%s: default route found: interface %s, gateway %s, subnet %s", logPrefix, defaultInterface, defaultGateway.String(), subnet.String())
|
||||
return defaultInterface, defaultGateway, subnet, nil
|
||||
}
|
||||
|
||||
func reversedHexToIPv4(reversedHex string) (IP net.IP, err error) {
|
||||
bytes, err := hex.DecodeString(reversedHex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot parse reversed IP hex %q: %s", reversedHex, err)
|
||||
} else if len(bytes) != 4 {
|
||||
return nil, fmt.Errorf("hex string contains %d bytes instead of 4", len(bytes))
|
||||
}
|
||||
return []byte{bytes[3], bytes[2], bytes[1], bytes[0]}, nil
|
||||
}
|
||||
|
||||
func hexToIPv4Mask(hexString string) (mask net.IPMask, err error) {
|
||||
bytes, err := hex.DecodeString(hexString)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot parse hex mask %q: %s", hexString, err)
|
||||
} else if len(bytes) != 4 {
|
||||
return nil, fmt.Errorf("hex string contains %d bytes instead of 4", len(bytes))
|
||||
}
|
||||
return []byte{bytes[3], bytes[2], bytes[1], bytes[0]}, nil
|
||||
}
|
||||
171
internal/firewall/route_test.go
Normal file
171
internal/firewall/route_test.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
filesmocks "github.com/qdm12/golibs/files/mocks"
|
||||
loggingmocks "github.com/qdm12/golibs/logging/mocks"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
)
|
||||
|
||||
func Test_getDefaultRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
data []byte
|
||||
readErr error
|
||||
defaultInterface string
|
||||
defaultGateway net.IP
|
||||
defaultSubnet net.IPNet
|
||||
err error
|
||||
}{
|
||||
"no data": {
|
||||
err: fmt.Errorf("not enough lines (1) found in %s", constants.NetRoute)},
|
||||
"read error": {
|
||||
readErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error")},
|
||||
"not enough fields line 1": {
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 00000000
|
||||
eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0`),
|
||||
err: fmt.Errorf("not enough fields in \"eth0 00000000\"")},
|
||||
"not enough fields line 2": {
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
|
||||
eth0 000011AC 00000000 0001 0 0 0`),
|
||||
err: fmt.Errorf("not enough fields in \"eth0 000011AC 00000000 0001 0 0 0\"")},
|
||||
"bad gateway": {
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 00000000 x 0003 0 0 0 00000000 0 0 0
|
||||
eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0`),
|
||||
err: fmt.Errorf("cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'")},
|
||||
"bad net number": {
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
|
||||
eth0 x 00000000 0001 0 0 0 0000FFFF 0 0 0`),
|
||||
err: fmt.Errorf("cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'")},
|
||||
"bad net mask": {
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
|
||||
eth0 000011AC 00000000 0001 0 0 0 x 0 0 0`),
|
||||
err: fmt.Errorf("cannot parse hex mask \"x\": encoding/hex: invalid byte: U+0078 'x'")},
|
||||
"success": {
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
|
||||
eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0`),
|
||||
defaultInterface: "eth0",
|
||||
defaultGateway: net.IP{0xac, 0x11, 0x0, 0x1},
|
||||
defaultSubnet: net.IPNet{
|
||||
IP: net.IP{0xac, 0x11, 0x0, 0x0},
|
||||
Mask: net.IPMask{0xff, 0xff, 0x0, 0x0},
|
||||
}},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fileManager := &filesmocks.FileManager{}
|
||||
fileManager.On("ReadFile", string(constants.NetRoute)).
|
||||
Return(tc.data, tc.readErr).Once()
|
||||
logger := &loggingmocks.Logger{}
|
||||
logger.On("Info", "%s: detecting default network route", logPrefix).Once()
|
||||
if tc.err == nil {
|
||||
logger.On("Info", "%s: default route found: interface %s, gateway %s, subnet %s",
|
||||
logPrefix, tc.defaultInterface, tc.defaultGateway.String(), tc.defaultSubnet.String()).Once()
|
||||
}
|
||||
c := &configurator{logger: logger, fileManager: fileManager}
|
||||
defaultInterface, defaultGateway, defaultSubnet, err := c.GetDefaultRoute()
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.defaultInterface, defaultInterface)
|
||||
assert.Equal(t, tc.defaultGateway, defaultGateway)
|
||||
assert.Equal(t, tc.defaultSubnet, defaultSubnet)
|
||||
fileManager.AssertExpectations(t)
|
||||
logger.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_reversedHexToIPv4(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
reversedHex string
|
||||
IP net.IP
|
||||
err error
|
||||
}{
|
||||
"empty hex": {
|
||||
err: fmt.Errorf("hex string contains 0 bytes instead of 4")},
|
||||
"bad hex": {
|
||||
reversedHex: "x",
|
||||
err: fmt.Errorf("cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'")},
|
||||
"3 bytes hex": {
|
||||
reversedHex: "9abcde",
|
||||
err: fmt.Errorf("hex string contains 3 bytes instead of 4")},
|
||||
"correct hex": {
|
||||
reversedHex: "010011AC",
|
||||
IP: []byte{0xac, 0x11, 0x0, 0x1},
|
||||
err: nil},
|
||||
"correct hex 2": {
|
||||
reversedHex: "000011AC",
|
||||
IP: []byte{0xac, 0x11, 0x0, 0x0},
|
||||
err: nil},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
IP, err := reversedHexToIPv4(tc.reversedHex)
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.IP, IP)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_hexMaskToDecMask(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
hexString string
|
||||
mask net.IPMask
|
||||
err error
|
||||
}{
|
||||
"empty hex": {
|
||||
err: fmt.Errorf("hex string contains 0 bytes instead of 4")},
|
||||
"bad hex": {
|
||||
hexString: "x",
|
||||
err: fmt.Errorf("cannot parse hex mask \"x\": encoding/hex: invalid byte: U+0078 'x'")},
|
||||
"3 bytes hex": {
|
||||
hexString: "9abcde",
|
||||
err: fmt.Errorf("hex string contains 3 bytes instead of 4")},
|
||||
"16": {
|
||||
hexString: "0000FFFF",
|
||||
mask: []byte{0xff, 0xff, 0x0, 0x0},
|
||||
err: nil},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mask, err := hexToIPv4Mask(tc.hexString)
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.mask, mask)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user