Simplified loop mechanism for openvpn and dns
- Refers to #91 - http control server starts without waiting for unbound and/or openvpn - Trying to get rid of waiter and use channels directly - Simpler main.go - More robust logic overall
This commit is contained in:
@@ -3,7 +3,6 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"regexp"
|
"regexp"
|
||||||
@@ -211,40 +210,49 @@ func _main(background context.Context, args []string) int {
|
|||||||
go streamMerger.Merge(ctx, stderr, command.MergeName("shadowsocks error"), command.MergeColor(constants.ColorShadowsocksError()))
|
go streamMerger.Merge(ctx, stderr, command.MergeName("shadowsocks error"), command.MergeColor(constants.ColorShadowsocksError()))
|
||||||
}
|
}
|
||||||
|
|
||||||
httpServer := server.New("0.0.0.0:8000", logger)
|
restartOpenvpn := make(chan struct{})
|
||||||
|
restartUnbound := make(chan struct{})
|
||||||
|
openvpnDone := make(chan struct{})
|
||||||
|
unboundDone := make(chan struct{})
|
||||||
|
serverDone := make(chan struct{})
|
||||||
|
|
||||||
go openvpnRunLoop(ctx, ovpnConf, streamMerger, logger, httpServer, waiter, fatalOnError)
|
openvpnLooper := openvpn.NewLooper(ovpnConf, allSettings.OpenVPN, logger, streamMerger, fatalOnError)
|
||||||
|
// wait for restartOpenvpn
|
||||||
|
go openvpnLooper.Run(ctx, restartOpenvpn, openvpnDone)
|
||||||
|
|
||||||
waiter.Add(func() error {
|
unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, logger,
|
||||||
err := httpServer.Run(ctx)
|
streamMerger, allSettings.System.UID, allSettings.System.GID)
|
||||||
logger.Error("http server: %s", err)
|
// wait for restartUnbound
|
||||||
return err
|
go unboundLooper.Run(ctx, restartUnbound, unboundDone)
|
||||||
})
|
|
||||||
|
|
||||||
startUnboundCh := make(chan struct{})
|
|
||||||
go unboundRunLoop(ctx, startUnboundCh, logger, dnsConf, allSettings.DNS, allSettings.System.UID, allSettings.System.GID, waiter, streamMerger, httpServer)
|
|
||||||
if !allSettings.DNS.Enabled {
|
|
||||||
httpServer.SetUnboundRestart(func() {})
|
|
||||||
dnsConf.UseDNSInternally(allSettings.DNS.PlaintextAddress)
|
|
||||||
if err := dnsConf.UseDNSSystemWide(allSettings.DNS.PlaintextAddress); err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
first := true
|
||||||
|
var restartTickerContext context.Context
|
||||||
|
var restartTickerCancel context.CancelFunc = func() {}
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
restartTickerCancel()
|
||||||
return
|
return
|
||||||
case <-connectedCh: // blocks until openvpn is connected
|
case <-connectedCh: // blocks until openvpn is connected
|
||||||
if allSettings.DNS.Enabled {
|
if first {
|
||||||
startUnboundCh <- struct{}{}
|
first = false
|
||||||
|
restartUnbound <- struct{}{}
|
||||||
}
|
}
|
||||||
|
restartTickerCancel()
|
||||||
|
restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
|
||||||
|
go unboundLooper.RunRestartTicker(restartTickerContext, restartUnbound)
|
||||||
onConnected(allSettings, logger, fileManager, routingConf, defaultInterface, providerConf)
|
onConnected(allSettings, logger, fileManager, routingConf, defaultInterface, providerConf)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
httpServer := server.New("0.0.0.0:8000", logger, restartOpenvpn, restartUnbound)
|
||||||
|
go httpServer.Run(ctx, serverDone)
|
||||||
|
|
||||||
|
// Start openvpn for the first time
|
||||||
|
restartOpenvpn <- struct{}{}
|
||||||
|
|
||||||
signalsCh := make(chan os.Signal, 1)
|
signalsCh := make(chan os.Signal, 1)
|
||||||
signal.Notify(signalsCh,
|
signal.Notify(signalsCh,
|
||||||
syscall.SIGINT,
|
syscall.SIGINT,
|
||||||
@@ -278,6 +286,9 @@ func _main(background context.Context, args []string) int {
|
|||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
exitStatus = 1
|
exitStatus = 1
|
||||||
}
|
}
|
||||||
|
<-serverDone
|
||||||
|
<-unboundDone
|
||||||
|
<-openvpnDone
|
||||||
return exitStatus
|
return exitStatus
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -348,34 +359,6 @@ func trimEventualProgramPrefix(s string) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func openvpnRunLoop(ctx context.Context, ovpnConf openvpn.Configurator, streamMerger command.StreamMerger,
|
|
||||||
logger logging.Logger, httpServer server.Server, waiter command.Waiter, fatalOnError func(err error)) {
|
|
||||||
logger = logger.WithPrefix("openvpn: ")
|
|
||||||
waitErrors := make(chan error)
|
|
||||||
for ctx.Err() == nil {
|
|
||||||
logger.Info("starting")
|
|
||||||
openvpnCtx, openvpnCancel := context.WithCancel(ctx)
|
|
||||||
stream, waitFn, err := ovpnConf.Start(openvpnCtx)
|
|
||||||
fatalOnError(err)
|
|
||||||
httpServer.SetOpenVPNRestart(openvpnCancel)
|
|
||||||
go streamMerger.Merge(openvpnCtx, stream, command.MergeName("openvpn"), command.MergeColor(constants.ColorOpenvpn()))
|
|
||||||
waiter.Add(func() error {
|
|
||||||
return <-waitErrors
|
|
||||||
})
|
|
||||||
err = waitFn()
|
|
||||||
waitErrors <- fmt.Errorf("openvpn: %w", err)
|
|
||||||
switch {
|
|
||||||
case ctx.Err() != nil:
|
|
||||||
logger.Warn("context canceled: exiting openvpn run loop")
|
|
||||||
case openvpnCtx.Err() == context.Canceled:
|
|
||||||
logger.Info("triggered openvpn restart")
|
|
||||||
default:
|
|
||||||
logger.Warn(err)
|
|
||||||
openvpnCancel()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func onConnected(allSettings settings.Settings,
|
func onConnected(allSettings settings.Settings,
|
||||||
logger logging.Logger, fileManager files.FileManager,
|
logger logging.Logger, fileManager files.FileManager,
|
||||||
routingConf routing.Routing, defaultInterface string,
|
routingConf routing.Routing, defaultInterface string,
|
||||||
@@ -393,7 +376,7 @@ func onConnected(allSettings settings.Settings,
|
|||||||
} else {
|
} else {
|
||||||
logger.Info("Gateway VPN IP address: %s", vpnGatewayIP)
|
logger.Info("Gateway VPN IP address: %s", vpnGatewayIP)
|
||||||
}
|
}
|
||||||
time.AfterFunc(7*time.Second, func() { // wait for Unbound to start - TODO use signal channel
|
time.AfterFunc(10*time.Second, func() { // wait for Unbound to start - TODO use signal channel
|
||||||
publicIP, err := publicip.NewIPGetter(network.NewClient(3 * time.Second)).Get()
|
publicIP, err := publicip.NewIPGetter(network.NewClient(3 * time.Second)).Get()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
@@ -411,126 +394,6 @@ func onConnected(allSettings settings.Settings,
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func fallbackToUnencryptedIPv4DNS(dnsConf dns.Configurator, providers []models.DNSProvider) error {
|
|
||||||
var targetIP net.IP
|
|
||||||
for _, provider := range providers {
|
|
||||||
data := constants.DNSProviderMapping()[provider]
|
|
||||||
for _, targetIP = range data.IPs {
|
|
||||||
if targetIP.To4() != nil {
|
|
||||||
dnsConf.UseDNSInternally(targetIP)
|
|
||||||
return dnsConf.UseDNSSystemWide(targetIP)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// No IPv4 address found
|
|
||||||
return fmt.Errorf("no ipv4 DNS address found for providers %s", providers)
|
|
||||||
}
|
|
||||||
|
|
||||||
func unboundRun(ctx, oldCtx context.Context, oldCancel context.CancelFunc, timer *time.Timer,
|
|
||||||
dnsConf dns.Configurator, settings settings.DNS, uid, gid int,
|
|
||||||
streamMerger command.StreamMerger, waiter command.Waiter, httpServer server.Server) (
|
|
||||||
newCtx context.Context, newCancel context.CancelFunc, setupErr, startErr, waitErr error) {
|
|
||||||
if timer != nil {
|
|
||||||
timer.Stop()
|
|
||||||
timer.Reset(settings.UpdatePeriod)
|
|
||||||
}
|
|
||||||
if err := dnsConf.DownloadRootHints(uid, gid); err != nil {
|
|
||||||
return oldCtx, oldCancel, err, nil, nil
|
|
||||||
}
|
|
||||||
if err := dnsConf.DownloadRootKey(uid, gid); err != nil {
|
|
||||||
return oldCtx, oldCancel, err, nil, nil
|
|
||||||
}
|
|
||||||
if err := dnsConf.MakeUnboundConf(settings, uid, gid); err != nil {
|
|
||||||
return oldCtx, oldCancel, err, nil, nil
|
|
||||||
}
|
|
||||||
newCtx, newCancel = context.WithCancel(ctx)
|
|
||||||
oldCancel()
|
|
||||||
stream, waitFn, err := dnsConf.Start(newCtx, settings.VerbosityDetailsLevel)
|
|
||||||
if err != nil {
|
|
||||||
return newCtx, newCancel, nil, err, nil
|
|
||||||
}
|
|
||||||
go streamMerger.Merge(newCtx, stream, command.MergeName("unbound"), command.MergeColor(constants.ColorUnbound()))
|
|
||||||
dnsConf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound
|
|
||||||
if err := dnsConf.UseDNSSystemWide(net.IP{127, 0, 0, 1}); err != nil { // use Unbound
|
|
||||||
return newCtx, newCancel, nil, err, nil
|
|
||||||
}
|
|
||||||
if err := dnsConf.WaitForUnbound(); err != nil {
|
|
||||||
return newCtx, newCancel, nil, err, nil
|
|
||||||
}
|
|
||||||
// Unbound is up and running at this point
|
|
||||||
httpServer.SetUnboundRestart(newCancel)
|
|
||||||
waitError := make(chan error)
|
|
||||||
waiterError := make(chan error)
|
|
||||||
waiter.Add(func() error { //nolint:scopelint
|
|
||||||
return <-waiterError
|
|
||||||
})
|
|
||||||
go func() {
|
|
||||||
err := fmt.Errorf("unbound: %w", waitFn())
|
|
||||||
waitError <- err
|
|
||||||
waiterError <- err
|
|
||||||
}()
|
|
||||||
if timer == nil {
|
|
||||||
waitErr := <-waitError
|
|
||||||
return newCtx, newCancel, nil, nil, waitErr
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-timer.C:
|
|
||||||
return newCtx, newCancel, nil, nil, nil
|
|
||||||
case waitErr := <-waitError:
|
|
||||||
return newCtx, newCancel, nil, nil, waitErr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func unboundRunLoop(ctx context.Context, startCh chan struct{}, logger logging.Logger, dnsConf dns.Configurator,
|
|
||||||
settings settings.DNS, uid, gid int,
|
|
||||||
waiter command.Waiter, streamMerger command.StreamMerger, httpServer server.Server,
|
|
||||||
) {
|
|
||||||
logger = logger.WithPrefix("unbound dns over tls setup: ")
|
|
||||||
select {
|
|
||||||
case <-startCh:
|
|
||||||
case <-ctx.Done():
|
|
||||||
logger.Warn("context canceled: exiting unbound run loop")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := fallbackToUnencryptedIPv4DNS(dnsConf, settings.Providers); err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
}
|
|
||||||
var timer *time.Timer
|
|
||||||
if settings.UpdatePeriod > 0 {
|
|
||||||
timer = time.NewTimer(settings.UpdatePeriod)
|
|
||||||
}
|
|
||||||
unboundCtx, unboundCancel := context.WithCancel(ctx)
|
|
||||||
defer unboundCancel()
|
|
||||||
for ctx.Err() == nil {
|
|
||||||
var setupErr, startErr, waitErr error
|
|
||||||
unboundCtx, unboundCancel, setupErr, startErr, waitErr = unboundRun(
|
|
||||||
ctx, unboundCtx, unboundCancel, timer, dnsConf, settings,
|
|
||||||
uid, gid, streamMerger, waiter, httpServer)
|
|
||||||
switch {
|
|
||||||
case ctx.Err() != nil:
|
|
||||||
logger.Warn("context canceled: exiting unbound run loop")
|
|
||||||
case timer != nil && !timer.Stop():
|
|
||||||
logger.Info("planned restart of unbound")
|
|
||||||
case unboundCtx.Err() == context.Canceled:
|
|
||||||
logger.Info("triggered restart of unbound")
|
|
||||||
case setupErr != nil:
|
|
||||||
logger.Warn(setupErr)
|
|
||||||
case startErr != nil:
|
|
||||||
logger.Error(startErr)
|
|
||||||
unboundCancel()
|
|
||||||
if err := fallbackToUnencryptedIPv4DNS(dnsConf, settings.Providers); err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
}
|
|
||||||
case waitErr != nil:
|
|
||||||
logger.Warn(waitErr)
|
|
||||||
if err := fallbackToUnencryptedIPv4DNS(dnsConf, settings.Providers); err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
}
|
|
||||||
logger.Warn("restarting unbound because of unexpected exit")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupPortForwarding(logger logging.Logger, providerConf provider.Provider, filepath models.Filepath, uid, gid int) {
|
func setupPortForwarding(logger logging.Logger, providerConf provider.Provider, filepath models.Filepath, uid, gid int) {
|
||||||
pfLogger := logger.WithPrefix("port forwarding: ")
|
pfLogger := logger.WithPrefix("port forwarding: ")
|
||||||
var port uint16
|
var port uint16
|
||||||
|
|||||||
183
internal/dns/loop.go
Normal file
183
internal/dns/loop.go
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/qdm12/golibs/command"
|
||||||
|
"github.com/qdm12/golibs/logging"
|
||||||
|
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||||
|
"github.com/qdm12/private-internet-access-docker/internal/settings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Looper interface {
|
||||||
|
Run(ctx context.Context, restart <-chan struct{}, done chan<- struct{})
|
||||||
|
RunRestartTicker(ctx context.Context, restart chan<- struct{})
|
||||||
|
}
|
||||||
|
|
||||||
|
type looper struct {
|
||||||
|
conf Configurator
|
||||||
|
settings settings.DNS
|
||||||
|
logger logging.Logger
|
||||||
|
streamMerger command.StreamMerger
|
||||||
|
uid int
|
||||||
|
gid int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger,
|
||||||
|
streamMerger command.StreamMerger, uid, gid int) Looper {
|
||||||
|
return &looper{
|
||||||
|
conf: conf,
|
||||||
|
settings: settings,
|
||||||
|
logger: logger.WithPrefix("dns over tls: "),
|
||||||
|
uid: uid,
|
||||||
|
gid: gid,
|
||||||
|
streamMerger: streamMerger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *looper) attemptingRestart(err error) {
|
||||||
|
l.logger.Warn(err)
|
||||||
|
l.logger.Info("attempting restart in 10 seconds")
|
||||||
|
time.Sleep(10 * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *looper) Run(ctx context.Context, restart <-chan struct{}, done chan<- struct{}) {
|
||||||
|
l.fallbackToUnencryptedDNS()
|
||||||
|
select {
|
||||||
|
case <-restart:
|
||||||
|
case <-ctx.Done():
|
||||||
|
close(done)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, unboundCancel := context.WithCancel(ctx)
|
||||||
|
for {
|
||||||
|
if !l.settings.Enabled {
|
||||||
|
// wait for another restart signal to recheck if it is enabled
|
||||||
|
select {
|
||||||
|
case <-restart:
|
||||||
|
case <-ctx.Done():
|
||||||
|
unboundCancel()
|
||||||
|
close(done)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ctx.Err() == context.Canceled {
|
||||||
|
unboundCancel()
|
||||||
|
close(done)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup
|
||||||
|
if err := l.conf.DownloadRootHints(l.uid, l.gid); err != nil {
|
||||||
|
l.attemptingRestart(err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := l.conf.DownloadRootKey(l.uid, l.gid); err != nil {
|
||||||
|
l.attemptingRestart(err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := l.conf.MakeUnboundConf(l.settings, l.uid, l.gid); err != nil {
|
||||||
|
l.attemptingRestart(err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start command
|
||||||
|
unboundCancel()
|
||||||
|
unboundCtx, unboundCancel := context.WithCancel(ctx)
|
||||||
|
stream, waitFn, err := l.conf.Start(unboundCtx, l.settings.VerbosityDetailsLevel)
|
||||||
|
if err != nil {
|
||||||
|
unboundCancel()
|
||||||
|
l.fallbackToUnencryptedDNS()
|
||||||
|
l.attemptingRestart(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Started successfully
|
||||||
|
go l.streamMerger.Merge(unboundCtx, stream,
|
||||||
|
command.MergeName("unbound"), command.MergeColor(constants.ColorUnbound()))
|
||||||
|
l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound
|
||||||
|
if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}); err != nil { // use Unbound
|
||||||
|
unboundCancel()
|
||||||
|
l.fallbackToUnencryptedDNS()
|
||||||
|
l.attemptingRestart(err)
|
||||||
|
}
|
||||||
|
if err := l.conf.WaitForUnbound(); err != nil {
|
||||||
|
unboundCancel()
|
||||||
|
l.fallbackToUnencryptedDNS()
|
||||||
|
l.attemptingRestart(err)
|
||||||
|
}
|
||||||
|
waitError := make(chan error)
|
||||||
|
go func() {
|
||||||
|
err := waitFn() // blocking
|
||||||
|
if unboundCtx.Err() != context.Canceled {
|
||||||
|
waitError <- fmt.Errorf("unbound: %w", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for one of the three cases below
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
l.logger.Warn("context canceled: exiting loop")
|
||||||
|
unboundCancel()
|
||||||
|
close(waitError)
|
||||||
|
close(done)
|
||||||
|
return
|
||||||
|
case <-restart: // triggered restart
|
||||||
|
unboundCancel()
|
||||||
|
close(waitError)
|
||||||
|
l.logger.Info("restarting")
|
||||||
|
case err := <-waitError: // unexpected error
|
||||||
|
unboundCancel()
|
||||||
|
close(waitError)
|
||||||
|
l.fallbackToUnencryptedDNS()
|
||||||
|
l.attemptingRestart(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *looper) fallbackToUnencryptedDNS() {
|
||||||
|
// Try with user provided plaintext ip address
|
||||||
|
targetIP := l.settings.PlaintextAddress
|
||||||
|
if targetIP != nil {
|
||||||
|
l.conf.UseDNSInternally(targetIP)
|
||||||
|
if err := l.conf.UseDNSSystemWide(targetIP); err != nil {
|
||||||
|
l.logger.Error(err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try with any IPv4 address from the providers chosen
|
||||||
|
for _, provider := range l.settings.Providers {
|
||||||
|
data := constants.DNSProviderMapping()[provider]
|
||||||
|
for _, targetIP = range data.IPs {
|
||||||
|
if targetIP.To4() != nil {
|
||||||
|
l.conf.UseDNSInternally(targetIP)
|
||||||
|
if err := l.conf.UseDNSSystemWide(targetIP); err != nil {
|
||||||
|
l.logger.Error(err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// No IPv4 address found
|
||||||
|
l.logger.Error("no ipv4 DNS address found for providers %s", l.settings.Providers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *looper) RunRestartTicker(ctx context.Context, restart chan<- struct{}) {
|
||||||
|
if l.settings.UpdatePeriod == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ticker := time.NewTicker(l.settings.UpdatePeriod)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
ticker.Stop()
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
restart <- struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
76
internal/openvpn/loop.go
Normal file
76
internal/openvpn/loop.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
package openvpn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/qdm12/golibs/command"
|
||||||
|
"github.com/qdm12/golibs/logging"
|
||||||
|
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||||
|
"github.com/qdm12/private-internet-access-docker/internal/settings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Looper interface {
|
||||||
|
Run(ctx context.Context, restart <-chan struct{}, done chan<- struct{})
|
||||||
|
}
|
||||||
|
|
||||||
|
type looper struct {
|
||||||
|
conf Configurator
|
||||||
|
settings settings.OpenVPN
|
||||||
|
logger logging.Logger
|
||||||
|
streamMerger command.StreamMerger
|
||||||
|
fatalOnError func(err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLooper(conf Configurator, settings settings.OpenVPN, logger logging.Logger,
|
||||||
|
streamMerger command.StreamMerger, fatalOnError func(err error)) Looper {
|
||||||
|
return &looper{
|
||||||
|
conf: conf,
|
||||||
|
settings: settings,
|
||||||
|
logger: logger.WithPrefix("openvpn: "),
|
||||||
|
streamMerger: streamMerger,
|
||||||
|
fatalOnError: fatalOnError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *looper) Run(ctx context.Context, restart <-chan struct{}, done chan<- struct{}) {
|
||||||
|
select {
|
||||||
|
case <-restart:
|
||||||
|
case <-ctx.Done():
|
||||||
|
close(done)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
openvpnCtx, openvpnCancel := context.WithCancel(ctx)
|
||||||
|
stream, waitFn, err := l.conf.Start(openvpnCtx)
|
||||||
|
l.fatalOnError(err)
|
||||||
|
go l.streamMerger.Merge(openvpnCtx, stream,
|
||||||
|
command.MergeName("openvpn"), command.MergeColor(constants.ColorOpenvpn()))
|
||||||
|
waitError := make(chan error)
|
||||||
|
go func() {
|
||||||
|
err := waitFn() // blocking
|
||||||
|
if openvpnCtx.Err() != context.Canceled {
|
||||||
|
waitError <- fmt.Errorf("openvpn: %w", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
l.logger.Warn("context canceled: exiting loop")
|
||||||
|
openvpnCancel()
|
||||||
|
close(waitError)
|
||||||
|
close(done)
|
||||||
|
return
|
||||||
|
case <-restart: // triggered restart
|
||||||
|
l.logger.Info("restarting")
|
||||||
|
openvpnCancel()
|
||||||
|
close(waitError)
|
||||||
|
case err := <-waitError: // unexpected error
|
||||||
|
l.logger.Warn(err)
|
||||||
|
l.logger.Info("restarting")
|
||||||
|
openvpnCancel()
|
||||||
|
close(waitError)
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,55 +4,37 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/qdm12/golibs/logging"
|
"github.com/qdm12/golibs/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Server interface {
|
type Server interface {
|
||||||
SetOpenVPNRestart(f func())
|
Run(ctx context.Context, serverDone chan struct{})
|
||||||
SetUnboundRestart(f func())
|
|
||||||
Run(ctx context.Context) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type server struct {
|
type server struct {
|
||||||
address string
|
address string
|
||||||
logger logging.Logger
|
logger logging.Logger
|
||||||
restartOpenvpn func()
|
restartOpenvpn chan<- struct{}
|
||||||
restartOpenvpnSet context.Context
|
restartUnbound chan<- struct{}
|
||||||
restartOpenvpnSetSignal func()
|
|
||||||
restartUnbound func()
|
|
||||||
restartUnboundSet context.Context
|
|
||||||
restartUnboundSetSignal func()
|
|
||||||
sync.RWMutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(address string, logger logging.Logger) Server {
|
func New(address string, logger logging.Logger, restartOpenvpn, restartUnbound chan<- struct{}) Server {
|
||||||
restartOpenvpnSet, restartOpenvpnSetSignal := context.WithCancel(context.Background())
|
|
||||||
restartUnboundSet, restartUnboundSetSignal := context.WithCancel(context.Background())
|
|
||||||
return &server{
|
return &server{
|
||||||
address: address,
|
address: address,
|
||||||
logger: logger.WithPrefix("http server: "),
|
logger: logger.WithPrefix("http server: "),
|
||||||
restartOpenvpnSet: restartOpenvpnSet,
|
restartOpenvpn: restartOpenvpn,
|
||||||
restartOpenvpnSetSignal: restartOpenvpnSetSignal,
|
restartUnbound: restartUnbound,
|
||||||
restartUnboundSet: restartUnboundSet,
|
|
||||||
restartUnboundSetSignal: restartUnboundSetSignal,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) Run(ctx context.Context) error {
|
func (s *server) Run(ctx context.Context, serverDone chan struct{}) {
|
||||||
if s.restartOpenvpnSet.Err() == nil {
|
|
||||||
s.logger.Warn("restartOpenvpn function is not set, waiting...")
|
|
||||||
<-s.restartOpenvpnSet.Done()
|
|
||||||
}
|
|
||||||
if s.restartUnboundSet.Err() == nil {
|
|
||||||
s.logger.Warn("restartUnbound function is not set, waiting...")
|
|
||||||
<-s.restartUnboundSet.Done()
|
|
||||||
}
|
|
||||||
server := http.Server{Addr: s.address, Handler: s.makeHandler()}
|
server := http.Server{Addr: s.address, Handler: s.makeHandler()}
|
||||||
go func() {
|
go func() {
|
||||||
|
defer close(serverDone)
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
|
s.logger.Warn("context canceled: exiting loop")
|
||||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||||
@@ -60,24 +42,9 @@ func (s *server) Run(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
s.logger.Info("listening on %s", s.address)
|
s.logger.Info("listening on %s", s.address)
|
||||||
return server.ListenAndServe()
|
err := server.ListenAndServe()
|
||||||
}
|
if err != nil && ctx.Err() != context.Canceled {
|
||||||
|
s.logger.Error(err)
|
||||||
func (s *server) SetOpenVPNRestart(f func()) {
|
|
||||||
s.Lock()
|
|
||||||
defer s.Unlock()
|
|
||||||
s.restartOpenvpn = f
|
|
||||||
if s.restartOpenvpnSet.Err() == nil {
|
|
||||||
s.restartOpenvpnSetSignal()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *server) SetUnboundRestart(f func()) {
|
|
||||||
s.Lock()
|
|
||||||
defer s.Unlock()
|
|
||||||
s.restartUnbound = f
|
|
||||||
if s.restartUnboundSet.Err() == nil {
|
|
||||||
s.restartUnboundSetSignal()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -88,13 +55,9 @@ func (s *server) makeHandler() http.HandlerFunc {
|
|||||||
case http.MethodGet:
|
case http.MethodGet:
|
||||||
switch r.RequestURI {
|
switch r.RequestURI {
|
||||||
case "/openvpn/actions/restart":
|
case "/openvpn/actions/restart":
|
||||||
s.RLock()
|
s.restartOpenvpn <- struct{}{}
|
||||||
defer s.RUnlock()
|
|
||||||
s.restartOpenvpn()
|
|
||||||
case "/unbound/actions/restart":
|
case "/unbound/actions/restart":
|
||||||
s.RLock()
|
s.restartUnbound <- struct{}{}
|
||||||
defer s.RUnlock()
|
|
||||||
s.restartUnbound()
|
|
||||||
default:
|
default:
|
||||||
routeDoesNotExist(s.logger, w, r)
|
routeDoesNotExist(s.logger, w, r)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user