Firewall refactoring
- Ability to enable and disable rules in various loops - Simplified code overall - Port forwarding moved into openvpn loop - Route addition and removal improved
This commit is contained in:
@@ -17,13 +17,10 @@ import (
|
||||
"github.com/qdm12/golibs/network"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/alpine"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/cli"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/dns"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/firewall"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/openvpn"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/params"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/provider"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/publicip"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/routing"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/server"
|
||||
@@ -72,8 +69,8 @@ func _main(background context.Context, args []string) int {
|
||||
alpineConf := alpine.NewConfigurator(fileManager)
|
||||
ovpnConf := openvpn.NewConfigurator(logger, fileManager)
|
||||
dnsConf := dns.NewConfigurator(logger, client, fileManager)
|
||||
firewallConf := firewall.NewConfigurator(logger)
|
||||
routingConf := routing.NewRouting(logger, fileManager)
|
||||
firewallConf := firewall.NewConfigurator(logger, routingConf, fileManager)
|
||||
tinyProxyConf := tinyproxy.NewConfigurator(fileManager, logger)
|
||||
shadowsocksConf := shadowsocks.NewConfigurator(fileManager, logger)
|
||||
streamMerger := command.NewStreamMerger()
|
||||
@@ -93,12 +90,6 @@ func _main(background context.Context, args []string) int {
|
||||
// Should never change
|
||||
uid, gid := allSettings.System.UID, allSettings.System.GID
|
||||
|
||||
providerConf := provider.New(allSettings.VPNSP, logger, client, fileManager, firewallConf)
|
||||
|
||||
if !allSettings.Firewall.Enabled {
|
||||
firewallConf.Disable()
|
||||
}
|
||||
|
||||
err = alpineConf.CreateUser("nonrootuser", uid)
|
||||
fatalOnError(err)
|
||||
err = fileManager.SetOwnership("/etc/unbound", uid, gid)
|
||||
@@ -112,17 +103,6 @@ func _main(background context.Context, args []string) int {
|
||||
fatalOnError(err)
|
||||
}
|
||||
|
||||
defaultInterface, defaultGateway, defaultSubnet, err := routingConf.DefaultRoute()
|
||||
fatalOnError(err)
|
||||
|
||||
// Temporarily reset chain policies allowing Kubernetes sidecar to
|
||||
// successfully restart the container. Without this, the existing rules will
|
||||
// pre-exist, preventing the nslookup of the PIA region address. These will
|
||||
// simply be redundant at Docker runtime as they will already be set this way
|
||||
// Thanks to @npawelek https://github.com/npawelek
|
||||
err = firewallConf.AcceptAll(ctx)
|
||||
fatalOnError(err)
|
||||
|
||||
connectedCh := make(chan struct{})
|
||||
signalConnected := func() {
|
||||
connectedCh <- struct{}{}
|
||||
@@ -130,44 +110,23 @@ func _main(background context.Context, args []string) int {
|
||||
defer close(connectedCh)
|
||||
go collectStreamLines(ctx, streamMerger, logger, signalConnected)
|
||||
|
||||
connections, err := providerConf.GetOpenVPNConnections(allSettings.OpenVPN.Provider.ServerSelection)
|
||||
fatalOnError(err)
|
||||
err = providerConf.BuildConf(
|
||||
connections,
|
||||
allSettings.OpenVPN.Verbosity,
|
||||
uid,
|
||||
gid,
|
||||
allSettings.OpenVPN.Root,
|
||||
allSettings.OpenVPN.Cipher,
|
||||
allSettings.OpenVPN.Auth,
|
||||
allSettings.OpenVPN.Provider.ExtraConfigOptions,
|
||||
)
|
||||
fatalOnError(err)
|
||||
|
||||
err = routingConf.AddRoutesVia(ctx, allSettings.Firewall.AllowedSubnets, defaultGateway, defaultInterface)
|
||||
fatalOnError(err)
|
||||
err = firewallConf.Clear(ctx)
|
||||
fatalOnError(err)
|
||||
err = firewallConf.BlockAll(ctx)
|
||||
fatalOnError(err)
|
||||
err = firewallConf.CreateGeneralRules(ctx)
|
||||
fatalOnError(err)
|
||||
err = firewallConf.CreateVPNRules(ctx, constants.TUN, defaultInterface, connections)
|
||||
fatalOnError(err)
|
||||
err = firewallConf.CreateLocalSubnetsRules(ctx, defaultSubnet, allSettings.Firewall.AllowedSubnets, defaultInterface)
|
||||
fatalOnError(err)
|
||||
err = firewallConf.RunUserPostRules(ctx, fileManager, "/iptables/post-rules.txt")
|
||||
fatalOnError(err)
|
||||
|
||||
// TODO replace these with methods on loopers and pass loopers around
|
||||
restartOpenvpn := make(chan struct{})
|
||||
portForward := make(chan struct{})
|
||||
restartUnbound := make(chan struct{})
|
||||
restartPublicIP := make(chan struct{})
|
||||
restartTinyproxy := make(chan struct{})
|
||||
restartShadowsocks := make(chan struct{})
|
||||
|
||||
openvpnLooper := openvpn.NewLooper(ovpnConf, allSettings.OpenVPN, logger, streamMerger, fatalOnError, uid, gid)
|
||||
if allSettings.Firewall.Enabled {
|
||||
err := firewallConf.SetEnabled(ctx, true) // disabled by default
|
||||
fatalOnError(err)
|
||||
}
|
||||
|
||||
openvpnLooper := openvpn.NewLooper(allSettings.VPNSP, allSettings.OpenVPN, uid, gid,
|
||||
ovpnConf, firewallConf, logger, client, fileManager, streamMerger, fatalOnError)
|
||||
// wait for restartOpenvpn
|
||||
go openvpnLooper.Run(ctx, restartOpenvpn, wg)
|
||||
go openvpnLooper.Run(ctx, restartOpenvpn, portForward, wg)
|
||||
|
||||
unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, logger, streamMerger, uid, gid)
|
||||
// wait for restartUnbound
|
||||
@@ -191,7 +150,6 @@ func _main(background context.Context, args []string) int {
|
||||
}
|
||||
|
||||
go func() {
|
||||
first := true
|
||||
var restartTickerContext context.Context
|
||||
var restartTickerCancel context.CancelFunc = func() {}
|
||||
for {
|
||||
@@ -200,14 +158,10 @@ func _main(background context.Context, args []string) int {
|
||||
restartTickerCancel()
|
||||
return
|
||||
case <-connectedCh: // blocks until openvpn is connected
|
||||
if first {
|
||||
first = false
|
||||
restartUnbound <- struct{}{}
|
||||
}
|
||||
restartTickerCancel()
|
||||
restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
|
||||
go unboundLooper.RunRestartTicker(restartTickerContext, restartUnbound)
|
||||
onConnected(allSettings, logger, routingConf, defaultInterface, providerConf, restartPublicIP)
|
||||
onConnected(allSettings, logger, routingConf, portForward, restartUnbound, restartPublicIP)
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -224,11 +178,10 @@ func _main(background context.Context, args []string) int {
|
||||
syscall.SIGTERM,
|
||||
os.Interrupt,
|
||||
)
|
||||
exitStatus := 0
|
||||
shutdownErrorsCount := 0
|
||||
select {
|
||||
case signal := <-signalsCh:
|
||||
logger.Warn("Caught OS signal %s, shutting down", signal)
|
||||
exitStatus = 1
|
||||
cancel()
|
||||
case <-ctx.Done():
|
||||
logger.Warn("context canceled, shutting down")
|
||||
@@ -236,20 +189,37 @@ func _main(background context.Context, args []string) int {
|
||||
logger.Info("Clearing ip status file %s", allSettings.System.IPStatusFilepath)
|
||||
if err := fileManager.Remove(string(allSettings.System.IPStatusFilepath)); err != nil {
|
||||
logger.Error(err)
|
||||
exitStatus = 1
|
||||
shutdownErrorsCount++
|
||||
}
|
||||
if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
|
||||
logger.Info("Clearing forwarded port status file %s", allSettings.OpenVPN.Provider.PortForwarding.Filepath)
|
||||
if err := fileManager.Remove(string(allSettings.OpenVPN.Provider.PortForwarding.Filepath)); err != nil {
|
||||
logger.Error(err)
|
||||
exitStatus = 1
|
||||
shutdownErrorsCount++
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
return exitStatus
|
||||
waiting, waited := context.WithTimeout(context.Background(), time.Second)
|
||||
go func() {
|
||||
defer waited()
|
||||
wg.Wait()
|
||||
}()
|
||||
<-waiting.Done()
|
||||
if waiting.Err() == context.DeadlineExceeded {
|
||||
if shutdownErrorsCount > 0 {
|
||||
logger.Warn("Shutdown had %d errors", shutdownErrorsCount)
|
||||
}
|
||||
logger.Warn("Shutdown timed out")
|
||||
return 1
|
||||
}
|
||||
if shutdownErrorsCount > 0 {
|
||||
logger.Warn("Shutdown had %d errors")
|
||||
return 1
|
||||
}
|
||||
logger.Info("Shutdown successful")
|
||||
return 0
|
||||
}
|
||||
|
||||
func makeFatalOnError(logger logging.Logger, cancel func(), wg *sync.WaitGroup) func(err error) {
|
||||
func makeFatalOnError(logger logging.Logger, cancel context.CancelFunc, wg *sync.WaitGroup) func(err error) {
|
||||
return func(err error) {
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
@@ -321,48 +291,25 @@ func trimEventualProgramPrefix(s string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func onConnected(allSettings settings.Settings,
|
||||
logger logging.Logger, routingConf routing.Routing, defaultInterface string,
|
||||
providerConf provider.Provider, restartPublicIP chan<- struct{},
|
||||
func onConnected(allSettings settings.Settings, logger logging.Logger, routingConf routing.Routing,
|
||||
portForward, restartUnbound, restartPublicIP chan<- struct{},
|
||||
) {
|
||||
restartUnbound <- struct{}{}
|
||||
restartPublicIP <- struct{}{}
|
||||
uid, gid := allSettings.System.UID, allSettings.System.GID
|
||||
if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
|
||||
time.AfterFunc(5*time.Second, func() {
|
||||
setupPortForwarding(logger, providerConf, allSettings.OpenVPN.Provider.PortForwarding.Filepath, uid, gid)
|
||||
portForward <- struct{}{}
|
||||
})
|
||||
}
|
||||
|
||||
vpnGatewayIP, err := routingConf.VPNGatewayIP(defaultInterface)
|
||||
defaultInterface, _, _, err := routingConf.DefaultRoute()
|
||||
if err != nil {
|
||||
logger.Warn(err)
|
||||
} else {
|
||||
logger.Info("Gateway VPN IP address: %s", vpnGatewayIP)
|
||||
}
|
||||
}
|
||||
|
||||
func setupPortForwarding(logger logging.Logger, providerConf provider.Provider, filepath models.Filepath, uid, gid int) {
|
||||
pfLogger := logger.WithPrefix("port forwarding: ")
|
||||
var port uint16
|
||||
var err error
|
||||
for {
|
||||
port, err = providerConf.GetPortForward()
|
||||
vpnGatewayIP, err := routingConf.VPNGatewayIP(defaultInterface)
|
||||
if err != nil {
|
||||
pfLogger.Error(err)
|
||||
pfLogger.Info("retrying in 5 seconds...")
|
||||
time.Sleep(5 * time.Second)
|
||||
logger.Warn(err)
|
||||
} else {
|
||||
pfLogger.Info("port forwarded is %d", port)
|
||||
break
|
||||
logger.Info("Gateway VPN IP address: %s", vpnGatewayIP)
|
||||
}
|
||||
}
|
||||
pfLogger.Info("writing forwarded port to %s", filepath)
|
||||
if err := providerConf.WritePortForward(filepath, port, uid, gid); err != nil {
|
||||
pfLogger.Error(err)
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := providerConf.AllowPortForwardFirewall(ctx, constants.TUN, port); err != nil {
|
||||
pfLogger.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user