Rewrite of the entrypoint in Golang (#71)
- General improvements
- Parallel download of only needed files at start
- Prettier console output with all streams merged (openvpn, unbound, shadowsocks etc.)
- Simplified Docker final image
- Faster bootup
- DNS over TLS
- Finer grain blocking at DNS level: malicious, ads and surveillance
- Choose your DNS over TLS providers
- Ability to use multiple DNS over TLS providers for DNS split horizon
- Environment variables for DNS logging
- DNS block lists needed are downloaded and built automatically at start, in parallel
- PIA
- A random region is selected if the REGION parameter is left empty (thanks @rorph for your PR)
- Routing and iptables adjusted so it can work as a Kubernetes pod sidecar (thanks @rorph for your PR)
This commit is contained in:
63
internal/constants/dns.go
Normal file
63
internal/constants/dns.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package constants
|
||||
|
||||
import (
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
const (
|
||||
// Cloudflare is a DNS over TLS provider
|
||||
Cloudflare models.DNSProvider = "cloudflare"
|
||||
// Google is a DNS over TLS provider
|
||||
Google models.DNSProvider = "google"
|
||||
// Quad9 is a DNS over TLS provider
|
||||
Quad9 models.DNSProvider = "quad9"
|
||||
// Quadrant is a DNS over TLS provider
|
||||
Quadrant models.DNSProvider = "quadrant"
|
||||
// CleanBrowsing is a DNS over TLS provider
|
||||
CleanBrowsing models.DNSProvider = "cleanbrowsing"
|
||||
// SecureDNS is a DNS over TLS provider
|
||||
SecureDNS models.DNSProvider = "securedns"
|
||||
// LibreDNS is a DNS over TLS provider
|
||||
LibreDNS models.DNSProvider = "libredns"
|
||||
)
|
||||
|
||||
const (
|
||||
CloudflareAddress1 models.DNSForwardAddress = "1.1.1.1@853#cloudflare-dns.com"
|
||||
CloudflareAddress2 models.DNSForwardAddress = "1.0.0.1@853#cloudflare-dns.com"
|
||||
GoogleAddress1 models.DNSForwardAddress = "8.8.8.8@853#dns.google"
|
||||
GoogleAddress2 models.DNSForwardAddress = "8.8.4.4@853#dns.google"
|
||||
Quad9Address1 models.DNSForwardAddress = "9.9.9.9@853#dns.quad9.net"
|
||||
Quad9Address2 models.DNSForwardAddress = "149.112.112.112@853#dns.quad9.net"
|
||||
QuadrantAddress models.DNSForwardAddress = "12.159.2.159@853#dns-tls.qis.io"
|
||||
CleanBrowsingAddress1 models.DNSForwardAddress = "185.228.168.9@853#security-filter-dns.cleanbrowsing.org"
|
||||
CleanBrowsingAddress2 models.DNSForwardAddress = "185.228.169.9@853#security-filter-dns.cleanbrowsing.org"
|
||||
SecureDNSAddress models.DNSForwardAddress = "146.185.167.43@853#dot.securedns.eu"
|
||||
LibreDNSAddress models.DNSForwardAddress = "116.203.115.192@853#dot.libredns.gr"
|
||||
)
|
||||
|
||||
var DNSAddressesMapping = map[models.DNSProvider][]models.DNSForwardAddress{
|
||||
Cloudflare: []models.DNSForwardAddress{CloudflareAddress1, CloudflareAddress2},
|
||||
Google: []models.DNSForwardAddress{GoogleAddress1, GoogleAddress2},
|
||||
Quad9: []models.DNSForwardAddress{Quad9Address1, Quad9Address2},
|
||||
Quadrant: []models.DNSForwardAddress{QuadrantAddress},
|
||||
CleanBrowsing: []models.DNSForwardAddress{CleanBrowsingAddress1, CleanBrowsingAddress2},
|
||||
SecureDNS: []models.DNSForwardAddress{SecureDNSAddress},
|
||||
LibreDNS: []models.DNSForwardAddress{LibreDNSAddress},
|
||||
}
|
||||
|
||||
// Block lists URLs
|
||||
const (
|
||||
AdsBlockListHostnamesURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/ads-hostnames.updated"
|
||||
AdsBlockListIPsURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/ads-ips.updated"
|
||||
MaliciousBlockListHostnamesURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/malicious-hostnames.updated"
|
||||
MaliciousBlockListIPsURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/malicious-ips.updated"
|
||||
SurveillanceBlockListHostnamesURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/surveillance-hostnames.updated"
|
||||
SurveillanceBlockListIPsURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/surveillance-ips.updated"
|
||||
)
|
||||
|
||||
// DNS certificates to fetch
|
||||
// TODO obtain from source directly, see qdm12/updated)
|
||||
const (
|
||||
NamedRootURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/named.root.updated"
|
||||
RootKeyURL models.URL = "https://raw.githubusercontent.com/qdm12/files/master/root.key.updated"
|
||||
)
|
||||
10
internal/constants/openvpn.go
Normal file
10
internal/constants/openvpn.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package constants
|
||||
|
||||
import (
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
const (
|
||||
TUN models.VPNDevice = "tun0"
|
||||
TAP models.VPNDevice = "tap0"
|
||||
)
|
||||
28
internal/constants/paths.go
Normal file
28
internal/constants/paths.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package constants
|
||||
|
||||
import (
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
const (
|
||||
// UnboundConf is the file path to the Unbound configuration file
|
||||
UnboundConf models.Filepath = "/etc/unbound/unbound.conf"
|
||||
// ResolvConf is the file path to the system resolv.conf file
|
||||
ResolvConf models.Filepath = "/etc/resolv.conf"
|
||||
// OpenVPNAuthConf is the file path to the OpenVPN auth file
|
||||
OpenVPNAuthConf models.Filepath = "/etc/openvpn/auth.conf"
|
||||
// OpenVPNConf is the file path to the OpenVPN client configuration file
|
||||
OpenVPNConf models.Filepath = "/etc/openvpn/target.ovpn"
|
||||
// TunnelDevice is the file path to tun device
|
||||
TunnelDevice models.Filepath = "/dev/net/tun"
|
||||
// NetRoute is the path to the file containing information on the network route
|
||||
NetRoute models.Filepath = "/proc/net/route"
|
||||
// TinyProxyConf is the filepath to the tinyproxy configuration file
|
||||
TinyProxyConf models.Filepath = "/etc/tinyproxy/tinyproxy.conf"
|
||||
// ShadowsocksConf is the filepath to the shadowsocks configuration file
|
||||
ShadowsocksConf models.Filepath = "/etc/shadowsocks.json"
|
||||
// RootHints is the filepath to the root.hints file used by Unbound
|
||||
RootHints models.Filepath = "/etc/unbound/root.hints"
|
||||
// RootKey is the filepath to the root.key file used by Unbound
|
||||
RootKey models.Filepath = "/etc/unbound/root.key"
|
||||
)
|
||||
70
internal/constants/pia.go
Normal file
70
internal/constants/pia.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package constants
|
||||
|
||||
import (
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
const (
|
||||
// PIAEncryptionNormal is the normal level of encryption for communication with PIA servers
|
||||
PIAEncryptionNormal models.PIAEncryption = "normal"
|
||||
// PIAEncryptionStrong is the strong level of encryption for communication with PIA servers
|
||||
PIAEncryptionStrong models.PIAEncryption = "strong"
|
||||
)
|
||||
|
||||
const (
|
||||
AUMelbourne models.PIARegion = "AU Melbourne"
|
||||
AUPerth models.PIARegion = "AU Perth"
|
||||
AUSydney models.PIARegion = "AU Sydney"
|
||||
Austria models.PIARegion = "Austria"
|
||||
Belgium models.PIARegion = "Belgium"
|
||||
CAMontreal models.PIARegion = "CA Montreal"
|
||||
CAToronto models.PIARegion = "CA Toronto"
|
||||
CAVancouver models.PIARegion = "CA Vancouver"
|
||||
CzechRepublic models.PIARegion = "Czech Republic"
|
||||
DEBerlin models.PIARegion = "DE Berlin"
|
||||
DEFrankfurt models.PIARegion = "DE Frankfurt"
|
||||
Denmark models.PIARegion = "Denmark"
|
||||
Finland models.PIARegion = "Finland"
|
||||
France models.PIARegion = "France"
|
||||
HongKong models.PIARegion = "Hong Kong"
|
||||
Hungary models.PIARegion = "Hungary"
|
||||
India models.PIARegion = "India"
|
||||
Ireland models.PIARegion = "Ireland"
|
||||
Israel models.PIARegion = "Israel"
|
||||
Italy models.PIARegion = "Italy"
|
||||
Japan models.PIARegion = "Japan"
|
||||
Luxembourg models.PIARegion = "Luxembourg"
|
||||
Mexico models.PIARegion = "Mexico"
|
||||
Netherlands models.PIARegion = "Netherlands"
|
||||
NewZealand models.PIARegion = "New Zealand"
|
||||
Norway models.PIARegion = "Norway"
|
||||
Poland models.PIARegion = "Poland"
|
||||
Romania models.PIARegion = "Romania"
|
||||
Singapore models.PIARegion = "Singapore"
|
||||
Spain models.PIARegion = "Spain"
|
||||
Sweden models.PIARegion = "Sweden"
|
||||
Switzerland models.PIARegion = "Switzerland"
|
||||
UAE models.PIARegion = "UAE"
|
||||
UKLondon models.PIARegion = "UK London"
|
||||
UKManchester models.PIARegion = "UK Manchester"
|
||||
UKSouthampton models.PIARegion = "UK Southampton"
|
||||
USAtlanta models.PIARegion = "US Atlanta"
|
||||
USCalifornia models.PIARegion = "US California"
|
||||
USChicago models.PIARegion = "US Chicago"
|
||||
USDenver models.PIARegion = "US Denver"
|
||||
USEast models.PIARegion = "US East"
|
||||
USFlorida models.PIARegion = "US Florida"
|
||||
USHouston models.PIARegion = "US Houston"
|
||||
USLasVegas models.PIARegion = "US Las Vegas"
|
||||
USNewYorkCity models.PIARegion = "US New York City"
|
||||
USSeattle models.PIARegion = "US Seattle"
|
||||
USSiliconValley models.PIARegion = "US Silicon Valley"
|
||||
USTexas models.PIARegion = "US Texas"
|
||||
USWashingtonDC models.PIARegion = "US Washington DC"
|
||||
USWest models.PIARegion = "US West"
|
||||
)
|
||||
|
||||
const (
|
||||
PIAOpenVPNURL models.URL = "https://www.privateinternetaccess.com/openvpn"
|
||||
PIAPortForwardURL models.URL = "http://209.222.18.222:2000"
|
||||
)
|
||||
13
internal/constants/splash.go
Normal file
13
internal/constants/splash.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package constants
|
||||
|
||||
const (
|
||||
// Annoucement is a message annoucement
|
||||
Annoucement = "Total rewrite in Go with many new features"
|
||||
// AnnoucementExpiration is the expiration time of the annoucement in unix timestamp
|
||||
AnnoucementExpiration = 1582761600
|
||||
)
|
||||
|
||||
const (
|
||||
// IssueLink is the link for users to use to create issues
|
||||
IssueLink = "https://github.com/qdm12/private-internet-access-docker/issues/new"
|
||||
)
|
||||
16
internal/constants/tinyproxy.go
Normal file
16
internal/constants/tinyproxy.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package constants
|
||||
|
||||
import (
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
const (
|
||||
// TinyProxyInfoLevel is the info log level for TinyProxy
|
||||
TinyProxyInfoLevel models.TinyProxyLogLevel = "Info"
|
||||
// TinyProxyWarnLevel is the warning log level for TinyProxy
|
||||
TinyProxyWarnLevel models.TinyProxyLogLevel = "Warning"
|
||||
// TinyProxyErrorLevel is the error log level for TinyProxy
|
||||
TinyProxyErrorLevel models.TinyProxyLogLevel = "Error"
|
||||
// TinyProxyCriticalLevel is the critical log level for TinyProxy
|
||||
TinyProxyCriticalLevel models.TinyProxyLogLevel = "Critical"
|
||||
)
|
||||
21
internal/constants/vpn.go
Normal file
21
internal/constants/vpn.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package constants
|
||||
|
||||
import (
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
const (
|
||||
// PrivateInternetAccess is a VPN provider
|
||||
PrivateInternetAccess models.VPNProvider = "private internet access"
|
||||
// Mullvad is a VPN provider
|
||||
Mullvad models.VPNProvider = "mullvad"
|
||||
// Windscribe is a VPN provider
|
||||
Windscribe models.VPNProvider = "windscribe"
|
||||
)
|
||||
|
||||
const (
|
||||
// TCP is a network protocol (reliable and slower than UDP)
|
||||
TCP models.NetworkProtocol = "tcp"
|
||||
// UDP is a network protocol (unreliable and faster than TCP)
|
||||
UDP models.NetworkProtocol = "udp"
|
||||
)
|
||||
40
internal/dns/command.go
Normal file
40
internal/dns/command.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
)
|
||||
|
||||
func (c *configurator) Start(verbosityDetailsLevel uint8) (stdout io.ReadCloser, err error) {
|
||||
c.logger.Info("%s: starting unbound", logPrefix)
|
||||
args := []string{"-d", "-c", string(constants.UnboundConf)}
|
||||
if verbosityDetailsLevel > 0 {
|
||||
args = append(args, "-"+strings.Repeat("v", int(verbosityDetailsLevel)))
|
||||
}
|
||||
// Only logs to stderr
|
||||
_, stdout, _, err = c.commander.Start("unbound", args...)
|
||||
return stdout, err
|
||||
}
|
||||
|
||||
func (c *configurator) Version() (version string, err error) {
|
||||
output, err := c.commander.Run("unbound", "-V")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unbound version: %w", err)
|
||||
}
|
||||
for _, line := range strings.Split(output, "\n") {
|
||||
if strings.Contains(line, "Version ") {
|
||||
words := strings.Fields(line)
|
||||
if len(words) < 2 {
|
||||
continue
|
||||
}
|
||||
version = words[1]
|
||||
}
|
||||
}
|
||||
if version == "" {
|
||||
return "", fmt.Errorf("unbound version was not found in %q", output)
|
||||
}
|
||||
return version, nil
|
||||
}
|
||||
69
internal/dns/command_test.go
Normal file
69
internal/dns/command_test.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
commandMocks "github.com/qdm12/golibs/command/mocks"
|
||||
loggingMocks "github.com/qdm12/golibs/logging/mocks"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
)
|
||||
|
||||
func Test_Start(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := &loggingMocks.Logger{}
|
||||
logger.On("Info", "%s: starting unbound", logPrefix).Once()
|
||||
commander := &commandMocks.Commander{}
|
||||
commander.On("Start", "unbound", "-d", "-c", string(constants.UnboundConf), "-vv").
|
||||
Return(nil, nil, nil, nil).Once()
|
||||
c := &configurator{commander: commander, logger: logger}
|
||||
stdout, err := c.Start(2)
|
||||
assert.Nil(t, stdout)
|
||||
assert.NoError(t, err)
|
||||
logger.AssertExpectations(t)
|
||||
commander.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func Test_Version(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
runOutput string
|
||||
runErr error
|
||||
version string
|
||||
err error
|
||||
}{
|
||||
"no data": {
|
||||
err: fmt.Errorf(`unbound version was not found in ""`),
|
||||
},
|
||||
"2 lines with version": {
|
||||
runOutput: "Version \nVersion 1.0-a hello\n",
|
||||
version: "1.0-a",
|
||||
},
|
||||
"run error": {
|
||||
runErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("unbound version: error"),
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
commander := &commandMocks.Commander{}
|
||||
commander.On("Run", "unbound", "-V").
|
||||
Return(tc.runOutput, tc.runErr).Once()
|
||||
c := &configurator{commander: commander}
|
||||
version, err := c.Version()
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.version, version)
|
||||
commander.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
280
internal/dns/conf.go
Normal file
280
internal/dns/conf.go
Normal file
@@ -0,0 +1,280 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/network"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/settings"
|
||||
)
|
||||
|
||||
func (c *configurator) MakeUnboundConf(settings settings.DNS, uid, gid int) (err error) {
|
||||
c.logger.Info("%s: generating Unbound configuration", logPrefix)
|
||||
lines, warnings, err := generateUnboundConf(settings, c.client, c.logger)
|
||||
for _, warning := range warnings {
|
||||
c.logger.Warn(warning)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.fileManager.WriteLinesToFile(
|
||||
string(constants.UnboundConf),
|
||||
lines,
|
||||
files.FileOwnership(uid, gid),
|
||||
files.FilePermissions(0400))
|
||||
}
|
||||
|
||||
// MakeUnboundConf generates an Unbound configuration from the user provided settings
|
||||
func generateUnboundConf(settings settings.DNS, client network.Client, logger logging.Logger) (lines []string, warnings []error, err error) {
|
||||
serverSection := map[string]string{
|
||||
// Logging
|
||||
"verbosity": fmt.Sprintf("%d", settings.VerbosityLevel),
|
||||
"val-log-level": fmt.Sprintf("%d", settings.ValidationLogLevel),
|
||||
"use-syslog": "no",
|
||||
// Performance
|
||||
"num-threads": "1",
|
||||
"prefetch": "yes",
|
||||
"prefetch-key": "yes",
|
||||
"key-cache-size": "16m",
|
||||
"key-cache-slabs": "4",
|
||||
"msg-cache-size": "4m",
|
||||
"msg-cache-slabs": "4",
|
||||
"rrset-cache-size": "4m",
|
||||
"rrset-cache-slabs": "4",
|
||||
"cache-min-ttl": "3600",
|
||||
"cache-max-ttl": "9000",
|
||||
// Privacy
|
||||
"rrset-roundrobin": "yes",
|
||||
"hide-identity": "yes",
|
||||
"hide-version": "yes",
|
||||
// Security
|
||||
"tls-cert-bundle": "\"/etc/ssl/certs/ca-certificates.crt\"",
|
||||
"root-hints": fmt.Sprintf("%q", constants.RootHints),
|
||||
"trust-anchor-file": fmt.Sprintf("%q", constants.RootKey),
|
||||
"harden-below-nxdomain": "yes",
|
||||
"harden-referral-path": "yes",
|
||||
"harden-algo-downgrade": "yes",
|
||||
// Network
|
||||
"do-ip4": "yes",
|
||||
"do-ip6": "no",
|
||||
"interface": "127.0.0.1",
|
||||
"port": "53",
|
||||
// Other
|
||||
"username": "\"nonrootuser\"",
|
||||
}
|
||||
|
||||
// Block lists
|
||||
hostnamesLines, ipsLines, warnings := buildBlocked(client,
|
||||
settings.BlockMalicious, settings.BlockAds, settings.BlockSurveillance,
|
||||
settings.AllowedHostnames, settings.PrivateAddresses,
|
||||
)
|
||||
logger.Info("%s: %d hostnames blocked overall", logPrefix, len(hostnamesLines))
|
||||
logger.Info("%s: %d IP addresses blocked overall", logPrefix, len(ipsLines))
|
||||
sort.Slice(hostnamesLines, func(i, j int) bool { // for unit tests really
|
||||
return hostnamesLines[i] < hostnamesLines[j]
|
||||
})
|
||||
sort.Slice(ipsLines, func(i, j int) bool { // for unit tests really
|
||||
return ipsLines[i] < ipsLines[j]
|
||||
})
|
||||
|
||||
// Server
|
||||
lines = append(lines, "server:")
|
||||
var serverLines []string
|
||||
for k, v := range serverSection {
|
||||
serverLines = append(serverLines, " "+k+": "+v)
|
||||
}
|
||||
sort.Slice(serverLines, func(i, j int) bool {
|
||||
return serverLines[i] < serverLines[j]
|
||||
})
|
||||
lines = append(lines, serverLines...)
|
||||
lines = append(lines, hostnamesLines...)
|
||||
lines = append(lines, ipsLines...)
|
||||
|
||||
// Forward zone
|
||||
lines = append(lines, "forward-zone:")
|
||||
forwardZoneSection := map[string]string{
|
||||
"name": "\".\"",
|
||||
"forward-tls-upstream": "yes",
|
||||
}
|
||||
var forwardZoneLines []string
|
||||
for k, v := range forwardZoneSection {
|
||||
forwardZoneLines = append(forwardZoneLines, " "+k+": "+v)
|
||||
}
|
||||
sort.Slice(forwardZoneLines, func(i, j int) bool {
|
||||
return forwardZoneLines[i] < forwardZoneLines[j]
|
||||
})
|
||||
for _, provider := range settings.Providers {
|
||||
forwardAddresses, ok := constants.DNSAddressesMapping[provider]
|
||||
if !ok || len(forwardAddresses) == 0 {
|
||||
return nil, warnings, fmt.Errorf("DNS provider %q does not have any matching forward addresses", provider)
|
||||
}
|
||||
for _, forwardAddress := range forwardAddresses {
|
||||
forwardZoneLines = append(forwardZoneLines, fmt.Sprintf(" forward-addr: %s", forwardAddress))
|
||||
}
|
||||
}
|
||||
lines = append(lines, forwardZoneLines...)
|
||||
return lines, warnings, nil
|
||||
}
|
||||
|
||||
func buildBlocked(client network.Client, blockMalicious, blockAds, blockSurveillance bool,
|
||||
allowedHostnames, privateAddresses []string) (hostnamesLines, ipsLines []string, errs []error) {
|
||||
chHostnames := make(chan []string)
|
||||
chIPs := make(chan []string)
|
||||
chErrors := make(chan []error)
|
||||
go func() {
|
||||
lines, errs := buildBlockedHostnames(client, blockMalicious, blockAds, blockSurveillance, allowedHostnames)
|
||||
chHostnames <- lines
|
||||
chErrors <- errs
|
||||
}()
|
||||
go func() {
|
||||
lines, errs := buildBlockedIPs(client, blockMalicious, blockAds, blockSurveillance, privateAddresses)
|
||||
chIPs <- lines
|
||||
chErrors <- errs
|
||||
}()
|
||||
n := 2
|
||||
for n > 0 {
|
||||
select {
|
||||
case lines := <-chHostnames:
|
||||
hostnamesLines = append(hostnamesLines, lines...)
|
||||
case lines := <-chIPs:
|
||||
ipsLines = append(ipsLines, lines...)
|
||||
case routineErrs := <-chErrors:
|
||||
errs = append(errs, routineErrs...)
|
||||
n--
|
||||
}
|
||||
}
|
||||
return hostnamesLines, ipsLines, errs
|
||||
}
|
||||
|
||||
func getList(client network.Client, URL string) (results []string, err error) {
|
||||
content, status, err := client.GetContent(URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if status != 200 {
|
||||
return nil, fmt.Errorf("HTTP status code is %d and not 200", status)
|
||||
}
|
||||
results = strings.Split(string(content), "\n")
|
||||
|
||||
// remove empty lines
|
||||
last := len(results) - 1
|
||||
for i := range results {
|
||||
if len(results[i]) == 0 {
|
||||
results[i] = results[last]
|
||||
last--
|
||||
}
|
||||
}
|
||||
results = results[:last+1]
|
||||
|
||||
if len(results) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func buildBlockedHostnames(client network.Client, blockMalicious, blockAds, blockSurveillance bool,
|
||||
allowedHostnames []string) (lines []string, errs []error) {
|
||||
chResults := make(chan []string)
|
||||
chError := make(chan error)
|
||||
listsLeftToFetch := 0
|
||||
if blockMalicious {
|
||||
listsLeftToFetch++
|
||||
go func() {
|
||||
results, err := getList(client, string(constants.MaliciousBlockListHostnamesURL))
|
||||
chResults <- results
|
||||
chError <- err
|
||||
}()
|
||||
}
|
||||
if blockAds {
|
||||
listsLeftToFetch++
|
||||
go func() {
|
||||
results, err := getList(client, string(constants.AdsBlockListHostnamesURL))
|
||||
chResults <- results
|
||||
chError <- err
|
||||
}()
|
||||
}
|
||||
if blockSurveillance {
|
||||
listsLeftToFetch++
|
||||
go func() {
|
||||
results, err := getList(client, string(constants.SurveillanceBlockListHostnamesURL))
|
||||
chResults <- results
|
||||
chError <- err
|
||||
}()
|
||||
}
|
||||
uniqueResults := make(map[string]struct{})
|
||||
for listsLeftToFetch > 0 {
|
||||
select {
|
||||
case results := <-chResults:
|
||||
for _, result := range results {
|
||||
uniqueResults[result] = struct{}{}
|
||||
}
|
||||
case err := <-chError:
|
||||
listsLeftToFetch--
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, allowedHostname := range allowedHostnames {
|
||||
delete(uniqueResults, allowedHostname)
|
||||
}
|
||||
for result := range uniqueResults {
|
||||
lines = append(lines, " local-zone: \""+result+"\" static")
|
||||
}
|
||||
return lines, errs
|
||||
}
|
||||
|
||||
func buildBlockedIPs(client network.Client, blockMalicious, blockAds, blockSurveillance bool,
|
||||
privateAddresses []string) (lines []string, errs []error) {
|
||||
chResults := make(chan []string)
|
||||
chError := make(chan error)
|
||||
listsLeftToFetch := 0
|
||||
if blockMalicious {
|
||||
listsLeftToFetch++
|
||||
go func() {
|
||||
results, err := getList(client, string(constants.MaliciousBlockListIPsURL))
|
||||
chResults <- results
|
||||
chError <- err
|
||||
}()
|
||||
}
|
||||
if blockAds {
|
||||
listsLeftToFetch++
|
||||
go func() {
|
||||
results, err := getList(client, string(constants.AdsBlockListIPsURL))
|
||||
chResults <- results
|
||||
chError <- err
|
||||
}()
|
||||
}
|
||||
if blockSurveillance {
|
||||
listsLeftToFetch++
|
||||
go func() {
|
||||
results, err := getList(client, string(constants.SurveillanceBlockListIPsURL))
|
||||
chResults <- results
|
||||
chError <- err
|
||||
}()
|
||||
}
|
||||
uniqueResults := make(map[string]struct{})
|
||||
for listsLeftToFetch > 0 {
|
||||
select {
|
||||
case results := <-chResults:
|
||||
for _, result := range results {
|
||||
uniqueResults[result] = struct{}{}
|
||||
}
|
||||
case err := <-chError:
|
||||
listsLeftToFetch--
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, privateAddress := range privateAddresses {
|
||||
uniqueResults[privateAddress] = struct{}{}
|
||||
}
|
||||
for result := range uniqueResults {
|
||||
lines = append(lines, " private-address: "+result)
|
||||
}
|
||||
return lines, errs
|
||||
}
|
||||
518
internal/dns/conf_test.go
Normal file
518
internal/dns/conf_test.go
Normal file
@@ -0,0 +1,518 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/network/mocks"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/settings"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_generateUnboundConf(t *testing.T) {
|
||||
t.Parallel()
|
||||
settings := settings.DNS{
|
||||
Providers: []models.DNSProvider{constants.Cloudflare, constants.Quad9},
|
||||
AllowedHostnames: []string{"a"},
|
||||
PrivateAddresses: []string{"9.9.9.9"},
|
||||
BlockMalicious: true,
|
||||
BlockSurveillance: false,
|
||||
BlockAds: false,
|
||||
VerbosityLevel: 2,
|
||||
ValidationLogLevel: 3,
|
||||
}
|
||||
client := &mocks.Client{}
|
||||
client.On("GetContent", string(constants.MaliciousBlockListHostnamesURL)).
|
||||
Return([]byte("b\na\nc"), 200, nil).Once()
|
||||
client.On("GetContent", string(constants.MaliciousBlockListIPsURL)).
|
||||
Return([]byte("c\nd\n"), 200, nil).Once()
|
||||
emptyLogger, err := logging.NewEmptyLogger()
|
||||
require.NoError(t, err)
|
||||
lines, warnings, err := generateUnboundConf(settings, client, emptyLogger)
|
||||
require.Len(t, warnings, 0)
|
||||
require.NoError(t, err)
|
||||
client.AssertExpectations(t)
|
||||
expected := `
|
||||
server:
|
||||
cache-max-ttl: 9000
|
||||
cache-min-ttl: 3600
|
||||
do-ip4: yes
|
||||
do-ip6: no
|
||||
harden-algo-downgrade: yes
|
||||
harden-below-nxdomain: yes
|
||||
harden-referral-path: yes
|
||||
hide-identity: yes
|
||||
hide-version: yes
|
||||
interface: 127.0.0.1
|
||||
key-cache-size: 16m
|
||||
key-cache-slabs: 4
|
||||
msg-cache-size: 4m
|
||||
msg-cache-slabs: 4
|
||||
num-threads: 1
|
||||
port: 53
|
||||
prefetch-key: yes
|
||||
prefetch: yes
|
||||
root-hints: "/etc/unbound/root.hints"
|
||||
rrset-cache-size: 4m
|
||||
rrset-cache-slabs: 4
|
||||
rrset-roundrobin: yes
|
||||
tls-cert-bundle: "/etc/ssl/certs/ca-certificates.crt"
|
||||
trust-anchor-file: "/etc/unbound/root.key"
|
||||
use-syslog: no
|
||||
username: "nonrootuser"
|
||||
val-log-level: 3
|
||||
verbosity: 2
|
||||
local-zone: "b" static
|
||||
local-zone: "c" static
|
||||
private-address: 9.9.9.9
|
||||
private-address: c
|
||||
private-address: d
|
||||
forward-zone:
|
||||
forward-tls-upstream: yes
|
||||
name: "."
|
||||
forward-addr: 1.1.1.1@853#cloudflare-dns.com
|
||||
forward-addr: 1.0.0.1@853#cloudflare-dns.com
|
||||
forward-addr: 9.9.9.9@853#dns.quad9.net
|
||||
forward-addr: 149.112.112.112@853#dns.quad9.net`
|
||||
assert.Equal(t, expected, "\n"+strings.Join(lines, "\n"))
|
||||
}
|
||||
|
||||
func Test_buildBlocked(t *testing.T) {
|
||||
t.Parallel()
|
||||
type blockParams struct {
|
||||
blocked bool
|
||||
content []byte
|
||||
clientErr error
|
||||
}
|
||||
tests := map[string]struct {
|
||||
malicious blockParams
|
||||
ads blockParams
|
||||
surveillance blockParams
|
||||
allowedHostnames []string
|
||||
privateAddresses []string
|
||||
hostnamesLines []string
|
||||
ipsLines []string
|
||||
errsString []string
|
||||
}{
|
||||
"none blocked": {},
|
||||
"all blocked without lists": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
},
|
||||
ads: blockParams{
|
||||
blocked: true,
|
||||
},
|
||||
surveillance: blockParams{
|
||||
blocked: true,
|
||||
},
|
||||
},
|
||||
"all blocked with lists": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("malicious"),
|
||||
},
|
||||
ads: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("ads"),
|
||||
},
|
||||
surveillance: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("surveillance"),
|
||||
},
|
||||
hostnamesLines: []string{
|
||||
" local-zone: \"ads\" static",
|
||||
" local-zone: \"malicious\" static",
|
||||
" local-zone: \"surveillance\" static"},
|
||||
ipsLines: []string{
|
||||
" private-address: ads",
|
||||
" private-address: malicious",
|
||||
" private-address: surveillance"},
|
||||
},
|
||||
"all blocked with allowed hostnames": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("malicious"),
|
||||
},
|
||||
ads: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("ads"),
|
||||
},
|
||||
surveillance: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("surveillance"),
|
||||
},
|
||||
allowedHostnames: []string{"ads"},
|
||||
hostnamesLines: []string{
|
||||
" local-zone: \"malicious\" static",
|
||||
" local-zone: \"surveillance\" static"},
|
||||
ipsLines: []string{
|
||||
" private-address: ads",
|
||||
" private-address: malicious",
|
||||
" private-address: surveillance"},
|
||||
},
|
||||
"all blocked with private addresses": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("malicious"),
|
||||
},
|
||||
ads: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("ads"),
|
||||
},
|
||||
surveillance: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("surveillance"),
|
||||
},
|
||||
privateAddresses: []string{"ads", "192.100.1.5"},
|
||||
hostnamesLines: []string{
|
||||
" local-zone: \"ads\" static",
|
||||
" local-zone: \"malicious\" static",
|
||||
" local-zone: \"surveillance\" static"},
|
||||
ipsLines: []string{
|
||||
" private-address: 192.100.1.5",
|
||||
" private-address: ads",
|
||||
" private-address: malicious",
|
||||
" private-address: surveillance"},
|
||||
},
|
||||
"all blocked with lists and one error": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("malicious"),
|
||||
},
|
||||
ads: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("ads"),
|
||||
clientErr: fmt.Errorf("ads error"),
|
||||
},
|
||||
surveillance: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("surveillance"),
|
||||
},
|
||||
hostnamesLines: []string{
|
||||
" local-zone: \"malicious\" static",
|
||||
" local-zone: \"surveillance\" static"},
|
||||
ipsLines: []string{
|
||||
" private-address: malicious",
|
||||
" private-address: surveillance"},
|
||||
errsString: []string{"ads error", "ads error"},
|
||||
},
|
||||
"all blocked with errors": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
clientErr: fmt.Errorf("malicious"),
|
||||
},
|
||||
ads: blockParams{
|
||||
blocked: true,
|
||||
clientErr: fmt.Errorf("ads"),
|
||||
},
|
||||
surveillance: blockParams{
|
||||
blocked: true,
|
||||
clientErr: fmt.Errorf("surveillance"),
|
||||
},
|
||||
errsString: []string{"malicious", "malicious", "ads", "ads", "surveillance", "surveillance"},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := &mocks.Client{}
|
||||
if tc.malicious.blocked {
|
||||
client.On("GetContent", string(constants.MaliciousBlockListHostnamesURL)).
|
||||
Return(tc.malicious.content, 200, tc.malicious.clientErr).Once()
|
||||
client.On("GetContent", string(constants.MaliciousBlockListIPsURL)).
|
||||
Return(tc.malicious.content, 200, tc.malicious.clientErr).Once()
|
||||
}
|
||||
if tc.ads.blocked {
|
||||
client.On("GetContent", string(constants.AdsBlockListHostnamesURL)).
|
||||
Return(tc.ads.content, 200, tc.ads.clientErr).Once()
|
||||
client.On("GetContent", string(constants.AdsBlockListIPsURL)).
|
||||
Return(tc.ads.content, 200, tc.ads.clientErr).Once()
|
||||
}
|
||||
if tc.surveillance.blocked {
|
||||
client.On("GetContent", string(constants.SurveillanceBlockListHostnamesURL)).
|
||||
Return(tc.surveillance.content, 200, tc.surveillance.clientErr).Once()
|
||||
client.On("GetContent", string(constants.SurveillanceBlockListIPsURL)).
|
||||
Return(tc.surveillance.content, 200, tc.surveillance.clientErr).Once()
|
||||
}
|
||||
hostnamesLines, ipsLines, errs := buildBlocked(client, tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked,
|
||||
tc.allowedHostnames, tc.privateAddresses)
|
||||
var errsString []string
|
||||
for _, err := range errs {
|
||||
errsString = append(errsString, err.Error())
|
||||
}
|
||||
assert.ElementsMatch(t, tc.errsString, errsString)
|
||||
assert.ElementsMatch(t, tc.hostnamesLines, hostnamesLines)
|
||||
assert.ElementsMatch(t, tc.ipsLines, ipsLines)
|
||||
client.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_getList(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
content []byte
|
||||
status int
|
||||
clientErr error
|
||||
results []string
|
||||
err error
|
||||
}{
|
||||
"no result": {nil, 200, nil, nil, nil},
|
||||
"bad status": {nil, 500, nil, nil, fmt.Errorf("HTTP status code is 500 and not 200")},
|
||||
"network error": {nil, 200, fmt.Errorf("error"), nil, fmt.Errorf("error")},
|
||||
"results": {[]byte("a\nb\nc\n"), 200, nil, []string{"a", "b", "c"}, nil},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := &mocks.Client{}
|
||||
client.On("GetContent", "irrelevant_url").Return(
|
||||
tc.content, tc.status, tc.clientErr,
|
||||
).Once()
|
||||
results, err := getList(client, "irrelevant_url")
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.results, results)
|
||||
client.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_buildBlockedHostnames(t *testing.T) {
|
||||
t.Parallel()
|
||||
type blockParams struct {
|
||||
blocked bool
|
||||
content []byte
|
||||
clientErr error
|
||||
}
|
||||
tests := map[string]struct {
|
||||
malicious blockParams
|
||||
ads blockParams
|
||||
surveillance blockParams
|
||||
allowedHostnames []string
|
||||
lines []string
|
||||
errsString []string
|
||||
}{
|
||||
"nothing blocked": {
|
||||
lines: nil,
|
||||
errsString: nil,
|
||||
},
|
||||
"only malicious blocked": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_a\nsite_b"),
|
||||
clientErr: nil,
|
||||
},
|
||||
lines: []string{
|
||||
" local-zone: \"site_a\" static",
|
||||
" local-zone: \"site_b\" static"},
|
||||
errsString: nil,
|
||||
},
|
||||
"all blocked with some duplicates": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_a\nsite_b"),
|
||||
},
|
||||
ads: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_a\nsite_c"),
|
||||
},
|
||||
surveillance: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_c\nsite_a"),
|
||||
},
|
||||
lines: []string{
|
||||
" local-zone: \"site_a\" static",
|
||||
" local-zone: \"site_b\" static",
|
||||
" local-zone: \"site_c\" static"},
|
||||
errsString: nil,
|
||||
},
|
||||
"all blocked with one errored": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_a\nsite_b"),
|
||||
},
|
||||
ads: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_a\nsite_c"),
|
||||
},
|
||||
surveillance: blockParams{
|
||||
blocked: true,
|
||||
clientErr: fmt.Errorf("surveillance error"),
|
||||
},
|
||||
lines: []string{
|
||||
" local-zone: \"site_a\" static",
|
||||
" local-zone: \"site_b\" static",
|
||||
" local-zone: \"site_c\" static"},
|
||||
errsString: []string{"surveillance error"},
|
||||
},
|
||||
"blocked with allowed hostnames": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_a\nsite_b"),
|
||||
},
|
||||
ads: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_c\nsite_d"),
|
||||
},
|
||||
allowedHostnames: []string{"site_b", "site_c"},
|
||||
lines: []string{
|
||||
" local-zone: \"site_a\" static",
|
||||
" local-zone: \"site_d\" static"},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := &mocks.Client{}
|
||||
if tc.malicious.blocked {
|
||||
client.On("GetContent", string(constants.MaliciousBlockListHostnamesURL)).
|
||||
Return(tc.malicious.content, 200, tc.malicious.clientErr).Once()
|
||||
}
|
||||
if tc.ads.blocked {
|
||||
client.On("GetContent", string(constants.AdsBlockListHostnamesURL)).
|
||||
Return(tc.ads.content, 200, tc.ads.clientErr).Once()
|
||||
}
|
||||
if tc.surveillance.blocked {
|
||||
client.On("GetContent", string(constants.SurveillanceBlockListHostnamesURL)).
|
||||
Return(tc.surveillance.content, 200, tc.surveillance.clientErr).Once()
|
||||
}
|
||||
lines, errs := buildBlockedHostnames(client,
|
||||
tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked, tc.allowedHostnames)
|
||||
var errsString []string
|
||||
for _, err := range errs {
|
||||
errsString = append(errsString, err.Error())
|
||||
}
|
||||
assert.ElementsMatch(t, tc.errsString, errsString)
|
||||
assert.ElementsMatch(t, tc.lines, lines)
|
||||
client.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_buildBlockedIPs(t *testing.T) {
|
||||
t.Parallel()
|
||||
type blockParams struct {
|
||||
blocked bool
|
||||
content []byte
|
||||
clientErr error
|
||||
}
|
||||
tests := map[string]struct {
|
||||
malicious blockParams
|
||||
ads blockParams
|
||||
surveillance blockParams
|
||||
privateAddresses []string
|
||||
lines []string
|
||||
errsString []string
|
||||
}{
|
||||
"nothing blocked": {
|
||||
lines: nil,
|
||||
errsString: nil,
|
||||
},
|
||||
"only malicious blocked": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_a\nsite_b"),
|
||||
clientErr: nil,
|
||||
},
|
||||
lines: []string{
|
||||
" private-address: site_a",
|
||||
" private-address: site_b"},
|
||||
errsString: nil,
|
||||
},
|
||||
"all blocked with some duplicates": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_a\nsite_b"),
|
||||
},
|
||||
ads: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_a\nsite_c"),
|
||||
},
|
||||
surveillance: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_c\nsite_a"),
|
||||
},
|
||||
lines: []string{
|
||||
" private-address: site_a",
|
||||
" private-address: site_b",
|
||||
" private-address: site_c"},
|
||||
errsString: nil,
|
||||
},
|
||||
"all blocked with one errored": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_a\nsite_b"),
|
||||
},
|
||||
ads: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_a\nsite_c"),
|
||||
},
|
||||
surveillance: blockParams{
|
||||
blocked: true,
|
||||
clientErr: fmt.Errorf("surveillance error"),
|
||||
},
|
||||
lines: []string{
|
||||
" private-address: site_a",
|
||||
" private-address: site_b",
|
||||
" private-address: site_c"},
|
||||
errsString: []string{"surveillance error"},
|
||||
},
|
||||
"blocked with private addresses": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_a\nsite_b"),
|
||||
},
|
||||
ads: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_c"),
|
||||
},
|
||||
privateAddresses: []string{"site_c", "site_d"},
|
||||
lines: []string{
|
||||
" private-address: site_a",
|
||||
" private-address: site_b",
|
||||
" private-address: site_c",
|
||||
" private-address: site_d"},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := &mocks.Client{}
|
||||
if tc.malicious.blocked {
|
||||
client.On("GetContent", string(constants.MaliciousBlockListIPsURL)).
|
||||
Return(tc.malicious.content, 200, tc.malicious.clientErr).Once()
|
||||
}
|
||||
if tc.ads.blocked {
|
||||
client.On("GetContent", string(constants.AdsBlockListIPsURL)).
|
||||
Return(tc.ads.content, 200, tc.ads.clientErr).Once()
|
||||
}
|
||||
if tc.surveillance.blocked {
|
||||
client.On("GetContent", string(constants.SurveillanceBlockListIPsURL)).
|
||||
Return(tc.surveillance.content, 200, tc.surveillance.clientErr).Once()
|
||||
}
|
||||
lines, errs := buildBlockedIPs(client,
|
||||
tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked, tc.privateAddresses)
|
||||
var errsString []string
|
||||
for _, err := range errs {
|
||||
errsString = append(errsString, err.Error())
|
||||
}
|
||||
assert.ElementsMatch(t, tc.errsString, errsString)
|
||||
assert.ElementsMatch(t, tc.lines, lines)
|
||||
client.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
38
internal/dns/dns.go
Normal file
38
internal/dns/dns.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/qdm12/golibs/command"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/network"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/settings"
|
||||
)
|
||||
|
||||
const logPrefix = "dns configurator"
|
||||
|
||||
type Configurator interface {
|
||||
DownloadRootHints(uid, gid int) error
|
||||
DownloadRootKey(uid, gid int) error
|
||||
MakeUnboundConf(settings settings.DNS, uid, gid int) (err error)
|
||||
SetLocalNameserver() error
|
||||
Start(logLevel uint8) (stdout io.ReadCloser, err error)
|
||||
Version() (version string, err error)
|
||||
}
|
||||
|
||||
type configurator struct {
|
||||
logger logging.Logger
|
||||
client network.Client
|
||||
fileManager files.FileManager
|
||||
commander command.Commander
|
||||
}
|
||||
|
||||
func NewConfigurator(logger logging.Logger, client network.Client, fileManager files.FileManager) Configurator {
|
||||
return &configurator{
|
||||
logger: logger,
|
||||
client: client,
|
||||
fileManager: fileManager,
|
||||
commander: command.NewCommander(),
|
||||
}
|
||||
}
|
||||
32
internal/dns/os.go
Normal file
32
internal/dns/os.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
)
|
||||
|
||||
func (c *configurator) SetLocalNameserver() error {
|
||||
c.logger.Info("%s: setting local nameserver to 127.0.0.1", logPrefix)
|
||||
data, err := c.fileManager.ReadFile(string(constants.ResolvConf))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s := strings.TrimSuffix(string(data), "\n")
|
||||
lines := strings.Split(s, "\n")
|
||||
if len(lines) == 1 && lines[0] == "" {
|
||||
lines = nil
|
||||
}
|
||||
found := false
|
||||
for i := range lines {
|
||||
if strings.HasPrefix(lines[i], "nameserver ") {
|
||||
lines[i] = "nameserver 127.0.0.1"
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
lines = append(lines, "nameserver 127.0.0.1")
|
||||
}
|
||||
data = []byte(strings.Join(lines, "\n"))
|
||||
return c.fileManager.WriteToFile(string(constants.ResolvConf), data)
|
||||
}
|
||||
72
internal/dns/os_test.go
Normal file
72
internal/dns/os_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
filesmocks "github.com/qdm12/golibs/files/mocks"
|
||||
loggingmocks "github.com/qdm12/golibs/logging/mocks"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_SetLocalNameserver(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
data []byte
|
||||
writtenData []byte
|
||||
readErr error
|
||||
writeErr error
|
||||
err error
|
||||
}{
|
||||
"no data": {
|
||||
writtenData: []byte("nameserver 127.0.0.1"),
|
||||
},
|
||||
"read error": {
|
||||
readErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"write error": {
|
||||
writtenData: []byte("nameserver 127.0.0.1"),
|
||||
writeErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"lines without nameserver": {
|
||||
data: []byte("abc\ndef\n"),
|
||||
writtenData: []byte("abc\ndef\nnameserver 127.0.0.1"),
|
||||
},
|
||||
"lines with nameserver": {
|
||||
data: []byte("abc\nnameserver abc def\ndef\n"),
|
||||
writtenData: []byte("abc\nnameserver 127.0.0.1\ndef"),
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fileManager := &filesmocks.FileManager{}
|
||||
fileManager.On("ReadFile", string(constants.ResolvConf)).
|
||||
Return(tc.data, tc.readErr).Once()
|
||||
if tc.readErr == nil {
|
||||
fileManager.On("WriteToFile", string(constants.ResolvConf), tc.writtenData).
|
||||
Return(tc.writeErr).Once()
|
||||
}
|
||||
logger := &loggingmocks.Logger{}
|
||||
logger.On("Info", "%s: setting local nameserver to 127.0.0.1", logPrefix).Once()
|
||||
c := &configurator{
|
||||
fileManager: fileManager,
|
||||
logger: logger,
|
||||
}
|
||||
err := c.SetLocalNameserver()
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
fileManager.AssertExpectations(t)
|
||||
logger.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
38
internal/dns/roots.go
Normal file
38
internal/dns/roots.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
)
|
||||
|
||||
func (c *configurator) DownloadRootHints(uid, gid int) error {
|
||||
c.logger.Info("%s: downloading root hints from %s", logPrefix, constants.NamedRootURL)
|
||||
content, status, err := c.client.GetContent(string(constants.NamedRootURL))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if status != 200 {
|
||||
return fmt.Errorf("HTTP status code is %d for %s", status, constants.NamedRootURL)
|
||||
}
|
||||
return c.fileManager.WriteToFile(
|
||||
string(constants.RootHints),
|
||||
content,
|
||||
files.FileOwnership(uid, gid),
|
||||
files.FilePermissions(0400))
|
||||
}
|
||||
|
||||
func (c *configurator) DownloadRootKey(uid, gid int) error {
|
||||
c.logger.Info("%s: downloading root key from %s", logPrefix, constants.RootKeyURL)
|
||||
content, status, err := c.client.GetContent(string(constants.RootKeyURL))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if status != 200 {
|
||||
return fmt.Errorf("HTTP status code is %d for %s", status, constants.RootKeyURL)
|
||||
}
|
||||
return c.fileManager.WriteToFile(
|
||||
string(constants.RootKey),
|
||||
content,
|
||||
files.FileOwnership(uid, gid),
|
||||
files.FilePermissions(0400))
|
||||
}
|
||||
144
internal/dns/roots_test.go
Normal file
144
internal/dns/roots_test.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
filesMocks "github.com/qdm12/golibs/files/mocks"
|
||||
loggingMocks "github.com/qdm12/golibs/logging/mocks"
|
||||
networkMocks "github.com/qdm12/golibs/network/mocks"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
)
|
||||
|
||||
func Test_DownloadRootHints(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
content []byte
|
||||
status int
|
||||
clientErr error
|
||||
writeErr error
|
||||
err error
|
||||
}{
|
||||
"no data": {
|
||||
status: http.StatusOK,
|
||||
},
|
||||
"bad status": {
|
||||
status: http.StatusBadRequest,
|
||||
err: fmt.Errorf("HTTP status code is 400 for https://raw.githubusercontent.com/qdm12/files/master/named.root.updated"),
|
||||
},
|
||||
"client error": {
|
||||
clientErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"write error": {
|
||||
status: http.StatusOK,
|
||||
writeErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"data": {
|
||||
content: []byte("content"),
|
||||
status: http.StatusOK,
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := &loggingMocks.Logger{}
|
||||
logger.On("Info", "%s: downloading root hints from %s", logPrefix, constants.NamedRootURL).Once()
|
||||
client := &networkMocks.Client{}
|
||||
client.On("GetContent", string(constants.NamedRootURL)).
|
||||
Return(tc.content, tc.status, tc.clientErr).Once()
|
||||
fileManager := &filesMocks.FileManager{}
|
||||
if tc.clientErr == nil && tc.status == http.StatusOK {
|
||||
fileManager.On(
|
||||
"WriteToFile",
|
||||
string(constants.RootHints),
|
||||
tc.content,
|
||||
mock.AnythingOfType("files.WriteOptionSetter"),
|
||||
mock.AnythingOfType("files.WriteOptionSetter")).
|
||||
Return(tc.writeErr).Once()
|
||||
}
|
||||
c := &configurator{logger: logger, client: client, fileManager: fileManager}
|
||||
err := c.DownloadRootHints(1000, 1000)
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
logger.AssertExpectations(t)
|
||||
client.AssertExpectations(t)
|
||||
fileManager.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_DownloadRootKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
content []byte
|
||||
status int
|
||||
clientErr error
|
||||
writeErr error
|
||||
err error
|
||||
}{
|
||||
"no data": {
|
||||
status: http.StatusOK,
|
||||
},
|
||||
"bad status": {
|
||||
status: http.StatusBadRequest,
|
||||
err: fmt.Errorf("HTTP status code is 400 for https://raw.githubusercontent.com/qdm12/files/master/root.key.updated"),
|
||||
},
|
||||
"client error": {
|
||||
clientErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"write error": {
|
||||
status: http.StatusOK,
|
||||
writeErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"data": {
|
||||
content: []byte("content"),
|
||||
status: http.StatusOK,
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := &loggingMocks.Logger{}
|
||||
logger.On("Info", "%s: downloading root key from %s", logPrefix, constants.RootKeyURL).Once()
|
||||
client := &networkMocks.Client{}
|
||||
client.On("GetContent", string(constants.RootKeyURL)).
|
||||
Return(tc.content, tc.status, tc.clientErr).Once()
|
||||
fileManager := &filesMocks.FileManager{}
|
||||
if tc.clientErr == nil && tc.status == http.StatusOK {
|
||||
fileManager.On(
|
||||
"WriteToFile",
|
||||
string(constants.RootKey),
|
||||
tc.content,
|
||||
mock.AnythingOfType("files.WriteOptionSetter"),
|
||||
mock.AnythingOfType("files.WriteOptionSetter"),
|
||||
).Return(tc.writeErr).Once()
|
||||
}
|
||||
c := &configurator{logger: logger, client: client, fileManager: fileManager}
|
||||
err := c.DownloadRootKey(1000, 1001)
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
logger.AssertExpectations(t)
|
||||
client.AssertExpectations(t)
|
||||
fileManager.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
40
internal/env/env.go
vendored
Normal file
40
internal/env/env.go
vendored
Normal file
@@ -0,0 +1,40 @@
|
||||
package env
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
type Env interface {
|
||||
FatalOnError(err error)
|
||||
PrintVersion(program string, commandFn func() (string, error))
|
||||
}
|
||||
|
||||
type env struct {
|
||||
logger logging.Logger
|
||||
osExit func(n int)
|
||||
}
|
||||
|
||||
func New(logger logging.Logger) Env {
|
||||
return &env{
|
||||
logger: logger,
|
||||
osExit: os.Exit,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *env) FatalOnError(err error) {
|
||||
if err != nil {
|
||||
e.logger.Error(err)
|
||||
e.osExit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *env) PrintVersion(program string, commandFn func() (string, error)) {
|
||||
version, err := commandFn()
|
||||
if err != nil {
|
||||
e.logger.Error(err)
|
||||
} else {
|
||||
e.logger.Info("%s version: %s", program, version)
|
||||
}
|
||||
}
|
||||
90
internal/env/env_test.go
vendored
Normal file
90
internal/env/env_test.go
vendored
Normal file
@@ -0,0 +1,90 @@
|
||||
package env
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/qdm12/golibs/logging/mocks"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func Test_FatalOnError(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
err error
|
||||
}{
|
||||
"nil": {},
|
||||
"err": {fmt.Errorf("error")},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var logged string
|
||||
var exitCode int
|
||||
logger := &mocks.Logger{}
|
||||
if tc.err != nil {
|
||||
logger.On("Error", tc.err).
|
||||
Run(func(args mock.Arguments) {
|
||||
err := args.Get(0).(error)
|
||||
logged = err.Error()
|
||||
}).Once()
|
||||
}
|
||||
osExit := func(n int) { exitCode = n }
|
||||
e := &env{logger, osExit}
|
||||
e.FatalOnError(tc.err)
|
||||
if tc.err != nil {
|
||||
assert.Equal(t, logged, tc.err.Error())
|
||||
assert.Equal(t, exitCode, 1)
|
||||
} else {
|
||||
assert.Empty(t, logged)
|
||||
assert.Zero(t, exitCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_PrintVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
program string
|
||||
commandVersion string
|
||||
commandErr error
|
||||
}{
|
||||
"no data": {},
|
||||
"data": {"binu", "2.3-5", nil},
|
||||
"error": {"binu", "", fmt.Errorf("error")},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var logged string
|
||||
logger := &mocks.Logger{}
|
||||
if tc.commandErr != nil {
|
||||
logger.On("Error", tc.commandErr).
|
||||
Run(func(args mock.Arguments) {
|
||||
err := args.Get(0).(error)
|
||||
logged = err.Error()
|
||||
}).Once()
|
||||
} else {
|
||||
logger.On("Info", "%s version: %s", tc.program, tc.commandVersion).
|
||||
Run(func(args mock.Arguments) {
|
||||
format := args.Get(0).(string)
|
||||
program := args.Get(1).(string)
|
||||
version := args.Get(2).(string)
|
||||
logged = fmt.Sprintf(format, program, version)
|
||||
}).Once()
|
||||
}
|
||||
e := &env{logger: logger}
|
||||
commandFn := func() (string, error) { return tc.commandVersion, tc.commandErr }
|
||||
e.PrintVersion(tc.program, commandFn)
|
||||
if tc.commandErr != nil {
|
||||
assert.Equal(t, logged, tc.commandErr.Error())
|
||||
} else {
|
||||
assert.Equal(t, logged, fmt.Sprintf("%s version: %s", tc.program, tc.commandVersion))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
42
internal/firewall/firewall.go
Normal file
42
internal/firewall/firewall.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/qdm12/golibs/command"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
const logPrefix = "firewall configurator"
|
||||
|
||||
// Configurator allows to change firewall rules and modify network routes
|
||||
type Configurator interface {
|
||||
Version() (string, error)
|
||||
AcceptAll() error
|
||||
Clear() error
|
||||
BlockAll() error
|
||||
CreateGeneralRules() error
|
||||
CreateVPNRules(dev models.VPNDevice, serverIPs []net.IP, defaultInterface string,
|
||||
port uint16, protocol models.NetworkProtocol) error
|
||||
CreateLocalSubnetsRules(subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error
|
||||
AddRoutesVia(subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error
|
||||
GetDefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error)
|
||||
AllowInputTrafficOnPort(device models.VPNDevice, port uint16) error
|
||||
}
|
||||
|
||||
type configurator struct {
|
||||
commander command.Commander
|
||||
logger logging.Logger
|
||||
fileManager files.FileManager
|
||||
}
|
||||
|
||||
// NewConfigurator creates a new Configurator instance
|
||||
func NewConfigurator(logger logging.Logger, fileManager files.FileManager) Configurator {
|
||||
return &configurator{
|
||||
commander: command.NewCommander(),
|
||||
logger: logger,
|
||||
fileManager: fileManager,
|
||||
}
|
||||
}
|
||||
130
internal/firewall/iptables.go
Normal file
130
internal/firewall/iptables.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
// Version obtains the version of the installed iptables
|
||||
func (c *configurator) Version() (string, error) {
|
||||
output, err := c.commander.Run("iptables", "--version")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
words := strings.Fields(output)
|
||||
if len(words) < 2 {
|
||||
return "", fmt.Errorf("iptables --version: output is too short: %q", output)
|
||||
}
|
||||
return words[1], nil
|
||||
}
|
||||
|
||||
func (c *configurator) runIptablesInstructions(instructions []string) error {
|
||||
for _, instruction := range instructions {
|
||||
if err := c.runIptablesInstruction(instruction); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *configurator) runIptablesInstruction(instruction string) error {
|
||||
flags := strings.Fields(instruction)
|
||||
if output, err := c.commander.Run("iptables", flags...); err != nil {
|
||||
return fmt.Errorf("failed executing %q: %s: %w", instruction, output, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *configurator) Clear() error {
|
||||
c.logger.Info("%s: clearing all rules", logPrefix)
|
||||
return c.runIptablesInstructions([]string{
|
||||
"--flush",
|
||||
"--delete-chain",
|
||||
"-t nat --flush",
|
||||
"-t nat --delete-chain",
|
||||
})
|
||||
}
|
||||
|
||||
func (c *configurator) AcceptAll() error {
|
||||
c.logger.Info("%s: accepting all traffic", logPrefix)
|
||||
return c.runIptablesInstructions([]string{
|
||||
"-P INPUT ACCEPT",
|
||||
"-P OUTPUT ACCEPT",
|
||||
"-P FORWARD ACCEPT",
|
||||
})
|
||||
}
|
||||
|
||||
func (c *configurator) BlockAll() error {
|
||||
c.logger.Info("%s: blocking all traffic", logPrefix)
|
||||
return c.runIptablesInstructions([]string{
|
||||
"-P INPUT DROP",
|
||||
"-F OUTPUT",
|
||||
"-P OUTPUT DROP",
|
||||
"-P FORWARD DROP",
|
||||
})
|
||||
}
|
||||
|
||||
func (c *configurator) CreateGeneralRules() error {
|
||||
c.logger.Info("%s: creating general rules", logPrefix)
|
||||
return c.runIptablesInstructions([]string{
|
||||
"-A OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
|
||||
"-A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
|
||||
"-A OUTPUT -o lo -j ACCEPT",
|
||||
"-A INPUT -i lo -j ACCEPT",
|
||||
})
|
||||
}
|
||||
|
||||
func (c *configurator) CreateVPNRules(dev models.VPNDevice, serverIPs []net.IP,
|
||||
defaultInterface string, port uint16, protocol models.NetworkProtocol) error {
|
||||
for _, serverIP := range serverIPs {
|
||||
c.logger.Info("%s: allowing output traffic to VPN server %s through %s on port %s %d",
|
||||
logPrefix, serverIP, defaultInterface, protocol, port)
|
||||
if err := c.runIptablesInstruction(
|
||||
fmt.Sprintf("-A OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
|
||||
serverIP, defaultInterface, protocol, protocol, port)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := c.runIptablesInstruction(fmt.Sprintf("-A OUTPUT -o %s -j ACCEPT", dev)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *configurator) CreateLocalSubnetsRules(subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error {
|
||||
subnetStr := subnet.String()
|
||||
c.logger.Info("%s: accepting input and output traffic for %s", logPrefix, subnetStr)
|
||||
if err := c.runIptablesInstructions([]string{
|
||||
fmt.Sprintf("-A INPUT -s %s -d %s -j ACCEPT", subnetStr, subnetStr),
|
||||
fmt.Sprintf("-A OUTPUT -s %s -d %s -j ACCEPT", subnetStr, subnetStr),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, extraSubnet := range extraSubnets {
|
||||
extraSubnetStr := extraSubnet.String()
|
||||
c.logger.Info("%s: accepting input traffic through %s from %s to %s", logPrefix, defaultInterface, extraSubnetStr, subnetStr)
|
||||
if err := c.runIptablesInstruction(
|
||||
fmt.Sprintf("-A INPUT -i %s -s %s -d %s -j ACCEPT", defaultInterface, extraSubnetStr, subnetStr)); err != nil {
|
||||
return err
|
||||
}
|
||||
// Thanks to @npawelek
|
||||
c.logger.Info("%s: accepting output traffic through %s from %s to %s", logPrefix, defaultInterface, subnetStr, extraSubnetStr)
|
||||
if err := c.runIptablesInstruction(
|
||||
fmt.Sprintf("-A OUTPUT -o %s -s %s -d %s -j ACCEPT", defaultInterface, subnetStr, extraSubnetStr)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Used for port forwarding
|
||||
func (c *configurator) AllowInputTrafficOnPort(device models.VPNDevice, port uint16) error {
|
||||
c.logger.Info("%s: accepting input traffic through %s on port %d", logPrefix, device, port)
|
||||
return c.runIptablesInstructions([]string{
|
||||
fmt.Sprintf("-A INPUT -i %s -p tcp --dport %d -j ACCEPT", device, port),
|
||||
fmt.Sprintf("-A INPUT -i %s -p udp --dport %d -j ACCEPT", device, port),
|
||||
})
|
||||
}
|
||||
88
internal/firewall/route.go
Normal file
88
internal/firewall/route.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"net"
|
||||
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
)
|
||||
|
||||
func (c *configurator) AddRoutesVia(subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error {
|
||||
for _, subnet := range subnets {
|
||||
subnetStr := subnet.String()
|
||||
output, err := c.commander.Run("ip", "route", "show", subnetStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot read route %s: %s: %w", subnetStr, output, err)
|
||||
} else if len(output) > 0 { // thanks to @npawelek https://github.com/npawelek
|
||||
continue // already exists
|
||||
// TODO remove it instead and continue execution below
|
||||
}
|
||||
c.logger.Info("%s: adding %s as route via %s", logPrefix, subnetStr, defaultInterface)
|
||||
output, err = c.commander.Run("ip", "route", "add", subnetStr, "via", defaultGateway.String(), "dev", defaultInterface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot add route for %s via %s %s %s: %s: %w", subnetStr, defaultGateway.String(), "dev", defaultInterface, output, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *configurator) GetDefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error) {
|
||||
c.logger.Info("%s: detecting default network route", logPrefix)
|
||||
data, err := c.fileManager.ReadFile(string(constants.NetRoute))
|
||||
if err != nil {
|
||||
return "", nil, defaultSubnet, err
|
||||
}
|
||||
// Verify number of lines and fields
|
||||
lines := strings.Split(string(data), "\n")
|
||||
if len(lines) < 3 {
|
||||
return "", nil, defaultSubnet, fmt.Errorf("not enough lines (%d) found in %s", len(lines), constants.NetRoute)
|
||||
}
|
||||
fieldsLine1 := strings.Fields(lines[1])
|
||||
if len(fieldsLine1) < 3 {
|
||||
return "", nil, defaultSubnet, fmt.Errorf("not enough fields in %q", lines[1])
|
||||
}
|
||||
fieldsLine2 := strings.Fields(lines[2])
|
||||
if len(fieldsLine2) < 8 {
|
||||
return "", nil, defaultSubnet, fmt.Errorf("not enough fields in %q", lines[2])
|
||||
}
|
||||
// get information
|
||||
defaultInterface = fieldsLine1[0]
|
||||
defaultGateway, err = reversedHexToIPv4(fieldsLine1[2])
|
||||
if err != nil {
|
||||
return "", nil, defaultSubnet, err
|
||||
}
|
||||
netNumber, err := reversedHexToIPv4(fieldsLine2[1])
|
||||
if err != nil {
|
||||
return "", nil, defaultSubnet, err
|
||||
}
|
||||
netMask, err := hexToIPv4Mask(fieldsLine2[7])
|
||||
if err != nil {
|
||||
return "", nil, defaultSubnet, err
|
||||
}
|
||||
subnet := net.IPNet{IP: netNumber, Mask: netMask}
|
||||
c.logger.Info("%s: default route found: interface %s, gateway %s, subnet %s", logPrefix, defaultInterface, defaultGateway.String(), subnet.String())
|
||||
return defaultInterface, defaultGateway, subnet, nil
|
||||
}
|
||||
|
||||
func reversedHexToIPv4(reversedHex string) (IP net.IP, err error) {
|
||||
bytes, err := hex.DecodeString(reversedHex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot parse reversed IP hex %q: %s", reversedHex, err)
|
||||
} else if len(bytes) != 4 {
|
||||
return nil, fmt.Errorf("hex string contains %d bytes instead of 4", len(bytes))
|
||||
}
|
||||
return []byte{bytes[3], bytes[2], bytes[1], bytes[0]}, nil
|
||||
}
|
||||
|
||||
func hexToIPv4Mask(hexString string) (mask net.IPMask, err error) {
|
||||
bytes, err := hex.DecodeString(hexString)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot parse hex mask %q: %s", hexString, err)
|
||||
} else if len(bytes) != 4 {
|
||||
return nil, fmt.Errorf("hex string contains %d bytes instead of 4", len(bytes))
|
||||
}
|
||||
return []byte{bytes[3], bytes[2], bytes[1], bytes[0]}, nil
|
||||
}
|
||||
171
internal/firewall/route_test.go
Normal file
171
internal/firewall/route_test.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
filesmocks "github.com/qdm12/golibs/files/mocks"
|
||||
loggingmocks "github.com/qdm12/golibs/logging/mocks"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
)
|
||||
|
||||
func Test_getDefaultRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
data []byte
|
||||
readErr error
|
||||
defaultInterface string
|
||||
defaultGateway net.IP
|
||||
defaultSubnet net.IPNet
|
||||
err error
|
||||
}{
|
||||
"no data": {
|
||||
err: fmt.Errorf("not enough lines (1) found in %s", constants.NetRoute)},
|
||||
"read error": {
|
||||
readErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error")},
|
||||
"not enough fields line 1": {
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 00000000
|
||||
eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0`),
|
||||
err: fmt.Errorf("not enough fields in \"eth0 00000000\"")},
|
||||
"not enough fields line 2": {
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
|
||||
eth0 000011AC 00000000 0001 0 0 0`),
|
||||
err: fmt.Errorf("not enough fields in \"eth0 000011AC 00000000 0001 0 0 0\"")},
|
||||
"bad gateway": {
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 00000000 x 0003 0 0 0 00000000 0 0 0
|
||||
eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0`),
|
||||
err: fmt.Errorf("cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'")},
|
||||
"bad net number": {
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
|
||||
eth0 x 00000000 0001 0 0 0 0000FFFF 0 0 0`),
|
||||
err: fmt.Errorf("cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'")},
|
||||
"bad net mask": {
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
|
||||
eth0 000011AC 00000000 0001 0 0 0 x 0 0 0`),
|
||||
err: fmt.Errorf("cannot parse hex mask \"x\": encoding/hex: invalid byte: U+0078 'x'")},
|
||||
"success": {
|
||||
data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0
|
||||
eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0`),
|
||||
defaultInterface: "eth0",
|
||||
defaultGateway: net.IP{0xac, 0x11, 0x0, 0x1},
|
||||
defaultSubnet: net.IPNet{
|
||||
IP: net.IP{0xac, 0x11, 0x0, 0x0},
|
||||
Mask: net.IPMask{0xff, 0xff, 0x0, 0x0},
|
||||
}},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fileManager := &filesmocks.FileManager{}
|
||||
fileManager.On("ReadFile", string(constants.NetRoute)).
|
||||
Return(tc.data, tc.readErr).Once()
|
||||
logger := &loggingmocks.Logger{}
|
||||
logger.On("Info", "%s: detecting default network route", logPrefix).Once()
|
||||
if tc.err == nil {
|
||||
logger.On("Info", "%s: default route found: interface %s, gateway %s, subnet %s",
|
||||
logPrefix, tc.defaultInterface, tc.defaultGateway.String(), tc.defaultSubnet.String()).Once()
|
||||
}
|
||||
c := &configurator{logger: logger, fileManager: fileManager}
|
||||
defaultInterface, defaultGateway, defaultSubnet, err := c.GetDefaultRoute()
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.defaultInterface, defaultInterface)
|
||||
assert.Equal(t, tc.defaultGateway, defaultGateway)
|
||||
assert.Equal(t, tc.defaultSubnet, defaultSubnet)
|
||||
fileManager.AssertExpectations(t)
|
||||
logger.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_reversedHexToIPv4(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
reversedHex string
|
||||
IP net.IP
|
||||
err error
|
||||
}{
|
||||
"empty hex": {
|
||||
err: fmt.Errorf("hex string contains 0 bytes instead of 4")},
|
||||
"bad hex": {
|
||||
reversedHex: "x",
|
||||
err: fmt.Errorf("cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'")},
|
||||
"3 bytes hex": {
|
||||
reversedHex: "9abcde",
|
||||
err: fmt.Errorf("hex string contains 3 bytes instead of 4")},
|
||||
"correct hex": {
|
||||
reversedHex: "010011AC",
|
||||
IP: []byte{0xac, 0x11, 0x0, 0x1},
|
||||
err: nil},
|
||||
"correct hex 2": {
|
||||
reversedHex: "000011AC",
|
||||
IP: []byte{0xac, 0x11, 0x0, 0x0},
|
||||
err: nil},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
IP, err := reversedHexToIPv4(tc.reversedHex)
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.IP, IP)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_hexMaskToDecMask(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
hexString string
|
||||
mask net.IPMask
|
||||
err error
|
||||
}{
|
||||
"empty hex": {
|
||||
err: fmt.Errorf("hex string contains 0 bytes instead of 4")},
|
||||
"bad hex": {
|
||||
hexString: "x",
|
||||
err: fmt.Errorf("cannot parse hex mask \"x\": encoding/hex: invalid byte: U+0078 'x'")},
|
||||
"3 bytes hex": {
|
||||
hexString: "9abcde",
|
||||
err: fmt.Errorf("hex string contains 3 bytes instead of 4")},
|
||||
"16": {
|
||||
hexString: "0000FFFF",
|
||||
mask: []byte{0xff, 0xff, 0x0, 0x0},
|
||||
err: nil},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mask, err := hexToIPv4Mask(tc.hexString)
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.mask, mask)
|
||||
})
|
||||
}
|
||||
}
|
||||
24
internal/models/alias.go
Normal file
24
internal/models/alias.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package models
|
||||
|
||||
type (
|
||||
// VPNDevice is the device name used to tunnel using Openvpn
|
||||
VPNDevice string
|
||||
// DNSProvider is a DNS over TLS server provider name
|
||||
DNSProvider string
|
||||
// DNSForwardAddress is the Unbound formatted forward address
|
||||
DNSForwardAddress string
|
||||
// PIAEncryption defines the level of encryption for communication with PIA servers
|
||||
PIAEncryption string
|
||||
// PIARegion contains the list of regions available for PIA
|
||||
PIARegion string
|
||||
// URL is an HTTP(s) URL address
|
||||
URL string
|
||||
// Filepath is a local filesytem file path
|
||||
Filepath string
|
||||
// TinyProxyLogLevel is the log level for TinyProxy
|
||||
TinyProxyLogLevel string
|
||||
// VPNProvider is the name of the VPN provider to be used
|
||||
VPNProvider string
|
||||
// NetworkProtocol contains the network protocol to be used to communicate with the VPN servers
|
||||
NetworkProtocol string
|
||||
)
|
||||
23
internal/openvpn/auth.go
Normal file
23
internal/openvpn/auth.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package openvpn
|
||||
|
||||
import (
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
)
|
||||
|
||||
// WriteAuthFile writes the OpenVPN auth file to disk with the right permissions
|
||||
func (c *configurator) WriteAuthFile(user, password string, uid, gid int) error {
|
||||
authExists, err := c.fileManager.FileExists(string(constants.OpenVPNAuthConf))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if authExists { // in case of container stop/start
|
||||
c.logger.Info("%s: %s already exists", logPrefix, constants.OpenVPNAuthConf)
|
||||
return nil
|
||||
}
|
||||
c.logger.Info("%s: writing auth file %s", logPrefix, constants.OpenVPNAuthConf)
|
||||
return c.fileManager.WriteLinesToFile(
|
||||
string(constants.OpenVPNAuthConf),
|
||||
[]string{user, password},
|
||||
files.FileOwnership(uid, gid),
|
||||
files.FilePermissions(0400))
|
||||
}
|
||||
28
internal/openvpn/command.go
Normal file
28
internal/openvpn/command.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package openvpn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
)
|
||||
|
||||
func (c *configurator) Start() (stdout io.ReadCloser, err error) {
|
||||
c.logger.Info("%s: starting openvpn", logPrefix)
|
||||
stdout, _, _, err = c.commander.Start("openvpn", "--config", string(constants.OpenVPNConf))
|
||||
return stdout, err
|
||||
}
|
||||
|
||||
func (c *configurator) Version() (string, error) {
|
||||
output, err := c.commander.Run("openvpn", "--version")
|
||||
if err != nil && err.Error() != "exit status 1" {
|
||||
return "", err
|
||||
}
|
||||
firstLine := strings.Split(output, "\n")[0]
|
||||
words := strings.Fields(firstLine)
|
||||
if len(words) < 2 {
|
||||
return "", fmt.Errorf("openvpn --version: first line is too short: %q", firstLine)
|
||||
}
|
||||
return words[1], nil
|
||||
}
|
||||
35
internal/openvpn/openvpn.go
Normal file
35
internal/openvpn/openvpn.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package openvpn
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/qdm12/golibs/command"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
const logPrefix = "openvpn configurator"
|
||||
|
||||
type Configurator interface {
|
||||
Version() (string, error)
|
||||
WriteAuthFile(user, password string, uid, gid int) error
|
||||
CheckTUN() error
|
||||
Start() (stdout io.ReadCloser, err error)
|
||||
}
|
||||
|
||||
type configurator struct {
|
||||
fileManager files.FileManager
|
||||
logger logging.Logger
|
||||
commander command.Commander
|
||||
openFile func(name string, flag int, perm os.FileMode) (*os.File, error)
|
||||
}
|
||||
|
||||
func NewConfigurator(logger logging.Logger, fileManager files.FileManager) Configurator {
|
||||
return &configurator{
|
||||
fileManager: fileManager,
|
||||
logger: logger,
|
||||
commander: command.NewCommander(),
|
||||
openFile: os.OpenFile,
|
||||
}
|
||||
}
|
||||
21
internal/openvpn/tun.go
Normal file
21
internal/openvpn/tun.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package openvpn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
)
|
||||
|
||||
// CheckTUN checks the tunnel device is present and accessible
|
||||
func (c *configurator) CheckTUN() error {
|
||||
c.logger.Info("%s: checking for device %s", logPrefix, constants.TunnelDevice)
|
||||
f, err := c.openFile(string(constants.TunnelDevice), os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("TUN device is not available: %w", err)
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
c.logger.Warn("Could not close TUN device file: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
93
internal/params/dns.go
Normal file
93
internal/params/dns.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package params
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
libparams "github.com/qdm12/golibs/params"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
// GetDNSOverTLS obtains if the DNS over TLS should be enabled
|
||||
// from the environment variable DOT
|
||||
func (p *paramsReader) GetDNSOverTLS() (DNSOverTLS bool, err error) {
|
||||
return p.envParams.GetOnOff("DOT", libparams.Default("on"))
|
||||
}
|
||||
|
||||
// GetDNSOverTLSProviders obtains the DNS over TLS providers to use
|
||||
// from the environment variable DOT_PROVIDERS
|
||||
func (p *paramsReader) GetDNSOverTLSProviders() (providers []models.DNSProvider, err error) {
|
||||
s, err := p.envParams.GetEnv("DOT_PROVIDERS", libparams.Default("cloudflare"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, word := range strings.Split(s, ",") {
|
||||
provider := models.DNSProvider(word)
|
||||
switch provider {
|
||||
case constants.Cloudflare, constants.Google, constants.Quad9, constants.Quadrant, constants.CleanBrowsing, constants.SecureDNS, constants.LibreDNS:
|
||||
providers = append(providers, provider)
|
||||
default:
|
||||
return nil, fmt.Errorf("DNS over TLS provider %q is not valid", provider)
|
||||
}
|
||||
}
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
// GetDNSOverTLSVerbosity obtains the verbosity level to use for Unbound
|
||||
// from the environment variable DOT_VERBOSITY
|
||||
func (p *paramsReader) GetDNSOverTLSVerbosity() (verbosityLevel uint8, err error) {
|
||||
n, err := p.envParams.GetEnvIntRange("DOT_VERBOSITY", 0, 5, libparams.Default("1"))
|
||||
return uint8(n), err
|
||||
}
|
||||
|
||||
// GetDNSOverTLSVerbosityDetails obtains the log level to use for Unbound
|
||||
// from the environment variable DOT_VERBOSITY_DETAILS
|
||||
func (p *paramsReader) GetDNSOverTLSVerbosityDetails() (verbosityDetailsLevel uint8, err error) {
|
||||
n, err := p.envParams.GetEnvIntRange("DOT_VERBOSITY_DETAILS", 0, 4, libparams.Default("0"))
|
||||
return uint8(n), err
|
||||
}
|
||||
|
||||
// GetDNSOverTLSValidationLogLevel obtains the log level to use for Unbound DOT validation
|
||||
// from the environment variable DOT_VALIDATION_LOGLEVEL
|
||||
func (p *paramsReader) GetDNSOverTLSValidationLogLevel() (validationLogLevel uint8, err error) {
|
||||
n, err := p.envParams.GetEnvIntRange("DOT_VALIDATION_LOGLEVEL", 0, 2, libparams.Default("0"))
|
||||
return uint8(n), err
|
||||
}
|
||||
|
||||
// GetDNSMaliciousBlocking obtains if malicious hostnames/IPs should be blocked
|
||||
// from being resolved by Unbound, using the environment variable BLOCK_MALICIOUS
|
||||
func (p *paramsReader) GetDNSMaliciousBlocking() (blocking bool, err error) {
|
||||
return p.envParams.GetOnOff("BLOCK_MALICIOUS", libparams.Default("on"))
|
||||
}
|
||||
|
||||
// GetDNSSurveillanceBlocking obtains if surveillance hostnames/IPs should be blocked
|
||||
// from being resolved by Unbound, using the environment variable BLOCK_NSA
|
||||
func (p *paramsReader) GetDNSSurveillanceBlocking() (blocking bool, err error) {
|
||||
return p.envParams.GetOnOff("BLOCK_NSA", libparams.Default("off"))
|
||||
}
|
||||
|
||||
// GetDNSAdsBlocking obtains if ads hostnames/IPs should be blocked
|
||||
// from being resolved by Unbound, using the environment variable BLOCK_ADS
|
||||
func (p *paramsReader) GetDNSAdsBlocking() (blocking bool, err error) {
|
||||
return p.envParams.GetOnOff("BLOCK_ADS", libparams.Default("off"))
|
||||
}
|
||||
|
||||
// GetDNSUnblockedHostnames obtains a list of hostnames to unblock from block lists
|
||||
// from the comma separated list for the environment variable UNBLOCK
|
||||
func (p *paramsReader) GetDNSUnblockedHostnames() (hostnames []string, err error) {
|
||||
s, err := p.envParams.GetEnv("UNBLOCK")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(s) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
hostnames = strings.Split(s, ",")
|
||||
for _, hostname := range hostnames {
|
||||
if !p.verifier.MatchHostname(hostname) {
|
||||
return nil, fmt.Errorf("hostname %q does not seem valid", hostname)
|
||||
}
|
||||
}
|
||||
return hostnames, nil
|
||||
}
|
||||
29
internal/params/firewall.go
Normal file
29
internal/params/firewall.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package params
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GetExtraSubnets obtains the CIDR subnets from the comma separated list of the
|
||||
// environment variable EXTRA_SUBNETS
|
||||
func (p *paramsReader) GetExtraSubnets() (extraSubnets []net.IPNet, err error) {
|
||||
s, err := p.envParams.GetEnv("EXTRA_SUBNETS")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if s == "" {
|
||||
return nil, nil
|
||||
}
|
||||
subnets := strings.Split(s, ",")
|
||||
for _, subnet := range subnets {
|
||||
_, cidr, err := net.ParseCIDR(subnet)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not parse subnet %q from environment variable with key EXTRA_SUBNETS: %w", subnet, err)
|
||||
} else if cidr == nil {
|
||||
return nil, fmt.Errorf("parsing subnet %q resulted in a nil CIDR", subnet)
|
||||
}
|
||||
extraSubnets = append(extraSubnets, *cidr)
|
||||
}
|
||||
return extraSubnets, nil
|
||||
}
|
||||
13
internal/params/openvpn.go
Normal file
13
internal/params/openvpn.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package params
|
||||
|
||||
import (
|
||||
libparams "github.com/qdm12/golibs/params"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
// GetNetworkProtocol obtains the network protocol to use to connect to the
|
||||
// VPN servers from the environment variable PROTOCOL
|
||||
func (p *paramsReader) GetNetworkProtocol() (protocol models.NetworkProtocol, err error) {
|
||||
s, err := p.envParams.GetValueIfInside("PROTOCOL", []string{"tcp", "udp"}, libparams.Default("udp"))
|
||||
return models.NetworkProtocol(s), err
|
||||
}
|
||||
73
internal/params/params.go
Normal file
73
internal/params/params.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package params
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/qdm12/golibs/logging"
|
||||
libparams "github.com/qdm12/golibs/params"
|
||||
"github.com/qdm12/golibs/verification"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
// ParamsReader contains methods to obtain parameters
|
||||
type ParamsReader interface {
|
||||
// DNS over TLS getters
|
||||
GetDNSOverTLS() (DNSOverTLS bool, err error)
|
||||
GetDNSOverTLSProviders() (providers []models.DNSProvider, err error)
|
||||
GetDNSOverTLSVerbosity() (verbosityLevel uint8, err error)
|
||||
GetDNSOverTLSVerbosityDetails() (verbosityDetailsLevel uint8, err error)
|
||||
GetDNSOverTLSValidationLogLevel() (validationLogLevel uint8, err error)
|
||||
GetDNSMaliciousBlocking() (blocking bool, err error)
|
||||
GetDNSSurveillanceBlocking() (blocking bool, err error)
|
||||
GetDNSAdsBlocking() (blocking bool, err error)
|
||||
GetDNSUnblockedHostnames() (hostnames []string, err error)
|
||||
|
||||
// Firewall getters
|
||||
GetExtraSubnets() (extraSubnets []net.IPNet, err error)
|
||||
|
||||
// VPN getters
|
||||
GetNetworkProtocol() (protocol models.NetworkProtocol, err error)
|
||||
|
||||
// PIA getters
|
||||
GetUser() (s string, err error)
|
||||
GetPassword() (s string, err error)
|
||||
GetPortForwarding() (activated bool, err error)
|
||||
GetPortForwardingStatusFilepath() (filepath models.Filepath, err error)
|
||||
GetPIAEncryption() (models.PIAEncryption, error)
|
||||
GetPIARegion() (models.PIARegion, error)
|
||||
|
||||
// Shadowsocks getters
|
||||
GetShadowSocks() (activated bool, err error)
|
||||
GetShadowSocksLog() (activated bool, err error)
|
||||
GetShadowSocksPort() (port uint16, err error)
|
||||
GetShadowSocksPassword() (password string, err error)
|
||||
|
||||
// Tinyproxy getters
|
||||
GetTinyProxy() (activated bool, err error)
|
||||
GetTinyProxyLog() (models.TinyProxyLogLevel, error)
|
||||
GetTinyProxyPort() (port uint16, err error)
|
||||
GetTinyProxyUser() (user string, err error)
|
||||
GetTinyProxyPassword() (password string, err error)
|
||||
|
||||
// Version getters
|
||||
GetVersion() string
|
||||
GetBuildDate() string
|
||||
GetVcsRef() string
|
||||
}
|
||||
|
||||
type paramsReader struct {
|
||||
envParams libparams.EnvParams
|
||||
logger logging.Logger
|
||||
verifier verification.Verifier
|
||||
unsetEnv func(key string) error
|
||||
}
|
||||
|
||||
func NewParamsReader(logger logging.Logger) ParamsReader {
|
||||
return ¶msReader{
|
||||
envParams: libparams.NewEnvParams(),
|
||||
logger: logger,
|
||||
verifier: verification.NewVerifier(),
|
||||
unsetEnv: os.Unsetenv,
|
||||
}
|
||||
}
|
||||
87
internal/params/pia.go
Normal file
87
internal/params/pia.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package params
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
|
||||
libparams "github.com/qdm12/golibs/params"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
// GetUser obtains the user to use to connect to the VPN servers
|
||||
func (p *paramsReader) GetUser() (s string, err error) {
|
||||
defer func() {
|
||||
unsetenvErr := p.unsetEnv("USER")
|
||||
if err == nil {
|
||||
err = unsetenvErr
|
||||
}
|
||||
}()
|
||||
s, err = p.envParams.GetEnv("USER")
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else if len(s) == 0 {
|
||||
return s, fmt.Errorf("USER environment variable cannot be empty")
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// GetPassword obtains the password to use to connect to the VPN servers
|
||||
func (p *paramsReader) GetPassword() (s string, err error) {
|
||||
defer func() {
|
||||
unsetenvErr := p.unsetEnv("PASSWORD")
|
||||
if err == nil {
|
||||
err = unsetenvErr
|
||||
}
|
||||
}()
|
||||
s, err = p.envParams.GetEnv("PASSWORD")
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else if len(s) == 0 {
|
||||
return s, fmt.Errorf("PASSWORD environment variable cannot be empty")
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// GetPortForwarding obtains if port forwarding on the VPN provider server
|
||||
// side is enabled or not from the environment variable PORT_FORWARDING
|
||||
func (p *paramsReader) GetPortForwarding() (activated bool, err error) {
|
||||
s, err := p.envParams.GetEnv("PORT_FORWARDING", libparams.Default("off"))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
// Custom for retro-compatibility
|
||||
if s == "false" || s == "off" {
|
||||
return false, nil
|
||||
} else if s == "true" || s == "on" {
|
||||
return true, nil
|
||||
}
|
||||
return false, fmt.Errorf("PORT_FORWARDING can only be \"on\" or \"off\"")
|
||||
}
|
||||
|
||||
// GetPortForwardingStatusFilepath obtains the port forwarding status file path
|
||||
// from the environment variable PORT_FORWARDING_STATUS_FILE
|
||||
func (p *paramsReader) GetPortForwardingStatusFilepath() (filepath models.Filepath, err error) {
|
||||
filepathStr, err := p.envParams.GetPath("PORT_FORWARDING_STATUS_FILE", libparams.Default("/forwarded_port"))
|
||||
return models.Filepath(filepathStr), err
|
||||
}
|
||||
|
||||
// GetPIAEncryption obtains the encryption level for the PIA connection
|
||||
// from the environment variable ENCRYPTION
|
||||
func (p *paramsReader) GetPIAEncryption() (models.PIAEncryption, error) {
|
||||
s, err := p.envParams.GetValueIfInside("ENCRYPTION", []string{"normal", "strong"}, libparams.Default("strong"))
|
||||
return models.PIAEncryption(s), err
|
||||
}
|
||||
|
||||
// GetPIARegion obtains the region for the PIA server from the
|
||||
// environment variable REGION
|
||||
func (p *paramsReader) GetPIARegion() (region models.PIARegion, err error) {
|
||||
choices := []string{
|
||||
string(constants.AUMelbourne), string(constants.AUPerth), string(constants.AUSydney), string(constants.Austria), string(constants.Belgium), string(constants.CAMontreal), string(constants.CAToronto), string(constants.CAVancouver), string(constants.CzechRepublic), string(constants.DEBerlin), string(constants.DEFrankfurt), string(constants.Denmark), string(constants.Finland), string(constants.France), string(constants.HongKong), string(constants.Hungary), string(constants.India), string(constants.Ireland), string(constants.Israel), string(constants.Italy), string(constants.Japan), string(constants.Luxembourg), string(constants.Mexico), string(constants.Netherlands), string(constants.NewZealand), string(constants.Norway), string(constants.Poland), string(constants.Romania), string(constants.Singapore), string(constants.Spain), string(constants.Sweden), string(constants.Switzerland), string(constants.UAE), string(constants.UKLondon), string(constants.UKManchester), string(constants.UKSouthampton), string(constants.USAtlanta), string(constants.USCalifornia), string(constants.USChicago), string(constants.USDenver), string(constants.USEast), string(constants.USFlorida), string(constants.USHouston), string(constants.USLasVegas), string(constants.USNewYorkCity), string(constants.USSeattle), string(constants.USSiliconValley), string(constants.USTexas), string(constants.USWashingtonDC), string(constants.USWest),
|
||||
}
|
||||
s, err := p.envParams.GetValueIfInside("REGION", choices)
|
||||
if len(s) == 0 { // Suggestion by @rorph https://github.com/rorph
|
||||
s = choices[rand.Int()%len(choices)]
|
||||
}
|
||||
return models.PIARegion(s), err
|
||||
}
|
||||
40
internal/params/shadowsocks.go
Normal file
40
internal/params/shadowsocks.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package params
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
libparams "github.com/qdm12/golibs/params"
|
||||
)
|
||||
|
||||
// GetShadowSocks obtains if ShadowSocks is on from the environment variable
|
||||
// SHADOWSOCKS
|
||||
func (p *paramsReader) GetShadowSocks() (activated bool, err error) {
|
||||
return p.envParams.GetOnOff("SHADOWSOCKS", libparams.Default("off"))
|
||||
}
|
||||
|
||||
// GetShadowSocksLog obtains the ShadowSocks log level from the environment variable
|
||||
// SHADOWSOCKS_LOG
|
||||
func (p *paramsReader) GetShadowSocksLog() (activated bool, err error) {
|
||||
return p.envParams.GetOnOff("SHADOWSOCKS_LOG", libparams.Default("off"))
|
||||
}
|
||||
|
||||
// GetShadowSocksPort obtains the ShadowSocks listening port from the environment variable
|
||||
// SHADOWSOCKS_PORT
|
||||
func (p *paramsReader) GetShadowSocksPort() (port uint16, err error) {
|
||||
portStr, err := p.envParams.GetEnv("SHADOWSOCKS_PORT", libparams.Default("8388"))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := p.verifier.VerifyPort(portStr); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
portUint64, err := strconv.ParseUint(portStr, 10, 16)
|
||||
return uint16(portUint64), err
|
||||
}
|
||||
|
||||
// GetShadowSocksPassword obtains the ShadowSocks server password from the environment variable
|
||||
// SHADOWSOCKS_PASSWORD
|
||||
func (p *paramsReader) GetShadowSocksPassword() (password string, err error) {
|
||||
defer p.unsetEnv("SHADOWSOCKS_PASSWORD")
|
||||
return p.envParams.GetEnv("SHADOWSOCKS_PASSWORD")
|
||||
}
|
||||
94
internal/params/tinyproxy.go
Normal file
94
internal/params/tinyproxy.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package params
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
libparams "github.com/qdm12/golibs/params"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
// GetTinyProxy obtains if TinyProxy is on from the environment variable
|
||||
// TINYPROXY, and using PROXY as a retro-compatibility name
|
||||
func (p *paramsReader) GetTinyProxy() (activated bool, err error) {
|
||||
// Retro-compatibility
|
||||
s, err := p.envParams.GetEnv("PROXY")
|
||||
if err != nil {
|
||||
return false, err
|
||||
} else if len(s) != 0 {
|
||||
p.logger.Warn("You are using the old environment variable PROXY, please consider changing it to TINYPROXY")
|
||||
return p.envParams.GetOnOff("PROXY", libparams.Compulsory())
|
||||
}
|
||||
return p.envParams.GetOnOff("TINYPROXY", libparams.Default("off"))
|
||||
}
|
||||
|
||||
// GetTinyProxyLog obtains the TinyProxy log level from the environment variable
|
||||
// TINYPROXY_LOG, and using PROXY_LOG_LEVEL as a retro-compatibility name
|
||||
func (p *paramsReader) GetTinyProxyLog() (models.TinyProxyLogLevel, error) {
|
||||
// Retro-compatibility
|
||||
s, err := p.envParams.GetEnv("PROXY_LOG_LEVEL")
|
||||
if err != nil {
|
||||
return models.TinyProxyLogLevel(s), err
|
||||
} else if len(s) != 0 {
|
||||
p.logger.Warn("You are using the old environment variable PROXY_LOG_LEVEL, please consider changing it to TINYPROXY_LOG")
|
||||
s, err = p.envParams.GetValueIfInside("PROXY_LOG_LEVEL", []string{"info", "warning", "error", "critical"}, libparams.Compulsory())
|
||||
return models.TinyProxyLogLevel(s), err
|
||||
}
|
||||
s, err = p.envParams.GetValueIfInside("TINYPROXY_LOG", []string{"info", "warning", "error", "critical"}, libparams.Default("info"))
|
||||
return models.TinyProxyLogLevel(s), err
|
||||
}
|
||||
|
||||
// GetTinyProxyPort obtains the TinyProxy listening port from the environment variable
|
||||
// TINYPROXY_PORT, and using PROXY_PORT as a retro-compatibility name
|
||||
func (p *paramsReader) GetTinyProxyPort() (port uint16, err error) {
|
||||
// Retro-compatibility
|
||||
portStr, err := p.envParams.GetEnv("PROXY_PORT")
|
||||
if err != nil {
|
||||
return 0, err
|
||||
} else if len(portStr) != 0 {
|
||||
p.logger.Warn("You are using the old environment variable PROXY_PORT, please consider changing it to TINYPROXY_PORT")
|
||||
} else {
|
||||
portStr, err = p.envParams.GetEnv("TINYPROXY_PORT", libparams.Default("8888"))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if err := p.verifier.VerifyPort(portStr); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
portUint64, err := strconv.ParseUint(portStr, 10, 16)
|
||||
return uint16(portUint64), err
|
||||
}
|
||||
|
||||
// GetTinyProxyUser obtains the TinyProxy server user from the environment variable
|
||||
// TINYPROXY_USER, and using PROXY_USER as a retro-compatibility name
|
||||
func (p *paramsReader) GetTinyProxyUser() (user string, err error) {
|
||||
defer p.unsetEnv("PROXY_USER")
|
||||
defer p.unsetEnv("TINYPROXY_USER")
|
||||
// Retro-compatibility
|
||||
user, err = p.envParams.GetEnv("PROXY_USER")
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
if len(user) != 0 {
|
||||
p.logger.Warn("You are using the old environment variable PROXY_USER, please consider changing it to TINYPROXY_USER")
|
||||
return user, nil
|
||||
}
|
||||
return p.envParams.GetEnv("TINYPROXY_USER")
|
||||
}
|
||||
|
||||
// GetTinyProxyPassword obtains the TinyProxy server password from the environment variable
|
||||
// TINYPROXY_PASSWORD, and using PROXY_PASSWORD as a retro-compatibility name
|
||||
func (p *paramsReader) GetTinyProxyPassword() (password string, err error) {
|
||||
defer p.unsetEnv("PROXY_PASSWORD")
|
||||
defer p.unsetEnv("TINYPROXY_PASSWORD")
|
||||
// Retro-compatibility
|
||||
password, err = p.envParams.GetEnv("PROXY_PASSWORD")
|
||||
if err != nil {
|
||||
return password, err
|
||||
}
|
||||
if len(password) != 0 {
|
||||
p.logger.Warn("You are using the old environment variable PROXY_PASSWORD, please consider changing it to TINYPROXY_PASSWORD")
|
||||
return password, nil
|
||||
}
|
||||
return p.envParams.GetEnv("TINYPROXY_PASSWORD")
|
||||
}
|
||||
20
internal/params/version.go
Normal file
20
internal/params/version.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package params
|
||||
|
||||
import (
|
||||
"github.com/qdm12/golibs/params"
|
||||
)
|
||||
|
||||
func (p *paramsReader) GetVersion() string {
|
||||
version, _ := p.envParams.GetEnv("VERSION", params.Default("?"))
|
||||
return version
|
||||
}
|
||||
|
||||
func (p *paramsReader) GetBuildDate() string {
|
||||
buildDate, _ := p.envParams.GetEnv("BUILD_DATE", params.Default("?"))
|
||||
return buildDate
|
||||
}
|
||||
|
||||
func (p *paramsReader) GetVcsRef() string {
|
||||
buildDate, _ := p.envParams.GetEnv("VCS_REF", params.Default("?"))
|
||||
return buildDate
|
||||
}
|
||||
61
internal/pia/download.go
Normal file
61
internal/pia/download.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package pia
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
func (c *configurator) DownloadOvpnConfig(encryption models.PIAEncryption,
|
||||
protocol models.NetworkProtocol, region models.PIARegion) (lines []string, err error) {
|
||||
c.logger.Info("%s: downloading openvpn configuration files", logPrefix)
|
||||
URL := buildZipURL(encryption, protocol)
|
||||
content, status, err := c.client.GetContent(URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if status != 200 {
|
||||
return nil, fmt.Errorf("HTTP Get %s resulted in HTTP status code %d", URL, status)
|
||||
}
|
||||
filename := fmt.Sprintf("%s.ovpn", region)
|
||||
fileContent, err := getFileInZip(content, filename)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: %w", URL, err)
|
||||
}
|
||||
lines = strings.Split(string(fileContent), "\n")
|
||||
return lines, nil
|
||||
}
|
||||
|
||||
func buildZipURL(encryption models.PIAEncryption, protocol models.NetworkProtocol) (URL string) {
|
||||
URL = string(constants.PIAOpenVPNURL) + "/openvpn"
|
||||
if encryption == constants.PIAEncryptionStrong {
|
||||
URL += "-strong"
|
||||
}
|
||||
if protocol == constants.TCP {
|
||||
URL += "-tcp"
|
||||
}
|
||||
return URL + ".zip"
|
||||
}
|
||||
|
||||
func getFileInZip(zipContent []byte, filename string) (fileContent []byte, err error) {
|
||||
contentLength := int64(len(zipContent))
|
||||
r, err := zip.NewReader(bytes.NewReader(zipContent), contentLength)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, f := range r.File {
|
||||
if f.Name == filename {
|
||||
readCloser, err := f.Open()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer readCloser.Close()
|
||||
return ioutil.ReadAll(readCloser)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("%s not found in zip archive file", filename)
|
||||
}
|
||||
31
internal/pia/modify.go
Normal file
31
internal/pia/modify.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package pia
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
)
|
||||
|
||||
func (c *configurator) ModifyLines(lines []string, IPs []net.IP, port uint16) (modifiedLines []string) {
|
||||
c.logger.Info("%s: adapting openvpn configuration for server IP addresses and port %d", logPrefix, port)
|
||||
// Remove lines
|
||||
for _, line := range lines {
|
||||
if strings.Contains(line, "privateinternetaccess.com") ||
|
||||
strings.Contains(line, "resolv-retry") {
|
||||
continue
|
||||
}
|
||||
modifiedLines = append(modifiedLines, line)
|
||||
}
|
||||
// Add lines
|
||||
for _, IP := range IPs {
|
||||
modifiedLines = append(modifiedLines, fmt.Sprintf("remote %s %d", IP.String(), port))
|
||||
}
|
||||
modifiedLines = append(modifiedLines, "auth-user-pass "+string(constants.OpenVPNAuthConf))
|
||||
modifiedLines = append(modifiedLines, "auth-retry nointeract")
|
||||
modifiedLines = append(modifiedLines, "pull-filter ignore \"auth-token\"") // prevent auth failed loops
|
||||
modifiedLines = append(modifiedLines, "user nonrootuser")
|
||||
modifiedLines = append(modifiedLines, "mute-replay-warnings")
|
||||
return modifiedLines
|
||||
}
|
||||
31
internal/pia/modify_test.go
Normal file
31
internal/pia/modify_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package pia
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
loggingMocks "github.com/qdm12/golibs/logging/mocks"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_ModifyLines(t *testing.T) {
|
||||
t.Parallel()
|
||||
original, err := ioutil.ReadFile("testdata/ovpn.golden")
|
||||
require.NoError(t, err)
|
||||
originalLines := strings.Split(string(original), "\n")
|
||||
expected, err := ioutil.ReadFile("testdata/ovpn.modified.golden")
|
||||
require.NoError(t, err)
|
||||
expectedLines := strings.Split(string(expected), "\n")
|
||||
|
||||
var port uint16 = 3000
|
||||
IPs := []net.IP{net.IP{100, 10, 10, 10}, net.IP{100, 20, 20, 20}}
|
||||
logger := &loggingMocks.Logger{}
|
||||
logger.On("Info", "%s: adapting openvpn configuration for server IP addresses and port %d", logPrefix, port).Once()
|
||||
c := &configurator{logger: logger}
|
||||
modifiedLines := c.ModifyLines(originalLines, IPs, port)
|
||||
assert.Equal(t, expectedLines, modifiedLines)
|
||||
logger.AssertExpectations(t)
|
||||
}
|
||||
54
internal/pia/parse.go
Normal file
54
internal/pia/parse.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package pia
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
func (c *configurator) ParseConfig(lines []string) (IPs []net.IP, port uint16, device models.VPNDevice, err error) {
|
||||
c.logger.Info("%s: parsing openvpn configuration", logPrefix)
|
||||
remoteLineFound := false
|
||||
deviceLineFound := false
|
||||
for _, line := range lines {
|
||||
if strings.HasPrefix(line, "remote ") {
|
||||
remoteLineFound = true
|
||||
words := strings.Fields(line)
|
||||
if len(words) != 3 {
|
||||
return nil, 0, "", fmt.Errorf("line %q misses information", line)
|
||||
}
|
||||
host := words[1]
|
||||
if err := c.verifyPort(words[2]); err != nil {
|
||||
return nil, 0, "", fmt.Errorf("line %q has an invalid port: %w", line, err)
|
||||
}
|
||||
portUint64, _ := strconv.ParseUint(words[2], 10, 16)
|
||||
port = uint16(portUint64)
|
||||
IPs, err = c.lookupIP(host)
|
||||
if err != nil {
|
||||
return nil, 0, "", err
|
||||
}
|
||||
} else if strings.HasPrefix(line, "dev ") {
|
||||
deviceLineFound = true
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) != 2 {
|
||||
return nil, 0, "", fmt.Errorf("line %q misses information", line)
|
||||
}
|
||||
device = models.VPNDevice(fields[1] + "0")
|
||||
if device != constants.TUN && device != constants.TAP {
|
||||
return nil, 0, "", fmt.Errorf("device %q is not valid", device)
|
||||
}
|
||||
}
|
||||
}
|
||||
if remoteLineFound && deviceLineFound {
|
||||
c.logger.Info("%s: Found %d PIA server IP addresses, port %d and device %s", logPrefix, len(IPs), port, device)
|
||||
return IPs, port, device, nil
|
||||
} else if !remoteLineFound {
|
||||
return nil, 0, "", fmt.Errorf("remote line not found in Openvpn configuration")
|
||||
} else {
|
||||
return nil, 0, "", fmt.Errorf("device line not found in Openvpn configuration")
|
||||
}
|
||||
}
|
||||
99
internal/pia/parse_test.go
Normal file
99
internal/pia/parse_test.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package pia
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
loggingMocks "github.com/qdm12/golibs/logging/mocks"
|
||||
"github.com/qdm12/golibs/verification"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_ParseConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
original, err := ioutil.ReadFile("testdata/ovpn.golden")
|
||||
require.NoError(t, err)
|
||||
exampleLines := strings.Split(string(original), "\n")
|
||||
tests := map[string]struct {
|
||||
lines []string
|
||||
lookupIPs []net.IP
|
||||
lookupIPErr error
|
||||
IPs []net.IP
|
||||
port uint16
|
||||
device models.VPNDevice
|
||||
err error
|
||||
}{
|
||||
"no data": {
|
||||
err: fmt.Errorf("remote line not found in Openvpn configuration"),
|
||||
},
|
||||
"bad remote line": {
|
||||
lines: []string{"remote field2"},
|
||||
err: fmt.Errorf("line \"remote field2\" misses information"),
|
||||
},
|
||||
"bad remote port": {
|
||||
lines: []string{"remote field2 port"},
|
||||
err: fmt.Errorf("line \"remote field2 port\" has an invalid port: port \"port\" is not a valid integer"),
|
||||
},
|
||||
"lookupIP error": {
|
||||
lines: []string{"remote host 1000"},
|
||||
lookupIPErr: fmt.Errorf("lookup error"),
|
||||
err: fmt.Errorf("lookup error"),
|
||||
},
|
||||
"missing dev line": {
|
||||
lines: []string{"remote host 1994"},
|
||||
err: fmt.Errorf("device line not found in Openvpn configuration"),
|
||||
},
|
||||
"bad dev line": {
|
||||
lines: []string{"dev field2 field3"},
|
||||
err: fmt.Errorf("line \"dev field2 field3\" misses information"),
|
||||
},
|
||||
"bad device": {
|
||||
lines: []string{"dev xx"},
|
||||
err: fmt.Errorf("device \"xx0\" is not valid"),
|
||||
},
|
||||
"valid lines": {
|
||||
lines: []string{"remote host 1194", "dev tap", "blabla"},
|
||||
port: 1194,
|
||||
device: constants.TAP,
|
||||
},
|
||||
"real data": {
|
||||
lines: exampleLines,
|
||||
lookupIPs: []net.IP{{100, 100, 100, 100}, {100, 100, 200, 200}},
|
||||
IPs: []net.IP{{100, 100, 100, 100}, {100, 100, 200, 200}},
|
||||
port: 1198,
|
||||
device: constants.TUN,
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := &loggingMocks.Logger{}
|
||||
logger.On("Info", "%s: parsing openvpn configuration", logPrefix).Once()
|
||||
if tc.err == nil {
|
||||
logger.On("Info", "%s: Found %d PIA server IP addresses, port %d and device %s", logPrefix, len(tc.IPs), tc.port, tc.device).Once()
|
||||
}
|
||||
lookupIP := func(host string) ([]net.IP, error) {
|
||||
return tc.lookupIPs, tc.lookupIPErr
|
||||
}
|
||||
c := &configurator{logger: logger, verifyPort: verification.NewVerifier().VerifyPort, lookupIP: lookupIP}
|
||||
IPs, port, device, err := c.ParseConfig(tc.lines)
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.IPs, IPs)
|
||||
assert.Equal(t, tc.port, port)
|
||||
assert.Equal(t, tc.device, device)
|
||||
logger.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
41
internal/pia/pia.go
Normal file
41
internal/pia/pia.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package pia
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/qdm12/golibs/crypto/random"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/network"
|
||||
"github.com/qdm12/golibs/verification"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/firewall"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
const logPrefix = "PIA configurator"
|
||||
|
||||
// Configurator contains methods to download, read and modify the openvpn configuration to connect as a client
|
||||
type Configurator interface {
|
||||
DownloadOvpnConfig(encryption models.PIAEncryption,
|
||||
protocol models.NetworkProtocol, region models.PIARegion) (lines []string, err error)
|
||||
ParseConfig(lines []string) (IPs []net.IP, port uint16, device models.VPNDevice, err error)
|
||||
ModifyLines(lines []string, IPs []net.IP, port uint16) (modifiedLines []string)
|
||||
GetPortForward() (port uint16, err error)
|
||||
WritePortForward(filepath models.Filepath, port uint16) (err error)
|
||||
AllowPortForwardFirewall(device models.VPNDevice, port uint16) (err error)
|
||||
}
|
||||
|
||||
type configurator struct {
|
||||
client network.Client
|
||||
fileManager files.FileManager
|
||||
firewall firewall.Configurator
|
||||
logger logging.Logger
|
||||
random random.Random
|
||||
verifyPort func(port string) error
|
||||
lookupIP func(host string) ([]net.IP, error)
|
||||
}
|
||||
|
||||
// NewConfigurator returns a new Configurator object
|
||||
func NewConfigurator(client network.Client, fileManager files.FileManager, firewall firewall.Configurator, logger logging.Logger) Configurator {
|
||||
return &configurator{client, fileManager, firewall, logger, random.NewRandom(), verification.NewVerifier().VerifyPort, net.LookupIP}
|
||||
}
|
||||
46
internal/pia/portforward.go
Normal file
46
internal/pia/portforward.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package pia
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
func (c *configurator) GetPortForward() (port uint16, err error) {
|
||||
c.logger.Info("%s: Obtaining port to be forwarded", logPrefix)
|
||||
b, err := c.random.GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
clientID := hex.EncodeToString(b)
|
||||
url := fmt.Sprintf("%s/?client_id=%s", constants.PIAPortForwardURL, clientID)
|
||||
content, status, err := c.client.GetContent(url)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
} else if status != 200 {
|
||||
return 0, fmt.Errorf("status is %d for %s; does your PIA server support port forwarding?", status, url)
|
||||
} else if len(content) == 0 {
|
||||
return 0, fmt.Errorf("port forwarding is already activated on this connection, has expired, or you are not connected to a PIA region that supports port forwarding")
|
||||
}
|
||||
body := struct {
|
||||
Port uint16 `json:"port"`
|
||||
}{}
|
||||
if err := json.Unmarshal(content, &body); err != nil {
|
||||
return 0, fmt.Errorf("port forwarding response: %w", err)
|
||||
}
|
||||
c.logger.Info("%s: Port forwarded is %d", logPrefix, port)
|
||||
return body.Port, nil
|
||||
}
|
||||
|
||||
func (c *configurator) WritePortForward(filepath models.Filepath, port uint16) (err error) {
|
||||
c.logger.Info("%s: Writing forwarded port to %s", logPrefix, filepath)
|
||||
return c.fileManager.WriteLinesToFile(string(filepath), []string{fmt.Sprintf("%d", port)})
|
||||
}
|
||||
|
||||
func (c *configurator) AllowPortForwardFirewall(device models.VPNDevice, port uint16) (err error) {
|
||||
c.logger.Info("%s: Allowing forwarded port %d through firewall", logPrefix, port)
|
||||
return c.firewall.AllowInputTrafficOnPort(device, port)
|
||||
}
|
||||
72
internal/pia/testdata/ovpn.golden
vendored
Normal file
72
internal/pia/testdata/ovpn.golden
vendored
Normal file
@@ -0,0 +1,72 @@
|
||||
client
|
||||
dev tun
|
||||
proto udp
|
||||
remote belgium.privateinternetaccess.com 1198
|
||||
resolv-retry infinite
|
||||
nobind
|
||||
persist-key
|
||||
persist-tun
|
||||
cipher aes-128-cbc
|
||||
auth sha1
|
||||
tls-client
|
||||
remote-cert-tls server
|
||||
|
||||
auth-user-pass
|
||||
compress
|
||||
verb 1
|
||||
reneg-sec 0
|
||||
<crl-verify>
|
||||
-----BEGIN X509 CRL-----
|
||||
MIICWDCCAUAwDQYJKoZIhvcNAQENBQAwgegxCzAJBgNVBAYTAlVTMQswCQYDVQQI
|
||||
EwJDQTETMBEGA1UEBxMKTG9zQW5nZWxlczEgMB4GA1UEChMXUHJpdmF0ZSBJbnRl
|
||||
cm5ldCBBY2Nlc3MxIDAeBgNVBAsTF1ByaXZhdGUgSW50ZXJuZXQgQWNjZXNzMSAw
|
||||
HgYDVQQDExdQcml2YXRlIEludGVybmV0IEFjY2VzczEgMB4GA1UEKRMXUHJpdmF0
|
||||
ZSBJbnRlcm5ldCBBY2Nlc3MxLzAtBgkqhkiG9w0BCQEWIHNlY3VyZUBwcml2YXRl
|
||||
aW50ZXJuZXRhY2Nlc3MuY29tFw0xNjA3MDgxOTAwNDZaFw0zNjA3MDMxOTAwNDZa
|
||||
MCYwEQIBARcMMTYwNzA4MTkwMDQ2MBECAQYXDDE2MDcwODE5MDA0NjANBgkqhkiG
|
||||
9w0BAQ0FAAOCAQEAQZo9X97ci8EcPYu/uK2HB152OZbeZCINmYyluLDOdcSvg6B5
|
||||
jI+ffKN3laDvczsG6CxmY3jNyc79XVpEYUnq4rT3FfveW1+Ralf+Vf38HdpwB8EW
|
||||
B4hZlQ205+21CALLvZvR8HcPxC9KEnev1mU46wkTiov0EKc+EdRxkj5yMgv0V2Re
|
||||
ze7AP+NQ9ykvDScH4eYCsmufNpIjBLhpLE2cuZZXBLcPhuRzVoU3l7A9lvzG9mjA
|
||||
5YijHJGHNjlWFqyrn1CfYS6koa4TGEPngBoAziWRbDGdhEgJABHrpoaFYaL61zqy
|
||||
MR6jC0K2ps9qyZAN74LEBedEfK7tBOzWMwr58A==
|
||||
-----END X509 CRL-----
|
||||
</crl-verify>
|
||||
|
||||
<ca>
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIFqzCCBJOgAwIBAgIJAKZ7D5Yv87qDMA0GCSqGSIb3DQEBDQUAMIHoMQswCQYD
|
||||
VQQGEwJVUzELMAkGA1UECBMCQ0ExEzARBgNVBAcTCkxvc0FuZ2VsZXMxIDAeBgNV
|
||||
BAoTF1ByaXZhdGUgSW50ZXJuZXQgQWNjZXNzMSAwHgYDVQQLExdQcml2YXRlIElu
|
||||
dGVybmV0IEFjY2VzczEgMB4GA1UEAxMXUHJpdmF0ZSBJbnRlcm5ldCBBY2Nlc3Mx
|
||||
IDAeBgNVBCkTF1ByaXZhdGUgSW50ZXJuZXQgQWNjZXNzMS8wLQYJKoZIhvcNAQkB
|
||||
FiBzZWN1cmVAcHJpdmF0ZWludGVybmV0YWNjZXNzLmNvbTAeFw0xNDA0MTcxNzM1
|
||||
MThaFw0zNDA0MTIxNzM1MThaMIHoMQswCQYDVQQGEwJVUzELMAkGA1UECBMCQ0Ex
|
||||
EzARBgNVBAcTCkxvc0FuZ2VsZXMxIDAeBgNVBAoTF1ByaXZhdGUgSW50ZXJuZXQg
|
||||
QWNjZXNzMSAwHgYDVQQLExdQcml2YXRlIEludGVybmV0IEFjY2VzczEgMB4GA1UE
|
||||
AxMXUHJpdmF0ZSBJbnRlcm5ldCBBY2Nlc3MxIDAeBgNVBCkTF1ByaXZhdGUgSW50
|
||||
ZXJuZXQgQWNjZXNzMS8wLQYJKoZIhvcNAQkBFiBzZWN1cmVAcHJpdmF0ZWludGVy
|
||||
bmV0YWNjZXNzLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAPXD
|
||||
L1L9tX6DGf36liA7UBTy5I869z0UVo3lImfOs/GSiFKPtInlesP65577nd7UNzzX
|
||||
lH/P/CnFPdBWlLp5ze3HRBCc/Avgr5CdMRkEsySL5GHBZsx6w2cayQ2EcRhVTwWp
|
||||
cdldeNO+pPr9rIgPrtXqT4SWViTQRBeGM8CDxAyTopTsobjSiYZCF9Ta1gunl0G/
|
||||
8Vfp+SXfYCC+ZzWvP+L1pFhPRqzQQ8k+wMZIovObK1s+nlwPaLyayzw9a8sUnvWB
|
||||
/5rGPdIYnQWPgoNlLN9HpSmsAcw2z8DXI9pIxbr74cb3/HSfuYGOLkRqrOk6h4RC
|
||||
OfuWoTrZup1uEOn+fw8CAwEAAaOCAVQwggFQMB0GA1UdDgQWBBQv63nQ/pJAt5tL
|
||||
y8VJcbHe22ZOsjCCAR8GA1UdIwSCARYwggESgBQv63nQ/pJAt5tLy8VJcbHe22ZO
|
||||
sqGB7qSB6zCB6DELMAkGA1UEBhMCVVMxCzAJBgNVBAgTAkNBMRMwEQYDVQQHEwpM
|
||||
b3NBbmdlbGVzMSAwHgYDVQQKExdQcml2YXRlIEludGVybmV0IEFjY2VzczEgMB4G
|
||||
A1UECxMXUHJpdmF0ZSBJbnRlcm5ldCBBY2Nlc3MxIDAeBgNVBAMTF1ByaXZhdGUg
|
||||
SW50ZXJuZXQgQWNjZXNzMSAwHgYDVQQpExdQcml2YXRlIEludGVybmV0IEFjY2Vz
|
||||
czEvMC0GCSqGSIb3DQEJARYgc2VjdXJlQHByaXZhdGVpbnRlcm5ldGFjY2Vzcy5j
|
||||
b22CCQCmew+WL/O6gzAMBgNVHRMEBTADAQH/MA0GCSqGSIb3DQEBDQUAA4IBAQAn
|
||||
a5PgrtxfwTumD4+3/SYvwoD66cB8IcK//h1mCzAduU8KgUXocLx7QgJWo9lnZ8xU
|
||||
ryXvWab2usg4fqk7FPi00bED4f4qVQFVfGfPZIH9QQ7/48bPM9RyfzImZWUCenK3
|
||||
7pdw4Bvgoys2rHLHbGen7f28knT2j/cbMxd78tQc20TIObGjo8+ISTRclSTRBtyC
|
||||
GohseKYpTS9himFERpUgNtefvYHbn70mIOzfOJFTVqfrptf9jXa9N8Mpy3ayfodz
|
||||
1wiqdteqFXkTYoSDctgKMiZ6GdocK9nMroQipIQtpnwd4yBDWIyC6Bvlkrq5TQUt
|
||||
YDQ8z9v+DMO6iwyIDRiU
|
||||
-----END CERTIFICATE-----
|
||||
</ca>
|
||||
|
||||
disable-occ
|
||||
78
internal/pia/testdata/ovpn.modified.golden
vendored
Normal file
78
internal/pia/testdata/ovpn.modified.golden
vendored
Normal file
@@ -0,0 +1,78 @@
|
||||
client
|
||||
dev tun
|
||||
proto udp
|
||||
nobind
|
||||
persist-key
|
||||
persist-tun
|
||||
cipher aes-128-cbc
|
||||
auth sha1
|
||||
tls-client
|
||||
remote-cert-tls server
|
||||
|
||||
auth-user-pass
|
||||
compress
|
||||
verb 1
|
||||
reneg-sec 0
|
||||
<crl-verify>
|
||||
-----BEGIN X509 CRL-----
|
||||
MIICWDCCAUAwDQYJKoZIhvcNAQENBQAwgegxCzAJBgNVBAYTAlVTMQswCQYDVQQI
|
||||
EwJDQTETMBEGA1UEBxMKTG9zQW5nZWxlczEgMB4GA1UEChMXUHJpdmF0ZSBJbnRl
|
||||
cm5ldCBBY2Nlc3MxIDAeBgNVBAsTF1ByaXZhdGUgSW50ZXJuZXQgQWNjZXNzMSAw
|
||||
HgYDVQQDExdQcml2YXRlIEludGVybmV0IEFjY2VzczEgMB4GA1UEKRMXUHJpdmF0
|
||||
ZSBJbnRlcm5ldCBBY2Nlc3MxLzAtBgkqhkiG9w0BCQEWIHNlY3VyZUBwcml2YXRl
|
||||
aW50ZXJuZXRhY2Nlc3MuY29tFw0xNjA3MDgxOTAwNDZaFw0zNjA3MDMxOTAwNDZa
|
||||
MCYwEQIBARcMMTYwNzA4MTkwMDQ2MBECAQYXDDE2MDcwODE5MDA0NjANBgkqhkiG
|
||||
9w0BAQ0FAAOCAQEAQZo9X97ci8EcPYu/uK2HB152OZbeZCINmYyluLDOdcSvg6B5
|
||||
jI+ffKN3laDvczsG6CxmY3jNyc79XVpEYUnq4rT3FfveW1+Ralf+Vf38HdpwB8EW
|
||||
B4hZlQ205+21CALLvZvR8HcPxC9KEnev1mU46wkTiov0EKc+EdRxkj5yMgv0V2Re
|
||||
ze7AP+NQ9ykvDScH4eYCsmufNpIjBLhpLE2cuZZXBLcPhuRzVoU3l7A9lvzG9mjA
|
||||
5YijHJGHNjlWFqyrn1CfYS6koa4TGEPngBoAziWRbDGdhEgJABHrpoaFYaL61zqy
|
||||
MR6jC0K2ps9qyZAN74LEBedEfK7tBOzWMwr58A==
|
||||
-----END X509 CRL-----
|
||||
</crl-verify>
|
||||
|
||||
<ca>
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIFqzCCBJOgAwIBAgIJAKZ7D5Yv87qDMA0GCSqGSIb3DQEBDQUAMIHoMQswCQYD
|
||||
VQQGEwJVUzELMAkGA1UECBMCQ0ExEzARBgNVBAcTCkxvc0FuZ2VsZXMxIDAeBgNV
|
||||
BAoTF1ByaXZhdGUgSW50ZXJuZXQgQWNjZXNzMSAwHgYDVQQLExdQcml2YXRlIElu
|
||||
dGVybmV0IEFjY2VzczEgMB4GA1UEAxMXUHJpdmF0ZSBJbnRlcm5ldCBBY2Nlc3Mx
|
||||
IDAeBgNVBCkTF1ByaXZhdGUgSW50ZXJuZXQgQWNjZXNzMS8wLQYJKoZIhvcNAQkB
|
||||
FiBzZWN1cmVAcHJpdmF0ZWludGVybmV0YWNjZXNzLmNvbTAeFw0xNDA0MTcxNzM1
|
||||
MThaFw0zNDA0MTIxNzM1MThaMIHoMQswCQYDVQQGEwJVUzELMAkGA1UECBMCQ0Ex
|
||||
EzARBgNVBAcTCkxvc0FuZ2VsZXMxIDAeBgNVBAoTF1ByaXZhdGUgSW50ZXJuZXQg
|
||||
QWNjZXNzMSAwHgYDVQQLExdQcml2YXRlIEludGVybmV0IEFjY2VzczEgMB4GA1UE
|
||||
AxMXUHJpdmF0ZSBJbnRlcm5ldCBBY2Nlc3MxIDAeBgNVBCkTF1ByaXZhdGUgSW50
|
||||
ZXJuZXQgQWNjZXNzMS8wLQYJKoZIhvcNAQkBFiBzZWN1cmVAcHJpdmF0ZWludGVy
|
||||
bmV0YWNjZXNzLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAPXD
|
||||
L1L9tX6DGf36liA7UBTy5I869z0UVo3lImfOs/GSiFKPtInlesP65577nd7UNzzX
|
||||
lH/P/CnFPdBWlLp5ze3HRBCc/Avgr5CdMRkEsySL5GHBZsx6w2cayQ2EcRhVTwWp
|
||||
cdldeNO+pPr9rIgPrtXqT4SWViTQRBeGM8CDxAyTopTsobjSiYZCF9Ta1gunl0G/
|
||||
8Vfp+SXfYCC+ZzWvP+L1pFhPRqzQQ8k+wMZIovObK1s+nlwPaLyayzw9a8sUnvWB
|
||||
/5rGPdIYnQWPgoNlLN9HpSmsAcw2z8DXI9pIxbr74cb3/HSfuYGOLkRqrOk6h4RC
|
||||
OfuWoTrZup1uEOn+fw8CAwEAAaOCAVQwggFQMB0GA1UdDgQWBBQv63nQ/pJAt5tL
|
||||
y8VJcbHe22ZOsjCCAR8GA1UdIwSCARYwggESgBQv63nQ/pJAt5tLy8VJcbHe22ZO
|
||||
sqGB7qSB6zCB6DELMAkGA1UEBhMCVVMxCzAJBgNVBAgTAkNBMRMwEQYDVQQHEwpM
|
||||
b3NBbmdlbGVzMSAwHgYDVQQKExdQcml2YXRlIEludGVybmV0IEFjY2VzczEgMB4G
|
||||
A1UECxMXUHJpdmF0ZSBJbnRlcm5ldCBBY2Nlc3MxIDAeBgNVBAMTF1ByaXZhdGUg
|
||||
SW50ZXJuZXQgQWNjZXNzMSAwHgYDVQQpExdQcml2YXRlIEludGVybmV0IEFjY2Vz
|
||||
czEvMC0GCSqGSIb3DQEJARYgc2VjdXJlQHByaXZhdGVpbnRlcm5ldGFjY2Vzcy5j
|
||||
b22CCQCmew+WL/O6gzAMBgNVHRMEBTADAQH/MA0GCSqGSIb3DQEBDQUAA4IBAQAn
|
||||
a5PgrtxfwTumD4+3/SYvwoD66cB8IcK//h1mCzAduU8KgUXocLx7QgJWo9lnZ8xU
|
||||
ryXvWab2usg4fqk7FPi00bED4f4qVQFVfGfPZIH9QQ7/48bPM9RyfzImZWUCenK3
|
||||
7pdw4Bvgoys2rHLHbGen7f28knT2j/cbMxd78tQc20TIObGjo8+ISTRclSTRBtyC
|
||||
GohseKYpTS9himFERpUgNtefvYHbn70mIOzfOJFTVqfrptf9jXa9N8Mpy3ayfodz
|
||||
1wiqdteqFXkTYoSDctgKMiZ6GdocK9nMroQipIQtpnwd4yBDWIyC6Bvlkrq5TQUt
|
||||
YDQ8z9v+DMO6iwyIDRiU
|
||||
-----END CERTIFICATE-----
|
||||
</ca>
|
||||
|
||||
disable-occ
|
||||
|
||||
remote 100.10.10.10 3000
|
||||
remote 100.20.20.20 3000
|
||||
auth-user-pass /etc/openvpn/auth.conf
|
||||
auth-retry nointeract
|
||||
pull-filter ignore "auth-token"
|
||||
user nonrootuser
|
||||
mute-replay-warnings
|
||||
108
internal/settings/dns.go
Normal file
108
internal/settings/dns.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/params"
|
||||
)
|
||||
|
||||
// DNS contains settings to configure Unbound for DNS over TLS operation
|
||||
type DNS struct {
|
||||
Enabled bool
|
||||
Providers []models.DNSProvider
|
||||
AllowedHostnames []string
|
||||
PrivateAddresses []string
|
||||
BlockMalicious bool
|
||||
BlockSurveillance bool
|
||||
BlockAds bool
|
||||
VerbosityLevel uint8
|
||||
VerbosityDetailsLevel uint8
|
||||
ValidationLogLevel uint8
|
||||
}
|
||||
|
||||
func (d *DNS) String() string {
|
||||
if !d.Enabled {
|
||||
return "DNS over TLS settings: disabled"
|
||||
}
|
||||
blockMalicious, blockSurveillance, blockAds := "disabed", "disabed", "disabed"
|
||||
if d.BlockMalicious {
|
||||
blockMalicious = "enabled"
|
||||
}
|
||||
if d.BlockSurveillance {
|
||||
blockSurveillance = "enabled"
|
||||
}
|
||||
if d.BlockAds {
|
||||
blockAds = "enabled"
|
||||
}
|
||||
var providersStr []string
|
||||
for _, provider := range d.Providers {
|
||||
providersStr = append(providersStr, string(provider))
|
||||
}
|
||||
settingsList := []string{
|
||||
"DNS over TLS settings:",
|
||||
"DNS over TLS provider: \n |--" + strings.Join(providersStr, "\n |--"),
|
||||
"Block malicious: " + blockMalicious,
|
||||
"Block surveillance: " + blockSurveillance,
|
||||
"Block ads: " + blockAds,
|
||||
"Allowed hostnames: " + strings.Join(d.AllowedHostnames, ", "),
|
||||
"Private addresses:\n |--" + strings.Join(d.PrivateAddresses, "\n |--"),
|
||||
"Verbosity level: " + fmt.Sprintf("%d/5", d.VerbosityLevel),
|
||||
"Verbosity details level: " + fmt.Sprintf("%d/4", d.VerbosityDetailsLevel),
|
||||
"Validation log level: " + fmt.Sprintf("%d/2", d.ValidationLogLevel),
|
||||
}
|
||||
return strings.Join(settingsList, "\n |--")
|
||||
}
|
||||
|
||||
// GetDNSSettings obtains DNS over TLS settings from environment variables using the params package.
|
||||
func GetDNSSettings(params params.ParamsReader) (settings DNS, err error) {
|
||||
settings.Enabled, err = params.GetDNSOverTLS()
|
||||
if err != nil || !settings.Enabled {
|
||||
return settings, err
|
||||
}
|
||||
settings.Providers, err = params.GetDNSOverTLSProviders()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.AllowedHostnames, err = params.GetDNSUnblockedHostnames()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.BlockMalicious, err = params.GetDNSMaliciousBlocking()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.BlockSurveillance, err = params.GetDNSSurveillanceBlocking()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.BlockAds, err = params.GetDNSAdsBlocking()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.VerbosityLevel, err = params.GetDNSOverTLSVerbosity()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.VerbosityDetailsLevel, err = params.GetDNSOverTLSVerbosityDetails()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.ValidationLogLevel, err = params.GetDNSOverTLSValidationLogLevel()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.PrivateAddresses = []string{ // TODO make env variable
|
||||
"127.0.0.1/8",
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"169.254.0.0/16",
|
||||
"::1/128",
|
||||
"fc00::/7",
|
||||
"fe80::/10",
|
||||
"::ffff:0:0/96",
|
||||
}
|
||||
return settings, nil
|
||||
}
|
||||
34
internal/settings/firewall.go
Normal file
34
internal/settings/firewall.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/params"
|
||||
)
|
||||
|
||||
// Firewall contains settings to customize the firewall operation
|
||||
type Firewall struct {
|
||||
AllowedSubnets []net.IPNet
|
||||
}
|
||||
|
||||
func (f *Firewall) String() string {
|
||||
var allowedSubnets []string
|
||||
for _, net := range f.AllowedSubnets {
|
||||
allowedSubnets = append(allowedSubnets, net.String())
|
||||
}
|
||||
settingsList := []string{
|
||||
"Firewall settings:",
|
||||
"Allowed subnets: " + strings.Join(allowedSubnets, ", "),
|
||||
}
|
||||
return strings.Join(settingsList, "\n |--")
|
||||
}
|
||||
|
||||
// GetFirewallSettings obtains firewall settings from environment variables using the params package.
|
||||
func GetFirewallSettings(params params.ParamsReader) (settings Firewall, err error) {
|
||||
settings.AllowedSubnets, err = params.GetExtraSubnets()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
return settings, nil
|
||||
}
|
||||
30
internal/settings/openvpn.go
Normal file
30
internal/settings/openvpn.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/params"
|
||||
)
|
||||
|
||||
// OpenVPN contains settings to configure the OpenVPN client
|
||||
type OpenVPN struct {
|
||||
NetworkProtocol models.NetworkProtocol
|
||||
}
|
||||
|
||||
// GetOpenVPNSettings obtains the OpenVPN settings using the params functions
|
||||
func GetOpenVPNSettings(params params.ParamsReader) (settings OpenVPN, err error) {
|
||||
settings.NetworkProtocol, err = params.GetNetworkProtocol()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
func (o *OpenVPN) String() string {
|
||||
settingsList := []string{
|
||||
"OpenVPN settings:",
|
||||
"Network protocol: " + string(o.NetworkProtocol),
|
||||
}
|
||||
return strings.Join(settingsList, "\n|--")
|
||||
}
|
||||
72
internal/settings/pia.go
Normal file
72
internal/settings/pia.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/params"
|
||||
)
|
||||
|
||||
// PIA contains the settings to connect to a PIA server
|
||||
type PIA struct {
|
||||
User string
|
||||
Password string
|
||||
Encryption models.PIAEncryption
|
||||
Region models.PIARegion
|
||||
PortForwarding PortForwarding
|
||||
}
|
||||
|
||||
// PortForwarding contains settings for port forwarding
|
||||
type PortForwarding struct {
|
||||
Enabled bool
|
||||
Filepath models.Filepath
|
||||
}
|
||||
|
||||
func (p *PortForwarding) String() string {
|
||||
if p.Enabled {
|
||||
return fmt.Sprintf("on, saved in %s", p.Filepath)
|
||||
}
|
||||
return "off"
|
||||
}
|
||||
|
||||
func (p *PIA) String() string {
|
||||
settingsList := []string{
|
||||
"PIA settings:",
|
||||
"Region: " + string(p.Region),
|
||||
"Encryption: " + string(p.Encryption),
|
||||
"Port forwarding: " + p.PortForwarding.String(),
|
||||
}
|
||||
return strings.Join(settingsList, "\n |--")
|
||||
}
|
||||
|
||||
// GetPIASettings obtains PIA settings from environment variables using the params package.
|
||||
func GetPIASettings(params params.ParamsReader) (settings PIA, err error) {
|
||||
settings.User, err = params.GetUser()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.Password, err = params.GetPassword()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.Encryption, err = params.GetPIAEncryption()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.Region, err = params.GetPIARegion()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.PortForwarding.Enabled, err = params.GetPortForwarding()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
if settings.PortForwarding.Enabled {
|
||||
settings.PortForwarding.Filepath, err = params.GetPortForwardingStatusFilepath()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
}
|
||||
return settings, nil
|
||||
}
|
||||
60
internal/settings/settings.go
Normal file
60
internal/settings/settings.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/params"
|
||||
)
|
||||
|
||||
// Settings contains all settings for the program to run
|
||||
type Settings struct {
|
||||
OpenVPN OpenVPN
|
||||
PIA PIA
|
||||
DNS DNS
|
||||
Firewall Firewall
|
||||
TinyProxy TinyProxy
|
||||
ShadowSocks ShadowSocks
|
||||
}
|
||||
|
||||
func (s *Settings) String() string {
|
||||
return strings.Join([]string{
|
||||
"Settings summary below:",
|
||||
s.OpenVPN.String(),
|
||||
s.PIA.String(),
|
||||
s.DNS.String(),
|
||||
s.Firewall.String(),
|
||||
s.TinyProxy.String(),
|
||||
s.ShadowSocks.String(),
|
||||
"", // new line at the end
|
||||
}, "\n")
|
||||
}
|
||||
|
||||
// GetAllSettings obtains all settings for the program and returns an error as soon
|
||||
// as an error is encountered reading them.
|
||||
func GetAllSettings(params params.ParamsReader) (settings Settings, err error) {
|
||||
settings.OpenVPN, err = GetOpenVPNSettings(params)
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.PIA, err = GetPIASettings(params)
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.DNS, err = GetDNSSettings(params)
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.Firewall, err = GetFirewallSettings(params)
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.TinyProxy, err = GetTinyProxySettings(params)
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.ShadowSocks, err = GetShadowSocksSettings(params)
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
return settings, nil
|
||||
}
|
||||
48
internal/settings/shadowsocks.go
Normal file
48
internal/settings/shadowsocks.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/params"
|
||||
)
|
||||
|
||||
// ShadowSocks contains settings to configure the Shadowsocks server
|
||||
type ShadowSocks struct {
|
||||
Enabled bool
|
||||
Password string
|
||||
Log bool
|
||||
Port uint16
|
||||
}
|
||||
|
||||
func (s *ShadowSocks) String() string {
|
||||
if !s.Enabled {
|
||||
return "ShadowSocks settings: disabled"
|
||||
}
|
||||
settingsList := []string{
|
||||
"ShadowSocks settings:",
|
||||
fmt.Sprintf("Port: %d", s.Port),
|
||||
}
|
||||
return strings.Join(settingsList, "\n |--")
|
||||
}
|
||||
|
||||
// GetShadowSocksSettings obtains ShadowSocks settings from environment variables using the params package.
|
||||
func GetShadowSocksSettings(params params.ParamsReader) (settings ShadowSocks, err error) {
|
||||
settings.Enabled, err = params.GetShadowSocks()
|
||||
if err != nil || !settings.Enabled {
|
||||
return settings, err
|
||||
}
|
||||
settings.Port, err = params.GetShadowSocksPort()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.Password, err = params.GetShadowSocksPassword()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.Log, err = params.GetShadowSocksLog()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
return settings, nil
|
||||
}
|
||||
60
internal/settings/tinyproxy.go
Normal file
60
internal/settings/tinyproxy.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package settings
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/params"
|
||||
)
|
||||
|
||||
// TinyProxy contains settings to configure TinyProxy
|
||||
type TinyProxy struct {
|
||||
Enabled bool
|
||||
User string
|
||||
Password string
|
||||
Port uint16
|
||||
LogLevel models.TinyProxyLogLevel
|
||||
}
|
||||
|
||||
func (t *TinyProxy) String() string {
|
||||
if !t.Enabled {
|
||||
return "TinyProxy settings: disabled"
|
||||
}
|
||||
auth := "disabled"
|
||||
if t.User != "" {
|
||||
auth = "enabled"
|
||||
}
|
||||
settingsList := []string{
|
||||
"TinyProxy settings:",
|
||||
fmt.Sprintf("Port: %d", t.Port),
|
||||
"Authentication: " + auth,
|
||||
"Log level: " + string(t.LogLevel),
|
||||
}
|
||||
return "TinyProxy settings:\n" + strings.Join(settingsList, "\n |--")
|
||||
}
|
||||
|
||||
// GetTinyProxySettings obtains TinyProxy settings from environment variables using the params package.
|
||||
func GetTinyProxySettings(params params.ParamsReader) (settings TinyProxy, err error) {
|
||||
settings.Enabled, err = params.GetTinyProxy()
|
||||
if err != nil || !settings.Enabled {
|
||||
return settings, err
|
||||
}
|
||||
settings.User, err = params.GetTinyProxyUser()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.Password, err = params.GetTinyProxyPassword()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.Port, err = params.GetTinyProxyPort()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.LogLevel, err = params.GetTinyProxyLog()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
return settings, nil
|
||||
}
|
||||
40
internal/shadowsocks/command.go
Normal file
40
internal/shadowsocks/command.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package shadowsocks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
)
|
||||
|
||||
func (c *configurator) Start(server string, port uint16, password string, log bool) (stdout io.ReadCloser, err error) {
|
||||
c.logger.Info("%s: starting shadowsocks server", logPrefix)
|
||||
args := []string{
|
||||
"-c", string(constants.ShadowsocksConf),
|
||||
"-p", fmt.Sprintf("%d", port),
|
||||
"-k", password,
|
||||
}
|
||||
if log {
|
||||
args = append(args, "-v")
|
||||
}
|
||||
stdout, _, _, err = c.commander.Start("ss-server", args...)
|
||||
return stdout, err
|
||||
}
|
||||
|
||||
// Version obtains the version of the installed shadowsocks server
|
||||
func (c *configurator) Version() (string, error) {
|
||||
output, err := c.commander.Run("ss-server", "-h")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
lines := strings.Split(output, "\n")
|
||||
if len(lines) < 2 {
|
||||
return "", fmt.Errorf("ss-server -h: not enough lines in %q", output)
|
||||
}
|
||||
words := strings.Fields(lines[1])
|
||||
if len(words) < 2 {
|
||||
return "", fmt.Errorf("ss-server -h: line 2 is too short: %q", lines[1])
|
||||
}
|
||||
return words[1], nil
|
||||
}
|
||||
49
internal/shadowsocks/conf.go
Normal file
49
internal/shadowsocks/conf.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package shadowsocks
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
)
|
||||
|
||||
func (c *configurator) MakeConf(port uint16, password string, uid, gid int) (err error) {
|
||||
c.logger.Info("%s: generating configuration file", logPrefix)
|
||||
data := generateConf(port, password)
|
||||
return c.fileManager.WriteToFile(
|
||||
string(constants.ShadowsocksConf),
|
||||
data,
|
||||
files.FileOwnership(uid, gid),
|
||||
files.FilePermissions(0400))
|
||||
}
|
||||
|
||||
func generateConf(port uint16, password string) (data []byte) {
|
||||
conf := struct {
|
||||
Server string `json:"server"`
|
||||
User string `json:"user"`
|
||||
Method string `json:"method"`
|
||||
Timeout uint `json:"timeout"`
|
||||
FastOpen bool `json:"fast_open"`
|
||||
Mode string `json:"mode"`
|
||||
PortPassword map[string]string `json:"port_password"`
|
||||
Workers uint `json:"workers"`
|
||||
Interface string `json:"interface"`
|
||||
Nameserver string `json:"nameserver"`
|
||||
}{
|
||||
Server: "0.0.0.0",
|
||||
User: "nonrootuser",
|
||||
Method: "chacha20-ietf-poly1305",
|
||||
Timeout: 30,
|
||||
FastOpen: false,
|
||||
Mode: "tcp_and_udp",
|
||||
PortPassword: map[string]string{
|
||||
fmt.Sprintf("%d", port): password,
|
||||
},
|
||||
Workers: 2,
|
||||
Interface: "tun",
|
||||
Nameserver: "127.0.0.1",
|
||||
}
|
||||
data, _ = json.Marshal(conf)
|
||||
return data
|
||||
}
|
||||
79
internal/shadowsocks/conf_test.go
Normal file
79
internal/shadowsocks/conf_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package shadowsocks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
filesMocks "github.com/qdm12/golibs/files/mocks"
|
||||
loggingMocks "github.com/qdm12/golibs/logging/mocks"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_generateConf(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
port uint16
|
||||
password string
|
||||
data []byte
|
||||
}{
|
||||
"no data": {
|
||||
data: []byte(`{"server":"0.0.0.0","user":"nonrootuser","method":"chacha20-ietf-poly1305","timeout":30,"fast_open":false,"mode":"tcp_and_udp","port_password":{"0":""},"workers":2,"interface":"tun","nameserver":"127.0.0.1"}`),
|
||||
},
|
||||
"data": {
|
||||
port: 2000,
|
||||
password: "abcde",
|
||||
data: []byte(`{"server":"0.0.0.0","user":"nonrootuser","method":"chacha20-ietf-poly1305","timeout":30,"fast_open":false,"mode":"tcp_and_udp","port_password":{"2000":"abcde"},"workers":2,"interface":"tun","nameserver":"127.0.0.1"}`),
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
data := generateConf(tc.port, tc.password)
|
||||
assert.Equal(t, tc.data, data)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_MakeConf(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
writeErr error
|
||||
err error
|
||||
}{
|
||||
"no write error": {},
|
||||
"write error": {
|
||||
writeErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
logger := &loggingMocks.Logger{}
|
||||
logger.On("Info", "%s: generating configuration file", logPrefix).Once()
|
||||
fileManager := &filesMocks.FileManager{}
|
||||
fileManager.On("WriteToFile",
|
||||
string(constants.ShadowsocksConf),
|
||||
[]byte(`{"server":"0.0.0.0","user":"nonrootuser","method":"chacha20-ietf-poly1305","timeout":30,"fast_open":false,"mode":"tcp_and_udp","port_password":{"2000":"abcde"},"workers":2,"interface":"tun","nameserver":"127.0.0.1"}`),
|
||||
mock.AnythingOfType("files.WriteOptionSetter"),
|
||||
mock.AnythingOfType("files.WriteOptionSetter"),
|
||||
).
|
||||
Return(tc.writeErr).Once()
|
||||
c := &configurator{logger: logger, fileManager: fileManager}
|
||||
err := c.MakeConf(2000, "abcde", 1000, 1001)
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
logger.AssertExpectations(t)
|
||||
fileManager.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
27
internal/shadowsocks/shadowsocks.go
Normal file
27
internal/shadowsocks/shadowsocks.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package shadowsocks
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/qdm12/golibs/command"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
const logPrefix = "shadowsocks configurator"
|
||||
|
||||
type Configurator interface {
|
||||
Version() (string, error)
|
||||
MakeConf(port uint16, password string, uid, gid int) (err error)
|
||||
Start(server string, port uint16, password string, log bool) (stdout io.ReadCloser, err error)
|
||||
}
|
||||
|
||||
type configurator struct {
|
||||
fileManager files.FileManager
|
||||
logger logging.Logger
|
||||
commander command.Commander
|
||||
}
|
||||
|
||||
func NewConfigurator(fileManager files.FileManager, logger logging.Logger) Configurator {
|
||||
return &configurator{fileManager, logger, command.NewCommander()}
|
||||
}
|
||||
55
internal/splash/splash.go
Normal file
55
internal/splash/splash.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package splash
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/kyokomi/emoji"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/params"
|
||||
)
|
||||
|
||||
func Splash(paramsReader params.ParamsReader) string {
|
||||
version := paramsReader.GetVersion()
|
||||
vcsRef := paramsReader.GetVcsRef()
|
||||
buildDate := paramsReader.GetBuildDate()
|
||||
lines := title()
|
||||
lines = append(lines, "")
|
||||
lines = append(lines, fmt.Sprintf("Running version %s built on %s (commit %s)", version, buildDate, vcsRef))
|
||||
lines = append(lines, "")
|
||||
lines = append(lines, annoucement()...)
|
||||
lines = append(lines, "")
|
||||
lines = append(lines, links()...)
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func title() []string {
|
||||
return []string{
|
||||
"=========================================",
|
||||
"============= PIA container =============",
|
||||
"========== An exquisite mix of ==========",
|
||||
"==== OpenVPN, Unbound, DNS over TLS, ====",
|
||||
"===== Shadowsocks, Tinyproxy and Go =====",
|
||||
"=========================================",
|
||||
"=== Made with " + emoji.Sprint(":heart:") + " by github.com/qdm12 ====",
|
||||
"=========================================",
|
||||
}
|
||||
}
|
||||
|
||||
func annoucement() []string {
|
||||
timestamp := time.Now().UnixNano() / 1000000000
|
||||
if timestamp < constants.AnnoucementExpiration {
|
||||
return []string{emoji.Sprint(":rotating_light: ") + constants.Annoucement}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func links() []string {
|
||||
return []string{
|
||||
emoji.Sprint(":wrench: ") + "Need help? " + constants.IssueLink,
|
||||
emoji.Sprint(":computer: ") + "Email? quentin.mcgaw@gmail.com",
|
||||
emoji.Sprint(":coffee: ") + "Slack? Join from the Slack button on Github",
|
||||
emoji.Sprint(":money_with_wings: ") + "Help me? https://github.com/sponsors/qdm12",
|
||||
}
|
||||
}
|
||||
26
internal/tinyproxy/command.go
Normal file
26
internal/tinyproxy/command.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package tinyproxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (c *configurator) Start() (stdout io.ReadCloser, err error) {
|
||||
c.logger.Info("%s: starting tinyproxy server", logPrefix)
|
||||
stdout, _, _, err = c.commander.Start("tinyproxy", "-d")
|
||||
return stdout, err
|
||||
}
|
||||
|
||||
// Version obtains the version of the installed Tinyproxy server
|
||||
func (c *configurator) Version() (string, error) {
|
||||
output, err := c.commander.Run("tinyproxy", "-v")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
words := strings.Fields(output)
|
||||
if len(words) < 2 {
|
||||
return "", fmt.Errorf("tinyproxy -v: output is too short: %q", output)
|
||||
}
|
||||
return words[1], nil
|
||||
}
|
||||
44
internal/tinyproxy/conf.go
Normal file
44
internal/tinyproxy/conf.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package tinyproxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
func (c *configurator) MakeConf(logLevel models.TinyProxyLogLevel, port uint16, user, password string, uid, gid int) error {
|
||||
c.logger.Info("%s: generating tinyproxy configuration file", logPrefix)
|
||||
lines := generateConf(logLevel, port, user, password)
|
||||
return c.fileManager.WriteLinesToFile(string(constants.TinyProxyConf),
|
||||
lines,
|
||||
files.FileOwnership(uid, gid),
|
||||
files.FilePermissions(0400))
|
||||
}
|
||||
|
||||
func generateConf(logLevel models.TinyProxyLogLevel, port uint16, user, password string) (lines []string) {
|
||||
confMapping := map[string]string{
|
||||
"User": "nonrootuser",
|
||||
"Group": "tinyproxy",
|
||||
"Port": fmt.Sprintf("%d", port),
|
||||
"Timeout": "600",
|
||||
"DefaultErrorFile": "/usr/share/tinyproxy/default.html",
|
||||
"MaxClients": "100",
|
||||
"MinSpareServers": "5",
|
||||
"MaxSpareServers": "20",
|
||||
"StartServers": "10",
|
||||
"MaxRequestsPerChild": "0",
|
||||
"DisableViaHeader": "Yes",
|
||||
"LogLevel": string(logLevel),
|
||||
// "StatFile": "/usr/share/tinyproxy/stats.html",
|
||||
}
|
||||
if len(user) > 0 {
|
||||
confMapping["BasicAuth"] = fmt.Sprintf("%s %s", user, password)
|
||||
}
|
||||
for k, v := range confMapping {
|
||||
line := fmt.Sprintf("%s %s", k, v)
|
||||
lines = append(lines, line)
|
||||
}
|
||||
return lines
|
||||
}
|
||||
28
internal/tinyproxy/tinyproxy.go
Normal file
28
internal/tinyproxy/tinyproxy.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package tinyproxy
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/qdm12/golibs/command"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
const logPrefix = "tinyproxy configurator"
|
||||
|
||||
type Configurator interface {
|
||||
Version() (string, error)
|
||||
MakeConf(logLevel models.TinyProxyLogLevel, port uint16, user, password string, uid, gid int) error
|
||||
Start() (stdout io.ReadCloser, err error)
|
||||
}
|
||||
|
||||
type configurator struct {
|
||||
fileManager files.FileManager
|
||||
logger logging.Logger
|
||||
commander command.Commander
|
||||
}
|
||||
|
||||
func NewConfigurator(fileManager files.FileManager, logger logging.Logger) Configurator {
|
||||
return &configurator{fileManager, logger, command.NewCommander()}
|
||||
}
|
||||
Reference in New Issue
Block a user