Rewrite of the entrypoint in Golang (#71)

- General improvements
    - Parallel download of only needed files at start
    - Prettier console output with all streams merged (openvpn, unbound, shadowsocks etc.)
    - Simplified Docker final image
    - Faster bootup
- DNS over TLS
    - Finer grain blocking at DNS level: malicious, ads and surveillance
    - Choose your DNS over TLS providers
    - Ability to use multiple DNS over TLS providers for DNS split horizon
    - Environment variables for DNS logging
    - DNS block lists needed are downloaded and built automatically at start, in parallel
- PIA
    - A random region is selected if the REGION parameter is left empty (thanks @rorph for your PR)
    - Routing and iptables adjusted so it can work as a Kubernetes pod sidecar (thanks @rorph for your PR)
This commit is contained in:
Quentin McGaw
2020-02-06 20:42:46 -05:00
committed by GitHub
parent 3de4ffcf66
commit 64649039d9
74 changed files with 4598 additions and 1019 deletions

63
internal/constants/dns.go Normal file
View File

@@ -0,0 +1,63 @@
package constants
import (
"github.com/qdm12/private-internet-access-docker/internal/models"
)
const (
// Cloudflare is a DNS over TLS provider
Cloudflare models.DNSProvider = "cloudflare"
// Google is a DNS over TLS provider
Google models.DNSProvider = "google"
// Quad9 is a DNS over TLS provider
Quad9 models.DNSProvider = "quad9"
// Quadrant is a DNS over TLS provider
Quadrant models.DNSProvider = "quadrant"
// CleanBrowsing is a DNS over TLS provider
CleanBrowsing models.DNSProvider = "cleanbrowsing"
// SecureDNS is a DNS over TLS provider
SecureDNS models.DNSProvider = "securedns"
// LibreDNS is a DNS over TLS provider
LibreDNS models.DNSProvider = "libredns"
)
const (
CloudflareAddress1 models.DNSForwardAddress = "1.1.1.1@853#cloudflare-dns.com"
CloudflareAddress2 models.DNSForwardAddress = "1.0.0.1@853#cloudflare-dns.com"
GoogleAddress1 models.DNSForwardAddress = "8.8.8.8@853#dns.google"
GoogleAddress2 models.DNSForwardAddress = "8.8.4.4@853#dns.google"
Quad9Address1 models.DNSForwardAddress = "9.9.9.9@853#dns.quad9.net"
Quad9Address2 models.DNSForwardAddress = "149.112.112.112@853#dns.quad9.net"
QuadrantAddress models.DNSForwardAddress = "12.159.2.159@853#dns-tls.qis.io"
CleanBrowsingAddress1 models.DNSForwardAddress = "185.228.168.9@853#security-filter-dns.cleanbrowsing.org"
CleanBrowsingAddress2 models.DNSForwardAddress = "185.228.169.9@853#security-filter-dns.cleanbrowsing.org"
SecureDNSAddress models.DNSForwardAddress = "146.185.167.43@853#dot.securedns.eu"
LibreDNSAddress models.DNSForwardAddress = "116.203.115.192@853#dot.libredns.gr"
)
var DNSAddressesMapping = map[models.DNSProvider][]models.DNSForwardAddress{
Cloudflare: []models.DNSForwardAddress{CloudflareAddress1, CloudflareAddress2},
Google: []models.DNSForwardAddress{GoogleAddress1, GoogleAddress2},
Quad9: []models.DNSForwardAddress{Quad9Address1, Quad9Address2},
Quadrant: []models.DNSForwardAddress{QuadrantAddress},
CleanBrowsing: []models.DNSForwardAddress{CleanBrowsingAddress1, CleanBrowsingAddress2},
SecureDNS: []models.DNSForwardAddress{SecureDNSAddress},
LibreDNS: []models.DNSForwardAddress{LibreDNSAddress},
}
// Block lists URLs
const (
AdsBlockListHostnamesURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/ads-hostnames.updated"
AdsBlockListIPsURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/ads-ips.updated"
MaliciousBlockListHostnamesURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/malicious-hostnames.updated"
MaliciousBlockListIPsURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/malicious-ips.updated"
SurveillanceBlockListHostnamesURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/surveillance-hostnames.updated"
SurveillanceBlockListIPsURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/surveillance-ips.updated"
)
// DNS certificates to fetch
// TODO obtain from source directly, see qdm12/updated)
const (
NamedRootURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/named.root.updated"
RootKeyURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/root.key.updated"
)

View File

@@ -0,0 +1,10 @@
package constants
import (
"github.com/qdm12/private-internet-access-docker/internal/models"
)
const (
TUN models.VPNDevice = "tun0"
TAP models.VPNDevice = "tap0"
)

View File

@@ -0,0 +1,28 @@
package constants
import (
"github.com/qdm12/private-internet-access-docker/internal/models"
)
const (
// UnboundConf is the file path to the Unbound configuration file
UnboundConf models.Filepath = "/etc/unbound/unbound.conf"
// ResolvConf is the file path to the system resolv.conf file
ResolvConf models.Filepath = "/etc/resolv.conf"
// OpenVPNAuthConf is the file path to the OpenVPN auth file
OpenVPNAuthConf models.Filepath = "/etc/openvpn/auth.conf"
// OpenVPNConf is the file path to the OpenVPN client configuration file
OpenVPNConf models.Filepath = "/etc/openvpn/target.ovpn"
// TunnelDevice is the file path to tun device
TunnelDevice models.Filepath = "/dev/net/tun"
// NetRoute is the path to the file containing information on the network route
NetRoute models.Filepath = "/proc/net/route"
// TinyProxyConf is the filepath to the tinyproxy configuration file
TinyProxyConf models.Filepath = "/etc/tinyproxy/tinyproxy.conf"
// ShadowsocksConf is the filepath to the shadowsocks configuration file
ShadowsocksConf models.Filepath = "/etc/shadowsocks.json"
// RootHints is the filepath to the root.hints file used by Unbound
RootHints models.Filepath = "/etc/unbound/root.hints"
// RootKey is the filepath to the root.key file used by Unbound
RootKey models.Filepath = "/etc/unbound/root.key"
)

70
internal/constants/pia.go Normal file
View File

@@ -0,0 +1,70 @@
package constants
import (
"github.com/qdm12/private-internet-access-docker/internal/models"
)
const (
// PIAEncryptionNormal is the normal level of encryption for communication with PIA servers
PIAEncryptionNormal models.PIAEncryption = "normal"
// PIAEncryptionStrong is the strong level of encryption for communication with PIA servers
PIAEncryptionStrong models.PIAEncryption = "strong"
)
const (
AUMelbourne models.PIARegion = "AU Melbourne"
AUPerth models.PIARegion = "AU Perth"
AUSydney models.PIARegion = "AU Sydney"
Austria models.PIARegion = "Austria"
Belgium models.PIARegion = "Belgium"
CAMontreal models.PIARegion = "CA Montreal"
CAToronto models.PIARegion = "CA Toronto"
CAVancouver models.PIARegion = "CA Vancouver"
CzechRepublic models.PIARegion = "Czech Republic"
DEBerlin models.PIARegion = "DE Berlin"
DEFrankfurt models.PIARegion = "DE Frankfurt"
Denmark models.PIARegion = "Denmark"
Finland models.PIARegion = "Finland"
France models.PIARegion = "France"
HongKong models.PIARegion = "Hong Kong"
Hungary models.PIARegion = "Hungary"
India models.PIARegion = "India"
Ireland models.PIARegion = "Ireland"
Israel models.PIARegion = "Israel"
Italy models.PIARegion = "Italy"
Japan models.PIARegion = "Japan"
Luxembourg models.PIARegion = "Luxembourg"
Mexico models.PIARegion = "Mexico"
Netherlands models.PIARegion = "Netherlands"
NewZealand models.PIARegion = "New Zealand"
Norway models.PIARegion = "Norway"
Poland models.PIARegion = "Poland"
Romania models.PIARegion = "Romania"
Singapore models.PIARegion = "Singapore"
Spain models.PIARegion = "Spain"
Sweden models.PIARegion = "Sweden"
Switzerland models.PIARegion = "Switzerland"
UAE models.PIARegion = "UAE"
UKLondon models.PIARegion = "UK London"
UKManchester models.PIARegion = "UK Manchester"
UKSouthampton models.PIARegion = "UK Southampton"
USAtlanta models.PIARegion = "US Atlanta"
USCalifornia models.PIARegion = "US California"
USChicago models.PIARegion = "US Chicago"
USDenver models.PIARegion = "US Denver"
USEast models.PIARegion = "US East"
USFlorida models.PIARegion = "US Florida"
USHouston models.PIARegion = "US Houston"
USLasVegas models.PIARegion = "US Las Vegas"
USNewYorkCity models.PIARegion = "US New York City"
USSeattle models.PIARegion = "US Seattle"
USSiliconValley models.PIARegion = "US Silicon Valley"
USTexas models.PIARegion = "US Texas"
USWashingtonDC models.PIARegion = "US Washington DC"
USWest models.PIARegion = "US West"
)
const (
PIAOpenVPNURL models.URL = "https://www.privateinternetaccess.com/openvpn"
PIAPortForwardURL models.URL = "http://209.222.18.222:2000"
)

View File

@@ -0,0 +1,13 @@
package constants
const (
// Annoucement is a message annoucement
Annoucement = "Total rewrite in Go with many new features"
// AnnoucementExpiration is the expiration time of the annoucement in unix timestamp
AnnoucementExpiration = 1582761600
)
const (
// IssueLink is the link for users to use to create issues
IssueLink = "https://github.com/qdm12/private-internet-access-docker/issues/new"
)

View File

@@ -0,0 +1,16 @@
package constants
import (
"github.com/qdm12/private-internet-access-docker/internal/models"
)
const (
// TinyProxyInfoLevel is the info log level for TinyProxy
TinyProxyInfoLevel models.TinyProxyLogLevel = "Info"
// TinyProxyWarnLevel is the warning log level for TinyProxy
TinyProxyWarnLevel models.TinyProxyLogLevel = "Warning"
// TinyProxyErrorLevel is the error log level for TinyProxy
TinyProxyErrorLevel models.TinyProxyLogLevel = "Error"
// TinyProxyCriticalLevel is the critical log level for TinyProxy
TinyProxyCriticalLevel models.TinyProxyLogLevel = "Critical"
)

21
internal/constants/vpn.go Normal file
View File

@@ -0,0 +1,21 @@
package constants
import (
"github.com/qdm12/private-internet-access-docker/internal/models"
)
const (
// PrivateInternetAccess is a VPN provider
PrivateInternetAccess models.VPNProvider = "private internet access"
// Mullvad is a VPN provider
Mullvad models.VPNProvider = "mullvad"
// Windscribe is a VPN provider
Windscribe models.VPNProvider = "windscribe"
)
const (
// TCP is a network protocol (reliable and slower than UDP)
TCP models.NetworkProtocol = "tcp"
// UDP is a network protocol (unreliable and faster than TCP)
UDP models.NetworkProtocol = "udp"
)

40
internal/dns/command.go Normal file
View File

@@ -0,0 +1,40 @@
package dns
import (
"fmt"
"io"
"strings"
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
func (c *configurator) Start(verbosityDetailsLevel uint8) (stdout io.ReadCloser, err error) {
c.logger.Info("%s: starting unbound", logPrefix)
args := []string{"-d", "-c", string(constants.UnboundConf)}
if verbosityDetailsLevel > 0 {
args = append(args, "-"+strings.Repeat("v", int(verbosityDetailsLevel)))
}
// Only logs to stderr
_, stdout, _, err = c.commander.Start("unbound", args...)
return stdout, err
}
func (c *configurator) Version() (version string, err error) {
output, err := c.commander.Run("unbound", "-V")
if err != nil {
return "", fmt.Errorf("unbound version: %w", err)
}
for _, line := range strings.Split(output, "\n") {
if strings.Contains(line, "Version ") {
words := strings.Fields(line)
if len(words) < 2 {
continue
}
version = words[1]
}
}
if version == "" {
return "", fmt.Errorf("unbound version was not found in %q", output)
}
return version, nil
}

View File

@@ -0,0 +1,69 @@
package dns
import (
"fmt"
"testing"
commandMocks "github.com/qdm12/golibs/command/mocks"
loggingMocks "github.com/qdm12/golibs/logging/mocks"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
func Test_Start(t *testing.T) {
t.Parallel()
logger := &loggingMocks.Logger{}
logger.On("Info", "%s: starting unbound", logPrefix).Once()
commander := &commandMocks.Commander{}
commander.On("Start", "unbound", "-d", "-c", string(constants.UnboundConf), "-vv").
Return(nil, nil, nil, nil).Once()
c := &configurator{commander: commander, logger: logger}
stdout, err := c.Start(2)
assert.Nil(t, stdout)
assert.NoError(t, err)
logger.AssertExpectations(t)
commander.AssertExpectations(t)
}
func Test_Version(t *testing.T) {
t.Parallel()
tests := map[string]struct {
runOutput string
runErr error
version string
err error
}{
"no data": {
err: fmt.Errorf(`unbound version was not found in ""`),
},
"2 lines with version": {
runOutput: "Version \nVersion 1.0-a hello\n",
version: "1.0-a",
},
"run error": {
runErr: fmt.Errorf("error"),
err: fmt.Errorf("unbound version: error"),
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
commander := &commandMocks.Commander{}
commander.On("Run", "unbound", "-V").
Return(tc.runOutput, tc.runErr).Once()
c := &configurator{commander: commander}
version, err := c.Version()
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.version, version)
commander.AssertExpectations(t)
})
}
}

280
internal/dns/conf.go Normal file
View File

@@ -0,0 +1,280 @@
package dns
import (
"fmt"
"sort"
"strings"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network"
"github.com/qdm12/private-internet-access-docker/internal/constants"
"github.com/qdm12/private-internet-access-docker/internal/settings"
)
func (c *configurator) MakeUnboundConf(settings settings.DNS, uid, gid int) (err error) {
c.logger.Info("%s: generating Unbound configuration", logPrefix)
lines, warnings, err := generateUnboundConf(settings, c.client, c.logger)
for _, warning := range warnings {
c.logger.Warn(warning)
}
if err != nil {
return err
}
return c.fileManager.WriteLinesToFile(
string(constants.UnboundConf),
lines,
files.FileOwnership(uid, gid),
files.FilePermissions(0400))
}
// MakeUnboundConf generates an Unbound configuration from the user provided settings
func generateUnboundConf(settings settings.DNS, client network.Client, logger logging.Logger) (lines []string, warnings []error, err error) {
serverSection := map[string]string{
// Logging
"verbosity": fmt.Sprintf("%d", settings.VerbosityLevel),
"val-log-level": fmt.Sprintf("%d", settings.ValidationLogLevel),
"use-syslog": "no",
// Performance
"num-threads": "1",
"prefetch": "yes",
"prefetch-key": "yes",
"key-cache-size": "16m",
"key-cache-slabs": "4",
"msg-cache-size": "4m",
"msg-cache-slabs": "4",
"rrset-cache-size": "4m",
"rrset-cache-slabs": "4",
"cache-min-ttl": "3600",
"cache-max-ttl": "9000",
// Privacy
"rrset-roundrobin": "yes",
"hide-identity": "yes",
"hide-version": "yes",
// Security
"tls-cert-bundle": "\"/etc/ssl/certs/ca-certificates.crt\"",
"root-hints": fmt.Sprintf("%q", constants.RootHints),
"trust-anchor-file": fmt.Sprintf("%q", constants.RootKey),
"harden-below-nxdomain": "yes",
"harden-referral-path": "yes",
"harden-algo-downgrade": "yes",
// Network
"do-ip4": "yes",
"do-ip6": "no",
"interface": "127.0.0.1",
"port": "53",
// Other
"username": "\"nonrootuser\"",
}
// Block lists
hostnamesLines, ipsLines, warnings := buildBlocked(client,
settings.BlockMalicious, settings.BlockAds, settings.BlockSurveillance,
settings.AllowedHostnames, settings.PrivateAddresses,
)
logger.Info("%s: %d hostnames blocked overall", logPrefix, len(hostnamesLines))
logger.Info("%s: %d IP addresses blocked overall", logPrefix, len(ipsLines))
sort.Slice(hostnamesLines, func(i, j int) bool { // for unit tests really
return hostnamesLines[i] < hostnamesLines[j]
})
sort.Slice(ipsLines, func(i, j int) bool { // for unit tests really
return ipsLines[i] < ipsLines[j]
})
// Server
lines = append(lines, "server:")
var serverLines []string
for k, v := range serverSection {
serverLines = append(serverLines, " "+k+": "+v)
}
sort.Slice(serverLines, func(i, j int) bool {
return serverLines[i] < serverLines[j]
})
lines = append(lines, serverLines...)
lines = append(lines, hostnamesLines...)
lines = append(lines, ipsLines...)
// Forward zone
lines = append(lines, "forward-zone:")
forwardZoneSection := map[string]string{
"name": "\".\"",
"forward-tls-upstream": "yes",
}
var forwardZoneLines []string
for k, v := range forwardZoneSection {
forwardZoneLines = append(forwardZoneLines, " "+k+": "+v)
}
sort.Slice(forwardZoneLines, func(i, j int) bool {
return forwardZoneLines[i] < forwardZoneLines[j]
})
for _, provider := range settings.Providers {
forwardAddresses, ok := constants.DNSAddressesMapping[provider]
if !ok || len(forwardAddresses) == 0 {
return nil, warnings, fmt.Errorf("DNS provider %q does not have any matching forward addresses", provider)
}
for _, forwardAddress := range forwardAddresses {
forwardZoneLines = append(forwardZoneLines, fmt.Sprintf(" forward-addr: %s", forwardAddress))
}
}
lines = append(lines, forwardZoneLines...)
return lines, warnings, nil
}
func buildBlocked(client network.Client, blockMalicious, blockAds, blockSurveillance bool,
allowedHostnames, privateAddresses []string) (hostnamesLines, ipsLines []string, errs []error) {
chHostnames := make(chan []string)
chIPs := make(chan []string)
chErrors := make(chan []error)
go func() {
lines, errs := buildBlockedHostnames(client, blockMalicious, blockAds, blockSurveillance, allowedHostnames)
chHostnames <- lines
chErrors <- errs
}()
go func() {
lines, errs := buildBlockedIPs(client, blockMalicious, blockAds, blockSurveillance, privateAddresses)
chIPs <- lines
chErrors <- errs
}()
n := 2
for n > 0 {
select {
case lines := <-chHostnames:
hostnamesLines = append(hostnamesLines, lines...)
case lines := <-chIPs:
ipsLines = append(ipsLines, lines...)
case routineErrs := <-chErrors:
errs = append(errs, routineErrs...)
n--
}
}
return hostnamesLines, ipsLines, errs
}
func getList(client network.Client, URL string) (results []string, err error) {
content, status, err := client.GetContent(URL)
if err != nil {
return nil, err
} else if status != 200 {
return nil, fmt.Errorf("HTTP status code is %d and not 200", status)
}
results = strings.Split(string(content), "\n")
// remove empty lines
last := len(results) - 1
for i := range results {
if len(results[i]) == 0 {
results[i] = results[last]
last--
}
}
results = results[:last+1]
if len(results) == 0 {
return nil, nil
}
return results, nil
}
func buildBlockedHostnames(client network.Client, blockMalicious, blockAds, blockSurveillance bool,
allowedHostnames []string) (lines []string, errs []error) {
chResults := make(chan []string)
chError := make(chan error)
listsLeftToFetch := 0
if blockMalicious {
listsLeftToFetch++
go func() {
results, err := getList(client, string(constants.MaliciousBlockListHostnamesURL))
chResults <- results
chError <- err
}()
}
if blockAds {
listsLeftToFetch++
go func() {
results, err := getList(client, string(constants.AdsBlockListHostnamesURL))
chResults <- results
chError <- err
}()
}
if blockSurveillance {
listsLeftToFetch++
go func() {
results, err := getList(client, string(constants.SurveillanceBlockListHostnamesURL))
chResults <- results
chError <- err
}()
}
uniqueResults := make(map[string]struct{})
for listsLeftToFetch > 0 {
select {
case results := <-chResults:
for _, result := range results {
uniqueResults[result] = struct{}{}
}
case err := <-chError:
listsLeftToFetch--
if err != nil {
errs = append(errs, err)
}
}
}
for _, allowedHostname := range allowedHostnames {
delete(uniqueResults, allowedHostname)
}
for result := range uniqueResults {
lines = append(lines, " local-zone: \""+result+"\" static")
}
return lines, errs
}
func buildBlockedIPs(client network.Client, blockMalicious, blockAds, blockSurveillance bool,
privateAddresses []string) (lines []string, errs []error) {
chResults := make(chan []string)
chError := make(chan error)
listsLeftToFetch := 0
if blockMalicious {
listsLeftToFetch++
go func() {
results, err := getList(client, string(constants.MaliciousBlockListIPsURL))
chResults <- results
chError <- err
}()
}
if blockAds {
listsLeftToFetch++
go func() {
results, err := getList(client, string(constants.AdsBlockListIPsURL))
chResults <- results
chError <- err
}()
}
if blockSurveillance {
listsLeftToFetch++
go func() {
results, err := getList(client, string(constants.SurveillanceBlockListIPsURL))
chResults <- results
chError <- err
}()
}
uniqueResults := make(map[string]struct{})
for listsLeftToFetch > 0 {
select {
case results := <-chResults:
for _, result := range results {
uniqueResults[result] = struct{}{}
}
case err := <-chError:
listsLeftToFetch--
if err != nil {
errs = append(errs, err)
}
}
}
for _, privateAddress := range privateAddresses {
uniqueResults[privateAddress] = struct{}{}
}
for result := range uniqueResults {
lines = append(lines, " private-address: "+result)
}
return lines, errs
}

518
internal/dns/conf_test.go Normal file
View File

@@ -0,0 +1,518 @@
package dns
import (
"fmt"
"strings"
"testing"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network/mocks"
"github.com/qdm12/private-internet-access-docker/internal/constants"
"github.com/qdm12/private-internet-access-docker/internal/models"
"github.com/qdm12/private-internet-access-docker/internal/settings"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_generateUnboundConf(t *testing.T) {
t.Parallel()
settings := settings.DNS{
Providers: []models.DNSProvider{constants.Cloudflare, constants.Quad9},
AllowedHostnames: []string{"a"},
PrivateAddresses: []string{"9.9.9.9"},
BlockMalicious: true,
BlockSurveillance: false,
BlockAds: false,
VerbosityLevel: 2,
ValidationLogLevel: 3,
}
client := &mocks.Client{}
client.On("GetContent", string(constants.MaliciousBlockListHostnamesURL)).
Return([]byte("b\na\nc"), 200, nil).Once()
client.On("GetContent", string(constants.MaliciousBlockListIPsURL)).
Return([]byte("c\nd\n"), 200, nil).Once()
emptyLogger, err := logging.NewEmptyLogger()
require.NoError(t, err)
lines, warnings, err := generateUnboundConf(settings, client, emptyLogger)
require.Len(t, warnings, 0)
require.NoError(t, err)
client.AssertExpectations(t)
expected := `
server:
cache-max-ttl: 9000
cache-min-ttl: 3600
do-ip4: yes
do-ip6: no
harden-algo-downgrade: yes
harden-below-nxdomain: yes
harden-referral-path: yes
hide-identity: yes
hide-version: yes
interface: 127.0.0.1
key-cache-size: 16m
key-cache-slabs: 4
msg-cache-size: 4m
msg-cache-slabs: 4
num-threads: 1
port: 53
prefetch-key: yes
prefetch: yes
root-hints: "/etc/unbound/root.hints"
rrset-cache-size: 4m
rrset-cache-slabs: 4
rrset-roundrobin: yes
tls-cert-bundle: "/etc/ssl/certs/ca-certificates.crt"
trust-anchor-file: "/etc/unbound/root.key"
use-syslog: no
username: "nonrootuser"
val-log-level: 3
verbosity: 2
local-zone: "b" static
local-zone: "c" static
private-address: 9.9.9.9
private-address: c
private-address: d
forward-zone:
forward-tls-upstream: yes
name: "."
forward-addr: 1.1.1.1@853#cloudflare-dns.com
forward-addr: 1.0.0.1@853#cloudflare-dns.com
forward-addr: 9.9.9.9@853#dns.quad9.net
forward-addr: 149.112.112.112@853#dns.quad9.net`
assert.Equal(t, expected, "\n"+strings.Join(lines, "\n"))
}
func Test_buildBlocked(t *testing.T) {
t.Parallel()
type blockParams struct {
blocked bool
content []byte
clientErr error
}
tests := map[string]struct {
malicious blockParams
ads blockParams
surveillance blockParams
allowedHostnames []string
privateAddresses []string
hostnamesLines []string
ipsLines []string
errsString []string
}{
"none blocked": {},
"all blocked without lists": {
malicious: blockParams{
blocked: true,
},
ads: blockParams{
blocked: true,
},
surveillance: blockParams{
blocked: true,
},
},
"all blocked with lists": {
malicious: blockParams{
blocked: true,
content: []byte("malicious"),
},
ads: blockParams{
blocked: true,
content: []byte("ads"),
},
surveillance: blockParams{
blocked: true,
content: []byte("surveillance"),
},
hostnamesLines: []string{
" local-zone: \"ads\" static",
" local-zone: \"malicious\" static",
" local-zone: \"surveillance\" static"},
ipsLines: []string{
" private-address: ads",
" private-address: malicious",
" private-address: surveillance"},
},
"all blocked with allowed hostnames": {
malicious: blockParams{
blocked: true,
content: []byte("malicious"),
},
ads: blockParams{
blocked: true,
content: []byte("ads"),
},
surveillance: blockParams{
blocked: true,
content: []byte("surveillance"),
},
allowedHostnames: []string{"ads"},
hostnamesLines: []string{
" local-zone: \"malicious\" static",
" local-zone: \"surveillance\" static"},
ipsLines: []string{
" private-address: ads",
" private-address: malicious",
" private-address: surveillance"},
},
"all blocked with private addresses": {
malicious: blockParams{
blocked: true,
content: []byte("malicious"),
},
ads: blockParams{
blocked: true,
content: []byte("ads"),
},
surveillance: blockParams{
blocked: true,
content: []byte("surveillance"),
},
privateAddresses: []string{"ads", "192.100.1.5"},
hostnamesLines: []string{
" local-zone: \"ads\" static",
" local-zone: \"malicious\" static",
" local-zone: \"surveillance\" static"},
ipsLines: []string{
" private-address: 192.100.1.5",
" private-address: ads",
" private-address: malicious",
" private-address: surveillance"},
},
"all blocked with lists and one error": {
malicious: blockParams{
blocked: true,
content: []byte("malicious"),
},
ads: blockParams{
blocked: true,
content: []byte("ads"),
clientErr: fmt.Errorf("ads error"),
},
surveillance: blockParams{
blocked: true,
content: []byte("surveillance"),
},
hostnamesLines: []string{
" local-zone: \"malicious\" static",
" local-zone: \"surveillance\" static"},
ipsLines: []string{
" private-address: malicious",
" private-address: surveillance"},
errsString: []string{"ads error", "ads error"},
},
"all blocked with errors": {
malicious: blockParams{
blocked: true,
clientErr: fmt.Errorf("malicious"),
},
ads: blockParams{
blocked: true,
clientErr: fmt.Errorf("ads"),
},
surveillance: blockParams{
blocked: true,
clientErr: fmt.Errorf("surveillance"),
},
errsString: []string{"malicious", "malicious", "ads", "ads", "surveillance", "surveillance"},
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
client := &mocks.Client{}
if tc.malicious.blocked {
client.On("GetContent", string(constants.MaliciousBlockListHostnamesURL)).
Return(tc.malicious.content, 200, tc.malicious.clientErr).Once()
client.On("GetContent", string(constants.MaliciousBlockListIPsURL)).
Return(tc.malicious.content, 200, tc.malicious.clientErr).Once()
}
if tc.ads.blocked {
client.On("GetContent", string(constants.AdsBlockListHostnamesURL)).
Return(tc.ads.content, 200, tc.ads.clientErr).Once()
client.On("GetContent", string(constants.AdsBlockListIPsURL)).
Return(tc.ads.content, 200, tc.ads.clientErr).Once()
}
if tc.surveillance.blocked {
client.On("GetContent", string(constants.SurveillanceBlockListHostnamesURL)).
Return(tc.surveillance.content, 200, tc.surveillance.clientErr).Once()
client.On("GetContent", string(constants.SurveillanceBlockListIPsURL)).
Return(tc.surveillance.content, 200, tc.surveillance.clientErr).Once()
}
hostnamesLines, ipsLines, errs := buildBlocked(client, tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked,
tc.allowedHostnames, tc.privateAddresses)
var errsString []string
for _, err := range errs {
errsString = append(errsString, err.Error())
}
assert.ElementsMatch(t, tc.errsString, errsString)
assert.ElementsMatch(t, tc.hostnamesLines, hostnamesLines)
assert.ElementsMatch(t, tc.ipsLines, ipsLines)
client.AssertExpectations(t)
})
}
}
func Test_getList(t *testing.T) {
t.Parallel()
tests := map[string]struct {
content []byte
status int
clientErr error
results []string
err error
}{
"no result": {nil, 200, nil, nil, nil},
"bad status": {nil, 500, nil, nil, fmt.Errorf("HTTP status code is 500 and not 200")},
"network error": {nil, 200, fmt.Errorf("error"), nil, fmt.Errorf("error")},
"results": {[]byte("a\nb\nc\n"), 200, nil, []string{"a", "b", "c"}, nil},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
client := &mocks.Client{}
client.On("GetContent", "irrelevant_url").Return(
tc.content, tc.status, tc.clientErr,
).Once()
results, err := getList(client, "irrelevant_url")
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.results, results)
client.AssertExpectations(t)
})
}
}
func Test_buildBlockedHostnames(t *testing.T) {
t.Parallel()
type blockParams struct {
blocked bool
content []byte
clientErr error
}
tests := map[string]struct {
malicious blockParams
ads blockParams
surveillance blockParams
allowedHostnames []string
lines []string
errsString []string
}{
"nothing blocked": {
lines: nil,
errsString: nil,
},
"only malicious blocked": {
malicious: blockParams{
blocked: true,
content: []byte("site_a\nsite_b"),
clientErr: nil,
},
lines: []string{
" local-zone: \"site_a\" static",
" local-zone: \"site_b\" static"},
errsString: nil,
},
"all blocked with some duplicates": {
malicious: blockParams{
blocked: true,
content: []byte("site_a\nsite_b"),
},
ads: blockParams{
blocked: true,
content: []byte("site_a\nsite_c"),
},
surveillance: blockParams{
blocked: true,
content: []byte("site_c\nsite_a"),
},
lines: []string{
" local-zone: \"site_a\" static",
" local-zone: \"site_b\" static",
" local-zone: \"site_c\" static"},
errsString: nil,
},
"all blocked with one errored": {
malicious: blockParams{
blocked: true,
content: []byte("site_a\nsite_b"),
},
ads: blockParams{
blocked: true,
content: []byte("site_a\nsite_c"),
},
surveillance: blockParams{
blocked: true,
clientErr: fmt.Errorf("surveillance error"),
},
lines: []string{
" local-zone: \"site_a\" static",
" local-zone: \"site_b\" static",
" local-zone: \"site_c\" static"},
errsString: []string{"surveillance error"},
},
"blocked with allowed hostnames": {
malicious: blockParams{
blocked: true,
content: []byte("site_a\nsite_b"),
},
ads: blockParams{
blocked: true,
content: []byte("site_c\nsite_d"),
},
allowedHostnames: []string{"site_b", "site_c"},
lines: []string{
" local-zone: \"site_a\" static",
" local-zone: \"site_d\" static"},
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
client := &mocks.Client{}
if tc.malicious.blocked {
client.On("GetContent", string(constants.MaliciousBlockListHostnamesURL)).
Return(tc.malicious.content, 200, tc.malicious.clientErr).Once()
}
if tc.ads.blocked {
client.On("GetContent", string(constants.AdsBlockListHostnamesURL)).
Return(tc.ads.content, 200, tc.ads.clientErr).Once()
}
if tc.surveillance.blocked {
client.On("GetContent", string(constants.SurveillanceBlockListHostnamesURL)).
Return(tc.surveillance.content, 200, tc.surveillance.clientErr).Once()
}
lines, errs := buildBlockedHostnames(client,
tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked, tc.allowedHostnames)
var errsString []string
for _, err := range errs {
errsString = append(errsString, err.Error())
}
assert.ElementsMatch(t, tc.errsString, errsString)
assert.ElementsMatch(t, tc.lines, lines)
client.AssertExpectations(t)
})
}
}
func Test_buildBlockedIPs(t *testing.T) {
t.Parallel()
type blockParams struct {
blocked bool
content []byte
clientErr error
}
tests := map[string]struct {
malicious blockParams
ads blockParams
surveillance blockParams
privateAddresses []string
lines []string
errsString []string
}{
"nothing blocked": {
lines: nil,
errsString: nil,
},
"only malicious blocked": {
malicious: blockParams{
blocked: true,
content: []byte("site_a\nsite_b"),
clientErr: nil,
},
lines: []string{
" private-address: site_a",
" private-address: site_b"},
errsString: nil,
},
"all blocked with some duplicates": {
malicious: blockParams{
blocked: true,
content: []byte("site_a\nsite_b"),
},
ads: blockParams{
blocked: true,
content: []byte("site_a\nsite_c"),
},
surveillance: blockParams{
blocked: true,
content: []byte("site_c\nsite_a"),
},
lines: []string{
" private-address: site_a",
" private-address: site_b",
" private-address: site_c"},
errsString: nil,
},
"all blocked with one errored": {
malicious: blockParams{
blocked: true,
content: []byte("site_a\nsite_b"),
},
ads: blockParams{
blocked: true,
content: []byte("site_a\nsite_c"),
},
surveillance: blockParams{
blocked: true,
clientErr: fmt.Errorf("surveillance error"),
},
lines: []string{
" private-address: site_a",
" private-address: site_b",
" private-address: site_c"},
errsString: []string{"surveillance error"},
},
"blocked with private addresses": {
malicious: blockParams{
blocked: true,
content: []byte("site_a\nsite_b"),
},
ads: blockParams{
blocked: true,
content: []byte("site_c"),
},
privateAddresses: []string{"site_c", "site_d"},
lines: []string{
" private-address: site_a",
" private-address: site_b",
" private-address: site_c",
" private-address: site_d"},
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
client := &mocks.Client{}
if tc.malicious.blocked {
client.On("GetContent", string(constants.MaliciousBlockListIPsURL)).
Return(tc.malicious.content, 200, tc.malicious.clientErr).Once()
}
if tc.ads.blocked {
client.On("GetContent", string(constants.AdsBlockListIPsURL)).
Return(tc.ads.content, 200, tc.ads.clientErr).Once()
}
if tc.surveillance.blocked {
client.On("GetContent", string(constants.SurveillanceBlockListIPsURL)).
Return(tc.surveillance.content, 200, tc.surveillance.clientErr).Once()
}
lines, errs := buildBlockedIPs(client,
tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked, tc.privateAddresses)
var errsString []string
for _, err := range errs {
errsString = append(errsString, err.Error())
}
assert.ElementsMatch(t, tc.errsString, errsString)
assert.ElementsMatch(t, tc.lines, lines)
client.AssertExpectations(t)
})
}
}

