Firewall refactoring
- Ability to enable and disable rules in various loops - Simplified code overall - Port forwarding moved into openvpn loop - Route addition and removal improved
This commit is contained in:
@@ -17,13 +17,10 @@ import (
|
|||||||
"github.com/qdm12/golibs/network"
|
"github.com/qdm12/golibs/network"
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/alpine"
|
"github.com/qdm12/private-internet-access-docker/internal/alpine"
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/cli"
|
"github.com/qdm12/private-internet-access-docker/internal/cli"
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/dns"
|
"github.com/qdm12/private-internet-access-docker/internal/dns"
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/firewall"
|
"github.com/qdm12/private-internet-access-docker/internal/firewall"
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/openvpn"
|
"github.com/qdm12/private-internet-access-docker/internal/openvpn"
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/params"
|
"github.com/qdm12/private-internet-access-docker/internal/params"
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/provider"
|
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/publicip"
|
"github.com/qdm12/private-internet-access-docker/internal/publicip"
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/routing"
|
"github.com/qdm12/private-internet-access-docker/internal/routing"
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/server"
|
"github.com/qdm12/private-internet-access-docker/internal/server"
|
||||||
@@ -72,8 +69,8 @@ func _main(background context.Context, args []string) int {
|
|||||||
alpineConf := alpine.NewConfigurator(fileManager)
|
alpineConf := alpine.NewConfigurator(fileManager)
|
||||||
ovpnConf := openvpn.NewConfigurator(logger, fileManager)
|
ovpnConf := openvpn.NewConfigurator(logger, fileManager)
|
||||||
dnsConf := dns.NewConfigurator(logger, client, fileManager)
|
dnsConf := dns.NewConfigurator(logger, client, fileManager)
|
||||||
firewallConf := firewall.NewConfigurator(logger)
|
|
||||||
routingConf := routing.NewRouting(logger, fileManager)
|
routingConf := routing.NewRouting(logger, fileManager)
|
||||||
|
firewallConf := firewall.NewConfigurator(logger, routingConf, fileManager)
|
||||||
tinyProxyConf := tinyproxy.NewConfigurator(fileManager, logger)
|
tinyProxyConf := tinyproxy.NewConfigurator(fileManager, logger)
|
||||||
shadowsocksConf := shadowsocks.NewConfigurator(fileManager, logger)
|
shadowsocksConf := shadowsocks.NewConfigurator(fileManager, logger)
|
||||||
streamMerger := command.NewStreamMerger()
|
streamMerger := command.NewStreamMerger()
|
||||||
@@ -93,12 +90,6 @@ func _main(background context.Context, args []string) int {
|
|||||||
// Should never change
|
// Should never change
|
||||||
uid, gid := allSettings.System.UID, allSettings.System.GID
|
uid, gid := allSettings.System.UID, allSettings.System.GID
|
||||||
|
|
||||||
providerConf := provider.New(allSettings.VPNSP, logger, client, fileManager, firewallConf)
|
|
||||||
|
|
||||||
if !allSettings.Firewall.Enabled {
|
|
||||||
firewallConf.Disable()
|
|
||||||
}
|
|
||||||
|
|
||||||
err = alpineConf.CreateUser("nonrootuser", uid)
|
err = alpineConf.CreateUser("nonrootuser", uid)
|
||||||
fatalOnError(err)
|
fatalOnError(err)
|
||||||
err = fileManager.SetOwnership("/etc/unbound", uid, gid)
|
err = fileManager.SetOwnership("/etc/unbound", uid, gid)
|
||||||
@@ -112,17 +103,6 @@ func _main(background context.Context, args []string) int {
|
|||||||
fatalOnError(err)
|
fatalOnError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultInterface, defaultGateway, defaultSubnet, err := routingConf.DefaultRoute()
|
|
||||||
fatalOnError(err)
|
|
||||||
|
|
||||||
// Temporarily reset chain policies allowing Kubernetes sidecar to
|
|
||||||
// successfully restart the container. Without this, the existing rules will
|
|
||||||
// pre-exist, preventing the nslookup of the PIA region address. These will
|
|
||||||
// simply be redundant at Docker runtime as they will already be set this way
|
|
||||||
// Thanks to @npawelek https://github.com/npawelek
|
|
||||||
err = firewallConf.AcceptAll(ctx)
|
|
||||||
fatalOnError(err)
|
|
||||||
|
|
||||||
connectedCh := make(chan struct{})
|
connectedCh := make(chan struct{})
|
||||||
signalConnected := func() {
|
signalConnected := func() {
|
||||||
connectedCh <- struct{}{}
|
connectedCh <- struct{}{}
|
||||||
@@ -130,44 +110,23 @@ func _main(background context.Context, args []string) int {
|
|||||||
defer close(connectedCh)
|
defer close(connectedCh)
|
||||||
go collectStreamLines(ctx, streamMerger, logger, signalConnected)
|
go collectStreamLines(ctx, streamMerger, logger, signalConnected)
|
||||||
|
|
||||||
connections, err := providerConf.GetOpenVPNConnections(allSettings.OpenVPN.Provider.ServerSelection)
|
// TODO replace these with methods on loopers and pass loopers around
|
||||||
fatalOnError(err)
|
|
||||||
err = providerConf.BuildConf(
|
|
||||||
connections,
|
|
||||||
allSettings.OpenVPN.Verbosity,
|
|
||||||
uid,
|
|
||||||
gid,
|
|
||||||
allSettings.OpenVPN.Root,
|
|
||||||
allSettings.OpenVPN.Cipher,
|
|
||||||
allSettings.OpenVPN.Auth,
|
|
||||||
allSettings.OpenVPN.Provider.ExtraConfigOptions,
|
|
||||||
)
|
|
||||||
fatalOnError(err)
|
|
||||||
|
|
||||||
err = routingConf.AddRoutesVia(ctx, allSettings.Firewall.AllowedSubnets, defaultGateway, defaultInterface)
|
|
||||||
fatalOnError(err)
|
|
||||||
err = firewallConf.Clear(ctx)
|
|
||||||
fatalOnError(err)
|
|
||||||
err = firewallConf.BlockAll(ctx)
|
|
||||||
fatalOnError(err)
|
|
||||||
err = firewallConf.CreateGeneralRules(ctx)
|
|
||||||
fatalOnError(err)
|
|
||||||
err = firewallConf.CreateVPNRules(ctx, constants.TUN, defaultInterface, connections)
|
|
||||||
fatalOnError(err)
|
|
||||||
err = firewallConf.CreateLocalSubnetsRules(ctx, defaultSubnet, allSettings.Firewall.AllowedSubnets, defaultInterface)
|
|
||||||
fatalOnError(err)
|
|
||||||
err = firewallConf.RunUserPostRules(ctx, fileManager, "/iptables/post-rules.txt")
|
|
||||||
fatalOnError(err)
|
|
||||||
|
|
||||||
restartOpenvpn := make(chan struct{})
|
restartOpenvpn := make(chan struct{})
|
||||||
|
portForward := make(chan struct{})
|
||||||
restartUnbound := make(chan struct{})
|
restartUnbound := make(chan struct{})
|
||||||
restartPublicIP := make(chan struct{})
|
restartPublicIP := make(chan struct{})
|
||||||
restartTinyproxy := make(chan struct{})
|
restartTinyproxy := make(chan struct{})
|
||||||
restartShadowsocks := make(chan struct{})
|
restartShadowsocks := make(chan struct{})
|
||||||
|
|
||||||
openvpnLooper := openvpn.NewLooper(ovpnConf, allSettings.OpenVPN, logger, streamMerger, fatalOnError, uid, gid)
|
if allSettings.Firewall.Enabled {
|
||||||
|
err := firewallConf.SetEnabled(ctx, true) // disabled by default
|
||||||
|
fatalOnError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
openvpnLooper := openvpn.NewLooper(allSettings.VPNSP, allSettings.OpenVPN, uid, gid,
|
||||||
|
ovpnConf, firewallConf, logger, client, fileManager, streamMerger, fatalOnError)
|
||||||
// wait for restartOpenvpn
|
// wait for restartOpenvpn
|
||||||
go openvpnLooper.Run(ctx, restartOpenvpn, wg)
|
go openvpnLooper.Run(ctx, restartOpenvpn, portForward, wg)
|
||||||
|
|
||||||
unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, logger, streamMerger, uid, gid)
|
unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, logger, streamMerger, uid, gid)
|
||||||
// wait for restartUnbound
|
// wait for restartUnbound
|
||||||
@@ -191,7 +150,6 @@ func _main(background context.Context, args []string) int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
first := true
|
|
||||||
var restartTickerContext context.Context
|
var restartTickerContext context.Context
|
||||||
var restartTickerCancel context.CancelFunc = func() {}
|
var restartTickerCancel context.CancelFunc = func() {}
|
||||||
for {
|
for {
|
||||||
@@ -200,14 +158,10 @@ func _main(background context.Context, args []string) int {
|
|||||||
restartTickerCancel()
|
restartTickerCancel()
|
||||||
return
|
return
|
||||||
case <-connectedCh: // blocks until openvpn is connected
|
case <-connectedCh: // blocks until openvpn is connected
|
||||||
if first {
|
|
||||||
first = false
|
|
||||||
restartUnbound <- struct{}{}
|
|
||||||
}
|
|
||||||
restartTickerCancel()
|
restartTickerCancel()
|
||||||
restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
|
restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
|
||||||
go unboundLooper.RunRestartTicker(restartTickerContext, restartUnbound)
|
go unboundLooper.RunRestartTicker(restartTickerContext, restartUnbound)
|
||||||
onConnected(allSettings, logger, routingConf, defaultInterface, providerConf, restartPublicIP)
|
onConnected(allSettings, logger, routingConf, portForward, restartUnbound, restartPublicIP)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -224,11 +178,10 @@ func _main(background context.Context, args []string) int {
|
|||||||
syscall.SIGTERM,
|
syscall.SIGTERM,
|
||||||
os.Interrupt,
|
os.Interrupt,
|
||||||
)
|
)
|
||||||
exitStatus := 0
|
shutdownErrorsCount := 0
|
||||||
select {
|
select {
|
||||||
case signal := <-signalsCh:
|
case signal := <-signalsCh:
|
||||||
logger.Warn("Caught OS signal %s, shutting down", signal)
|
logger.Warn("Caught OS signal %s, shutting down", signal)
|
||||||
exitStatus = 1
|
|
||||||
cancel()
|
cancel()
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
logger.Warn("context canceled, shutting down")
|
logger.Warn("context canceled, shutting down")
|
||||||
@@ -236,20 +189,37 @@ func _main(background context.Context, args []string) int {
|
|||||||
logger.Info("Clearing ip status file %s", allSettings.System.IPStatusFilepath)
|
logger.Info("Clearing ip status file %s", allSettings.System.IPStatusFilepath)
|
||||||
if err := fileManager.Remove(string(allSettings.System.IPStatusFilepath)); err != nil {
|
if err := fileManager.Remove(string(allSettings.System.IPStatusFilepath)); err != nil {
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
exitStatus = 1
|
shutdownErrorsCount++
|
||||||
}
|
}
|
||||||
if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
|
if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
|
||||||
logger.Info("Clearing forwarded port status file %s", allSettings.OpenVPN.Provider.PortForwarding.Filepath)
|
logger.Info("Clearing forwarded port status file %s", allSettings.OpenVPN.Provider.PortForwarding.Filepath)
|
||||||
if err := fileManager.Remove(string(allSettings.OpenVPN.Provider.PortForwarding.Filepath)); err != nil {
|
if err := fileManager.Remove(string(allSettings.OpenVPN.Provider.PortForwarding.Filepath)); err != nil {
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
exitStatus = 1
|
shutdownErrorsCount++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
waiting, waited := context.WithTimeout(context.Background(), time.Second)
|
||||||
|
go func() {
|
||||||
|
defer waited()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
return exitStatus
|
}()
|
||||||
|
<-waiting.Done()
|
||||||
|
if waiting.Err() == context.DeadlineExceeded {
|
||||||
|
if shutdownErrorsCount > 0 {
|
||||||
|
logger.Warn("Shutdown had %d errors", shutdownErrorsCount)
|
||||||
|
}
|
||||||
|
logger.Warn("Shutdown timed out")
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if shutdownErrorsCount > 0 {
|
||||||
|
logger.Warn("Shutdown had %d errors")
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
logger.Info("Shutdown successful")
|
||||||
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeFatalOnError(logger logging.Logger, cancel func(), wg *sync.WaitGroup) func(err error) {
|
func makeFatalOnError(logger logging.Logger, cancel context.CancelFunc, wg *sync.WaitGroup) func(err error) {
|
||||||
return func(err error) {
|
return func(err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
@@ -321,48 +291,25 @@ func trimEventualProgramPrefix(s string) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func onConnected(allSettings settings.Settings,
|
func onConnected(allSettings settings.Settings, logger logging.Logger, routingConf routing.Routing,
|
||||||
logger logging.Logger, routingConf routing.Routing, defaultInterface string,
|
portForward, restartUnbound, restartPublicIP chan<- struct{},
|
||||||
providerConf provider.Provider, restartPublicIP chan<- struct{},
|
|
||||||
) {
|
) {
|
||||||
|
restartUnbound <- struct{}{}
|
||||||
restartPublicIP <- struct{}{}
|
restartPublicIP <- struct{}{}
|
||||||
uid, gid := allSettings.System.UID, allSettings.System.GID
|
|
||||||
if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
|
if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
|
||||||
time.AfterFunc(5*time.Second, func() {
|
time.AfterFunc(5*time.Second, func() {
|
||||||
setupPortForwarding(logger, providerConf, allSettings.OpenVPN.Provider.PortForwarding.Filepath, uid, gid)
|
portForward <- struct{}{}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
defaultInterface, _, _, err := routingConf.DefaultRoute()
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn(err)
|
||||||
|
} else {
|
||||||
vpnGatewayIP, err := routingConf.VPNGatewayIP(defaultInterface)
|
vpnGatewayIP, err := routingConf.VPNGatewayIP(defaultInterface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn(err)
|
logger.Warn(err)
|
||||||
} else {
|
} else {
|
||||||
logger.Info("Gateway VPN IP address: %s", vpnGatewayIP)
|
logger.Info("Gateway VPN IP address: %s", vpnGatewayIP)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func setupPortForwarding(logger logging.Logger, providerConf provider.Provider, filepath models.Filepath, uid, gid int) {
|
|
||||||
pfLogger := logger.WithPrefix("port forwarding: ")
|
|
||||||
var port uint16
|
|
||||||
var err error
|
|
||||||
for {
|
|
||||||
port, err = providerConf.GetPortForward()
|
|
||||||
if err != nil {
|
|
||||||
pfLogger.Error(err)
|
|
||||||
pfLogger.Info("retrying in 5 seconds...")
|
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
} else {
|
|
||||||
pfLogger.Info("port forwarded is %d", port)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
pfLogger.Info("writing forwarded port to %s", filepath)
|
|
||||||
if err := providerConf.WritePortForward(filepath, port, uid, gid); err != nil {
|
|
||||||
pfLogger.Error(err)
|
|
||||||
}
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := providerConf.AllowPortForwardFirewall(ctx, constants.TUN, port); err != nil {
|
|
||||||
pfLogger.Error(err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
149
internal/firewall/enable.go
Normal file
149
internal/firewall/enable.go
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *configurator) SetEnabled(ctx context.Context, enabled bool) (err error) {
|
||||||
|
c.stateMutex.Lock()
|
||||||
|
defer c.stateMutex.Unlock()
|
||||||
|
|
||||||
|
if enabled == c.enabled {
|
||||||
|
if enabled {
|
||||||
|
c.logger.Info("already enabled")
|
||||||
|
} else {
|
||||||
|
c.logger.Info("already disabled")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !enabled {
|
||||||
|
c.logger.Info("disabling...")
|
||||||
|
if err = c.disable(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.enabled = false
|
||||||
|
c.logger.Info("disabled successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.logger.Info("enabling...")
|
||||||
|
|
||||||
|
if err := c.enable(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.enabled = true
|
||||||
|
c.logger.Info("enabled successfully")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *configurator) disable(ctx context.Context) (err error) {
|
||||||
|
if err = c.clearAllRules(ctx); err != nil {
|
||||||
|
return fmt.Errorf("cannot disable firewall: %w", err)
|
||||||
|
}
|
||||||
|
if err = c.setAllPolicies(ctx, "ACCEPT"); err != nil {
|
||||||
|
return fmt.Errorf("cannot disable firewall: %w", err)
|
||||||
|
}
|
||||||
|
// TODO routes?
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// To use in defered call when enabling the firewall
|
||||||
|
func (c *configurator) fallbackToDisabled(ctx context.Context) {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := c.SetEnabled(ctx, false); err != nil {
|
||||||
|
c.logger.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *configurator) enable(ctx context.Context) (err error) { //nolint:gocognit
|
||||||
|
defaultInterface, defaultGateway, defaultSubnet, err := c.routing.DefaultRoute()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println(1)
|
||||||
|
if err = c.setAllPolicies(ctx, "DROP"); err != nil {
|
||||||
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
const remove = false
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
c.fallbackToDisabled(ctx)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Loopback traffic
|
||||||
|
if err = c.acceptInputThroughInterface(ctx, "lo", remove); err != nil {
|
||||||
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
|
}
|
||||||
|
if err = c.acceptOutputThroughInterface(ctx, "lo", remove); err != nil {
|
||||||
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = c.acceptEstablishedRelatedTraffic(ctx, remove); err != nil {
|
||||||
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
|
}
|
||||||
|
for _, conn := range c.vpnConnections {
|
||||||
|
if err = c.acceptOutputTrafficToVPN(ctx, defaultInterface, conn, remove); err != nil {
|
||||||
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err = c.acceptOutputThroughInterface(ctx, string(constants.TUN), remove); err != nil {
|
||||||
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
|
}
|
||||||
|
if err := c.acceptInputFromToSubnet(ctx, defaultSubnet, "*", remove); err != nil {
|
||||||
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
|
}
|
||||||
|
if err := c.acceptOutputFromToSubnet(ctx, defaultSubnet, "*", remove); err != nil {
|
||||||
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
|
}
|
||||||
|
for _, subnet := range c.allowedSubnets {
|
||||||
|
if err := c.acceptInputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
|
||||||
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
|
}
|
||||||
|
if err := c.acceptOutputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
|
||||||
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Re-ensure all routes exist
|
||||||
|
for _, subnet := range c.allowedSubnets {
|
||||||
|
if err := c.routing.AddRouteVia(ctx, subnet, defaultGateway, defaultInterface); err != nil {
|
||||||
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for port := range c.allowedPorts {
|
||||||
|
// TODO restrict interface
|
||||||
|
if err := c.acceptInputToPort(ctx, "*", constants.TCP, port, remove); err != nil {
|
||||||
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
|
}
|
||||||
|
if err := c.acceptInputToPort(ctx, "*", constants.UDP, port, remove); err != nil {
|
||||||
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.portForwarded > 0 {
|
||||||
|
const tun = string(constants.TUN)
|
||||||
|
if err := c.acceptInputToPort(ctx, tun, constants.TCP, c.portForwarded, remove); err != nil {
|
||||||
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
|
}
|
||||||
|
if err := c.acceptInputToPort(ctx, tun, constants.UDP, c.portForwarded, remove); err != nil {
|
||||||
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.runUserPostRules(ctx, "/iptables/post-rules.txt", remove); err != nil {
|
||||||
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -3,42 +3,49 @@ package firewall
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/qdm12/golibs/command"
|
"github.com/qdm12/golibs/command"
|
||||||
"github.com/qdm12/golibs/files"
|
"github.com/qdm12/golibs/files"
|
||||||
"github.com/qdm12/golibs/logging"
|
"github.com/qdm12/golibs/logging"
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||||
|
"github.com/qdm12/private-internet-access-docker/internal/routing"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Configurator allows to change firewall rules and modify network routes
|
// Configurator allows to change firewall rules and modify network routes
|
||||||
type Configurator interface {
|
type Configurator interface {
|
||||||
Version(ctx context.Context) (string, error)
|
Version(ctx context.Context) (string, error)
|
||||||
AcceptAll(ctx context.Context) error
|
SetEnabled(ctx context.Context, enabled bool) (err error)
|
||||||
Clear(ctx context.Context) error
|
SetVPNConnections(ctx context.Context, connections []models.OpenVPNConnection) (err error)
|
||||||
BlockAll(ctx context.Context) error
|
SetAllowedSubnets(ctx context.Context, subnets []net.IPNet) (err error)
|
||||||
CreateGeneralRules(ctx context.Context) error
|
SetAllowedPort(ctx context.Context, port uint16) error
|
||||||
CreateVPNRules(ctx context.Context, dev models.VPNDevice, defaultInterface string, connections []models.OpenVPNConnection) error
|
RemoveAllowedPort(ctx context.Context, port uint16) (err error)
|
||||||
CreateLocalSubnetsRules(ctx context.Context, subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error
|
SetPortForward(ctx context.Context, port uint16) (err error)
|
||||||
AllowInputTrafficOnPort(ctx context.Context, device models.VPNDevice, port uint16) error
|
|
||||||
AllowAnyIncomingOnPort(ctx context.Context, port uint16) error
|
|
||||||
RunUserPostRules(ctx context.Context, fileManager files.FileManager, filepath string) error
|
|
||||||
Disable()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type configurator struct {
|
type configurator struct { //nolint:maligned
|
||||||
commander command.Commander
|
commander command.Commander
|
||||||
logger logging.Logger
|
logger logging.Logger
|
||||||
disabled bool
|
routing routing.Routing
|
||||||
|
fileManager files.FileManager // for custom iptables rules
|
||||||
|
iptablesMutex sync.Mutex
|
||||||
|
|
||||||
|
// State
|
||||||
|
enabled bool
|
||||||
|
vpnConnections []models.OpenVPNConnection
|
||||||
|
allowedSubnets []net.IPNet
|
||||||
|
allowedPorts map[uint16]struct{}
|
||||||
|
portForwarded uint16
|
||||||
|
stateMutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewConfigurator creates a new Configurator instance
|
// NewConfigurator creates a new Configurator instance
|
||||||
func NewConfigurator(logger logging.Logger) Configurator {
|
func NewConfigurator(logger logging.Logger, routing routing.Routing, fileManager files.FileManager) Configurator {
|
||||||
return &configurator{
|
return &configurator{
|
||||||
commander: command.NewCommander(),
|
commander: command.NewCommander(),
|
||||||
logger: logger.WithPrefix("firewall configurator: "),
|
logger: logger.WithPrefix("firewall: "),
|
||||||
|
routing: routing,
|
||||||
|
fileManager: fileManager,
|
||||||
|
allowedPorts: make(map[uint16]struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *configurator) Disable() {
|
|
||||||
c.disabled = true
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -6,10 +6,32 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/qdm12/golibs/files"
|
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func appendOrDelete(remove bool) string {
|
||||||
|
if remove {
|
||||||
|
return "--delete"
|
||||||
|
}
|
||||||
|
return "--append"
|
||||||
|
}
|
||||||
|
|
||||||
|
// flipRule changes an append rule in a delete rule or a delete rule into an
|
||||||
|
// append rule.
|
||||||
|
func flipRule(rule string) string {
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(rule, "-A"):
|
||||||
|
return strings.Replace(rule, "-A", "-D", 1)
|
||||||
|
case strings.HasPrefix(rule, "--append"):
|
||||||
|
return strings.Replace(rule, "--append", "-D", 1)
|
||||||
|
case strings.HasPrefix(rule, "-D"):
|
||||||
|
return strings.Replace(rule, "-D", "-A", 1)
|
||||||
|
case strings.HasPrefix(rule, "--delete"):
|
||||||
|
return strings.Replace(rule, "--delete", "-A", 1)
|
||||||
|
}
|
||||||
|
return rule
|
||||||
|
}
|
||||||
|
|
||||||
// Version obtains the version of the installed iptables
|
// Version obtains the version of the installed iptables
|
||||||
func (c *configurator) Version(ctx context.Context) (string, error) {
|
func (c *configurator) Version(ctx context.Context) (string, error) {
|
||||||
output, err := c.commander.Run(ctx, "iptables", "--version")
|
output, err := c.commander.Run(ctx, "iptables", "--version")
|
||||||
@@ -33,6 +55,8 @@ func (c *configurator) runIptablesInstructions(ctx context.Context, instructions
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *configurator) runIptablesInstruction(ctx context.Context, instruction string) error {
|
func (c *configurator) runIptablesInstruction(ctx context.Context, instruction string) error {
|
||||||
|
c.iptablesMutex.Lock() // only one iptables command at once
|
||||||
|
defer c.iptablesMutex.Unlock()
|
||||||
flags := strings.Fields(instruction)
|
flags := strings.Fields(instruction)
|
||||||
if output, err := c.commander.Run(ctx, "iptables", flags...); err != nil {
|
if output, err := c.commander.Run(ctx, "iptables", flags...); err != nil {
|
||||||
return fmt.Errorf("failed executing \"iptables %s\": %s: %w", instruction, output, err)
|
return fmt.Errorf("failed executing \"iptables %s\": %s: %w", instruction, output, err)
|
||||||
@@ -40,146 +64,119 @@ func (c *configurator) runIptablesInstruction(ctx context.Context, instruction s
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *configurator) Clear(ctx context.Context) error {
|
func (c *configurator) clearAllRules(ctx context.Context) error {
|
||||||
if c.disabled {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
c.logger.Info("clearing all rules")
|
|
||||||
return c.runIptablesInstructions(ctx, []string{
|
return c.runIptablesInstructions(ctx, []string{
|
||||||
"--flush",
|
"--flush", // flush all chains
|
||||||
"--delete-chain",
|
"--delete-chain", // delete all chains
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *configurator) AcceptAll(ctx context.Context) error {
|
func (c *configurator) setAllPolicies(ctx context.Context, policy string) error {
|
||||||
if c.disabled {
|
switch policy {
|
||||||
return nil
|
case "ACCEPT", "DROP":
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("policy %q not recognized", policy)
|
||||||
}
|
}
|
||||||
c.logger.Info("accepting all traffic")
|
|
||||||
return c.runIptablesInstructions(ctx, []string{
|
return c.runIptablesInstructions(ctx, []string{
|
||||||
"-P INPUT ACCEPT",
|
fmt.Sprintf("--policy INPUT %s", policy),
|
||||||
"-P OUTPUT ACCEPT",
|
fmt.Sprintf("--policy OUTPUT %s", policy),
|
||||||
"-P FORWARD ACCEPT",
|
fmt.Sprintf("--policy FORWARD %s", policy),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *configurator) BlockAll(ctx context.Context) error {
|
func (c *configurator) acceptInputThroughInterface(ctx context.Context, intf string, remove bool) error {
|
||||||
if c.disabled {
|
return c.runIptablesInstruction(ctx, fmt.Sprintf(
|
||||||
return nil
|
"%s INPUT -i %s -j ACCEPT", appendOrDelete(remove), intf,
|
||||||
}
|
))
|
||||||
c.logger.Info("blocking all traffic")
|
}
|
||||||
|
|
||||||
|
func (c *configurator) acceptOutputThroughInterface(ctx context.Context, intf string, remove bool) error {
|
||||||
|
return c.runIptablesInstruction(ctx, fmt.Sprintf(
|
||||||
|
"%s OUTPUT -o %s -j ACCEPT", appendOrDelete(remove), intf,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *configurator) acceptEstablishedRelatedTraffic(ctx context.Context, remove bool) error {
|
||||||
return c.runIptablesInstructions(ctx, []string{
|
return c.runIptablesInstructions(ctx, []string{
|
||||||
"-P INPUT DROP",
|
fmt.Sprintf("%s OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
|
||||||
"-F OUTPUT",
|
fmt.Sprintf("%s INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", appendOrDelete(remove)),
|
||||||
"-P OUTPUT DROP",
|
|
||||||
"-P FORWARD DROP",
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *configurator) CreateGeneralRules(ctx context.Context) error {
|
func (c *configurator) acceptOutputTrafficToVPN(ctx context.Context, defaultInterface string, connection models.OpenVPNConnection, remove bool) error {
|
||||||
if c.disabled {
|
return c.runIptablesInstruction(ctx,
|
||||||
return nil
|
fmt.Sprintf("%s OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
|
||||||
}
|
appendOrDelete(remove), connection.IP, defaultInterface, connection.Protocol, connection.Protocol, connection.Port))
|
||||||
c.logger.Info("creating general rules")
|
|
||||||
return c.runIptablesInstructions(ctx, []string{
|
|
||||||
"-A OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
|
|
||||||
"-A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT",
|
|
||||||
"-A OUTPUT -o lo -j ACCEPT",
|
|
||||||
"-A INPUT -i lo -j ACCEPT",
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *configurator) CreateVPNRules(ctx context.Context, dev models.VPNDevice, defaultInterface string, connections []models.OpenVPNConnection) error {
|
func (c *configurator) acceptInputFromToSubnet(ctx context.Context, subnet net.IPNet, intf string, remove bool) error {
|
||||||
if c.disabled {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
for _, connection := range connections {
|
|
||||||
c.logger.Info("allowing output traffic to VPN server %s through %s on port %s %d",
|
|
||||||
connection.IP, defaultInterface, connection.Protocol, connection.Port)
|
|
||||||
if err := c.runIptablesInstruction(ctx,
|
|
||||||
fmt.Sprintf("-A OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT",
|
|
||||||
connection.IP, defaultInterface, connection.Protocol, connection.Protocol, connection.Port)); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := c.runIptablesInstruction(ctx, fmt.Sprintf("-A OUTPUT -o %s -j ACCEPT", dev)); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *configurator) CreateLocalSubnetsRules(ctx context.Context, subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error {
|
|
||||||
if c.disabled {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
subnetStr := subnet.String()
|
subnetStr := subnet.String()
|
||||||
c.logger.Info("accepting input and output traffic for %s", subnetStr)
|
interfaceFlag := "-i " + intf
|
||||||
if err := c.runIptablesInstructions(ctx, []string{
|
if intf == "*" { // all interfaces
|
||||||
fmt.Sprintf("-A INPUT -s %s -d %s -j ACCEPT", subnetStr, subnetStr),
|
interfaceFlag = ""
|
||||||
fmt.Sprintf("-A OUTPUT -s %s -d %s -j ACCEPT", subnetStr, subnetStr),
|
|
||||||
}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
for _, extraSubnet := range extraSubnets {
|
return c.runIptablesInstruction(ctx, fmt.Sprintf(
|
||||||
extraSubnetStr := extraSubnet.String()
|
"%s INPUT %s -s %s -d %s -j ACCEPT", appendOrDelete(remove), interfaceFlag, subnetStr, subnetStr,
|
||||||
c.logger.Info("accepting input traffic through %s from %s to %s", defaultInterface, extraSubnetStr, subnetStr)
|
))
|
||||||
if err := c.runIptablesInstruction(ctx,
|
|
||||||
fmt.Sprintf("-A INPUT -i %s -s %s -d %s -j ACCEPT", defaultInterface, extraSubnetStr, subnetStr)); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// Thanks to @npawelek
|
|
||||||
c.logger.Info("accepting output traffic through %s from %s to %s", defaultInterface, subnetStr, extraSubnetStr)
|
|
||||||
if err := c.runIptablesInstruction(ctx,
|
|
||||||
fmt.Sprintf("-A OUTPUT -o %s -s %s -d %s -j ACCEPT", defaultInterface, subnetStr, extraSubnetStr)); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Used for port forwarding
|
// Thanks to @npawelek
|
||||||
func (c *configurator) AllowInputTrafficOnPort(ctx context.Context, device models.VPNDevice, port uint16) error {
|
func (c *configurator) acceptOutputFromToSubnet(ctx context.Context, subnet net.IPNet, intf string, remove bool) error {
|
||||||
if c.disabled {
|
subnetStr := subnet.String()
|
||||||
return nil
|
interfaceFlag := "-o " + intf
|
||||||
|
if intf == "*" { // all interfaces
|
||||||
|
interfaceFlag = ""
|
||||||
}
|
}
|
||||||
c.logger.Info("accepting input traffic through %s on port %d", device, port)
|
return c.runIptablesInstruction(ctx, fmt.Sprintf(
|
||||||
return c.runIptablesInstructions(ctx, []string{
|
"%s OUTPUT %s -s %s -d %s -j ACCEPT", appendOrDelete(remove), interfaceFlag, subnetStr, subnetStr,
|
||||||
fmt.Sprintf("-A INPUT -i %s -p tcp --dport %d -j ACCEPT", device, port),
|
))
|
||||||
fmt.Sprintf("-A INPUT -i %s -p udp --dport %d -j ACCEPT", device, port),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *configurator) AllowAnyIncomingOnPort(ctx context.Context, port uint16) error {
|
// Used for port forwarding, with intf set to tun
|
||||||
if c.disabled {
|
func (c *configurator) acceptInputToPort(ctx context.Context, intf string, protocol models.NetworkProtocol, port uint16, remove bool) error {
|
||||||
return nil
|
interfaceFlag := "-i " + intf
|
||||||
|
if intf == "*" { // all interfaces
|
||||||
|
interfaceFlag = ""
|
||||||
}
|
}
|
||||||
c.logger.Info("accepting any input traffic on port %d", port)
|
return c.runIptablesInstruction(ctx,
|
||||||
return c.runIptablesInstructions(ctx, []string{
|
fmt.Sprintf("%s INPUT %s -p %s --dport %d -j ACCEPT", appendOrDelete(remove), interfaceFlag, protocol, port),
|
||||||
fmt.Sprintf("-A INPUT -p tcp --dport %d -j ACCEPT", port),
|
)
|
||||||
fmt.Sprintf("-A INPUT -p udp --dport %d -j ACCEPT", port),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *configurator) RunUserPostRules(ctx context.Context, fileManager files.FileManager, filepath string) error {
|
func (c *configurator) runUserPostRules(ctx context.Context, filepath string, remove bool) error {
|
||||||
exists, err := fileManager.FileExists(filepath)
|
exists, err := c.fileManager.FileExists(filepath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
} else if !exists {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
if exists {
|
b, err := c.fileManager.ReadFile(filepath)
|
||||||
b, err := fileManager.ReadFile(filepath)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
lines := strings.Split(string(b), "\n")
|
lines := strings.Split(string(b), "\n")
|
||||||
var rules []string
|
successfulRules := []string{}
|
||||||
|
defer func() {
|
||||||
|
// transaction-like rollback
|
||||||
|
if err == nil || ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, rule := range successfulRules {
|
||||||
|
_ = c.runIptablesInstruction(ctx, flipRule(rule))
|
||||||
|
}
|
||||||
|
}()
|
||||||
for _, line := range lines {
|
for _, line := range lines {
|
||||||
if !strings.HasPrefix(line, "iptables ") {
|
if !strings.HasPrefix(line, "iptables ") {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
rules = append(rules, strings.TrimPrefix(line, "iptables "))
|
rule := strings.TrimPrefix(line, "iptables ")
|
||||||
c.logger.Info("running user post firewall rule: %s", line)
|
if remove {
|
||||||
|
rule = flipRule(rule)
|
||||||
}
|
}
|
||||||
return c.runIptablesInstructions(ctx, rules)
|
if err = c.runIptablesInstruction(ctx, rule); err != nil {
|
||||||
|
return fmt.Errorf("cannot run custom rule: %w", err)
|
||||||
|
}
|
||||||
|
successfulRules = append(successfulRules, rule)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
109
internal/firewall/ports.go
Normal file
109
internal/firewall/ports.go
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *configurator) SetAllowedPort(ctx context.Context, port uint16) (err error) {
|
||||||
|
c.stateMutex.Lock()
|
||||||
|
defer c.stateMutex.Unlock()
|
||||||
|
|
||||||
|
if port == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !c.enabled {
|
||||||
|
c.logger.Info("firewall disabled, only updating allowed ports internal list")
|
||||||
|
c.allowedPorts[port] = struct{}{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.logger.Info("setting allowed port %d through firewall...", port)
|
||||||
|
|
||||||
|
if _, ok := c.allowedPorts[port]; ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const remove = false
|
||||||
|
if err := c.acceptInputToPort(ctx, "*", constants.TCP, port, remove); err != nil {
|
||||||
|
return fmt.Errorf("cannot set allowed port %d through firewall: %w", port, err)
|
||||||
|
}
|
||||||
|
if err := c.acceptInputToPort(ctx, "*", constants.UDP, port, remove); err != nil {
|
||||||
|
return fmt.Errorf("cannot set allowed port %d through firewall: %w", port, err)
|
||||||
|
}
|
||||||
|
c.allowedPorts[port] = struct{}{}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *configurator) RemoveAllowedPort(ctx context.Context, port uint16) (err error) {
|
||||||
|
c.stateMutex.Lock()
|
||||||
|
defer c.stateMutex.Unlock()
|
||||||
|
|
||||||
|
if port == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !c.enabled {
|
||||||
|
c.logger.Info("firewall disabled, only updating allowed ports internal list")
|
||||||
|
delete(c.allowedPorts, port)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.logger.Info("removing allowed port %d through firewall...", port)
|
||||||
|
|
||||||
|
if _, ok := c.allowedPorts[port]; !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const remove = true
|
||||||
|
if err := c.acceptInputToPort(ctx, "*", constants.TCP, port, remove); err != nil {
|
||||||
|
return fmt.Errorf("cannot remove allowed port %d through firewall: %w", port, err)
|
||||||
|
}
|
||||||
|
if err := c.acceptInputToPort(ctx, "*", constants.UDP, port, remove); err != nil {
|
||||||
|
return fmt.Errorf("cannot remove allowed port %d through firewall: %w", port, err)
|
||||||
|
}
|
||||||
|
delete(c.allowedPorts, port)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use 0 to remove
|
||||||
|
func (c *configurator) SetPortForward(ctx context.Context, port uint16) (err error) {
|
||||||
|
c.stateMutex.Lock()
|
||||||
|
defer c.stateMutex.Unlock()
|
||||||
|
|
||||||
|
if port == c.portForwarded {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !c.enabled {
|
||||||
|
c.logger.Info("firewall disabled, only updating port forwarded internally")
|
||||||
|
c.portForwarded = port
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const tun = string(constants.TUN)
|
||||||
|
if err := c.acceptInputToPort(ctx, tun, constants.TCP, c.portForwarded, true); err != nil {
|
||||||
|
return fmt.Errorf("cannot remove outdated port forward rule from firewall: %w", err)
|
||||||
|
}
|
||||||
|
if err := c.acceptInputToPort(ctx, tun, constants.UDP, c.portForwarded, true); err != nil {
|
||||||
|
return fmt.Errorf("cannot remove outdated port forward rule from firewall: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if port == 0 { // not changing port
|
||||||
|
c.portForwarded = 0
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.acceptInputToPort(ctx, tun, constants.TCP, port, false); err != nil {
|
||||||
|
return fmt.Errorf("cannot accept port forwarded through firewall: %w", err)
|
||||||
|
}
|
||||||
|
if err := c.acceptInputToPort(ctx, tun, constants.UDP, port, false); err != nil {
|
||||||
|
return fmt.Errorf("cannot accept port forwarded through firewall: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
127
internal/firewall/subnets.go
Normal file
127
internal/firewall/subnets.go
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *configurator) SetAllowedSubnets(ctx context.Context, subnets []net.IPNet) (err error) {
|
||||||
|
c.stateMutex.Lock()
|
||||||
|
defer c.stateMutex.Unlock()
|
||||||
|
|
||||||
|
if !c.enabled {
|
||||||
|
c.logger.Info("firewall disabled, only updating allowed subnets internal list")
|
||||||
|
c.allowedSubnets = make([]net.IPNet, len(subnets))
|
||||||
|
copy(c.allowedSubnets, subnets)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.logger.Info("setting allowed subnets through firewall...")
|
||||||
|
|
||||||
|
subnetsToAdd := findSubnetsToAdd(c.allowedSubnets, subnets)
|
||||||
|
subnetsToRemove := findSubnetsToRemove(c.allowedSubnets, subnets)
|
||||||
|
if len(subnetsToAdd) == 0 && len(subnetsToRemove) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultInterface, defaultGateway, _, err := c.routing.DefaultRoute()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot set allowed subnets through firewall: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.removeSubnets(ctx, subnetsToRemove, defaultInterface)
|
||||||
|
if err := c.addSubnets(ctx, subnetsToAdd, defaultInterface, defaultGateway); err != nil {
|
||||||
|
return fmt.Errorf("cannot set allowed subnets through firewall: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func findSubnetsToAdd(oldSubnets, newSubnets []net.IPNet) (subnetsToAdd []net.IPNet) {
|
||||||
|
for _, newSubnet := range newSubnets {
|
||||||
|
found := false
|
||||||
|
for _, oldSubnet := range oldSubnets {
|
||||||
|
if subnetsAreEqual(oldSubnet, newSubnet) {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
subnetsToAdd = append(subnetsToAdd, newSubnet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return subnetsToAdd
|
||||||
|
}
|
||||||
|
|
||||||
|
func findSubnetsToRemove(oldSubnets, newSubnets []net.IPNet) (subnetsToRemove []net.IPNet) {
|
||||||
|
for _, oldSubnet := range oldSubnets {
|
||||||
|
found := false
|
||||||
|
for _, newSubnet := range newSubnets {
|
||||||
|
if subnetsAreEqual(oldSubnet, newSubnet) {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
subnetsToRemove = append(subnetsToRemove, oldSubnet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return subnetsToRemove
|
||||||
|
}
|
||||||
|
|
||||||
|
func subnetsAreEqual(a, b net.IPNet) bool {
|
||||||
|
return a.IP.Equal(b.IP) && a.Mask.String() == b.Mask.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeSubnetFromSubnets(subnets []net.IPNet, subnet net.IPNet) []net.IPNet {
|
||||||
|
L := len(subnets)
|
||||||
|
for i := range subnets {
|
||||||
|
if subnetsAreEqual(subnet, subnets[i]) {
|
||||||
|
subnets[i] = subnets[L-1]
|
||||||
|
subnets = subnets[:L-1]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return subnets
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *configurator) removeSubnets(ctx context.Context, subnets []net.IPNet, defaultInterface string) {
|
||||||
|
const remove = true
|
||||||
|
for _, subnet := range subnets {
|
||||||
|
failed := false
|
||||||
|
if err := c.acceptInputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
|
||||||
|
failed = true
|
||||||
|
c.logger.Error("cannot remove outdated allowed subnet through firewall: %s", err)
|
||||||
|
}
|
||||||
|
if err := c.acceptOutputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
|
||||||
|
failed = true
|
||||||
|
c.logger.Error("cannot remove outdated allowed subnet through firewall: %s", err)
|
||||||
|
}
|
||||||
|
if err := c.routing.DeleteRouteVia(ctx, subnet); err != nil {
|
||||||
|
failed = true
|
||||||
|
c.logger.Error("cannot remove outdated allowed subnet route: %s", err)
|
||||||
|
}
|
||||||
|
if failed {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
c.allowedSubnets = removeSubnetFromSubnets(c.allowedSubnets, subnet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *configurator) addSubnets(ctx context.Context, subnets []net.IPNet, defaultInterface string, defaultGateway net.IP) error {
|
||||||
|
const remove = false
|
||||||
|
for _, subnet := range subnets {
|
||||||
|
if err := c.acceptInputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
|
||||||
|
return fmt.Errorf("cannot add allowed subnet through firewall: %w", err)
|
||||||
|
}
|
||||||
|
if err := c.acceptOutputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
|
||||||
|
return fmt.Errorf("cannot add allowed subnet through firewall: %w", err)
|
||||||
|
}
|
||||||
|
if err := c.routing.AddRouteVia(ctx, subnet, defaultGateway, defaultInterface); err != nil {
|
||||||
|
return fmt.Errorf("cannot add route for allowed subnet: %w", err)
|
||||||
|
}
|
||||||
|
c.allowedSubnets = append(c.allowedSubnets, subnet)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
112
internal/firewall/vpn.go
Normal file
112
internal/firewall/vpn.go
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||||
|
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *configurator) SetVPNConnections(ctx context.Context, connections []models.OpenVPNConnection) (err error) {
|
||||||
|
c.stateMutex.Lock()
|
||||||
|
defer c.stateMutex.Unlock()
|
||||||
|
|
||||||
|
if !c.enabled {
|
||||||
|
c.logger.Info("firewall disabled, only updating VPN connections internal list")
|
||||||
|
c.vpnConnections = make([]models.OpenVPNConnection, len(connections))
|
||||||
|
copy(c.vpnConnections, connections)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.logger.Info("setting VPN connections through firewall...")
|
||||||
|
|
||||||
|
connectionsToAdd := findConnectionsToAdd(c.vpnConnections, connections)
|
||||||
|
connectionsToRemove := findConnectionsToRemove(c.vpnConnections, connections)
|
||||||
|
if len(connectionsToAdd) == 0 && len(connectionsToRemove) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultInterface, _, _, err := c.routing.DefaultRoute()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot set VPN connections through firewall: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO remove elsewhere?
|
||||||
|
if err := c.acceptOutputThroughInterface(ctx, string(constants.TUN), false); err != nil {
|
||||||
|
return fmt.Errorf("cannot allow traffic through tunnel: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.removeConnections(ctx, connectionsToRemove, defaultInterface)
|
||||||
|
if err := c.addConnections(ctx, connectionsToAdd, defaultInterface); err != nil {
|
||||||
|
return fmt.Errorf("cannot set VPN connections through firewall: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeConnectionFromConnections(connections []models.OpenVPNConnection, connection models.OpenVPNConnection) []models.OpenVPNConnection {
|
||||||
|
L := len(connections)
|
||||||
|
for i := range connections {
|
||||||
|
if connection.Equal(connections[i]) {
|
||||||
|
connections[i] = connections[L-1]
|
||||||
|
connections = connections[:L-1]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return connections
|
||||||
|
}
|
||||||
|
|
||||||
|
func findConnectionsToAdd(oldConnections, newConnections []models.OpenVPNConnection) (connectionsToAdd []models.OpenVPNConnection) {
|
||||||
|
for _, newConnection := range newConnections {
|
||||||
|
found := false
|
||||||
|
for _, oldConnection := range oldConnections {
|
||||||
|
if oldConnection.Equal(newConnection) {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
connectionsToAdd = append(connectionsToAdd, newConnection)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return connectionsToAdd
|
||||||
|
}
|
||||||
|
|
||||||
|
func findConnectionsToRemove(oldConnections, newConnections []models.OpenVPNConnection) (connectionsToRemove []models.OpenVPNConnection) {
|
||||||
|
for _, oldConnection := range oldConnections {
|
||||||
|
found := false
|
||||||
|
for _, newConnection := range newConnections {
|
||||||
|
if oldConnection.Equal(newConnection) {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
connectionsToRemove = append(connectionsToRemove, oldConnection)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return connectionsToRemove
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *configurator) removeConnections(ctx context.Context, connections []models.OpenVPNConnection, defaultInterface string) {
|
||||||
|
for _, conn := range connections {
|
||||||
|
const remove = true
|
||||||
|
if err := c.acceptOutputTrafficToVPN(ctx, defaultInterface, conn, remove); err != nil {
|
||||||
|
c.logger.Error("cannot remove outdated VPN connection through firewall: %s", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
c.vpnConnections = removeConnectionFromConnections(c.vpnConnections, conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *configurator) addConnections(ctx context.Context, connections []models.OpenVPNConnection, defaultInterface string) error {
|
||||||
|
const remove = false
|
||||||
|
for _, conn := range connections {
|
||||||
|
if err := c.acceptOutputTrafficToVPN(ctx, defaultInterface, conn, remove); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.vpnConnections = append(c.vpnConnections, conn)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -7,3 +7,7 @@ type OpenVPNConnection struct {
|
|||||||
Port uint16
|
Port uint16
|
||||||
Protocol NetworkProtocol
|
Protocol NetworkProtocol
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o *OpenVPNConnection) Equal(other OpenVPNConnection) bool {
|
||||||
|
return o.IP.Equal(other.IP) && o.Port == other.Port && o.Protocol == other.Protocol
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,43 +2,64 @@ package openvpn
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/qdm12/golibs/command"
|
"github.com/qdm12/golibs/command"
|
||||||
|
"github.com/qdm12/golibs/files"
|
||||||
"github.com/qdm12/golibs/logging"
|
"github.com/qdm12/golibs/logging"
|
||||||
|
"github.com/qdm12/golibs/network"
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||||
|
"github.com/qdm12/private-internet-access-docker/internal/firewall"
|
||||||
|
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||||
|
"github.com/qdm12/private-internet-access-docker/internal/provider"
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/settings"
|
"github.com/qdm12/private-internet-access-docker/internal/settings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Looper interface {
|
type Looper interface {
|
||||||
Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup)
|
Run(ctx context.Context, restart, portForward <-chan struct{}, wg *sync.WaitGroup)
|
||||||
}
|
}
|
||||||
|
|
||||||
type looper struct {
|
type looper struct {
|
||||||
conf Configurator
|
// Variable parameters
|
||||||
|
provider models.VPNProvider
|
||||||
settings settings.OpenVPN
|
settings settings.OpenVPN
|
||||||
logger logging.Logger
|
// Fixed parameters
|
||||||
streamMerger command.StreamMerger
|
|
||||||
fatalOnError func(err error)
|
|
||||||
uid int
|
uid int
|
||||||
gid int
|
gid int
|
||||||
|
// Configurators
|
||||||
|
conf Configurator
|
||||||
|
fw firewall.Configurator
|
||||||
|
// Other objects
|
||||||
|
logger logging.Logger
|
||||||
|
client network.Client
|
||||||
|
fileManager files.FileManager
|
||||||
|
streamMerger command.StreamMerger
|
||||||
|
fatalOnError func(err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLooper(conf Configurator, settings settings.OpenVPN, logger logging.Logger,
|
func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
|
||||||
streamMerger command.StreamMerger, fatalOnError func(err error), uid, gid int) Looper {
|
uid, gid int,
|
||||||
|
conf Configurator, fw firewall.Configurator,
|
||||||
|
logger logging.Logger, client network.Client, fileManager files.FileManager,
|
||||||
|
streamMerger command.StreamMerger, fatalOnError func(err error)) Looper {
|
||||||
return &looper{
|
return &looper{
|
||||||
conf: conf,
|
provider: provider,
|
||||||
settings: settings,
|
settings: settings,
|
||||||
logger: logger.WithPrefix("openvpn: "),
|
|
||||||
streamMerger: streamMerger,
|
|
||||||
fatalOnError: fatalOnError,
|
|
||||||
uid: uid,
|
uid: uid,
|
||||||
gid: gid,
|
gid: gid,
|
||||||
|
conf: conf,
|
||||||
|
fw: fw,
|
||||||
|
logger: logger.WithPrefix("openvpn: "),
|
||||||
|
client: client,
|
||||||
|
fileManager: fileManager,
|
||||||
|
streamMerger: streamMerger,
|
||||||
|
fatalOnError: fatalOnError,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) {
|
func (l *looper) Run(ctx context.Context, restart, portForward <-chan struct{}, wg *sync.WaitGroup) {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
select {
|
select {
|
||||||
@@ -46,17 +67,51 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for {
|
defer l.logger.Warn("loop exited")
|
||||||
openvpnCtx, openvpnCancel := context.WithCancel(ctx)
|
|
||||||
err := l.conf.WriteAuthFile(
|
for ctx.Err() == nil {
|
||||||
l.settings.User,
|
providerConf := provider.New(l.provider, l.client, l.fileManager)
|
||||||
l.settings.Password,
|
connections, err := providerConf.GetOpenVPNConnections(l.settings.Provider.ServerSelection)
|
||||||
|
l.fatalOnError(err)
|
||||||
|
err = providerConf.BuildConf(
|
||||||
|
connections,
|
||||||
|
l.settings.Verbosity,
|
||||||
l.uid,
|
l.uid,
|
||||||
l.gid,
|
l.gid,
|
||||||
|
l.settings.Root,
|
||||||
|
l.settings.Cipher,
|
||||||
|
l.settings.Auth,
|
||||||
|
l.settings.Provider.ExtraConfigOptions,
|
||||||
)
|
)
|
||||||
l.fatalOnError(err)
|
l.fatalOnError(err)
|
||||||
stream, waitFn, err := l.conf.Start(openvpnCtx)
|
|
||||||
|
err = l.conf.WriteAuthFile(l.settings.User, l.settings.Password, l.uid, l.gid)
|
||||||
l.fatalOnError(err)
|
l.fatalOnError(err)
|
||||||
|
|
||||||
|
if err := l.fw.SetVPNConnections(ctx, connections); err != nil {
|
||||||
|
l.fatalOnError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
openvpnCtx, openvpnCancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
stream, waitFn, err := l.conf.Start(openvpnCtx)
|
||||||
|
if err != nil {
|
||||||
|
openvpnCancel()
|
||||||
|
l.logAndWait(ctx, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
go func(ctx context.Context) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-portForward:
|
||||||
|
l.portForward(ctx, providerConf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(openvpnCtx)
|
||||||
|
|
||||||
go l.streamMerger.Merge(openvpnCtx, stream,
|
go l.streamMerger.Merge(openvpnCtx, stream,
|
||||||
command.MergeName("openvpn"), command.MergeColor(constants.ColorOpenvpn()))
|
command.MergeName("openvpn"), command.MergeColor(constants.ColorOpenvpn()))
|
||||||
waitError := make(chan error)
|
waitError := make(chan error)
|
||||||
@@ -74,13 +129,53 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
|
|||||||
case <-restart: // triggered restart
|
case <-restart: // triggered restart
|
||||||
l.logger.Info("restarting")
|
l.logger.Info("restarting")
|
||||||
openvpnCancel()
|
openvpnCancel()
|
||||||
|
<-waitError
|
||||||
close(waitError)
|
close(waitError)
|
||||||
case err := <-waitError: // unexpected error
|
case err := <-waitError: // unexpected error
|
||||||
l.logger.Warn(err)
|
|
||||||
l.logger.Info("restarting")
|
|
||||||
openvpnCancel()
|
openvpnCancel()
|
||||||
close(waitError)
|
close(waitError)
|
||||||
time.Sleep(time.Second)
|
l.logAndWait(ctx, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *looper) logAndWait(ctx context.Context, err error) {
|
||||||
|
l.logger.Error(err)
|
||||||
|
l.logger.Info("retrying in 30 seconds")
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
defer cancel() // just for the linter
|
||||||
|
<-ctx.Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *looper) portForward(ctx context.Context, providerConf provider.Provider) {
|
||||||
|
if !l.settings.Provider.PortForwarding.Enabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var port uint16
|
||||||
|
err := fmt.Errorf("")
|
||||||
|
for err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
port, err = providerConf.GetPortForward()
|
||||||
|
if err != nil {
|
||||||
|
l.logAndWait(ctx, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
l.logger.Info("port forwarded is %d", port)
|
||||||
|
}
|
||||||
|
|
||||||
|
filepath := l.settings.Provider.PortForwarding.Filepath
|
||||||
|
l.logger.Info("writing forwarded port to %s", filepath)
|
||||||
|
err = l.fileManager.WriteLinesToFile(
|
||||||
|
string(filepath), []string{fmt.Sprintf("%d", port)},
|
||||||
|
files.Ownership(l.uid, l.gid), files.Permissions(0400),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
l.logger.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := l.fw.SetPortForward(ctx, port); err != nil {
|
||||||
|
l.logger.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package provider
|
package provider
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -125,11 +124,3 @@ func (c *cyberghost) BuildConf(connections []models.OpenVPNConnection, verbosity
|
|||||||
func (c *cyberghost) GetPortForward() (port uint16, err error) {
|
func (c *cyberghost) GetPortForward() (port uint16, err error) {
|
||||||
panic("port forwarding is not supported for cyberghost")
|
panic("port forwarding is not supported for cyberghost")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *cyberghost) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) {
|
|
||||||
panic("port forwarding is not supported for cyberghost")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *cyberghost) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) {
|
|
||||||
panic("port forwarding is not supported for cyberghost")
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,24 +1,20 @@
|
|||||||
package provider
|
package provider
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/qdm12/golibs/files"
|
"github.com/qdm12/golibs/files"
|
||||||
"github.com/qdm12/golibs/logging"
|
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mullvad struct {
|
type mullvad struct {
|
||||||
fileManager files.FileManager
|
fileManager files.FileManager
|
||||||
logger logging.Logger
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMullvad(fileManager files.FileManager, logger logging.Logger) *mullvad {
|
func newMullvad(fileManager files.FileManager) *mullvad {
|
||||||
return &mullvad{
|
return &mullvad{
|
||||||
fileManager: fileManager,
|
fileManager: fileManager,
|
||||||
logger: logger.WithPrefix("Mullvad configurator: "),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -106,11 +102,3 @@ func (m *mullvad) BuildConf(connections []models.OpenVPNConnection, verbosity, u
|
|||||||
func (m *mullvad) GetPortForward() (port uint16, err error) {
|
func (m *mullvad) GetPortForward() (port uint16, err error) {
|
||||||
panic("port forwarding is not supported for mullvad")
|
panic("port forwarding is not supported for mullvad")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mullvad) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) {
|
|
||||||
panic("port forwarding is not supported for mullvad")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mullvad) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) {
|
|
||||||
panic("port forwarding is not supported for mullvad")
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package provider
|
package provider
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -14,24 +13,21 @@ import (
|
|||||||
"github.com/qdm12/golibs/network"
|
"github.com/qdm12/golibs/network"
|
||||||
"github.com/qdm12/golibs/verification"
|
"github.com/qdm12/golibs/verification"
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/firewall"
|
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
type pia struct {
|
type pia struct {
|
||||||
client network.Client
|
client network.Client
|
||||||
fileManager files.FileManager
|
fileManager files.FileManager
|
||||||
firewall firewall.Configurator
|
|
||||||
random random.Random
|
random random.Random
|
||||||
verifyPort func(port string) error
|
verifyPort func(port string) error
|
||||||
lookupIP func(host string) ([]net.IP, error)
|
lookupIP func(host string) ([]net.IP, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newPrivateInternetAccess(client network.Client, fileManager files.FileManager, firewall firewall.Configurator) *pia {
|
func newPrivateInternetAccess(client network.Client, fileManager files.FileManager) *pia {
|
||||||
return &pia{
|
return &pia{
|
||||||
client: client,
|
client: client,
|
||||||
fileManager: fileManager,
|
fileManager: fileManager,
|
||||||
firewall: firewall,
|
|
||||||
random: random.NewRandom(),
|
random: random.NewRandom(),
|
||||||
verifyPort: verification.NewVerifier().VerifyPort,
|
verifyPort: verification.NewVerifier().VerifyPort,
|
||||||
lookupIP: net.LookupIP}
|
lookupIP: net.LookupIP}
|
||||||
@@ -168,7 +164,7 @@ func (p *pia) GetPortForward() (port uint16, err error) {
|
|||||||
}
|
}
|
||||||
clientID := hex.EncodeToString(b)
|
clientID := hex.EncodeToString(b)
|
||||||
url := fmt.Sprintf("%s/?client_id=%s", constants.PIAPortForwardURL, clientID)
|
url := fmt.Sprintf("%s/?client_id=%s", constants.PIAPortForwardURL, clientID)
|
||||||
content, status, err := p.client.GetContent(url)
|
content, status, err := p.client.GetContent(url) // TODO add ctx
|
||||||
switch {
|
switch {
|
||||||
case err != nil:
|
case err != nil:
|
||||||
return 0, err
|
return 0, err
|
||||||
@@ -185,15 +181,3 @@ func (p *pia) GetPortForward() (port uint16, err error) {
|
|||||||
}
|
}
|
||||||
return body.Port, nil
|
return body.Port, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *pia) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) {
|
|
||||||
return p.fileManager.WriteLinesToFile(
|
|
||||||
string(filepath),
|
|
||||||
[]string{fmt.Sprintf("%d", port)},
|
|
||||||
files.Ownership(uid, gid),
|
|
||||||
files.Permissions(0400))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *pia) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) {
|
|
||||||
return p.firewall.AllowInputTrafficOnPort(ctx, device, port)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,13 +1,9 @@
|
|||||||
package provider
|
package provider
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/qdm12/golibs/files"
|
"github.com/qdm12/golibs/files"
|
||||||
"github.com/qdm12/golibs/logging"
|
|
||||||
"github.com/qdm12/golibs/network"
|
"github.com/qdm12/golibs/network"
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/firewall"
|
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -16,16 +12,14 @@ type Provider interface {
|
|||||||
GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error)
|
GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error)
|
||||||
BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (err error)
|
BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (err error)
|
||||||
GetPortForward() (port uint16, err error)
|
GetPortForward() (port uint16, err error)
|
||||||
WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error)
|
|
||||||
AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(provider models.VPNProvider, logger logging.Logger, client network.Client, fileManager files.FileManager, firewall firewall.Configurator) Provider {
|
func New(provider models.VPNProvider, client network.Client, fileManager files.FileManager) Provider {
|
||||||
switch provider {
|
switch provider {
|
||||||
case constants.PrivateInternetAccess:
|
case constants.PrivateInternetAccess:
|
||||||
return newPrivateInternetAccess(client, fileManager, firewall)
|
return newPrivateInternetAccess(client, fileManager)
|
||||||
case constants.Mullvad:
|
case constants.Mullvad:
|
||||||
return newMullvad(fileManager, logger)
|
return newMullvad(fileManager)
|
||||||
case constants.Windscribe:
|
case constants.Windscribe:
|
||||||
return newWindscribe(fileManager)
|
return newWindscribe(fileManager)
|
||||||
case constants.Surfshark:
|
case constants.Surfshark:
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package provider
|
package provider
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -127,11 +126,3 @@ func (s *surfshark) BuildConf(connections []models.OpenVPNConnection, verbosity,
|
|||||||
func (s *surfshark) GetPortForward() (port uint16, err error) {
|
func (s *surfshark) GetPortForward() (port uint16, err error) {
|
||||||
panic("port forwarding is not supported for surfshark")
|
panic("port forwarding is not supported for surfshark")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *surfshark) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) {
|
|
||||||
panic("port forwarding is not supported for surfshark")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *surfshark) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) {
|
|
||||||
panic("port forwarding is not supported for surfshark")
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package provider
|
package provider
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -124,11 +123,3 @@ func (w *windscribe) BuildConf(connections []models.OpenVPNConnection, verbosity
|
|||||||
func (w *windscribe) GetPortForward() (port uint16, err error) {
|
func (w *windscribe) GetPortForward() (port uint16, err error) {
|
||||||
panic("port forwarding is not supported for windscribe")
|
panic("port forwarding is not supported for windscribe")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *windscribe) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) {
|
|
||||||
panic("port forwarding is not supported for windscribe")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *windscribe) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) {
|
|
||||||
panic("port forwarding is not supported for windscribe")
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -7,29 +7,34 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *routing) AddRoutesVia(ctx context.Context, subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error {
|
func (r *routing) AddRouteVia(ctx context.Context, subnet net.IPNet, defaultGateway net.IP, defaultInterface string) error {
|
||||||
for _, subnet := range subnets {
|
subnetStr := subnet.String()
|
||||||
|
r.logger.Info("adding %s as route via %s %s", subnetStr, defaultGateway, defaultInterface)
|
||||||
exists, err := r.routeExists(subnet)
|
exists, err := r.routeExists(subnet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
} else if exists { // thanks to @npawelek https://github.com/npawelek
|
} else if exists {
|
||||||
if err := r.removeRoute(ctx, subnet); err != nil {
|
return nil
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
output, err := r.commander.Run(ctx, "ip", "route", "add", subnetStr, "via", defaultGateway.String(), "dev", defaultInterface)
|
||||||
r.logger.Info("adding %s as route via %s", subnet.String(), defaultInterface)
|
|
||||||
output, err := r.commander.Run(ctx, "ip", "route", "add", subnet.String(), "via", defaultGateway.String(), "dev", defaultInterface)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot add route for %s via %s %s %s: %s: %w", subnet.String(), defaultGateway.String(), "dev", defaultInterface, output, err)
|
return fmt.Errorf("cannot add route for %s via %s %s %s: %s: %w", subnetStr, defaultGateway, "dev", defaultInterface, output, err)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *routing) removeRoute(ctx context.Context, subnet net.IPNet) (err error) {
|
func (r *routing) DeleteRouteVia(ctx context.Context, subnet net.IPNet) (err error) {
|
||||||
output, err := r.commander.Run(ctx, "ip", "route", "del", subnet.String())
|
subnetStr := subnet.String()
|
||||||
|
r.logger.Info("deleting route for %s", subnetStr)
|
||||||
|
exists, err := r.routeExists(subnet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot delete route for %s: %s: %w", subnet.String(), output, err)
|
return err
|
||||||
|
} else if !exists { // thanks to @npawelek https://github.com/npawelek
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
output, err := r.commander.Run(ctx, "ip", "route", "del", subnetStr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot delete route for %s: %s: %w", subnetStr, output, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,12 +8,16 @@ import (
|
|||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/qdm12/golibs/command/mock_command"
|
"github.com/qdm12/golibs/command/mock_command"
|
||||||
|
"github.com/qdm12/golibs/files/mock_files"
|
||||||
|
"github.com/qdm12/golibs/logging/mock_logging"
|
||||||
|
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_removeRoute(t *testing.T) {
|
func Test_DeleteRouteVia(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
ctx := context.Background()
|
||||||
tests := map[string]struct {
|
tests := map[string]struct {
|
||||||
subnet net.IPNet
|
subnet net.IPNet
|
||||||
runOutput string
|
runOutput string
|
||||||
@@ -22,26 +26,26 @@ func Test_removeRoute(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
"no output no error": {
|
"no output no error": {
|
||||||
subnet: net.IPNet{
|
subnet: net.IPNet{
|
||||||
IP: net.IP{192, 168, 1, 0},
|
IP: net.IP{192, 168, 2, 0},
|
||||||
Mask: net.IPMask{255, 255, 255, 0},
|
Mask: net.IPMask{255, 255, 255, 0},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"error only": {
|
"error only": {
|
||||||
subnet: net.IPNet{
|
subnet: net.IPNet{
|
||||||
IP: net.IP{192, 168, 1, 0},
|
IP: net.IP{192, 168, 2, 0},
|
||||||
Mask: net.IPMask{255, 255, 255, 0},
|
Mask: net.IPMask{255, 255, 255, 0},
|
||||||
},
|
},
|
||||||
runErr: fmt.Errorf("error"),
|
runErr: fmt.Errorf("error"),
|
||||||
err: fmt.Errorf("cannot delete route for 192.168.1.0/24: : error"),
|
err: fmt.Errorf("cannot delete route for 192.168.2.0/24: : error"),
|
||||||
},
|
},
|
||||||
"error and output": {
|
"error and output": {
|
||||||
subnet: net.IPNet{
|
subnet: net.IPNet{
|
||||||
IP: net.IP{192, 168, 1, 0},
|
IP: net.IP{192, 168, 2, 0},
|
||||||
Mask: net.IPMask{255, 255, 255, 0},
|
Mask: net.IPMask{255, 255, 255, 0},
|
||||||
},
|
},
|
||||||
runErr: fmt.Errorf("error"),
|
runErr: fmt.Errorf("error"),
|
||||||
runOutput: "output",
|
runOutput: "output",
|
||||||
err: fmt.Errorf("cannot delete route for 192.168.1.0/24: output: error"),
|
err: fmt.Errorf("cannot delete route for 192.168.2.0/24: output: error"),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for name, tc := range tests {
|
for name, tc := range tests {
|
||||||
@@ -50,12 +54,26 @@ func Test_removeRoute(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
mockCtrl := gomock.NewController(t)
|
mockCtrl := gomock.NewController(t)
|
||||||
defer mockCtrl.Finish()
|
defer mockCtrl.Finish()
|
||||||
commander := mock_command.NewMockCommander(mockCtrl)
|
|
||||||
|
|
||||||
commander.EXPECT().Run(context.Background(), "ip", "route", "del", tc.subnet.String()).
|
subnetStr := tc.subnet.String()
|
||||||
|
|
||||||
|
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||||
|
logger.EXPECT().Info("deleting route for %s")
|
||||||
|
commander := mock_command.NewMockCommander(mockCtrl)
|
||||||
|
commander.EXPECT().Run(ctx, "ip", "route", "del", subnetStr).
|
||||||
Return(tc.runOutput, tc.runErr).Times(1)
|
Return(tc.runOutput, tc.runErr).Times(1)
|
||||||
r := &routing{commander: commander}
|
fileManager := mock_files.NewMockFileManager(mockCtrl)
|
||||||
err := r.removeRoute(context.Background(), tc.subnet)
|
routesData := []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||||
|
eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0
|
||||||
|
`)
|
||||||
|
fileManager.EXPECT().ReadFile(string(constants.NetRoute)).Return(routesData, nil)
|
||||||
|
r := &routing{
|
||||||
|
logger: logger,
|
||||||
|
commander: commander,
|
||||||
|
fileManager: fileManager,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := r.DeleteRouteVia(ctx, tc.subnet)
|
||||||
if tc.err != nil {
|
if tc.err != nil {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.Equal(t, tc.err.Error(), err.Error())
|
assert.Equal(t, tc.err.Error(), err.Error())
|
||||||
|
|||||||
@@ -10,7 +10,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Routing interface {
|
type Routing interface {
|
||||||
AddRoutesVia(ctx context.Context, subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error
|
AddRouteVia(ctx context.Context, subnet net.IPNet, defaultGateway net.IP, defaultInterface string) error
|
||||||
|
DeleteRouteVia(ctx context.Context, subnet net.IPNet) (err error)
|
||||||
DefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error)
|
DefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error)
|
||||||
VPNGatewayIP(defaultInterface string) (ip net.IP, err error)
|
VPNGatewayIP(defaultInterface string) (ip net.IP, err error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,6 +59,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
|
|||||||
}
|
}
|
||||||
defer l.logger.Warn("loop exited")
|
defer l.logger.Warn("loop exited")
|
||||||
|
|
||||||
|
var previousPort uint16
|
||||||
for ctx.Err() == nil {
|
for ctx.Err() == nil {
|
||||||
nameserver := l.dnsSettings.PlaintextAddress.String()
|
nameserver := l.dnsSettings.PlaintextAddress.String()
|
||||||
if l.dnsSettings.Enabled {
|
if l.dnsSettings.Enabled {
|
||||||
@@ -75,11 +76,19 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
|
|||||||
l.logAndWait(ctx, err)
|
l.logAndWait(ctx, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
err = l.firewallConf.AllowAnyIncomingOnPort(ctx, l.settings.Port)
|
|
||||||
// TODO remove firewall rule on exit below
|
if previousPort > 0 {
|
||||||
if err != nil {
|
if err := l.firewallConf.RemoveAllowedPort(ctx, previousPort); err != nil {
|
||||||
l.logger.Error(err)
|
l.logger.Error(err)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
if err := l.firewallConf.SetAllowedPort(ctx, l.settings.Port); err != nil {
|
||||||
|
l.logger.Error(err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
previousPort = l.settings.Port
|
||||||
|
|
||||||
shadowsocksCtx, shadowsocksCancel := context.WithCancel(context.Background())
|
shadowsocksCtx, shadowsocksCancel := context.WithCancel(context.Background())
|
||||||
stdout, stderr, waitFn, err := l.conf.Start(shadowsocksCtx, "0.0.0.0", l.settings.Port, l.settings.Password, l.settings.Log)
|
stdout, stderr, waitFn, err := l.conf.Start(shadowsocksCtx, "0.0.0.0", l.settings.Port, l.settings.Password, l.settings.Log)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
|
|||||||
}
|
}
|
||||||
defer l.logger.Warn("loop exited")
|
defer l.logger.Warn("loop exited")
|
||||||
|
|
||||||
|
var previousPort uint16
|
||||||
for ctx.Err() == nil {
|
for ctx.Err() == nil {
|
||||||
err := l.conf.MakeConf(
|
err := l.conf.MakeConf(
|
||||||
l.settings.LogLevel,
|
l.settings.LogLevel,
|
||||||
@@ -69,11 +70,19 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
|
|||||||
l.logAndWait(ctx, err)
|
l.logAndWait(ctx, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
err = l.firewallConf.AllowAnyIncomingOnPort(ctx, l.settings.Port)
|
|
||||||
// TODO remove firewall rule on exit below
|
if previousPort > 0 {
|
||||||
if err != nil {
|
if err := l.firewallConf.RemoveAllowedPort(ctx, previousPort); err != nil {
|
||||||
l.logger.Error(err)
|
l.logger.Error(err)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
if err := l.firewallConf.SetAllowedPort(ctx, l.settings.Port); err != nil {
|
||||||
|
l.logger.Error(err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
previousPort = l.settings.Port
|
||||||
|
|
||||||
tinyproxyCtx, tinyproxyCancel := context.WithCancel(context.Background())
|
tinyproxyCtx, tinyproxyCancel := context.WithCancel(context.Background())
|
||||||
stream, waitFn, err := l.conf.Start(tinyproxyCtx)
|
stream, waitFn, err := l.conf.Start(tinyproxyCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user