Updated dependencies

This commit is contained in:
Quentin McGaw
2020-04-19 18:13:48 +00:00
parent cbd11bfdf2
commit e805d42197
21 changed files with 140 additions and 108 deletions

View File

@@ -65,11 +65,11 @@ func main() {
defer cancel()
streamMerger := command.NewStreamMerger(ctx)
e.PrintVersion("OpenVPN", ovpnConf.Version)
e.PrintVersion("Unbound", dnsConf.Version)
e.PrintVersion("IPtables", firewallConf.Version)
e.PrintVersion("TinyProxy", tinyProxyConf.Version)
e.PrintVersion("ShadowSocks", shadowsocksConf.Version)
e.PrintVersion(ctx, "OpenVPN", ovpnConf.Version)
e.PrintVersion(ctx, "Unbound", dnsConf.Version)
e.PrintVersion(ctx, "IPtables", firewallConf.Version)
e.PrintVersion(ctx, "TinyProxy", tinyProxyConf.Version)
e.PrintVersion(ctx, "ShadowSocks", shadowsocksConf.Version)
allSettings, err := settings.GetAllSettings(paramsReader)
e.FatalOnError(err)
@@ -111,7 +111,7 @@ func main() {
// pre-exist, preventing the nslookup of the PIA region address. These will
// simply be redundant at Docker runtime as they will already be set this way
// Thanks to @npawelek https://github.com/npawelek
err = firewallConf.AcceptAll()
err = firewallConf.AcceptAll(ctx)
e.FatalOnError(err)
go func() {
@@ -120,7 +120,7 @@ func main() {
err = streamMerger.CollectLines(func(line string) {
logger.Info(line)
if strings.Contains(line, "Initialization Sequence Completed") {
onConnected(logger, routingConf, fileManager, piaConf,
onConnected(ctx, logger, routingConf, fileManager, piaConf,
defaultInterface,
allSettings.VPNSP,
allSettings.PIA.PortForwarding.Enabled,
@@ -142,12 +142,12 @@ func main() {
e.FatalOnError(err)
err = dnsConf.MakeUnboundConf(allSettings.DNS, allSettings.System.UID, allSettings.System.GID)
e.FatalOnError(err)
stream, waitFn, err := dnsConf.Start(allSettings.DNS.VerbosityDetailsLevel)
stream, waitFn, err := dnsConf.Start(ctx, allSettings.DNS.VerbosityDetailsLevel)
e.FatalOnError(err)
go func() {
e.FatalOnError(waitFn())
}()
go streamMerger.Merge("unbound", stream)
go streamMerger.Merge(stream, command.MergeName("unbound"))
dnsConf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound
err = dnsConf.UseDNSSystemWide(net.IP{127, 0, 0, 1}) // use Unbound
e.FatalOnError(err)
@@ -209,17 +209,17 @@ func main() {
e.FatalOnError(err)
}
err = routingConf.AddRoutesVia(allSettings.Firewall.AllowedSubnets, defaultGateway, defaultInterface)
err = routingConf.AddRoutesVia(ctx, allSettings.Firewall.AllowedSubnets, defaultGateway, defaultInterface)
e.FatalOnError(err)
err = firewallConf.Clear()
err = firewallConf.Clear(ctx)
e.FatalOnError(err)
err = firewallConf.BlockAll()
err = firewallConf.BlockAll(ctx)
e.FatalOnError(err)
err = firewallConf.CreateGeneralRules()
err = firewallConf.CreateGeneralRules(ctx)
e.FatalOnError(err)
err = firewallConf.CreateVPNRules(constants.TUN, defaultInterface, connections)
err = firewallConf.CreateVPNRules(ctx, constants.TUN, defaultInterface, connections)
e.FatalOnError(err)
err = firewallConf.CreateLocalSubnetsRules(defaultSubnet, allSettings.Firewall.AllowedSubnets, defaultInterface)
err = firewallConf.CreateLocalSubnetsRules(ctx, defaultSubnet, allSettings.Firewall.AllowedSubnets, defaultInterface)
e.FatalOnError(err)
if allSettings.TinyProxy.Enabled {
@@ -231,16 +231,16 @@ func main() {
allSettings.System.UID,
allSettings.System.GID)
e.FatalOnError(err)
err = firewallConf.AllowAnyIncomingOnPort(allSettings.TinyProxy.Port)
err = firewallConf.AllowAnyIncomingOnPort(ctx, allSettings.TinyProxy.Port)
e.FatalOnError(err)
stream, waitFn, err := tinyProxyConf.Start()
stream, waitFn, err := tinyProxyConf.Start(ctx)
e.FatalOnError(err)
go func() {
if err := waitFn(); err != nil {
logger.Error(err)
}
}()
go streamMerger.Merge("tinyproxy", stream)
go streamMerger.Merge(stream, command.MergeName("tinyproxy"))
}
if allSettings.ShadowSocks.Enabled {
@@ -251,22 +251,22 @@ func main() {
allSettings.System.UID,
allSettings.System.GID)
e.FatalOnError(err)
err = firewallConf.AllowAnyIncomingOnPort(allSettings.ShadowSocks.Port)
err = firewallConf.AllowAnyIncomingOnPort(ctx, allSettings.ShadowSocks.Port)
e.FatalOnError(err)
stdout, stderr, waitFn, err := shadowsocksConf.Start("0.0.0.0", allSettings.ShadowSocks.Port, allSettings.ShadowSocks.Password, allSettings.ShadowSocks.Log)
stdout, stderr, waitFn, err := shadowsocksConf.Start(ctx, "0.0.0.0", allSettings.ShadowSocks.Port, allSettings.ShadowSocks.Password, allSettings.ShadowSocks.Log)
e.FatalOnError(err)
go func() {
if err := waitFn(); err != nil {
logger.Error(err)
}
}()
go streamMerger.Merge("shadowsocks", stdout)
go streamMerger.Merge("shadowsocks error", stderr)
go streamMerger.Merge(stdout, command.MergeName("shadowsocks"))
go streamMerger.Merge(stderr, command.MergeName("shadowsocks error"))
}
stream, waitFn, err := ovpnConf.Start()
stream, waitFn, err := ovpnConf.Start(ctx)
e.FatalOnError(err)
go streamMerger.Merge("openvpn", stream)
go streamMerger.Merge(stream, command.MergeName("openvpn"))
go signals.WaitForExit(func(signal string) int {
logger.Warn("Caught OS signal %s, shutting down", signal)
if allSettings.VPNSP == "pia" && allSettings.PIA.PortForwarding.Enabled {
@@ -281,6 +281,7 @@ func main() {
}
func onConnected(
ctx context.Context,
logger logging.Logger,
routingConf routing.Routing,
fileManager files.FileManager,
@@ -319,7 +320,7 @@ func onConnected(
logger.Error("port forwarding:", err)
return
}
if err := piaConf.AllowPortForwardFirewall(constants.TUN, port); err != nil {
if err := piaConf.AllowPortForwardFirewall(ctx, constants.TUN, port); err != nil {
logger.Error("port forwarding:", err)
return
}

6
go.mod
View File

@@ -4,8 +4,8 @@ go 1.14
require (
github.com/golang/mock v1.4.3
github.com/kyokomi/emoji v2.2.1+incompatible
github.com/qdm12/golibs v0.0.0-20200412175259-da41d65db446
github.com/kyokomi/emoji v2.2.2+incompatible
github.com/qdm12/golibs v0.0.0-20200419174016-f1c612728dfa
github.com/stretchr/testify v1.5.1
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa
golang.org/x/sys v0.0.0-20200413165638-669c56c373c4
)

17
go.sum
View File

@@ -10,6 +10,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/docker/go-units v0.3.3/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s=
github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU=
github.com/globalsign/mgo v0.0.0-20180905125535-1ca0a4f7cbcb h1:D4uzjWwKYQ5XnAvUbuvHW93esHg7F8N/OYeBBcJoTr0=
github.com/globalsign/mgo v0.0.0-20180905125535-1ca0a4f7cbcb/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q=
github.com/go-openapi/analysis v0.0.0-20180825180245-b006789cd277/go.mod h1:k70tL6pCuVxPJOHXQ+wIac1FUrvNkHolPie/cLEU6hI=
@@ -50,8 +52,15 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kyokomi/emoji v2.2.1+incompatible h1:uP/6J5y5U0XxPh6fv8YximpVD1uMrshXG78I1+uF5SA=
github.com/kyokomi/emoji v2.2.1+incompatible/go.mod h1:mZ6aGCD7yk8j6QY6KICwnZ2pxoszVseX1DNoGtU2tBA=
github.com/kyokomi/emoji v2.2.2+incompatible h1:gaQFbK2+uSxOR4iGZprJAbpmtqTrHhSdgOyIMD6Oidc=
github.com/kyokomi/emoji v2.2.2+incompatible/go.mod h1:mZ6aGCD7yk8j6QY6KICwnZ2pxoszVseX1DNoGtU2tBA=
github.com/mailru/easyjson v0.0.0-20180823135443-60711f1a8329 h1:2gxZ0XQIU/5z3Z3bUBu+FXuk2pFbkN6tcwi/pjyaDic=
github.com/mailru/easyjson v0.0.0-20180823135443-60711f1a8329/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA=
github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/mattn/go-isatty v0.0.11 h1:FxPOTFNqGkuDUGi3H/qkUbQO4ZiBa2brKq5r0l8TGeM=
github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE=
github.com/mitchellh/mapstructure v1.1.2 h1:fmNYVwqnSfB9mZU6OS2O6GsXM+wcskZDuKQzvN1EDeE=
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
github.com/mr-tron/base58 v1.1.3 h1:v+sk57XuaCKGXpWtVBX8YJzO7hMGx4Aajh4TQbdEFdc=
@@ -63,8 +72,8 @@ github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/qdm12/golibs v0.0.0-20200412175259-da41d65db446 h1:sBPYLwDSqRsOqHi7f34c7QMcoR1xLD1wLnOl0L7br6c=
github.com/qdm12/golibs v0.0.0-20200412175259-da41d65db446/go.mod h1:y4hRtiU2Al0+y2UP1I9e0yYu9VqemnMwyJVCkyhy9r8=
github.com/qdm12/golibs v0.0.0-20200419174016-f1c612728dfa h1:7kFbnjnVF87U1gF3LdTYi3b63oIaUWJXv8pZvRdJoNA=
github.com/qdm12/golibs v0.0.0-20200419174016-f1c612728dfa/go.mod h1:pikkTN7g7zRuuAnERwqW1yAFq6pYmxrxpjiwGvb0Ysc=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@@ -95,10 +104,14 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859 h1:R/3boaszxrf1GEUWTVDzSKVwL
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU=
golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200413165638-669c56c373c4 h1:opSr2sbRXk5X5/givKrrKj9HXxFpW2sdCiP8MJSKLQY=
golang.org/x/sys v0.0.0-20200413165638-669c56c373c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View File

@@ -1,6 +1,7 @@
package dns
import (
"context"
"fmt"
"io"
"strings"
@@ -8,19 +9,19 @@ import (
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
func (c *configurator) Start(verbosityDetailsLevel uint8) (stdout io.ReadCloser, waitFn func() error, err error) {
func (c *configurator) Start(ctx context.Context, verbosityDetailsLevel uint8) (stdout io.ReadCloser, waitFn func() error, err error) {
c.logger.Info("starting unbound")
args := []string{"-d", "-c", string(constants.UnboundConf)}
if verbosityDetailsLevel > 0 {
args = append(args, "-"+strings.Repeat("v", int(verbosityDetailsLevel)))
}
// Only logs to stderr
_, stdout, waitFn, err = c.commander.Start("unbound", args...)
_, stdout, waitFn, err = c.commander.Start(ctx, "unbound", args...)
return stdout, waitFn, err
}
func (c *configurator) Version() (version string, err error) {
output, err := c.commander.Run("unbound", "-V")
func (c *configurator) Version(ctx context.Context) (version string, err error) {
output, err := c.commander.Run(ctx, "unbound", "-V")
if err != nil {
return "", fmt.Errorf("unbound version: %w", err)
}

View File

@@ -1,6 +1,7 @@
package dns
import (
"context"
"fmt"
"testing"
@@ -20,10 +21,10 @@ func Test_Start(t *testing.T) {
logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("starting unbound").Times(1)
commander := mock_command.NewMockCommander(mockCtrl)
commander.EXPECT().Start("unbound", "-d", "-c", string(constants.UnboundConf), "-vv").
commander.EXPECT().Start(context.Background(), "unbound", "-d", "-c", string(constants.UnboundConf), "-vv").
Return(nil, nil, nil, nil).Times(1)
c := &configurator{commander: commander, logger: logger}
stdout, waitFn, err := c.Start(2)
stdout, waitFn, err := c.Start(context.Background(), 2)
assert.Nil(t, stdout)
assert.Nil(t, waitFn)
assert.NoError(t, err)
@@ -56,10 +57,10 @@ func Test_Version(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
commander := mock_command.NewMockCommander(mockCtrl)
commander.EXPECT().Run("unbound", "-V").
commander.EXPECT().Run(context.Background(), "unbound", "-V").
Return(tc.runOutput, tc.runErr).Times(1)
c := &configurator{commander: commander}
version, err := c.Version()
version, err := c.Version(context.Background())
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())

View File

@@ -1,6 +1,7 @@
package dns
import (
"context"
"io"
"net"
@@ -17,9 +18,9 @@ type Configurator interface {
MakeUnboundConf(settings settings.DNS, uid, gid int) (err error)
UseDNSInternally(IP net.IP)
UseDNSSystemWide(IP net.IP) error
Start(logLevel uint8) (stdout io.ReadCloser, waitFn func() error, err error)
Start(ctx context.Context, logLevel uint8) (stdout io.ReadCloser, waitFn func() error, err error)
WaitForUnbound() (err error)
Version() (version string, err error)
Version(ctx context.Context) (version string, err error)
}
type configurator struct {

7
internal/env/env.go vendored
View File

@@ -1,6 +1,7 @@
package env
import (
"context"
"os"
"github.com/qdm12/golibs/logging"
@@ -8,7 +9,7 @@ import (
type Env interface {
FatalOnError(err error)
PrintVersion(program string, commandFn func() (string, error))
PrintVersion(ctx context.Context, program string, commandFn func(ctx context.Context) (string, error))
}
type env struct {
@@ -30,8 +31,8 @@ func (e *env) FatalOnError(err error) {
}
}
func (e *env) PrintVersion(program string, commandFn func() (string, error)) {
version, err := commandFn()
func (e *env) PrintVersion(ctx context.Context, program string, commandFn func(ctx context.Context) (string, error)) {
version, err := commandFn(ctx)
if err != nil {
e.logger.Error(err)
} else {

View File

@@ -1,6 +1,7 @@
package env
import (
"context"
"fmt"
"testing"
@@ -75,8 +76,8 @@ func Test_PrintVersion(t *testing.T) {
}).Times(1)
}
e := &env{logger: logger}
commandFn := func() (string, error) { return tc.commandVersion, tc.commandErr }
e.PrintVersion(tc.program, commandFn)
commandFn := func(ctx context.Context) (string, error) { return tc.commandVersion, tc.commandErr }
e.PrintVersion(context.Background(), tc.program, commandFn)
if tc.commandErr != nil {
assert.Equal(t, logged, tc.commandErr.Error())
} else {

View File

@@ -1,6 +1,7 @@
package firewall
import (
"context"
"net"
"github.com/qdm12/golibs/command"
@@ -10,15 +11,15 @@ import (
// Configurator allows to change firewall rules and modify network routes
type Configurator interface {
Version() (string, error)
AcceptAll() error
Clear() error
BlockAll() error
CreateGeneralRules() error
CreateVPNRules(dev models.VPNDevice, defaultInterface string, connections []models.OpenVPNConnection) error
CreateLocalSubnetsRules(subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error
AllowInputTrafficOnPort(device models.VPNDevice, port uint16) error
AllowAnyIncomingOnPort(port uint16) error
Version(ctx context.Context) (string, error)
AcceptAll(ctx context.Context) error
Clear(ctx context.Context) error
BlockAll(ctx context.Context) error
CreateGeneralRules(ctx context.Context) error
CreateVPNRules(ctx context.Context, dev models.VPNDevice, defaultInterface string, connections []models.OpenVPNConnection) error
CreateLocalSubnetsRules(ctx context.Context, subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error
AllowInputTrafficOnPort(ctx context.Context, device models.VPNDevice, port uint16) error
AllowAnyIncomingOnPort(ctx context.Context, port uint16) error
}
type configurator struct {

View File

@@ -1,6 +1,7 @@
package firewall
import (
"context"
"fmt"
"net"
"strings"
@@ -9,8 +10,8 @@ import (
)
// Version obtains the version of the installed iptables
func (c *configurator) Version() (string, error) {
output, err := c.commander.Run("iptables", "--version")
func (c *configurator) Version(ctx context.Context) (string, error) {
output, err := c.commander.Run(ctx, "iptables", "--version")
if err != nil {
return "", err
}
@@ -21,26 +22,26 @@ func (c *configurator) Version() (string, error) {
return words[1], nil
}
func (c *configurator) runIptablesInstructions(instructions []string) error {
func (c *configurator) runIptablesInstructions(ctx context.Context, instructions []string) error {
for _, instruction := range instructions {
if err := c.runIptablesInstruction(instruction); err != nil {
if err := c.runIptablesInstruction(ctx, instruction); err != nil {
return err
}
}
return nil
}
func (c *configurator) runIptablesInstruction(instruction string) error {
func (c *configurator) runIptablesInstruction(ctx context.Context, instruction string) error {
flags := strings.Fields(instruction)
if output, err := c.commander.Run("iptables", flags...); err != nil {
if output, err := c.commander.Run(ctx, "iptables", flags...); err != nil {
return fmt.Errorf("failed executing %q: %s: %w", instruction, output, err)
}
return nil
}
func (c *configurator) Clear() error {
func (c *configurator) Clear(ctx context.Context) error {
c.logger.Info("clearing all rules")
return c.runIptablesInstructions([]string{
return c.runIptablesInstructions(ctx, []string{
"--flush",
"--delete-chain",
"-t nat --flush",
@@ -48,18 +49,18 @@ func (c *configurator) Clear() error {
})
}
func (c *configurator) AcceptAll() error {
func (c *configurator) AcceptAll(ctx context.Context) error {
c.logger.Info("accepting all traffic")
return c.runIptablesInstructions([]string{
return c.runIptablesInstructions(ctx, []string{
"-P INPUT ACCEPT",
"-P OUTPUT ACCEPT",
"-P FORWARD ACCEPT",
})
}
func (c *configurator) BlockAll() error {
func (c *configurator) BlockAll(ctx context.Context) error {
c.logger.Info("blocking all traffic")
return c.runIptablesInstructions([]string{
return c.runIptablesInstructions(ctx, []string{
"-P INPUT DROP",
"-F OUTPUT",
"-P OUTPUT DROP",
@@ -67,9 +68,9 @@ func (c *configurator) BlockAll() error {
})
}
func (c *configurator) CreateGeneralRules() error {
func (c *configurator) CreateGeneralRules(ctx context.Context) error {
c.logger.Info("creating general rules")
return c.runIptablesInstructions([]string{
return c.runIptablesInstructions(ctx, []string{
"-A OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
"-A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
"-A OUTPUT -o lo -j ACCEPT",
@@ -77,26 +78,26 @@ func (c *configurator) CreateGeneralRules() error {
})
}
func (c *configurator) CreateVPNRules(dev models.VPNDevice, defaultInterface string, connections []models.OpenVPNConnection) error {
func (c *configurator) CreateVPNRules(ctx context.Context, dev models.VPNDevice, defaultInterface string, connections []models.OpenVPNConnection) error {
for _, connection := range connections {
c.logger.Info("allowing output traffic to VPN server %s through %s on port %s %d",
connection.IP, defaultInterface, connection.Protocol, connection.Port)
if err := c.runIptablesInstruction(
if err := c.runIptablesInstruction(ctx,
fmt.Sprintf("-A OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
connection.IP, defaultInterface, connection.Protocol, connection.Protocol, connection.Port)); err != nil {
return err
}
}
if err := c.runIptablesInstruction(fmt.Sprintf("-A OUTPUT -o %s -j ACCEPT", dev)); err != nil {
if err := c.runIptablesInstruction(ctx, fmt.Sprintf("-A OUTPUT -o %s -j ACCEPT", dev)); err != nil {
return err
}
return nil
}
func (c *configurator) CreateLocalSubnetsRules(subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error {
func (c *configurator) CreateLocalSubnetsRules(ctx context.Context, subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error {
subnetStr := subnet.String()
c.logger.Info("accepting input and output traffic for %s", subnetStr)
if err := c.runIptablesInstructions([]string{
if err := c.runIptablesInstructions(ctx, []string{
fmt.Sprintf("-A INPUT -s %s -d %s -j ACCEPT", subnetStr, subnetStr),
fmt.Sprintf("-A OUTPUT -s %s -d %s -j ACCEPT", subnetStr, subnetStr),
}); err != nil {
@@ -105,13 +106,13 @@ func (c *configurator) CreateLocalSubnetsRules(subnet net.IPNet, extraSubnets []
for _, extraSubnet := range extraSubnets {
extraSubnetStr := extraSubnet.String()
c.logger.Info("accepting input traffic through %s from %s to %s", defaultInterface, extraSubnetStr, subnetStr)
if err := c.runIptablesInstruction(
if err := c.runIptablesInstruction(ctx,
fmt.Sprintf("-A INPUT -i %s -s %s -d %s -j ACCEPT", defaultInterface, extraSubnetStr, subnetStr)); err != nil {
return err
}
// Thanks to @npawelek
c.logger.Info("accepting output traffic through %s from %s to %s", defaultInterface, subnetStr, extraSubnetStr)
if err := c.runIptablesInstruction(
if err := c.runIptablesInstruction(ctx,
fmt.Sprintf("-A OUTPUT -o %s -s %s -d %s -j ACCEPT", defaultInterface, subnetStr, extraSubnetStr)); err != nil {
return err
}
@@ -120,17 +121,17 @@ func (c *configurator) CreateLocalSubnetsRules(subnet net.IPNet, extraSubnets []
}
// Used for port forwarding
func (c *configurator) AllowInputTrafficOnPort(device models.VPNDevice, port uint16) error {
func (c *configurator) AllowInputTrafficOnPort(ctx context.Context, device models.VPNDevice, port uint16) error {
c.logger.Info("accepting input traffic through %s on port %d", device, port)
return c.runIptablesInstructions([]string{
return c.runIptablesInstructions(ctx, []string{
fmt.Sprintf("-A INPUT -i %s -p tcp --dport %d -j ACCEPT", device, port),
fmt.Sprintf("-A INPUT -i %s -p udp --dport %d -j ACCEPT", device, port),
})
}
func (c *configurator) AllowAnyIncomingOnPort(port uint16) error {
func (c *configurator) AllowAnyIncomingOnPort(ctx context.Context, port uint16) error {
c.logger.Info("accepting any input traffic on port %d", port)
return c.runIptablesInstructions([]string{
return c.runIptablesInstructions(ctx, []string{
fmt.Sprintf("-A INPUT -p tcp --dport %d -j ACCEPT", port),
fmt.Sprintf("-A INPUT -p udp --dport %d -j ACCEPT", port),
})

View File

@@ -1,6 +1,7 @@
package openvpn
import (
"context"
"fmt"
"io"
"strings"
@@ -8,14 +9,14 @@ import (
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
func (c *configurator) Start() (stdout io.ReadCloser, waitFn func() error, err error) {
func (c *configurator) Start(ctx context.Context) (stdout io.ReadCloser, waitFn func() error, err error) {
c.logger.Info("starting openvpn")
stdout, _, waitFn, err = c.commander.Start("openvpn", "--config", string(constants.OpenVPNConf))
stdout, _, waitFn, err = c.commander.Start(ctx, "openvpn", "--config", string(constants.OpenVPNConf))
return stdout, waitFn, err
}
func (c *configurator) Version() (string, error) {
output, err := c.commander.Run("openvpn", "--version")
func (c *configurator) Version(ctx context.Context) (string, error) {
output, err := c.commander.Run(ctx, "openvpn", "--version")
if err != nil && err.Error() != "exit status 1" {
return "", err
}

View File

@@ -1,6 +1,7 @@
package openvpn
import (
"context"
"io"
"os"
@@ -11,11 +12,11 @@ import (
)
type Configurator interface {
Version() (string, error)
Version(ctx context.Context) (string, error)
WriteAuthFile(user, password string, uid, gid int) error
CheckTUN() error
CreateTUN() error
Start() (stdout io.ReadCloser, waitFn func() error, err error)
Start(ctx context.Context) (stdout io.ReadCloser, waitFn func() error, err error)
}
type configurator struct {

View File

@@ -1,6 +1,7 @@
package pia
import (
"context"
"net"
"github.com/qdm12/golibs/crypto/random"
@@ -20,7 +21,7 @@ type Configurator interface {
GetPortForward() (port uint16, err error)
WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error)
ClearPortForward(filepath models.Filepath, uid, gid int) (err error)
AllowPortForwardFirewall(device models.VPNDevice, port uint16) (err error)
AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error)
}
type configurator struct {

View File

@@ -1,6 +1,7 @@
package pia
import (
"context"
"encoding/hex"
"encoding/json"
"fmt"
@@ -47,9 +48,9 @@ func (c *configurator) WritePortForward(filepath models.Filepath, port uint16, u
files.Permissions(0400))
}
func (c *configurator) AllowPortForwardFirewall(device models.VPNDevice, port uint16) (err error) {
func (c *configurator) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) {
c.logger.Info("Allowing forwarded port %d through firewall", port)
return c.firewall.AllowInputTrafficOnPort(device, port)
return c.firewall.AllowInputTrafficOnPort(ctx, device, port)
}
func (c *configurator) ClearPortForward(filepath models.Filepath, uid, gid int) (err error) {

View File

@@ -1,23 +1,24 @@
package routing
import (
"context"
"net"
"fmt"
)
func (r *routing) AddRoutesVia(subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error {
func (r *routing) AddRoutesVia(ctx context.Context, subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error {
for _, subnet := range subnets {
exists, err := r.routeExists(subnet)
if err != nil {
return err
} else if exists { // thanks to @npawelek https://github.com/npawelek
if err := r.removeRoute(subnet); err != nil {
if err := r.removeRoute(ctx, subnet); err != nil {
return err
}
}
r.logger.Info("adding %s as route via %s", subnet.String(), defaultInterface)
output, err := r.commander.Run("ip", "route", "add", subnet.String(), "via", defaultGateway.String(), "dev", defaultInterface)
output, err := r.commander.Run(ctx, "ip", "route", "add", subnet.String(), "via", defaultGateway.String(), "dev", defaultInterface)
if err != nil {
return fmt.Errorf("cannot add route for %s via %s %s %s: %s: %w", subnet.String(), defaultGateway.String(), "dev", defaultInterface, output, err)
}
@@ -25,8 +26,8 @@ func (r *routing) AddRoutesVia(subnets []net.IPNet, defaultGateway net.IP, defau
return nil
}
func (r *routing) removeRoute(subnet net.IPNet) (err error) {
output, err := r.commander.Run("ip", "route", "del", subnet.String())
func (r *routing) removeRoute(ctx context.Context, subnet net.IPNet) (err error) {
output, err := r.commander.Run(ctx, "ip", "route", "del", subnet.String())
if err != nil {
return fmt.Errorf("cannot delete route for %s: %s: %w", subnet.String(), output, err)
}

View File

@@ -1,6 +1,7 @@
package routing
import (
"context"
"fmt"
"net"
"testing"
@@ -51,10 +52,10 @@ func Test_removeRoute(t *testing.T) {
defer mockCtrl.Finish()
commander := mock_command.NewMockCommander(mockCtrl)
commander.EXPECT().Run("ip", "route", "del", tc.subnet.String()).
commander.EXPECT().Run(context.Background(), "ip", "route", "del", tc.subnet.String()).
Return(tc.runOutput, tc.runErr).Times(1)
r := &routing{commander: commander}
err := r.removeRoute(tc.subnet)
err := r.removeRoute(context.Background(), tc.subnet)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())

View File

@@ -1,6 +1,7 @@
package routing
import (
"context"
"net"
"github.com/qdm12/golibs/command"
@@ -9,7 +10,7 @@ import (
)
type Routing interface {
AddRoutesVia(subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error
AddRoutesVia(ctx context.Context, subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error
DefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error)
CurrentPublicIP(defaultInterface string) (ip net.IP, err error)
}

View File

@@ -1,6 +1,7 @@
package shadowsocks
import (
"context"
"fmt"
"io"
"strings"
@@ -8,7 +9,7 @@ import (
"github.com/qdm12/private-internet-access-docker/internal/constants"
)
func (c *configurator) Start(server string, port uint16, password string, log bool) (stdout, stderr io.ReadCloser, waitFn func() error, err error) {
func (c *configurator) Start(ctx context.Context, server string, port uint16, password string, log bool) (stdout, stderr io.ReadCloser, waitFn func() error, err error) {
c.logger.Info("starting shadowsocks server")
args := []string{
"-c", string(constants.ShadowsocksConf),
@@ -18,13 +19,13 @@ func (c *configurator) Start(server string, port uint16, password string, log bo
if log {
args = append(args, "-v")
}
stdout, stderr, waitFn, err = c.commander.Start("ss-server", args...)
stdout, stderr, waitFn, err = c.commander.Start(ctx, "ss-server", args...)
return stdout, stderr, waitFn, err
}
// Version obtains the version of the installed shadowsocks server
func (c *configurator) Version() (string, error) {
output, err := c.commander.Run("ss-server", "-h")
func (c *configurator) Version(ctx context.Context) (string, error) {
output, err := c.commander.Run(ctx, "ss-server", "-h")
if err != nil {
return "", err
}

View File

@@ -1,6 +1,7 @@
package shadowsocks
import (
"context"
"io"
"github.com/qdm12/golibs/command"
@@ -9,9 +10,9 @@ import (
)
type Configurator interface {
Version() (string, error)
Version(ctx context.Context) (string, error)
MakeConf(port uint16, password, method string, uid, gid int) (err error)
Start(server string, port uint16, password string, log bool) (stdout, stderr io.ReadCloser, waitFn func() error, err error)
Start(ctx context.Context, server string, port uint16, password string, log bool) (stdout, stderr io.ReadCloser, waitFn func() error, err error)
}
type configurator struct {

View File

@@ -1,20 +1,21 @@
package tinyproxy
import (
"context"
"fmt"
"io"
"strings"
)
func (c *configurator) Start() (stdout io.ReadCloser, waitFn func() error, err error) {
func (c *configurator) Start(ctx context.Context) (stdout io.ReadCloser, waitFn func() error, err error) {
c.logger.Info("starting tinyproxy server")
stdout, _, waitFn, err = c.commander.Start("tinyproxy", "-d")
stdout, _, waitFn, err = c.commander.Start(ctx, "tinyproxy", "-d")
return stdout, waitFn, err
}
// Version obtains the version of the installed Tinyproxy server
func (c *configurator) Version() (string, error) {
output, err := c.commander.Run("tinyproxy", "-v")
func (c *configurator) Version(ctx context.Context) (string, error) {
output, err := c.commander.Run(ctx, "tinyproxy", "-v")
if err != nil {
return "", err
}

View File

@@ -1,6 +1,7 @@
package tinyproxy
import (
"context"
"io"
"github.com/qdm12/golibs/command"
@@ -10,9 +11,9 @@ import (
)
type Configurator interface {
Version() (string, error)
Version(ctx context.Context) (string, error)
MakeConf(logLevel models.TinyProxyLogLevel, port uint16, user, password string, uid, gid int) error
Start() (stdout io.ReadCloser, waitFn func() error, err error)
Start(ctx context.Context) (stdout io.ReadCloser, waitFn func() error, err error)
}
type configurator struct {