38
internal/dns/dns.go Normal file
View File

@@ -0,0 +1,38 @@
package dns
import (
"io"
"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network"
"github.com/qdm12/private-internet-access-docker/internal/settings"
)
const logPrefix = "dns configurator"
type Configurator interface {
DownloadRootHints(uid, gid int) error
DownloadRootKey(uid, gid int) error
MakeUnboundConf(settings settings.DNS, uid, gid int) (err error)
SetLocalNameserver() error
Start(logLevel uint8) (stdout io.ReadCloser, err error)
Version() (version string, err error)
}
type configurator struct {
logger logging.Logger
client network.Client
fileManager files.FileManager
commander command.Commander
}
func NewConfigurator(logger logging.Logger, client network.Client, fileManager files.FileManager) Configurator {
return &configurator{
logger: logger,
client: client,
fileManager: fileManager,
commander: command.NewCommander(),
}
}

32
internal/dns/os.go Normal file
View File

@@ -0,0 +1,32 @@
package dns
import (
"strings"
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
func (c *configurator) SetLocalNameserver() error {
c.logger.Info("%s: setting local nameserver to 127.0.0.1", logPrefix)
data, err := c.fileManager.ReadFile(string(constants.ResolvConf))
if err != nil {
return err
}
s := strings.TrimSuffix(string(data), "\n")
lines := strings.Split(s, "\n")
if len(lines) == 1 && lines[0] == "" {
lines = nil
}
found := false
for i := range lines {
if strings.HasPrefix(lines[i], "nameserver ") {
lines[i] = "nameserver 127.0.0.1"
found = true
}
}
if !found {
lines = append(lines, "nameserver 127.0.0.1")
}
data = []byte(strings.Join(lines, "\n"))
return c.fileManager.WriteToFile(string(constants.ResolvConf), data)
}

72
internal/dns/os_test.go Normal file
View File

@@ -0,0 +1,72 @@
package dns
import (
"fmt"
"testing"
filesmocks "github.com/qdm12/golibs/files/mocks"
loggingmocks "github.com/qdm12/golibs/logging/mocks"
"github.com/qdm12/private-internet-access-docker/internal/constants"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_SetLocalNameserver(t *testing.T) {
t.Parallel()
tests := map[string]struct {
data []byte
writtenData []byte
readErr error
writeErr error
err error
}{
"no data": {
writtenData: []byte("nameserver 127.0.0.1"),
},
"read error": {
readErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"write error": {
writtenData: []byte("nameserver 127.0.0.1"),
writeErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"lines without nameserver": {
data: []byte("abc\ndef\n"),
writtenData: []byte("abc\ndef\nnameserver 127.0.0.1"),
},
"lines with nameserver": {
data: []byte("abc\nnameserver abc def\ndef\n"),
writtenData: []byte("abc\nnameserver 127.0.0.1\ndef"),
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
fileManager := &filesmocks.FileManager{}
fileManager.On("ReadFile", string(constants.ResolvConf)).
Return(tc.data, tc.readErr).Once()
if tc.readErr == nil {
fileManager.On("WriteToFile", string(constants.ResolvConf), tc.writtenData).
Return(tc.writeErr).Once()
}
logger := &loggingmocks.Logger{}
logger.On("Info", "%s: setting local nameserver to 127.0.0.1", logPrefix).Once()
c := &configurator{
fileManager: fileManager,
logger: logger,
}
err := c.SetLocalNameserver()
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
fileManager.AssertExpectations(t)
logger.AssertExpectations(t)
})
}
}

38
internal/dns/roots.go Normal file
View File

@@ -0,0 +1,38 @@
package dns
import (
"fmt"
"github.com/qdm12/golibs/files"
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
func (c *configurator) DownloadRootHints(uid, gid int) error {
c.logger.Info("%s: downloading root hints from %s", logPrefix, constants.NamedRootURL)
content, status, err := c.client.GetContent(string(constants.NamedRootURL))
if err != nil {
return err
} else if status != 200 {
return fmt.Errorf("HTTP status code is %d for %s", status, constants.NamedRootURL)
}
return c.fileManager.WriteToFile(
string(constants.RootHints),
content,
files.FileOwnership(uid, gid),
files.FilePermissions(0400))
}
func (c *configurator) DownloadRootKey(uid, gid int) error {
c.logger.Info("%s: downloading root key from %s", logPrefix, constants.RootKeyURL)
content, status, err := c.client.GetContent(string(constants.RootKeyURL))
if err != nil {
return err
} else if status != 200 {
return fmt.Errorf("HTTP status code is %d for %s", status, constants.RootKeyURL)
}
return c.fileManager.WriteToFile(
string(constants.RootKey),
content,
files.FileOwnership(uid, gid),
files.FilePermissions(0400))
}

144
internal/dns/roots_test.go Normal file
View File

@@ -0,0 +1,144 @@
package dns
import (
"fmt"
"net/http"
"testing"
filesMocks "github.com/qdm12/golibs/files/mocks"
loggingMocks "github.com/qdm12/golibs/logging/mocks"
networkMocks "github.com/qdm12/golibs/network/mocks"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
func Test_DownloadRootHints(t *testing.T) {
t.Parallel()
tests := map[string]struct {
content []byte
status int
clientErr error
writeErr error
err error
}{
"no data": {
status: http.StatusOK,
},
"bad status": {
status: http.StatusBadRequest,
err: fmt.Errorf("HTTP status code is 400 for https://raw.githubusercontent.com/qdm12/files/master/named.root.updated"),
},
"client error": {
clientErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"write error": {
status: http.StatusOK,
writeErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"data": {
content: []byte("content"),
status: http.StatusOK,
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
logger := &loggingMocks.Logger{}
logger.On("Info", "%s: downloading root hints from %s", logPrefix, constants.NamedRootURL).Once()
client := &networkMocks.Client{}
client.On("GetContent", string(constants.NamedRootURL)).
Return(tc.content, tc.status, tc.clientErr).Once()
fileManager := &filesMocks.FileManager{}
if tc.clientErr == nil && tc.status == http.StatusOK {
fileManager.On(
"WriteToFile",
string(constants.RootHints),
tc.content,
mock.AnythingOfType("files.WriteOptionSetter"),
mock.AnythingOfType("files.WriteOptionSetter")).
Return(tc.writeErr).Once()
}
c := &configurator{logger: logger, client: client, fileManager: fileManager}
err := c.DownloadRootHints(1000, 1000)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
logger.AssertExpectations(t)
client.AssertExpectations(t)
fileManager.AssertExpectations(t)
})
}
}
func Test_DownloadRootKey(t *testing.T) {
t.Parallel()
tests := map[string]struct {
content []byte
status int
clientErr error
writeErr error
err error
}{
"no data": {
status: http.StatusOK,
},
"bad status": {
status: http.StatusBadRequest,
err: fmt.Errorf("HTTP status code is 400 for https://raw.githubusercontent.com/qdm12/files/master/root.key.updated"),
},
"client error": {
clientErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"write error": {
status: http.StatusOK,
writeErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"data": {
content: []byte("content"),
status: http.StatusOK,
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
logger := &loggingMocks.Logger{}
logger.On("Info", "%s: downloading root key from %s", logPrefix, constants.RootKeyURL).Once()
client := &networkMocks.Client{}
client.On("GetContent", string(constants.RootKeyURL)).
Return(tc.content, tc.status, tc.clientErr).Once()
fileManager := &filesMocks.FileManager{}
if tc.clientErr == nil && tc.status == http.StatusOK {
fileManager.On(
"WriteToFile",
string(constants.RootKey),
tc.content,
mock.AnythingOfType("files.WriteOptionSetter"),
mock.AnythingOfType("files.WriteOptionSetter"),
).Return(tc.writeErr).Once()
}
c := &configurator{logger: logger, client: client, fileManager: fileManager}
err := c.DownloadRootKey(1000, 1001)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
logger.AssertExpectations(t)
client.AssertExpectations(t)
fileManager.AssertExpectations(t)
})
}
}

40
internal/env/env.go vendored Normal file
View File

@@ -0,0 +1,40 @@
package env
import (
"os"
"github.com/qdm12/golibs/logging"
)
type Env interface {
FatalOnError(err error)
PrintVersion(program string, commandFn func() (string, error))
}
type env struct {
logger logging.Logger
osExit func(n int)
}
func New(logger logging.Logger) Env {
return &env{
logger: logger,
osExit: os.Exit,
}
}
func (e *env) FatalOnError(err error) {
if err != nil {
e.logger.Error(err)
e.osExit(1)
}
}
func (e *env) PrintVersion(program string, commandFn func() (string, error)) {
version, err := commandFn()
if err != nil {
e.logger.Error(err)
} else {
e.logger.Info("%s version: %s", program, version)
}
}

90
internal/env/env_test.go vendored Normal file
View File

@@ -0,0 +1,90 @@
package env
import (
"fmt"
"testing"
"github.com/qdm12/golibs/logging/mocks"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func Test_FatalOnError(t *testing.T) {
t.Parallel()
tests := map[string]struct {
err error
}{
"nil": {},
"err": {fmt.Errorf("error")},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
var logged string
var exitCode int
logger := &mocks.Logger{}
if tc.err != nil {
logger.On("Error", tc.err).
Run(func(args mock.Arguments) {
err := args.Get(0).(error)
logged = err.Error()
}).Once()
}
osExit := func(n int) { exitCode = n }
e := &env{logger, osExit}
e.FatalOnError(tc.err)
if tc.err != nil {
assert.Equal(t, logged, tc.err.Error())
assert.Equal(t, exitCode, 1)
} else {
assert.Empty(t, logged)
assert.Zero(t, exitCode)
}
})
}
}
func Test_PrintVersion(t *testing.T) {
t.Parallel()
tests := map[string]struct {
program string
commandVersion string
commandErr error
}{
"no data": {},
"data": {"binu", "2.3-5", nil},
"error": {"binu", "", fmt.Errorf("error")},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
var logged string
logger := &mocks.Logger{}
if tc.commandErr != nil {
logger.On("Error", tc.commandErr).
Run(func(args mock.Arguments) {
err := args.Get(0).(error)
logged = err.Error()
}).Once()
} else {
logger.On("Info", "%s version: %s", tc.program, tc.commandVersion).
Run(func(args mock.Arguments) {
format := args.Get(0).(string)
program := args.Get(1).(string)
version := args.Get(2).(string)
logged = fmt.Sprintf(format, program, version)
}).Once()
}
e := &env{logger: logger}
commandFn := func() (string, error) { return tc.commandVersion, tc.commandErr }
e.PrintVersion(tc.program, commandFn)
if tc.commandErr != nil {
assert.Equal(t, logged, tc.commandErr.Error())
} else {
assert.Equal(t, logged, fmt.Sprintf("%s version: %s", tc.program, tc.commandVersion))
}
})
}
}

View File

@@ -0,0 +1,42 @@
package firewall
import (
"net"
"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/private-internet-access-docker/internal/models"
)
const logPrefix = "firewall configurator"
// Configurator allows to change firewall rules and modify network routes
type Configurator interface {
Version() (string, error)
AcceptAll() error
Clear() error
BlockAll() error
CreateGeneralRules() error
CreateVPNRules(dev models.VPNDevice, serverIPs []net.IP, defaultInterface string,
port uint16, protocol models.NetworkProtocol) error
CreateLocalSubnetsRules(subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error
AddRoutesVia(subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error
GetDefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error)
AllowInputTrafficOnPort(device models.VPNDevice, port uint16) error
}
type configurator struct {
commander command.Commander
logger logging.Logger
fileManager files.FileManager
}
// NewConfigurator creates a new Configurator instance
func NewConfigurator(logger logging.Logger, fileManager files.FileManager) Configurator {
return &configurator{
commander: command.NewCommander(),
logger: logger,
fileManager: fileManager,
}
}

View File

@@ -0,0 +1,130 @@
package firewall
import (
"fmt"
"net"
"strings"
"github.com/qdm12/private-internet-access-docker/internal/models"
)
// Version obtains the version of the installed iptables
func (c *configurator) Version() (string, error) {
output, err := c.commander.Run("iptables", "--version")
if err != nil {
return "", err
}
words := strings.Fields(output)
if len(words) < 2 {
return "", fmt.Errorf("iptables --version: output is too short: %q", output)
}
return words[1], nil
}
func (c *configurator) runIptablesInstructions(instructions []string) error {
for _, instruction := range instructions {
if err := c.runIptablesInstruction(instruction); err != nil {
return err
}
}
return nil
}
func (c *configurator) runIptablesInstruction(instruction string) error {
flags := strings.Fields(instruction)
if output, err := c.commander.Run("iptables", flags...); err != nil {
return fmt.Errorf("failed executing %q: %s: %w", instruction, output, err)
}
return nil
}
func (c *configurator) Clear() error {
c.logger.Info("%s: clearing all rules", logPrefix)
return c.runIptablesInstructions([]string{
"--flush",
"--delete-chain",
"-t nat --flush",
"-t nat --delete-chain",
})
}
func (c *configurator) AcceptAll() error {
c.logger.Info("%s: accepting all traffic", logPrefix)
return c.runIptablesInstructions([]string{
"-P INPUT ACCEPT",
"-P OUTPUT ACCEPT",
"-P FORWARD ACCEPT",
})
}
func (c *configurator) BlockAll() error {
c.logger.Info("%s: blocking all traffic", logPrefix)
return c.runIptablesInstructions([]string{
"-P INPUT DROP",
"-F OUTPUT",
"-P OUTPUT DROP",
"-P FORWARD DROP",
})
}
func (c *configurator) CreateGeneralRules() error {
c.logger.Info("%s: creating general rules", logPrefix)
return c.runIptablesInstructions([]string{
"-A OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
"-A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
"-A OUTPUT -o lo -j ACCEPT",
"-A INPUT -i lo -j ACCEPT",
})
}
func (c *configurator) CreateVPNRules(dev models.VPNDevice, serverIPs []net.IP,
defaultInterface string, port uint16, protocol models.NetworkProtocol) error {
for _, serverIP := range serverIPs {
c.logger.Info("%s: allowing output traffic to VPN server %s through %s on port %s %d",
logPrefix, serverIP, defaultInterface, protocol, port)
if err := c.runIptablesInstruction(
fmt.Sprintf("-A OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
serverIP, defaultInterface, protocol, protocol, port)); err != nil {
return err
}
}
if err := c.runIptablesInstruction(fmt.Sprintf("-A OUTPUT -o %s -j ACCEPT", dev)); err != nil {
return err
}
return nil
}
func (c *configurator) CreateLocalSubnetsRules(subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error {
subnetStr := subnet.String()
c.logger.Info("%s: accepting input and output traffic for %s", logPrefix, subnetStr)
if err := c.runIptablesInstructions([]string{
fmt.Sprintf("-A INPUT -s %s -d %s -j ACCEPT", subnetStr, subnetStr),
fmt.Sprintf("-A OUTPUT -s %s -d %s -j ACCEPT", subnetStr, subnetStr),
}); err != nil {
return err
}
for _, extraSubnet := range extraSubnets {
extraSubnetStr := extraSubnet.String()
c.logger.Info("%s: accepting input traffic through %s from %s to %s", logPrefix, defaultInterface, extraSubnetStr, subnetStr)
if err := c.runIptablesInstruction(
fmt.Sprintf("-A INPUT -i %s -s %s -d %s -j ACCEPT", defaultInterface, extraSubnetStr, subnetStr)); err != nil {
return err
}
// Thanks to @npawelek
c.logger.Info("%s: accepting output traffic through %s from %s to %s", logPrefix, defaultInterface, subnetStr, extraSubnetStr)
if err := c.runIptablesInstruction(
fmt.Sprintf("-A OUTPUT -o %s -s %s -d %s -j ACCEPT", defaultInterface, subnetStr, extraSubnetStr)); err != nil {
return err
}
}
return nil
}
// Used for port forwarding
func (c *configurator) AllowInputTrafficOnPort(device models.VPNDevice, port uint16) error {
c.logger.Info("%s: accepting input traffic through %s on port %d", logPrefix, device, port)
return c.runIptablesInstructions([]string{
fmt.Sprintf("-A INPUT -i %s -p tcp --dport %d -j ACCEPT", device, port),
fmt.Sprintf("-A INPUT -i %s -p udp --dport %d -j ACCEPT", device, port),
})
}

View File

@@ -0,0 +1,88 @@
package firewall
import (
"encoding/hex"
"net"
"fmt"
"strings"
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
func (c *configurator) AddRoutesVia(subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error {
for _, subnet := range subnets {
subnetStr := subnet.String()
output, err := c.commander.Run("ip", "route", "show", subnetStr)
if err != nil {
return fmt.Errorf("cannot read route %s: %s: %w", subnetStr, output, err)
} else if len(output) > 0 { // thanks to @npawelek https://github.com/npawelek
continue // already exists
// TODO remove it instead and continue execution below
}
c.logger.Info("%s: adding %s as route via %s", logPrefix, subnetStr, defaultInterface)
output, err = c.commander.Run("ip", "route", "add", subnetStr, "via", defaultGateway.String(), "dev", defaultInterface)
if err != nil {
return fmt.Errorf("cannot add route for %s via %s %s %s: %s: %w", subnetStr, defaultGateway.String(), "dev", defaultInterface, output, err)
}
}
return nil
}
func (c *configurator) GetDefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error) {
c.logger.Info("%s: detecting default network route", logPrefix)
data, err := c.fileManager.ReadFile(string(constants.NetRoute))
if err != nil {
return "", nil, defaultSubnet, err
}
// Verify number of lines and fields
lines := strings.Split(string(data), "\n")
if len(lines) < 3 {
return "", nil, defaultSubnet, fmt.Errorf("not enough lines (%d) found in %s", len(lines), constants.NetRoute)
}
fieldsLine1 := strings.Fields(lines[1])
if len(fieldsLine1) < 3 {
return "", nil, defaultSubnet, fmt.Errorf("not enough fields in %q", lines[1])
}
fieldsLine2 := strings.Fields(lines[2])
if len(fieldsLine2) < 8 {
return "", nil, defaultSubnet, fmt.Errorf("not enough fields in %q", lines[2])
}
// get information
defaultInterface = fieldsLine1[0]
defaultGateway, err = reversedHexToIPv4(fieldsLine1[2])
if err != nil {
return "", nil, defaultSubnet, err
}
netNumber, err := reversedHexToIPv4(fieldsLine2[1])
if err != nil {
return "", nil, defaultSubnet, err
}
netMask, err := hexToIPv4Mask(fieldsLine2[7])
if err != nil {
return "", nil, defaultSubnet, err
}
subnet := net.IPNet{IP: netNumber, Mask: netMask}
c.logger.Info("%s: default route found: interface %s, gateway %s, subnet %s", logPrefix, defaultInterface, defaultGateway.String(), subnet.String())
return defaultInterface, defaultGateway, subnet, nil
}
func reversedHexToIPv4(reversedHex string) (IP net.IP, err error) {
bytes, err := hex.DecodeString(reversedHex)
if err != nil {
return nil, fmt.Errorf("cannot parse reversed IP hex %q: %s", reversedHex, err)
} else if len(bytes) != 4 {
return nil, fmt.Errorf("hex string contains %d bytes instead of 4", len(bytes))
}
return []byte{bytes[3], bytes[2], bytes[1], bytes[0]}, nil
}
func hexToIPv4Mask(hexString string) (mask net.IPMask, err error) {
bytes, err := hex.DecodeString(hexString)
if err != nil {
return nil, fmt.Errorf("cannot parse hex mask %q: %s", hexString, err)
} else if len(bytes) != 4 {
return nil, fmt.Errorf("hex string contains %d bytes instead of 4", len(bytes))
}
return []byte{bytes[3], bytes[2], bytes[1], bytes[0]}, nil
}

View File

@@ -0,0 +1,171 @@
package firewall
import (
"fmt"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
filesmocks "github.com/qdm12/golibs/files/mocks"
loggingmocks "github.com/qdm12/golibs/logging/mocks"
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
func Test_getDefaultRoute(t *testing.T) {
t.Parallel()
tests := map[string]struct {
data []byte
readErr error
defaultInterface string
defaultGateway net.IP
defaultSubnet net.IPNet
err error
}{
"no data": {
err: fmt.Errorf("not enough lines (1) found in %s", constants.NetRoute)},
"read error": {
readErr: fmt.Errorf("error"),
err: fmt.Errorf("error")},
"not enough fields line 1": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 00000000
eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0`),
err: fmt.Errorf("not enough fields in \"eth0 00000000\"")},
"not enough fields line 2": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
eth0 000011AC 00000000 0001 0 0 0`),
err: fmt.Errorf("not enough fields in \"eth0 000011AC 00000000 0001 0 0 0\"")},
"bad gateway": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 00000000 x 0003 0 0 0 00000000 0 0 0
eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0`),
err: fmt.Errorf("cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'")},
"bad net number": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
eth0 x 00000000 0001 0 0 0 0000FFFF 0 0 0`),
err: fmt.Errorf("cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'")},
"bad net mask": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
eth0 000011AC 00000000 0001 0 0 0 x 0 0 0`),
err: fmt.Errorf("cannot parse hex mask \"x\": encoding/hex: invalid byte: U+0078 'x'")},
"success": {
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0`),
defaultInterface: "eth0",
defaultGateway: net.IP{0xac, 0x11, 0x0, 0x1},
defaultSubnet: net.IPNet{
IP: net.IP{0xac, 0x11, 0x0, 0x0},
Mask: net.IPMask{0xff, 0xff, 0x0, 0x0},
}},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
fileManager := &filesmocks.FileManager{}
fileManager.On("ReadFile", string(constants.NetRoute)).
Return(tc.data, tc.readErr).Once()
logger := &loggingmocks.Logger{}
logger.On("Info", "%s: detecting default network route", logPrefix).Once()
if tc.err == nil {
logger.On("Info", "%s: default route found: interface %s, gateway %s, subnet %s",
logPrefix, tc.defaultInterface, tc.defaultGateway.String(), tc.defaultSubnet.String()).Once()
}
c := &configurator{logger: logger, fileManager: fileManager}
defaultInterface, defaultGateway, defaultSubnet, err := c.GetDefaultRoute()
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.defaultInterface, defaultInterface)
assert.Equal(t, tc.defaultGateway, defaultGateway)
assert.Equal(t, tc.defaultSubnet, defaultSubnet)
fileManager.AssertExpectations(t)
logger.AssertExpectations(t)
})
}
}
func Test_reversedHexToIPv4(t *testing.T) {
t.Parallel()
tests := map[string]struct {
reversedHex string
IP net.IP
err error
}{
"empty hex": {
err: fmt.Errorf("hex string contains 0 bytes instead of 4")},
"bad hex": {
reversedHex: "x",
err: fmt.Errorf("cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'")},
"3 bytes hex": {
reversedHex: "9abcde",
err: fmt.Errorf("hex string contains 3 bytes instead of 4")},
"correct hex": {
reversedHex: "010011AC",
IP: []byte{0xac, 0x11, 0x0, 0x1},
err: nil},
"correct hex 2": {
reversedHex: "000011AC",
IP: []byte{0xac, 0x11, 0x0, 0x0},
err: nil},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
IP, err := reversedHexToIPv4(tc.reversedHex)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.IP, IP)
})
}
}
func Test_hexMaskToDecMask(t *testing.T) {
t.Parallel()
tests := map[string]struct {
hexString string
mask net.IPMask
err error
}{
"empty hex": {
err: fmt.Errorf("hex string contains 0 bytes instead of 4")},
"bad hex": {
hexString: "x",
err: fmt.Errorf("cannot parse hex mask \"x\": encoding/hex: invalid byte: U+0078 'x'")},
"3 bytes hex": {
hexString: "9abcde",
err: fmt.Errorf("hex string contains 3 bytes instead of 4")},
"16": {
hexString: "0000FFFF",
mask: []byte{0xff, 0xff, 0x0, 0x0},
err: nil},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
mask, err := hexToIPv4Mask(tc.hexString)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.mask, mask)
})
}
}

