Firewall refactoring
- Ability to enable and disable rules in various loops - Simplified code overall - Port forwarding moved into openvpn loop - Route addition and removal improved
This commit is contained in:
@@ -2,43 +2,64 @@ package openvpn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/golibs/command"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/network"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/firewall"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/provider"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/settings"
|
||||
)
|
||||
|
||||
type Looper interface {
|
||||
Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup)
|
||||
Run(ctx context.Context, restart, portForward <-chan struct{}, wg *sync.WaitGroup)
|
||||
}
|
||||
|
||||
type looper struct {
|
||||
conf Configurator
|
||||
settings settings.OpenVPN
|
||||
// Variable parameters
|
||||
provider models.VPNProvider
|
||||
settings settings.OpenVPN
|
||||
// Fixed parameters
|
||||
uid int
|
||||
gid int
|
||||
// Configurators
|
||||
conf Configurator
|
||||
fw firewall.Configurator
|
||||
// Other objects
|
||||
logger logging.Logger
|
||||
client network.Client
|
||||
fileManager files.FileManager
|
||||
streamMerger command.StreamMerger
|
||||
fatalOnError func(err error)
|
||||
uid int
|
||||
gid int
|
||||
}
|
||||
|
||||
func NewLooper(conf Configurator, settings settings.OpenVPN, logger logging.Logger,
|
||||
streamMerger command.StreamMerger, fatalOnError func(err error), uid, gid int) Looper {
|
||||
func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
|
||||
uid, gid int,
|
||||
conf Configurator, fw firewall.Configurator,
|
||||
logger logging.Logger, client network.Client, fileManager files.FileManager,
|
||||
streamMerger command.StreamMerger, fatalOnError func(err error)) Looper {
|
||||
return &looper{
|
||||
conf: conf,
|
||||
provider: provider,
|
||||
settings: settings,
|
||||
logger: logger.WithPrefix("openvpn: "),
|
||||
streamMerger: streamMerger,
|
||||
fatalOnError: fatalOnError,
|
||||
uid: uid,
|
||||
gid: gid,
|
||||
conf: conf,
|
||||
fw: fw,
|
||||
logger: logger.WithPrefix("openvpn: "),
|
||||
client: client,
|
||||
fileManager: fileManager,
|
||||
streamMerger: streamMerger,
|
||||
fatalOnError: fatalOnError,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) {
|
||||
func (l *looper) Run(ctx context.Context, restart, portForward <-chan struct{}, wg *sync.WaitGroup) {
|
||||
wg.Add(1)
|
||||
defer wg.Done()
|
||||
select {
|
||||
@@ -46,17 +67,51 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
for {
|
||||
openvpnCtx, openvpnCancel := context.WithCancel(ctx)
|
||||
err := l.conf.WriteAuthFile(
|
||||
l.settings.User,
|
||||
l.settings.Password,
|
||||
defer l.logger.Warn("loop exited")
|
||||
|
||||
for ctx.Err() == nil {
|
||||
providerConf := provider.New(l.provider, l.client, l.fileManager)
|
||||
connections, err := providerConf.GetOpenVPNConnections(l.settings.Provider.ServerSelection)
|
||||
l.fatalOnError(err)
|
||||
err = providerConf.BuildConf(
|
||||
connections,
|
||||
l.settings.Verbosity,
|
||||
l.uid,
|
||||
l.gid,
|
||||
l.settings.Root,
|
||||
l.settings.Cipher,
|
||||
l.settings.Auth,
|
||||
l.settings.Provider.ExtraConfigOptions,
|
||||
)
|
||||
l.fatalOnError(err)
|
||||
stream, waitFn, err := l.conf.Start(openvpnCtx)
|
||||
|
||||
err = l.conf.WriteAuthFile(l.settings.User, l.settings.Password, l.uid, l.gid)
|
||||
l.fatalOnError(err)
|
||||
|
||||
if err := l.fw.SetVPNConnections(ctx, connections); err != nil {
|
||||
l.fatalOnError(err)
|
||||
}
|
||||
|
||||
openvpnCtx, openvpnCancel := context.WithCancel(context.Background())
|
||||
|
||||
stream, waitFn, err := l.conf.Start(openvpnCtx)
|
||||
if err != nil {
|
||||
openvpnCancel()
|
||||
l.logAndWait(ctx, err)
|
||||
continue
|
||||
}
|
||||
|
||||
go func(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-portForward:
|
||||
l.portForward(ctx, providerConf)
|
||||
}
|
||||
}
|
||||
}(openvpnCtx)
|
||||
|
||||
go l.streamMerger.Merge(openvpnCtx, stream,
|
||||
command.MergeName("openvpn"), command.MergeColor(constants.ColorOpenvpn()))
|
||||
waitError := make(chan error)
|
||||
@@ -74,13 +129,53 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
|
||||
case <-restart: // triggered restart
|
||||
l.logger.Info("restarting")
|
||||
openvpnCancel()
|
||||
<-waitError
|
||||
close(waitError)
|
||||
case err := <-waitError: // unexpected error
|
||||
l.logger.Warn(err)
|
||||
l.logger.Info("restarting")
|
||||
openvpnCancel()
|
||||
close(waitError)
|
||||
time.Sleep(time.Second)
|
||||
l.logAndWait(ctx, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *looper) logAndWait(ctx context.Context, err error) {
|
||||
l.logger.Error(err)
|
||||
l.logger.Info("retrying in 30 seconds")
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel() // just for the linter
|
||||
<-ctx.Done()
|
||||
}
|
||||
|
||||
func (l *looper) portForward(ctx context.Context, providerConf provider.Provider) {
|
||||
if !l.settings.Provider.PortForwarding.Enabled {
|
||||
return
|
||||
}
|
||||
var port uint16
|
||||
err := fmt.Errorf("")
|
||||
for err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
port, err = providerConf.GetPortForward()
|
||||
if err != nil {
|
||||
l.logAndWait(ctx, err)
|
||||
continue
|
||||
}
|
||||
l.logger.Info("port forwarded is %d", port)
|
||||
}
|
||||
|
||||
filepath := l.settings.Provider.PortForwarding.Filepath
|
||||
l.logger.Info("writing forwarded port to %s", filepath)
|
||||
err = l.fileManager.WriteLinesToFile(
|
||||
string(filepath), []string{fmt.Sprintf("%d", port)},
|
||||
files.Ownership(l.uid, l.gid), files.Permissions(0400),
|
||||
)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
|
||||
if err := l.fw.SetPortForward(ctx, port); err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user