24
internal/models/alias.go Normal file
View File

@@ -0,0 +1,24 @@
package models
type (
// VPNDevice is the device name used to tunnel using Openvpn
VPNDevice string
// DNSProvider is a DNS over TLS server provider name
DNSProvider string
// DNSForwardAddress is the Unbound formatted forward address
DNSForwardAddress string
// PIAEncryption defines the level of encryption for communication with PIA servers
PIAEncryption string
// PIARegion contains the list of regions available for PIA
PIARegion string
// URL is an HTTP(s) URL address
URL string
// Filepath is a local filesytem file path
Filepath string
// TinyProxyLogLevel is the log level for TinyProxy
TinyProxyLogLevel string
// VPNProvider is the name of the VPN provider to be used
VPNProvider string
// NetworkProtocol contains the network protocol to be used to communicate with the VPN servers
NetworkProtocol string
)

23
internal/openvpn/auth.go Normal file
View File

@@ -0,0 +1,23 @@
package openvpn
import (
"github.com/qdm12/golibs/files"
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
// WriteAuthFile writes the OpenVPN auth file to disk with the right permissions
func (c *configurator) WriteAuthFile(user, password string, uid, gid int) error {
authExists, err := c.fileManager.FileExists(string(constants.OpenVPNAuthConf))
if err != nil {
return err
} else if authExists { // in case of container stop/start
c.logger.Info("%s: %s already exists", logPrefix, constants.OpenVPNAuthConf)
return nil
}
c.logger.Info("%s: writing auth file %s", logPrefix, constants.OpenVPNAuthConf)
return c.fileManager.WriteLinesToFile(
string(constants.OpenVPNAuthConf),
[]string{user, password},
files.FileOwnership(uid, gid),
files.FilePermissions(0400))
}

View File

@@ -0,0 +1,28 @@
package openvpn
import (
"fmt"
"io"
"strings"
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
func (c *configurator) Start() (stdout io.ReadCloser, err error) {
c.logger.Info("%s: starting openvpn", logPrefix)
stdout, _, _, err = c.commander.Start("openvpn", "--config", string(constants.OpenVPNConf))
return stdout, err
}
func (c *configurator) Version() (string, error) {
output, err := c.commander.Run("openvpn", "--version")
if err != nil && err.Error() != "exit status 1" {
return "", err
}
firstLine := strings.Split(output, "\n")[0]
words := strings.Fields(firstLine)
if len(words) < 2 {
return "", fmt.Errorf("openvpn --version: first line is too short: %q", firstLine)
}
return words[1], nil
}

View File

@@ -0,0 +1,35 @@
package openvpn
import (
"io"
"os"
"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
)
const logPrefix = "openvpn configurator"
type Configurator interface {
Version() (string, error)
WriteAuthFile(user, password string, uid, gid int) error
CheckTUN() error
Start() (stdout io.ReadCloser, err error)
}
type configurator struct {
fileManager files.FileManager
logger logging.Logger
commander command.Commander
openFile func(name string, flag int, perm os.FileMode) (*os.File, error)
}
func NewConfigurator(logger logging.Logger, fileManager files.FileManager) Configurator {
return &configurator{
fileManager: fileManager,
logger: logger,
commander: command.NewCommander(),
openFile: os.OpenFile,
}
}

21
internal/openvpn/tun.go Normal file
View File

@@ -0,0 +1,21 @@
package openvpn
import (
"fmt"
"os"
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
// CheckTUN checks the tunnel device is present and accessible
func (c *configurator) CheckTUN() error {
c.logger.Info("%s: checking for device %s", logPrefix, constants.TunnelDevice)
f, err := c.openFile(string(constants.TunnelDevice), os.O_RDWR, 0)
if err != nil {
return fmt.Errorf("TUN device is not available: %w", err)
}
if err := f.Close(); err != nil {
c.logger.Warn("Could not close TUN device file: %s", err)
}
return nil
}

93
internal/params/dns.go Normal file
View File

@@ -0,0 +1,93 @@
package params
import (
"fmt"
"strings"
libparams "github.com/qdm12/golibs/params"
"github.com/qdm12/private-internet-access-docker/internal/constants"
"github.com/qdm12/private-internet-access-docker/internal/models"
)
// GetDNSOverTLS obtains if the DNS over TLS should be enabled
// from the environment variable DOT
func (p *paramsReader) GetDNSOverTLS() (DNSOverTLS bool, err error) {
return p.envParams.GetOnOff("DOT", libparams.Default("on"))
}
// GetDNSOverTLSProviders obtains the DNS over TLS providers to use
// from the environment variable DOT_PROVIDERS
func (p *paramsReader) GetDNSOverTLSProviders() (providers []models.DNSProvider, err error) {
s, err := p.envParams.GetEnv("DOT_PROVIDERS", libparams.Default("cloudflare"))
if err != nil {
return nil, err
}
for _, word := range strings.Split(s, ",") {
provider := models.DNSProvider(word)
switch provider {
case constants.Cloudflare, constants.Google, constants.Quad9, constants.Quadrant, constants.CleanBrowsing, constants.SecureDNS, constants.LibreDNS:
providers = append(providers, provider)
default:
return nil, fmt.Errorf("DNS over TLS provider %q is not valid", provider)
}
}
return providers, nil
}
// GetDNSOverTLSVerbosity obtains the verbosity level to use for Unbound
// from the environment variable DOT_VERBOSITY
func (p *paramsReader) GetDNSOverTLSVerbosity() (verbosityLevel uint8, err error) {
n, err := p.envParams.GetEnvIntRange("DOT_VERBOSITY", 0, 5, libparams.Default("1"))
return uint8(n), err
}
// GetDNSOverTLSVerbosityDetails obtains the log level to use for Unbound
// from the environment variable DOT_VERBOSITY_DETAILS
func (p *paramsReader) GetDNSOverTLSVerbosityDetails() (verbosityDetailsLevel uint8, err error) {
n, err := p.envParams.GetEnvIntRange("DOT_VERBOSITY_DETAILS", 0, 4, libparams.Default("0"))
return uint8(n), err
}
// GetDNSOverTLSValidationLogLevel obtains the log level to use for Unbound DOT validation
// from the environment variable DOT_VALIDATION_LOGLEVEL
func (p *paramsReader) GetDNSOverTLSValidationLogLevel() (validationLogLevel uint8, err error) {
n, err := p.envParams.GetEnvIntRange("DOT_VALIDATION_LOGLEVEL", 0, 2, libparams.Default("0"))
return uint8(n), err
}
// GetDNSMaliciousBlocking obtains if malicious hostnames/IPs should be blocked
// from being resolved by Unbound, using the environment variable BLOCK_MALICIOUS
func (p *paramsReader) GetDNSMaliciousBlocking() (blocking bool, err error) {
return p.envParams.GetOnOff("BLOCK_MALICIOUS", libparams.Default("on"))
}
// GetDNSSurveillanceBlocking obtains if surveillance hostnames/IPs should be blocked
// from being resolved by Unbound, using the environment variable BLOCK_NSA
func (p *paramsReader) GetDNSSurveillanceBlocking() (blocking bool, err error) {
return p.envParams.GetOnOff("BLOCK_NSA", libparams.Default("off"))
}
// GetDNSAdsBlocking obtains if ads hostnames/IPs should be blocked
// from being resolved by Unbound, using the environment variable BLOCK_ADS
func (p *paramsReader) GetDNSAdsBlocking() (blocking bool, err error) {
return p.envParams.GetOnOff("BLOCK_ADS", libparams.Default("off"))
}
// GetDNSUnblockedHostnames obtains a list of hostnames to unblock from block lists
// from the comma separated list for the environment variable UNBLOCK
func (p *paramsReader) GetDNSUnblockedHostnames() (hostnames []string, err error) {
s, err := p.envParams.GetEnv("UNBLOCK")
if err != nil {
return nil, err
}
if len(s) == 0 {
return nil, nil
}
hostnames = strings.Split(s, ",")
for _, hostname := range hostnames {
if !p.verifier.MatchHostname(hostname) {
return nil, fmt.Errorf("hostname %q does not seem valid", hostname)
}
}
return hostnames, nil
}

View File

@@ -0,0 +1,29 @@
package params
import (
"fmt"
"net"
"strings"
)
// GetExtraSubnets obtains the CIDR subnets from the comma separated list of the
// environment variable EXTRA_SUBNETS
func (p *paramsReader) GetExtraSubnets() (extraSubnets []net.IPNet, err error) {
s, err := p.envParams.GetEnv("EXTRA_SUBNETS")
if err != nil {
return nil, err
} else if s == "" {
return nil, nil
}
subnets := strings.Split(s, ",")
for _, subnet := range subnets {
_, cidr, err := net.ParseCIDR(subnet)
if err != nil {
return nil, fmt.Errorf("could not parse subnet %q from environment variable with key EXTRA_SUBNETS: %w", subnet, err)
} else if cidr == nil {
return nil, fmt.Errorf("parsing subnet %q resulted in a nil CIDR", subnet)
}
extraSubnets = append(extraSubnets, *cidr)
}
return extraSubnets, nil
}

View File

@@ -0,0 +1,13 @@
package params
import (
libparams "github.com/qdm12/golibs/params"
"github.com/qdm12/private-internet-access-docker/internal/models"
)
// GetNetworkProtocol obtains the network protocol to use to connect to the
// VPN servers from the environment variable PROTOCOL
func (p *paramsReader) GetNetworkProtocol() (protocol models.NetworkProtocol, err error) {
s, err := p.envParams.GetValueIfInside("PROTOCOL", []string{"tcp", "udp"}, libparams.Default("udp"))
return models.NetworkProtocol(s), err
}

73
internal/params/params.go Normal file
View File

@@ -0,0 +1,73 @@
package params
import (
"net"
"os"
"github.com/qdm12/golibs/logging"
libparams "github.com/qdm12/golibs/params"
"github.com/qdm12/golibs/verification"
"github.com/qdm12/private-internet-access-docker/internal/models"
)
// ParamsReader contains methods to obtain parameters
type ParamsReader interface {
// DNS over TLS getters
GetDNSOverTLS() (DNSOverTLS bool, err error)
GetDNSOverTLSProviders() (providers []models.DNSProvider, err error)
GetDNSOverTLSVerbosity() (verbosityLevel uint8, err error)
GetDNSOverTLSVerbosityDetails() (verbosityDetailsLevel uint8, err error)
GetDNSOverTLSValidationLogLevel() (validationLogLevel uint8, err error)
GetDNSMaliciousBlocking() (blocking bool, err error)
GetDNSSurveillanceBlocking() (blocking bool, err error)
GetDNSAdsBlocking() (blocking bool, err error)
GetDNSUnblockedHostnames() (hostnames []string, err error)
// Firewall getters
GetExtraSubnets() (extraSubnets []net.IPNet, err error)
// VPN getters
GetNetworkProtocol() (protocol models.NetworkProtocol, err error)
// PIA getters
GetUser() (s string, err error)
GetPassword() (s string, err error)
GetPortForwarding() (activated bool, err error)
GetPortForwardingStatusFilepath() (filepath models.Filepath, err error)
GetPIAEncryption() (models.PIAEncryption, error)
GetPIARegion() (models.PIARegion, error)
// Shadowsocks getters
GetShadowSocks() (activated bool, err error)
GetShadowSocksLog() (activated bool, err error)
GetShadowSocksPort() (port uint16, err error)
GetShadowSocksPassword() (password string, err error)
// Tinyproxy getters
GetTinyProxy() (activated bool, err error)
GetTinyProxyLog() (models.TinyProxyLogLevel, error)
GetTinyProxyPort() (port uint16, err error)
GetTinyProxyUser() (user string, err error)
GetTinyProxyPassword() (password string, err error)
// Version getters
GetVersion() string
GetBuildDate() string
GetVcsRef() string
}
type paramsReader struct {
envParams libparams.EnvParams
logger logging.Logger
verifier verification.Verifier
unsetEnv func(key string) error
}
func NewParamsReader(logger logging.Logger) ParamsReader {
return &paramsReader{
envParams: libparams.NewEnvParams(),
logger: logger,
verifier: verification.NewVerifier(),
unsetEnv: os.Unsetenv,
}
}

87
internal/params/pia.go Normal file
View File

@@ -0,0 +1,87 @@
package params
import (
"fmt"
"math/rand"
libparams "github.com/qdm12/golibs/params"
"github.com/qdm12/private-internet-access-docker/internal/constants"
"github.com/qdm12/private-internet-access-docker/internal/models"
)
// GetUser obtains the user to use to connect to the VPN servers
func (p *paramsReader) GetUser() (s string, err error) {
defer func() {
unsetenvErr := p.unsetEnv("USER")
if err == nil {
err = unsetenvErr
}
}()
s, err = p.envParams.GetEnv("USER")
if err != nil {
return "", err
} else if len(s) == 0 {
return s, fmt.Errorf("USER environment variable cannot be empty")
}
return s, nil
}
// GetPassword obtains the password to use to connect to the VPN servers
func (p *paramsReader) GetPassword() (s string, err error) {
defer func() {
unsetenvErr := p.unsetEnv("PASSWORD")
if err == nil {
err = unsetenvErr
}
}()
s, err = p.envParams.GetEnv("PASSWORD")
if err != nil {
return "", err
} else if len(s) == 0 {
return s, fmt.Errorf("PASSWORD environment variable cannot be empty")
}
return s, nil
}
// GetPortForwarding obtains if port forwarding on the VPN provider server
// side is enabled or not from the environment variable PORT_FORWARDING
func (p *paramsReader) GetPortForwarding() (activated bool, err error) {
s, err := p.envParams.GetEnv("PORT_FORWARDING", libparams.Default("off"))
if err != nil {
return false, err
}
// Custom for retro-compatibility
if s == "false" || s == "off" {
return false, nil
} else if s == "true" || s == "on" {
return true, nil
}
return false, fmt.Errorf("PORT_FORWARDING can only be \"on\" or \"off\"")
}
// GetPortForwardingStatusFilepath obtains the port forwarding status file path
// from the environment variable PORT_FORWARDING_STATUS_FILE
func (p *paramsReader) GetPortForwardingStatusFilepath() (filepath models.Filepath, err error) {
filepathStr, err := p.envParams.GetPath("PORT_FORWARDING_STATUS_FILE", libparams.Default("/forwarded_port"))
return models.Filepath(filepathStr), err
}
// GetPIAEncryption obtains the encryption level for the PIA connection
// from the environment variable ENCRYPTION
func (p *paramsReader) GetPIAEncryption() (models.PIAEncryption, error) {
s, err := p.envParams.GetValueIfInside("ENCRYPTION", []string{"normal", "strong"}, libparams.Default("strong"))
return models.PIAEncryption(s), err
}
// GetPIARegion obtains the region for the PIA server from the
// environment variable REGION
func (p *paramsReader) GetPIARegion() (region models.PIARegion, err error) {
choices := []string{
string(constants.AUMelbourne), string(constants.AUPerth), string(constants.AUSydney), string(constants.Austria), string(constants.Belgium), string(constants.CAMontreal), string(constants.CAToronto), string(constants.CAVancouver), string(constants.CzechRepublic), string(constants.DEBerlin), string(constants.DEFrankfurt), string(constants.Denmark), string(constants.Finland), string(constants.France), string(constants.HongKong), string(constants.Hungary), string(constants.India), string(constants.Ireland), string(constants.Israel), string(constants.Italy), string(constants.Japan), string(constants.Luxembourg), string(constants.Mexico), string(constants.Netherlands), string(constants.NewZealand), string(constants.Norway), string(constants.Poland), string(constants.Romania), string(constants.Singapore), string(constants.Spain), string(constants.Sweden), string(constants.Switzerland), string(constants.UAE), string(constants.UKLondon), string(constants.UKManchester), string(constants.UKSouthampton), string(constants.USAtlanta), string(constants.USCalifornia), string(constants.USChicago), string(constants.USDenver), string(constants.USEast), string(constants.USFlorida), string(constants.USHouston), string(constants.USLasVegas), string(constants.USNewYorkCity), string(constants.USSeattle), string(constants.USSiliconValley), string(constants.USTexas), string(constants.USWashingtonDC), string(constants.USWest),
}
s, err := p.envParams.GetValueIfInside("REGION", choices)
if len(s) == 0 { // Suggestion by @rorph https://github.com/rorph
s = choices[rand.Int()%len(choices)]
}
return models.PIARegion(s), err
}

View File

@@ -0,0 +1,40 @@
package params
import (
"strconv"
libparams "github.com/qdm12/golibs/params"
)
// GetShadowSocks obtains if ShadowSocks is on from the environment variable
// SHADOWSOCKS
func (p *paramsReader) GetShadowSocks() (activated bool, err error) {
return p.envParams.GetOnOff("SHADOWSOCKS", libparams.Default("off"))
}
// GetShadowSocksLog obtains the ShadowSocks log level from the environment variable
// SHADOWSOCKS_LOG
func (p *paramsReader) GetShadowSocksLog() (activated bool, err error) {
return p.envParams.GetOnOff("SHADOWSOCKS_LOG", libparams.Default("off"))
}
// GetShadowSocksPort obtains the ShadowSocks listening port from the environment variable
// SHADOWSOCKS_PORT
func (p *paramsReader) GetShadowSocksPort() (port uint16, err error) {
portStr, err := p.envParams.GetEnv("SHADOWSOCKS_PORT", libparams.Default("8388"))
if err != nil {
return 0, err
}
if err := p.verifier.VerifyPort(portStr); err != nil {
return 0, err
}
portUint64, err := strconv.ParseUint(portStr, 10, 16)
return uint16(portUint64), err
}
// GetShadowSocksPassword obtains the ShadowSocks server password from the environment variable
// SHADOWSOCKS_PASSWORD
func (p *paramsReader) GetShadowSocksPassword() (password string, err error) {
defer p.unsetEnv("SHADOWSOCKS_PASSWORD")
return p.envParams.GetEnv("SHADOWSOCKS_PASSWORD")
}

View File

@@ -0,0 +1,94 @@
package params
import (
"strconv"
libparams "github.com/qdm12/golibs/params"
"github.com/qdm12/private-internet-access-docker/internal/models"
)
// GetTinyProxy obtains if TinyProxy is on from the environment variable
// TINYPROXY, and using PROXY as a retro-compatibility name
func (p *paramsReader) GetTinyProxy() (activated bool, err error) {
// Retro-compatibility
s, err := p.envParams.GetEnv("PROXY")
if err != nil {
return false, err
} else if len(s) != 0 {
p.logger.Warn("You are using the old environment variable PROXY, please consider changing it to TINYPROXY")
return p.envParams.GetOnOff("PROXY", libparams.Compulsory())
}
return p.envParams.GetOnOff("TINYPROXY", libparams.Default("off"))
}
// GetTinyProxyLog obtains the TinyProxy log level from the environment variable
// TINYPROXY_LOG, and using PROXY_LOG_LEVEL as a retro-compatibility name
func (p *paramsReader) GetTinyProxyLog() (models.TinyProxyLogLevel, error) {
// Retro-compatibility
s, err := p.envParams.GetEnv("PROXY_LOG_LEVEL")
if err != nil {
return models.TinyProxyLogLevel(s), err
} else if len(s) != 0 {
p.logger.Warn("You are using the old environment variable PROXY_LOG_LEVEL, please consider changing it to TINYPROXY_LOG")
s, err = p.envParams.GetValueIfInside("PROXY_LOG_LEVEL", []string{"info", "warning", "error", "critical"}, libparams.Compulsory())
return models.TinyProxyLogLevel(s), err
}
s, err = p.envParams.GetValueIfInside("TINYPROXY_LOG", []string{"info", "warning", "error", "critical"}, libparams.Default("info"))
return models.TinyProxyLogLevel(s), err
}
// GetTinyProxyPort obtains the TinyProxy listening port from the environment variable
// TINYPROXY_PORT, and using PROXY_PORT as a retro-compatibility name
func (p *paramsReader) GetTinyProxyPort() (port uint16, err error) {
// Retro-compatibility
portStr, err := p.envParams.GetEnv("PROXY_PORT")
if err != nil {
return 0, err
} else if len(portStr) != 0 {
p.logger.Warn("You are using the old environment variable PROXY_PORT, please consider changing it to TINYPROXY_PORT")
} else {
portStr, err = p.envParams.GetEnv("TINYPROXY_PORT", libparams.Default("8888"))
if err != nil {
return 0, err
}
}
if err := p.verifier.VerifyPort(portStr); err != nil {
return 0, err
}
portUint64, err := strconv.ParseUint(portStr, 10, 16)
return uint16(portUint64), err
}
// GetTinyProxyUser obtains the TinyProxy server user from the environment variable
// TINYPROXY_USER, and using PROXY_USER as a retro-compatibility name
func (p *paramsReader) GetTinyProxyUser() (user string, err error) {
defer p.unsetEnv("PROXY_USER")
defer p.unsetEnv("TINYPROXY_USER")
// Retro-compatibility
user, err = p.envParams.GetEnv("PROXY_USER")
if err != nil {
return user, err
}
if len(user) != 0 {
p.logger.Warn("You are using the old environment variable PROXY_USER, please consider changing it to TINYPROXY_USER")
return user, nil
}
return p.envParams.GetEnv("TINYPROXY_USER")
}
// GetTinyProxyPassword obtains the TinyProxy server password from the environment variable
// TINYPROXY_PASSWORD, and using PROXY_PASSWORD as a retro-compatibility name
func (p *paramsReader) GetTinyProxyPassword() (password string, err error) {
defer p.unsetEnv("PROXY_PASSWORD")
defer p.unsetEnv("TINYPROXY_PASSWORD")
// Retro-compatibility
password, err = p.envParams.GetEnv("PROXY_PASSWORD")
if err != nil {
return password, err
}
if len(password) != 0 {
p.logger.Warn("You are using the old environment variable PROXY_PASSWORD, please consider changing it to TINYPROXY_PASSWORD")
return password, nil
}
return p.envParams.GetEnv("TINYPROXY_PASSWORD")
}

View File

@@ -0,0 +1,20 @@
package params
import (
"github.com/qdm12/golibs/params"
)
func (p *paramsReader) GetVersion() string {
version, _ := p.envParams.GetEnv("VERSION", params.Default("?"))
return version
}
func (p *paramsReader) GetBuildDate() string {
buildDate, _ := p.envParams.GetEnv("BUILD_DATE", params.Default("?"))
return buildDate
}
func (p *paramsReader) GetVcsRef() string {
buildDate, _ := p.envParams.GetEnv("VCS_REF", params.Default("?"))
return buildDate
}

61
internal/pia/download.go Normal file
View File

@@ -0,0 +1,61 @@
package pia
import (
"archive/zip"
"bytes"
"fmt"
"io/ioutil"
"strings"
"github.com/qdm12/private-internet-access-docker/internal/constants"
"github.com/qdm12/private-internet-access-docker/internal/models"
)
func (c *configurator) DownloadOvpnConfig(encryption models.PIAEncryption,
protocol models.NetworkProtocol, region models.PIARegion) (lines []string, err error) {
c.logger.Info("%s: downloading openvpn configuration files", logPrefix)
URL := buildZipURL(encryption, protocol)
content, status, err := c.client.GetContent(URL)
if err != nil {
return nil, err
} else if status != 200 {
return nil, fmt.Errorf("HTTP Get %s resulted in HTTP status code %d", URL, status)
}
filename := fmt.Sprintf("%s.ovpn", region)
fileContent, err := getFileInZip(content, filename)
if err != nil {
return nil, fmt.Errorf("%s: %w", URL, err)
}
lines = strings.Split(string(fileContent), "\n")
return lines, nil
}
func buildZipURL(encryption models.PIAEncryption, protocol models.NetworkProtocol) (URL string) {
URL = string(constants.PIAOpenVPNURL) + "/openvpn"
if encryption == constants.PIAEncryptionStrong {
URL += "-strong"
}
if protocol == constants.TCP {
URL += "-tcp"
}
return URL + ".zip"
}
func getFileInZip(zipContent []byte, filename string) (fileContent []byte, err error) {
contentLength := int64(len(zipContent))
r, err := zip.NewReader(bytes.NewReader(zipContent), contentLength)
if err != nil {
return nil, err
}
for _, f := range r.File {
if f.Name == filename {
readCloser, err := f.Open()
if err != nil {
return nil, err
}
defer readCloser.Close()
return ioutil.ReadAll(readCloser)
}
}
return nil, fmt.Errorf("%s not found in zip archive file", filename)
}

31
internal/pia/modify.go Normal file
View File

@@ -0,0 +1,31 @@
package pia
import (
"fmt"
"net"
"strings"
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
func (c *configurator) ModifyLines(lines []string, IPs []net.IP, port uint16) (modifiedLines []string) {
c.logger.Info("%s: adapting openvpn configuration for server IP addresses and port %d", logPrefix, port)
// Remove lines
for _, line := range lines {
if strings.Contains(line, "privateinternetaccess.com") ||
strings.Contains(line, "resolv-retry") {
continue
}
modifiedLines = append(modifiedLines, line)
}
// Add lines
for _, IP := range IPs {
modifiedLines = append(modifiedLines, fmt.Sprintf("remote %s %d", IP.String(), port))
}
modifiedLines = append(modifiedLines, "auth-user-pass "+string(constants.OpenVPNAuthConf))
modifiedLines = append(modifiedLines, "auth-retry nointeract")
modifiedLines = append(modifiedLines, "pull-filter ignore \"auth-token\"") // prevent auth failed loops
modifiedLines = append(modifiedLines, "user nonrootuser")
modifiedLines = append(modifiedLines, "mute-replay-warnings")
return modifiedLines
}

View File

@@ -0,0 +1,31 @@
package pia
import (
"io/ioutil"
"net"
"strings"
"testing"
loggingMocks "github.com/qdm12/golibs/logging/mocks"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_ModifyLines(t *testing.T) {
t.Parallel()
original, err := ioutil.ReadFile("testdata/ovpn.golden")
require.NoError(t, err)
originalLines := strings.Split(string(original), "\n")
expected, err := ioutil.ReadFile("testdata/ovpn.modified.golden")
require.NoError(t, err)
expectedLines := strings.Split(string(expected), "\n")
var port uint16 = 3000
IPs := []net.IP{net.IP{100, 10, 10, 10}, net.IP{100, 20, 20, 20}}
logger := &loggingMocks.Logger{}
logger.On("Info", "%s: adapting openvpn configuration for server IP addresses and port %d", logPrefix, port).Once()
c := &configurator{logger: logger}
modifiedLines := c.ModifyLines(originalLines, IPs, port)
assert.Equal(t, expectedLines, modifiedLines)
logger.AssertExpectations(t)
}

54
internal/pia/parse.go Normal file
View File

@@ -0,0 +1,54 @@
package pia
import (
"fmt"
"net"
"strconv"
"strings"
"github.com/qdm12/private-internet-access-docker/internal/constants"
"github.com/qdm12/private-internet-access-docker/internal/models"
)
func (c *configurator) ParseConfig(lines []string) (IPs []net.IP, port uint16, device models.VPNDevice, err error) {
c.logger.Info("%s: parsing openvpn configuration", logPrefix)
remoteLineFound := false
deviceLineFound := false
for _, line := range lines {
if strings.HasPrefix(line, "remote ") {
remoteLineFound = true
words := strings.Fields(line)
if len(words) != 3 {
return nil, 0, "", fmt.Errorf("line %q misses information", line)
}
host := words[1]
if err := c.verifyPort(words[2]); err != nil {
return nil, 0, "", fmt.Errorf("line %q has an invalid port: %w", line, err)
}
portUint64, _ := strconv.ParseUint(words[2], 10, 16)
port = uint16(portUint64)
IPs, err = c.lookupIP(host)
if err != nil {
return nil, 0, "", err
}
} else if strings.HasPrefix(line, "dev ") {
deviceLineFound = true
fields := strings.Fields(line)
if len(fields) != 2 {
return nil, 0, "", fmt.Errorf("line %q misses information", line)
}
device = models.VPNDevice(fields[1] + "0")
if device != constants.TUN && device != constants.TAP {
return nil, 0, "", fmt.Errorf("device %q is not valid", device)
}
}
}
if remoteLineFound && deviceLineFound {
c.logger.Info("%s: Found %d PIA server IP addresses, port %d and device %s", logPrefix, len(IPs), port, device)
return IPs, port, device, nil
} else if !remoteLineFound {
return nil, 0, "", fmt.Errorf("remote line not found in Openvpn configuration")
} else {
return nil, 0, "", fmt.Errorf("device line not found in Openvpn configuration")
}
}

View File

@@ -0,0 +1,99 @@
package pia
import (
"fmt"
"io/ioutil"
"net"
"strings"
"testing"
loggingMocks "github.com/qdm12/golibs/logging/mocks"
"github.com/qdm12/golibs/verification"
"github.com/qdm12/private-internet-access-docker/internal/constants"
"github.com/qdm12/private-internet-access-docker/internal/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_ParseConfig(t *testing.T) {
t.Parallel()
original, err := ioutil.ReadFile("testdata/ovpn.golden")
require.NoError(t, err)
exampleLines := strings.Split(string(original), "\n")
tests := map[string]struct {
lines []string
lookupIPs []net.IP
lookupIPErr error
IPs []net.IP
port uint16
device models.VPNDevice
err error
}{
"no data": {
err: fmt.Errorf("remote line not found in Openvpn configuration"),
},
"bad remote line": {
lines: []string{"remote field2"},
err: fmt.Errorf("line \"remote field2\" misses information"),
},
"bad remote port": {
lines: []string{"remote field2 port"},
err: fmt.Errorf("line \"remote field2 port\" has an invalid port: port \"port\" is not a valid integer"),
},
"lookupIP error": {
lines: []string{"remote host 1000"},
lookupIPErr: fmt.Errorf("lookup error"),
err: fmt.Errorf("lookup error"),
},
"missing dev line": {
lines: []string{"remote host 1994"},
err: fmt.Errorf("device line not found in Openvpn configuration"),
},
"bad dev line": {
lines: []string{"dev field2 field3"},
err: fmt.Errorf("line \"dev field2 field3\" misses information"),
},
"bad device": {
lines: []string{"dev xx"},
err: fmt.Errorf("device \"xx0\" is not valid"),
},
"valid lines": {
lines: []string{"remote host 1194", "dev tap", "blabla"},
port: 1194,
device: constants.TAP,
},
"real data": {
lines: exampleLines,
lookupIPs: []net.IP{{100, 100, 100, 100}, {100, 100, 200, 200}},
IPs: []net.IP{{100, 100, 100, 100}, {100, 100, 200, 200}},
port: 1198,
device: constants.TUN,
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
logger := &loggingMocks.Logger{}
logger.On("Info", "%s: parsing openvpn configuration", logPrefix).Once()
if tc.err == nil {
logger.On("Info", "%s: Found %d PIA server IP addresses, port %d and device %s", logPrefix, len(tc.IPs), tc.port, tc.device).Once()
}
lookupIP := func(host string) ([]net.IP, error) {
return tc.lookupIPs, tc.lookupIPErr
}
c := &configurator{logger: logger, verifyPort: verification.NewVerifier().VerifyPort, lookupIP: lookupIP}
IPs, port, device, err := c.ParseConfig(tc.lines)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.IPs, IPs)
assert.Equal(t, tc.port, port)
assert.Equal(t, tc.device, device)
logger.AssertExpectations(t)
})
}
}

41
internal/pia/pia.go Normal file
View File

@@ -0,0 +1,41 @@
package pia
import (
"net"
"github.com/qdm12/golibs/crypto/random"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network"
"github.com/qdm12/golibs/verification"
"github.com/qdm12/private-internet-access-docker/internal/firewall"
"github.com/qdm12/private-internet-access-docker/internal/models"
)
const logPrefix = "PIA configurator"
// Configurator contains methods to download, read and modify the openvpn configuration to connect as a client
type Configurator interface {
DownloadOvpnConfig(encryption models.PIAEncryption,
protocol models.NetworkProtocol, region models.PIARegion) (lines []string, err error)
ParseConfig(lines []string) (IPs []net.IP, port uint16, device models.VPNDevice, err error)
ModifyLines(lines []string, IPs []net.IP, port uint16) (modifiedLines []string)
GetPortForward() (port uint16, err error)
WritePortForward(filepath models.Filepath, port uint16) (err error)
AllowPortForwardFirewall(device models.VPNDevice, port uint16) (err error)
}
type configurator struct {
client network.Client
fileManager files.FileManager
firewall firewall.Configurator
logger logging.Logger
random random.Random
verifyPort func(port string) error
lookupIP func(host string) ([]net.IP, error)
}
// NewConfigurator returns a new Configurator object
func NewConfigurator(client network.Client, fileManager files.FileManager, firewall firewall.Configurator, logger logging.Logger) Configurator {
return &configurator{client, fileManager, firewall, logger, random.NewRandom(), verification.NewVerifier().VerifyPort, net.LookupIP}
}

View File

@@ -0,0 +1,46 @@
package pia
import (
"encoding/hex"
"encoding/json"
"fmt"
"github.com/qdm12/private-internet-access-docker/internal/constants"
"github.com/qdm12/private-internet-access-docker/internal/models"
)
func (c *configurator) GetPortForward() (port uint16, err error) {
c.logger.Info("%s: Obtaining port to be forwarded", logPrefix)
b, err := c.random.GenerateRandomBytes(32)
if err != nil {
return 0, err
}
clientID := hex.EncodeToString(b)
url := fmt.Sprintf("%s/?client_id=%s", constants.PIAPortForwardURL, clientID)
content, status, err := c.client.GetContent(url)
if err != nil {
return 0, err
} else if status != 200 {
return 0, fmt.Errorf("status is %d for %s; does your PIA server support port forwarding?", status, url)
} else if len(content) == 0 {
return 0, fmt.Errorf("port forwarding is already activated on this connection, has expired, or you are not connected to a PIA region that supports port forwarding")
}
body := struct {
Port uint16 `json:"port"`
}{}
if err := json.Unmarshal(content, &body); err != nil {
return 0, fmt.Errorf("port forwarding response: %w", err)
}
c.logger.Info("%s: Port forwarded is %d", logPrefix, port)
return body.Port, nil
}
func (c *configurator) WritePortForward(filepath models.Filepath, port uint16) (err error) {
c.logger.Info("%s: Writing forwarded port to %s", logPrefix, filepath)
return c.fileManager.WriteLinesToFile(string(filepath), []string{fmt.Sprintf("%d", port)})
}
func (c *configurator) AllowPortForwardFirewall(device models.VPNDevice, port uint16) (err error) {
c.logger.Info("%s: Allowing forwarded port %d through firewall", logPrefix, port)
return c.firewall.AllowInputTrafficOnPort(device, port)
}

72
internal/pia/testdata/ovpn.golden vendored Normal file
View File

@@ -0,0 +1,72 @@
client
dev tun
proto udp
remote belgium.privateinternetaccess.com 1198
resolv-retry infinite
nobind
persist-key
persist-tun
cipher aes-128-cbc
auth sha1
tls-client
remote-cert-tls server
auth-user-pass
compress
verb 1
reneg-sec 0
<crl-verify>
-----BEGIN X509 CRL-----
MIICWDCCAUAwDQYJKoZIhvcNAQENBQAwgegxCzAJBgNVBAYTAlVTMQswCQYDVQQI
EwJDQTETMBEGA1UEBxMKTG9zQW5nZWxlczEgMB4GA1UEChMXUHJpdmF0ZSBJbnRl
cm5ldCBBY2Nlc3MxIDAeBgNVBAsTF1ByaXZhdGUgSW50ZXJuZXQgQWNjZXNzMSAw
HgYDVQQDExdQcml2YXRlIEludGVybmV0IEFjY2VzczEgMB4GA1UEKRMXUHJpdmF0
ZSBJbnRlcm5ldCBBY2Nlc3MxLzAtBgkqhkiG9w0BCQEWIHNlY3VyZUBwcml2YXRl
aW50ZXJuZXRhY2Nlc3MuY29tFw0xNjA3MDgxOTAwNDZaFw0zNjA3MDMxOTAwNDZa
MCYwEQIBARcMMTYwNzA4MTkwMDQ2MBECAQYXDDE2MDcwODE5MDA0NjANBgkqhkiG
9w0BAQ0FAAOCAQEAQZo9X97ci8EcPYu/uK2HB152OZbeZCINmYyluLDOdcSvg6B5
jI+ffKN3laDvczsG6CxmY3jNyc79XVpEYUnq4rT3FfveW1+Ralf+Vf38HdpwB8EW
B4hZlQ205+21CALLvZvR8HcPxC9KEnev1mU46wkTiov0EKc+EdRxkj5yMgv0V2Re
ze7AP+NQ9ykvDScH4eYCsmufNpIjBLhpLE2cuZZXBLcPhuRzVoU3l7A9lvzG9mjA
5YijHJGHNjlWFqyrn1CfYS6koa4TGEPngBoAziWRbDGdhEgJABHrpoaFYaL61zqy
MR6jC0K2ps9qyZAN74LEBedEfK7tBOzWMwr58A==
-----END X509 CRL-----
</crl-verify>
<ca>
-----BEGIN CERTIFICATE-----
MIIFqzCCBJOgAwIBAgIJAKZ7D5Yv87qDMA0GCSqGSIb3DQEBDQUAMIHoMQswCQYD
VQQGEwJVUzELMAkGA1UECBMCQ0ExEzARBgNVBAcTCkxvc0FuZ2VsZXMxIDAeBgNV
BAoTF1ByaXZhdGUgSW50ZXJuZXQgQWNjZXNzMSAwHgYDVQQLExdQcml2YXRlIElu
dGVybmV0IEFjY2VzczEgMB4GA1UEAxMXUHJpdmF0ZSBJbnRlcm5ldCBBY2Nlc3Mx
IDAeBgNVBCkTF1ByaXZhdGUgSW50ZXJuZXQgQWNjZXNzMS8wLQYJKoZIhvcNAQkB
FiBzZWN1cmVAcHJpdmF0ZWludGVybmV0YWNjZXNzLmNvbTAeFw0xNDA0MTcxNzM1
MThaFw0zNDA0MTIxNzM1MThaMIHoMQswCQYDVQQGEwJVUzELMAkGA1UECBMCQ0Ex
EzARBgNVBAcTCkxvc0FuZ2VsZXMxIDAeBgNVBAoTF1ByaXZhdGUgSW50ZXJuZXQg
QWNjZXNzMSAwHgYDVQQLExdQcml2YXRlIEludGVybmV0IEFjY2VzczEgMB4GA1UE
AxMXUHJpdmF0ZSBJbnRlcm5ldCBBY2Nlc3MxIDAeBgNVBCkTF1ByaXZhdGUgSW50
ZXJuZXQgQWNjZXNzMS8wLQYJKoZIhvcNAQkBFiBzZWN1cmVAcHJpdmF0ZWludGVy
bmV0YWNjZXNzLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAPXD
L1L9tX6DGf36liA7UBTy5I869z0UVo3lImfOs/GSiFKPtInlesP65577nd7UNzzX
lH/P/CnFPdBWlLp5ze3HRBCc/Avgr5CdMRkEsySL5GHBZsx6w2cayQ2EcRhVTwWp
cdldeNO+pPr9rIgPrtXqT4SWViTQRBeGM8CDxAyTopTsobjSiYZCF9Ta1gunl0G/
8Vfp+SXfYCC+ZzWvP+L1pFhPRqzQQ8k+wMZIovObK1s+nlwPaLyayzw9a8sUnvWB
/5rGPdIYnQWPgoNlLN9HpSmsAcw2z8DXI9pIxbr74cb3/HSfuYGOLkRqrOk6h4RC
OfuWoTrZup1uEOn+fw8CAwEAAaOCAVQwggFQMB0GA1UdDgQWBBQv63nQ/pJAt5tL
y8VJcbHe22ZOsjCCAR8GA1UdIwSCARYwggESgBQv63nQ/pJAt5tLy8VJcbHe22ZO
sqGB7qSB6zCB6DELMAkGA1UEBhMCVVMxCzAJBgNVBAgTAkNBMRMwEQYDVQQHEwpM
b3NBbmdlbGVzMSAwHgYDVQQKExdQcml2YXRlIEludGVybmV0IEFjY2VzczEgMB4G
A1UECxMXUHJpdmF0ZSBJbnRlcm5ldCBBY2Nlc3MxIDAeBgNVBAMTF1ByaXZhdGUg
SW50ZXJuZXQgQWNjZXNzMSAwHgYDVQQpExdQcml2YXRlIEludGVybmV0IEFjY2Vz
czEvMC0GCSqGSIb3DQEJARYgc2VjdXJlQHByaXZhdGVpbnRlcm5ldGFjY2Vzcy5j
b22CCQCmew+WL/O6gzAMBgNVHRMEBTADAQH/MA0GCSqGSIb3DQEBDQUAA4IBAQAn
a5PgrtxfwTumD4+3/SYvwoD66cB8IcK//h1mCzAduU8KgUXocLx7QgJWo9lnZ8xU
ryXvWab2usg4fqk7FPi00bED4f4qVQFVfGfPZIH9QQ7/48bPM9RyfzImZWUCenK3
7pdw4Bvgoys2rHLHbGen7f28knT2j/cbMxd78tQc20TIObGjo8+ISTRclSTRBtyC
GohseKYpTS9himFERpUgNtefvYHbn70mIOzfOJFTVqfrptf9jXa9N8Mpy3ayfodz
1wiqdteqFXkTYoSDctgKMiZ6GdocK9nMroQipIQtpnwd4yBDWIyC6Bvlkrq5TQUt
YDQ8z9v+DMO6iwyIDRiU
-----END CERTIFICATE-----
</ca>
disable-occ

View File

@@ -0,0 +1,78 @@
client
dev tun
proto udp
nobind
persist-key
persist-tun
cipher aes-128-cbc
auth sha1
tls-client
remote-cert-tls server
auth-user-pass
compress
verb 1
reneg-sec 0
<crl-verify>
-----BEGIN X509 CRL-----
MIICWDCCAUAwDQYJKoZIhvcNAQENBQAwgegxCzAJBgNVBAYTAlVTMQswCQYDVQQI
EwJDQTETMBEGA1UEBxMKTG9zQW5nZWxlczEgMB4GA1UEChMXUHJpdmF0ZSBJbnRl
cm5ldCBBY2Nlc3MxIDAeBgNVBAsTF1ByaXZhdGUgSW50ZXJuZXQgQWNjZXNzMSAw
HgYDVQQDExdQcml2YXRlIEludGVybmV0IEFjY2VzczEgMB4GA1UEKRMXUHJpdmF0
ZSBJbnRlcm5ldCBBY2Nlc3MxLzAtBgkqhkiG9w0BCQEWIHNlY3VyZUBwcml2YXRl
aW50ZXJuZXRhY2Nlc3MuY29tFw0xNjA3MDgxOTAwNDZaFw0zNjA3MDMxOTAwNDZa
MCYwEQIBARcMMTYwNzA4MTkwMDQ2MBECAQYXDDE2MDcwODE5MDA0NjANBgkqhkiG
9w0BAQ0FAAOCAQEAQZo9X97ci8EcPYu/uK2HB152OZbeZCINmYyluLDOdcSvg6B5
jI+ffKN3laDvczsG6CxmY3jNyc79XVpEYUnq4rT3FfveW1+Ralf+Vf38HdpwB8EW
B4hZlQ205+21CALLvZvR8HcPxC9KEnev1mU46wkTiov0EKc+EdRxkj5yMgv0V2Re
ze7AP+NQ9ykvDScH4eYCsmufNpIjBLhpLE2cuZZXBLcPhuRzVoU3l7A9lvzG9mjA
5YijHJGHNjlWFqyrn1CfYS6koa4TGEPngBoAziWRbDGdhEgJABHrpoaFYaL61zqy
MR6jC0K2ps9qyZAN74LEBedEfK7tBOzWMwr58A==
-----END X509 CRL-----
</crl-verify>
<ca>
-----BEGIN CERTIFICATE-----
MIIFqzCCBJOgAwIBAgIJAKZ7D5Yv87qDMA0GCSqGSIb3DQEBDQUAMIHoMQswCQYD
VQQGEwJVUzELMAkGA1UECBMCQ0ExEzARBgNVBAcTCkxvc0FuZ2VsZXMxIDAeBgNV
BAoTF1ByaXZhdGUgSW50ZXJuZXQgQWNjZXNzMSAwHgYDVQQLExdQcml2YXRlIElu
dGVybmV0IEFjY2VzczEgMB4GA1UEAxMXUHJpdmF0ZSBJbnRlcm5ldCBBY2Nlc3Mx
IDAeBgNVBCkTF1ByaXZhdGUgSW50ZXJuZXQgQWNjZXNzMS8wLQYJKoZIhvcNAQkB
FiBzZWN1cmVAcHJpdmF0ZWludGVybmV0YWNjZXNzLmNvbTAeFw0xNDA0MTcxNzM1
MThaFw0zNDA0MTIxNzM1MThaMIHoMQswCQYDVQQGEwJVUzELMAkGA1UECBMCQ0Ex
EzARBgNVBAcTCkxvc0FuZ2VsZXMxIDAeBgNVBAoTF1ByaXZhdGUgSW50ZXJuZXQg
QWNjZXNzMSAwHgYDVQQLExdQcml2YXRlIEludGVybmV0IEFjY2VzczEgMB4GA1UE
AxMXUHJpdmF0ZSBJbnRlcm5ldCBBY2Nlc3MxIDAeBgNVBCkTF1ByaXZhdGUgSW50
ZXJuZXQgQWNjZXNzMS8wLQYJKoZIhvcNAQkBFiBzZWN1cmVAcHJpdmF0ZWludGVy
bmV0YWNjZXNzLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAPXD
L1L9tX6DGf36liA7UBTy5I869z0UVo3lImfOs/GSiFKPtInlesP65577nd7UNzzX
lH/P/CnFPdBWlLp5ze3HRBCc/Avgr5CdMRkEsySL5GHBZsx6w2cayQ2EcRhVTwWp
cdldeNO+pPr9rIgPrtXqT4SWViTQRBeGM8CDxAyTopTsobjSiYZCF9Ta1gunl0G/
8Vfp+SXfYCC+ZzWvP+L1pFhPRqzQQ8k+wMZIovObK1s+nlwPaLyayzw9a8sUnvWB
/5rGPdIYnQWPgoNlLN9HpSmsAcw2z8DXI9pIxbr74cb3/HSfuYGOLkRqrOk6h4RC
OfuWoTrZup1uEOn+fw8CAwEAAaOCAVQwggFQMB0GA1UdDgQWBBQv63nQ/pJAt5tL
y8VJcbHe22ZOsjCCAR8GA1UdIwSCARYwggESgBQv63nQ/pJAt5tLy8VJcbHe22ZO
sqGB7qSB6zCB6DELMAkGA1UEBhMCVVMxCzAJBgNVBAgTAkNBMRMwEQYDVQQHEwpM
b3NBbmdlbGVzMSAwHgYDVQQKExdQcml2YXRlIEludGVybmV0IEFjY2VzczEgMB4G
A1UECxMXUHJpdmF0ZSBJbnRlcm5ldCBBY2Nlc3MxIDAeBgNVBAMTF1ByaXZhdGUg
SW50ZXJuZXQgQWNjZXNzMSAwHgYDVQQpExdQcml2YXRlIEludGVybmV0IEFjY2Vz
czEvMC0GCSqGSIb3DQEJARYgc2VjdXJlQHByaXZhdGVpbnRlcm5ldGFjY2Vzcy5j
b22CCQCmew+WL/O6gzAMBgNVHRMEBTADAQH/MA0GCSqGSIb3DQEBDQUAA4IBAQAn
a5PgrtxfwTumD4+3/SYvwoD66cB8IcK//h1mCzAduU8KgUXocLx7QgJWo9lnZ8xU
ryXvWab2usg4fqk7FPi00bED4f4qVQFVfGfPZIH9QQ7/48bPM9RyfzImZWUCenK3
7pdw4Bvgoys2rHLHbGen7f28knT2j/cbMxd78tQc20TIObGjo8+ISTRclSTRBtyC
GohseKYpTS9himFERpUgNtefvYHbn70mIOzfOJFTVqfrptf9jXa9N8Mpy3ayfodz
1wiqdteqFXkTYoSDctgKMiZ6GdocK9nMroQipIQtpnwd4yBDWIyC6Bvlkrq5TQUt
YDQ8z9v+DMO6iwyIDRiU
-----END CERTIFICATE-----
</ca>
disable-occ
remote 100.10.10.10 3000
remote 100.20.20.20 3000
auth-user-pass /etc/openvpn/auth.conf
auth-retry nointeract
pull-filter ignore "auth-token"
user nonrootuser
mute-replay-warnings

108
internal/settings/dns.go Normal file
View File

@@ -0,0 +1,108 @@
package settings
import (
"fmt"
"strings"
"github.com/qdm12/private-internet-access-docker/internal/models"
"github.com/qdm12/private-internet-access-docker/internal/params"
)
// DNS contains settings to configure Unbound for DNS over TLS operation
type DNS struct {
Enabled bool
Providers []models.DNSProvider
AllowedHostnames []string
PrivateAddresses []string
BlockMalicious bool
BlockSurveillance bool
BlockAds bool
VerbosityLevel uint8
VerbosityDetailsLevel uint8
ValidationLogLevel uint8
}
func (d *DNS) String() string {
if !d.Enabled {
return "DNS over TLS settings: disabled"
}
blockMalicious, blockSurveillance, blockAds := "disabed", "disabed", "disabed"
if d.BlockMalicious {
blockMalicious = "enabled"
}
if d.BlockSurveillance {
blockSurveillance = "enabled"
}
if d.BlockAds {
blockAds = "enabled"
}
var providersStr []string
for _, provider := range d.Providers {
providersStr = append(providersStr, string(provider))
}
settingsList := []string{
"DNS over TLS settings:",
"DNS over TLS provider: \n |--" + strings.Join(providersStr, "\n |--"),
"Block malicious: " + blockMalicious,
"Block surveillance: " + blockSurveillance,
"Block ads: " + blockAds,
"Allowed hostnames: " + strings.Join(d.AllowedHostnames, ", "),
"Private addresses:\n |--" + strings.Join(d.PrivateAddresses, "\n |--"),
"Verbosity level: " + fmt.Sprintf("%d/5", d.VerbosityLevel),
"Verbosity details level: " + fmt.Sprintf("%d/4", d.VerbosityDetailsLevel),
"Validation log level: " + fmt.Sprintf("%d/2", d.ValidationLogLevel),
}
return strings.Join(settingsList, "\n |--")
}
// GetDNSSettings obtains DNS over TLS settings from environment variables using the params package.
func GetDNSSettings(params params.ParamsReader) (settings DNS, err error) {
settings.Enabled, err = params.GetDNSOverTLS()
if err != nil || !settings.Enabled {
return settings, err
}
settings.Providers, err = params.GetDNSOverTLSProviders()
if err != nil {
return settings, err
}
settings.AllowedHostnames, err = params.GetDNSUnblockedHostnames()
if err != nil {
return settings, err
}
settings.BlockMalicious, err = params.GetDNSMaliciousBlocking()
if err != nil {
return settings, err
}
settings.BlockSurveillance, err = params.GetDNSSurveillanceBlocking()
if err != nil {
return settings, err
}
settings.BlockAds, err = params.GetDNSAdsBlocking()
if err != nil {
return settings, err
}
settings.VerbosityLevel, err = params.GetDNSOverTLSVerbosity()
if err != nil {
return settings, err
}
settings.VerbosityDetailsLevel, err = params.GetDNSOverTLSVerbosityDetails()
if err != nil {
return settings, err
}
settings.ValidationLogLevel, err = params.GetDNSOverTLSValidationLogLevel()
if err != nil {
return settings, err
}
settings.PrivateAddresses = []string{ // TODO make env variable
"127.0.0.1/8",
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"169.254.0.0/16",
"::1/128",
"fc00::/7",
"fe80::/10",
"::ffff:0:0/96",
}
return settings, nil
}

View File

@@ -0,0 +1,34 @@
package settings
import (
"net"
"strings"
"github.com/qdm12/private-internet-access-docker/internal/params"
)
// Firewall contains settings to customize the firewall operation
type Firewall struct {
AllowedSubnets []net.IPNet
}
func (f *Firewall) String() string {
var allowedSubnets []string
for _, net := range f.AllowedSubnets {
allowedSubnets = append(allowedSubnets, net.String())
}
settingsList := []string{
"Firewall settings:",
"Allowed subnets: " + strings.Join(allowedSubnets, ", "),
}
return strings.Join(settingsList, "\n |--")
}
// GetFirewallSettings obtains firewall settings from environment variables using the params package.
func GetFirewallSettings(params params.ParamsReader) (settings Firewall, err error) {
settings.AllowedSubnets, err = params.GetExtraSubnets()
if err != nil {
return settings, err
}
return settings, nil
}

View File

@@ -0,0 +1,30 @@
package settings
import (
"strings"
"github.com/qdm12/private-internet-access-docker/internal/models"
"github.com/qdm12/private-internet-access-docker/internal/params"
)
// OpenVPN contains settings to configure the OpenVPN client
type OpenVPN struct {
NetworkProtocol models.NetworkProtocol
}
// GetOpenVPNSettings obtains the OpenVPN settings using the params functions
func GetOpenVPNSettings(params params.ParamsReader) (settings OpenVPN, err error) {
settings.NetworkProtocol, err = params.GetNetworkProtocol()
if err != nil {
return settings, err
}
return settings, nil
}
func (o *OpenVPN) String() string {
settingsList := []string{
"OpenVPN settings:",
"Network protocol: " + string(o.NetworkProtocol),
}
return strings.Join(settingsList, "\n|--")
}

72
internal/settings/pia.go Normal file
View File

@@ -0,0 +1,72 @@
package settings
import (
"fmt"
"strings"
"github.com/qdm12/private-internet-access-docker/internal/models"
"github.com/qdm12/private-internet-access-docker/internal/params"
)
// PIA contains the settings to connect to a PIA server
type PIA struct {
User string
Password string
Encryption models.PIAEncryption
Region models.PIARegion
PortForwarding PortForwarding
}
// PortForwarding contains settings for port forwarding
type PortForwarding struct {
Enabled bool
Filepath models.Filepath
}
func (p *PortForwarding) String() string {
if p.Enabled {
return fmt.Sprintf("on, saved in %s", p.Filepath)
}
return "off"
}
func (p *PIA) String() string {
settingsList := []string{
"PIA settings:",
"Region: " + string(p.Region),
"Encryption: " + string(p.Encryption),
"Port forwarding: " + p.PortForwarding.String(),
}
return strings.Join(settingsList, "\n |--")
}
// GetPIASettings obtains PIA settings from environment variables using the params package.
func GetPIASettings(params params.ParamsReader) (settings PIA, err error) {
settings.User, err = params.GetUser()
if err != nil {
return settings, err
}
settings.Password, err = params.GetPassword()
if err != nil {
return settings, err
}
settings.Encryption, err = params.GetPIAEncryption()
if err != nil {
return settings, err
}
settings.Region, err = params.GetPIARegion()
if err != nil {
return settings, err
}
settings.PortForwarding.Enabled, err = params.GetPortForwarding()
if err != nil {
return settings, err
}
if settings.PortForwarding.Enabled {
settings.PortForwarding.Filepath, err = params.GetPortForwardingStatusFilepath()
if err != nil {
return settings, err
}
}
return settings, nil
}

View File

@@ -0,0 +1,60 @@
package settings
import (
"strings"
"github.com/qdm12/private-internet-access-docker/internal/params"
)
// Settings contains all settings for the program to run
type Settings struct {
OpenVPN OpenVPN
PIA PIA
DNS DNS
Firewall Firewall
TinyProxy TinyProxy
ShadowSocks ShadowSocks
}
func (s *Settings) String() string {
return strings.Join([]string{
"Settings summary below:",
s.OpenVPN.String(),
s.PIA.String(),
s.DNS.String(),
s.Firewall.String(),
s.TinyProxy.String(),
s.ShadowSocks.String(),
"", // new line at the end
}, "\n")
}
// GetAllSettings obtains all settings for the program and returns an error as soon
// as an error is encountered reading them.
func GetAllSettings(params params.ParamsReader) (settings Settings, err error) {
settings.OpenVPN, err = GetOpenVPNSettings(params)
if err != nil {
return settings, err
}
settings.PIA, err = GetPIASettings(params)
if err != nil {
return settings, err
}
settings.DNS, err = GetDNSSettings(params)
if err != nil {
return settings, err
}
settings.Firewall, err = GetFirewallSettings(params)
if err != nil {
return settings, err
}
settings.TinyProxy, err = GetTinyProxySettings(params)
if err != nil {
return settings, err
}
settings.ShadowSocks, err = GetShadowSocksSettings(params)
if err != nil {
return settings, err
}
return settings, nil
}

View File

@@ -0,0 +1,48 @@
package settings
import (
"fmt"
"strings"
"github.com/qdm12/private-internet-access-docker/internal/params"
)
// ShadowSocks contains settings to configure the Shadowsocks server
type ShadowSocks struct {
Enabled bool
Password string
Log bool
Port uint16
}
func (s *ShadowSocks) String() string {
if !s.Enabled {
return "ShadowSocks settings: disabled"
}
settingsList := []string{
"ShadowSocks settings:",
fmt.Sprintf("Port: %d", s.Port),
}
return strings.Join(settingsList, "\n |--")
}
// GetShadowSocksSettings obtains ShadowSocks settings from environment variables using the params package.
func GetShadowSocksSettings(params params.ParamsReader) (settings ShadowSocks, err error) {
settings.Enabled, err = params.GetShadowSocks()
if err != nil || !settings.Enabled {
return settings, err
}
settings.Port, err = params.GetShadowSocksPort()
if err != nil {
return settings, err
}
settings.Password, err = params.GetShadowSocksPassword()
if err != nil {
return settings, err
}
settings.Log, err = params.GetShadowSocksLog()
if err != nil {
return settings, err
}
return settings, nil
}

View File

@@ -0,0 +1,60 @@
package settings
import (
"fmt"
"strings"
"github.com/qdm12/private-internet-access-docker/internal/models"
"github.com/qdm12/private-internet-access-docker/internal/params"
)
// TinyProxy contains settings to configure TinyProxy
type TinyProxy struct {
Enabled bool
User string
Password string
Port uint16
LogLevel models.TinyProxyLogLevel
}
func (t *TinyProxy) String() string {
if !t.Enabled {
return "TinyProxy settings: disabled"
}
auth := "disabled"
if t.User != "" {
auth = "enabled"
}
settingsList := []string{
"TinyProxy settings:",
fmt.Sprintf("Port: %d", t.Port),
"Authentication: " + auth,
"Log level: " + string(t.LogLevel),
}
return "TinyProxy settings:\n" + strings.Join(settingsList, "\n |--")
}
// GetTinyProxySettings obtains TinyProxy settings from environment variables using the params package.
func GetTinyProxySettings(params params.ParamsReader) (settings TinyProxy, err error) {
settings.Enabled, err = params.GetTinyProxy()
if err != nil || !settings.Enabled {
return settings, err
}
settings.User, err = params.GetTinyProxyUser()
if err != nil {
return settings, err
}
settings.Password, err = params.GetTinyProxyPassword()
if err != nil {
return settings, err
}
settings.Port, err = params.GetTinyProxyPort()
if err != nil {
return settings, err
}
settings.LogLevel, err = params.GetTinyProxyLog()
if err != nil {
return settings, err
}
return settings, nil
}

View File

@@ -0,0 +1,40 @@
package shadowsocks
import (
"fmt"
"io"
"strings"
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
func (c *configurator) Start(server string, port uint16, password string, log bool) (stdout io.ReadCloser, err error) {
c.logger.Info("%s: starting shadowsocks server", logPrefix)
args := []string{
"-c", string(constants.ShadowsocksConf),
"-p", fmt.Sprintf("%d", port),
"-k", password,
}
if log {
args = append(args, "-v")
}
stdout, _, _, err = c.commander.Start("ss-server", args...)
return stdout, err
}
// Version obtains the version of the installed shadowsocks server
func (c *configurator) Version() (string, error) {
output, err := c.commander.Run("ss-server", "-h")
if err != nil {
return "", err
}
lines := strings.Split(output, "\n")
if len(lines) < 2 {
return "", fmt.Errorf("ss-server -h: not enough lines in %q", output)
}
words := strings.Fields(lines[1])
if len(words) < 2 {
return "", fmt.Errorf("ss-server -h: line 2 is too short: %q", lines[1])
}
return words[1], nil
}

View File

@@ -0,0 +1,49 @@
package shadowsocks
import (
"encoding/json"
"fmt"
"github.com/qdm12/golibs/files"
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
func (c *configurator) MakeConf(port uint16, password string, uid, gid int) (err error) {
c.logger.Info("%s: generating configuration file", logPrefix)
data := generateConf(port, password)
return c.fileManager.WriteToFile(
string(constants.ShadowsocksConf),
data,
files.FileOwnership(uid, gid),
files.FilePermissions(0400))
}
func generateConf(port uint16, password string) (data []byte) {
conf := struct {
Server string `json:"server"`
User string `json:"user"`
Method string `json:"method"`
Timeout uint `json:"timeout"`
FastOpen bool `json:"fast_open"`
Mode string `json:"mode"`
PortPassword map[string]string `json:"port_password"`
Workers uint `json:"workers"`
Interface string `json:"interface"`
Nameserver string `json:"nameserver"`
}{
Server: "0.0.0.0",
User: "nonrootuser",
Method: "chacha20-ietf-poly1305",
Timeout: 30,
FastOpen: false,
Mode: "tcp_and_udp",
PortPassword: map[string]string{
fmt.Sprintf("%d", port): password,
},
Workers: 2,
Interface: "tun",
Nameserver: "127.0.0.1",
}
data, _ = json.Marshal(conf)
return data
}

View File

@@ -0,0 +1,79 @@
package shadowsocks
import (
"fmt"
"testing"
filesMocks "github.com/qdm12/golibs/files/mocks"
loggingMocks "github.com/qdm12/golibs/logging/mocks"
"github.com/qdm12/private-internet-access-docker/internal/constants"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
func Test_generateConf(t *testing.T) {
t.Parallel()
tests := map[string]struct {
port uint16
password string
data []byte
}{
"no data": {
data: []byte(`{"server":"0.0.0.0","user":"nonrootuser","method":"chacha20-ietf-poly1305","timeout":30,"fast_open":false,"mode":"tcp_and_udp","port_password":{"0":""},"workers":2,"interface":"tun","nameserver":"127.0.0.1"}`),
},
"data": {
port: 2000,
password: "abcde",
data: []byte(`{"server":"0.0.0.0","user":"nonrootuser","method":"chacha20-ietf-poly1305","timeout":30,"fast_open":false,"mode":"tcp_and_udp","port_password":{"2000":"abcde"},"workers":2,"interface":"tun","nameserver":"127.0.0.1"}`),
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
data := generateConf(tc.port, tc.password)
assert.Equal(t, tc.data, data)
})
}
}
func Test_MakeConf(t *testing.T) {
t.Parallel()
tests := map[string]struct {
writeErr error
err error
}{
"no write error": {},
"write error": {
writeErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
logger := &loggingMocks.Logger{}
logger.On("Info", "%s: generating configuration file", logPrefix).Once()
fileManager := &filesMocks.FileManager{}
fileManager.On("WriteToFile",
string(constants.ShadowsocksConf),
[]byte(`{"server":"0.0.0.0","user":"nonrootuser","method":"chacha20-ietf-poly1305","timeout":30,"fast_open":false,"mode":"tcp_and_udp","port_password":{"2000":"abcde"},"workers":2,"interface":"tun","nameserver":"127.0.0.1"}`),
mock.AnythingOfType("files.WriteOptionSetter"),
mock.AnythingOfType("files.WriteOptionSetter"),
).
Return(tc.writeErr).Once()
c := &configurator{logger: logger, fileManager: fileManager}
err := c.MakeConf(2000, "abcde", 1000, 1001)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
logger.AssertExpectations(t)
fileManager.AssertExpectations(t)
})
}
}

View File

@@ -0,0 +1,27 @@
package shadowsocks
import (
"io"
"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
)
const logPrefix = "shadowsocks configurator"
type Configurator interface {
Version() (string, error)
MakeConf(port uint16, password string, uid, gid int) (err error)
Start(server string, port uint16, password string, log bool) (stdout io.ReadCloser, err error)
}
type configurator struct {
fileManager files.FileManager
logger logging.Logger
commander command.Commander
}
func NewConfigurator(fileManager files.FileManager, logger logging.Logger) Configurator {
return &configurator{fileManager, logger, command.NewCommander()}
}

55
internal/splash/splash.go Normal file
View File

@@ -0,0 +1,55 @@
package splash
import (
"fmt"
"strings"
"time"
"github.com/kyokomi/emoji"
"github.com/qdm12/private-internet-access-docker/internal/constants"
"github.com/qdm12/private-internet-access-docker/internal/params"
)
func Splash(paramsReader params.ParamsReader) string {
version := paramsReader.GetVersion()
vcsRef := paramsReader.GetVcsRef()
buildDate := paramsReader.GetBuildDate()
lines := title()
lines = append(lines, "")
lines = append(lines, fmt.Sprintf("Running version %s built on %s (commit %s)", version, buildDate, vcsRef))
lines = append(lines, "")
lines = append(lines, annoucement()...)
lines = append(lines, "")
lines = append(lines, links()...)
return strings.Join(lines, "\n")
}
func title() []string {
return []string{
"=========================================",
"============= PIA container =============",
"========== An exquisite mix of ==========",
"==== OpenVPN, Unbound, DNS over TLS, ====",
"===== Shadowsocks, Tinyproxy and Go =====",
"=========================================",
"=== Made with " + emoji.Sprint(":heart:") + " by github.com/qdm12 ====",
"=========================================",
}
}
func annoucement() []string {
timestamp := time.Now().UnixNano() / 1000000000
if timestamp < constants.AnnoucementExpiration {
return []string{emoji.Sprint(":rotating_light: ") + constants.Annoucement}
}
return nil
}
func links() []string {
return []string{
emoji.Sprint(":wrench: ") + "Need help? " + constants.IssueLink,
emoji.Sprint(":computer: ") + "Email? quentin.mcgaw@gmail.com",
emoji.Sprint(":coffee: ") + "Slack? Join from the Slack button on Github",
emoji.Sprint(":money_with_wings: ") + "Help me? https://github.com/sponsors/qdm12",
}
}

View File

@@ -0,0 +1,26 @@
package tinyproxy
import (
"fmt"
"io"
"strings"
)
func (c *configurator) Start() (stdout io.ReadCloser, err error) {
c.logger.Info("%s: starting tinyproxy server", logPrefix)
stdout, _, _, err = c.commander.Start("tinyproxy", "-d")
return stdout, err
}
// Version obtains the version of the installed Tinyproxy server
func (c *configurator) Version() (string, error) {
output, err := c.commander.Run("tinyproxy", "-v")
if err != nil {
return "", err
}
words := strings.Fields(output)
if len(words) < 2 {
return "", fmt.Errorf("tinyproxy -v: output is too short: %q", output)
}
return words[1], nil
}

View File

@@ -0,0 +1,44 @@
package tinyproxy
import (
"fmt"
"github.com/qdm12/golibs/files"
"github.com/qdm12/private-internet-access-docker/internal/constants"
"github.com/qdm12/private-internet-access-docker/internal/models"
)
func (c *configurator) MakeConf(logLevel models.TinyProxyLogLevel, port uint16, user, password string, uid, gid int) error {
c.logger.Info("%s: generating tinyproxy configuration file", logPrefix)
lines := generateConf(logLevel, port, user, password)
return c.fileManager.WriteLinesToFile(string(constants.TinyProxyConf),
lines,
files.FileOwnership(uid, gid),
files.FilePermissions(0400))
}
func generateConf(logLevel models.TinyProxyLogLevel, port uint16, user, password string) (lines []string) {
confMapping := map[string]string{
"User": "nonrootuser",
"Group": "tinyproxy",
"Port": fmt.Sprintf("%d", port),
"Timeout": "600",
"DefaultErrorFile": "/usr/share/tinyproxy/default.html",
"MaxClients": "100",
"MinSpareServers": "5",
"MaxSpareServers": "20",
"StartServers": "10",
"MaxRequestsPerChild": "0",
"DisableViaHeader": "Yes",
"LogLevel": string(logLevel),
// "StatFile": "/usr/share/tinyproxy/stats.html",
}
if len(user) > 0 {
confMapping["BasicAuth"] = fmt.Sprintf("%s %s", user, password)
}
for k, v := range confMapping {
line := fmt.Sprintf("%s %s", k, v)
lines = append(lines, line)
}
return lines
}

View File

@@ -0,0 +1,28 @@
package tinyproxy
import (
"io"
"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/private-internet-access-docker/internal/models"
)
const logPrefix = "tinyproxy configurator"
type Configurator interface {
Version() (string, error)
MakeConf(logLevel models.TinyProxyLogLevel, port uint16, user, password string, uid, gid int) error
Start() (stdout io.ReadCloser, err error)
}
type configurator struct {
fileManager files.FileManager
logger logging.Logger
commander command.Commander
}
func NewConfigurator(fileManager files.FileManager, logger logging.Logger) Configurator {
return &configurator{fileManager, logger, command.NewCommander()}
}