Simplified loop mechanism for openvpn and dns
- Refers to #91 - http control server starts without waiting for unbound and/or openvpn - Trying to get rid of waiter and use channels directly - Simpler main.go - More robust logic overall
This commit is contained in:
183
internal/dns/loop.go
Normal file
183
internal/dns/loop.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/golibs/command"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/settings"
|
||||
)
|
||||
|
||||
type Looper interface {
|
||||
Run(ctx context.Context, restart <-chan struct{}, done chan<- struct{})
|
||||
RunRestartTicker(ctx context.Context, restart chan<- struct{})
|
||||
}
|
||||
|
||||
type looper struct {
|
||||
conf Configurator
|
||||
settings settings.DNS
|
||||
logger logging.Logger
|
||||
streamMerger command.StreamMerger
|
||||
uid int
|
||||
gid int
|
||||
}
|
||||
|
||||
func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger,
|
||||
streamMerger command.StreamMerger, uid, gid int) Looper {
|
||||
return &looper{
|
||||
conf: conf,
|
||||
settings: settings,
|
||||
logger: logger.WithPrefix("dns over tls: "),
|
||||
uid: uid,
|
||||
gid: gid,
|
||||
streamMerger: streamMerger,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *looper) attemptingRestart(err error) {
|
||||
l.logger.Warn(err)
|
||||
l.logger.Info("attempting restart in 10 seconds")
|
||||
time.Sleep(10 * time.Second)
|
||||
}
|
||||
|
||||
func (l *looper) Run(ctx context.Context, restart <-chan struct{}, done chan<- struct{}) {
|
||||
l.fallbackToUnencryptedDNS()
|
||||
select {
|
||||
case <-restart:
|
||||
case <-ctx.Done():
|
||||
close(done)
|
||||
return
|
||||
}
|
||||
_, unboundCancel := context.WithCancel(ctx)
|
||||
for {
|
||||
if !l.settings.Enabled {
|
||||
// wait for another restart signal to recheck if it is enabled
|
||||
select {
|
||||
case <-restart:
|
||||
case <-ctx.Done():
|
||||
unboundCancel()
|
||||
close(done)
|
||||
return
|
||||
}
|
||||
}
|
||||
if ctx.Err() == context.Canceled {
|
||||
unboundCancel()
|
||||
close(done)
|
||||
return
|
||||
}
|
||||
|
||||
// Setup
|
||||
if err := l.conf.DownloadRootHints(l.uid, l.gid); err != nil {
|
||||
l.attemptingRestart(err)
|
||||
continue
|
||||
}
|
||||
if err := l.conf.DownloadRootKey(l.uid, l.gid); err != nil {
|
||||
l.attemptingRestart(err)
|
||||
continue
|
||||
}
|
||||
if err := l.conf.MakeUnboundConf(l.settings, l.uid, l.gid); err != nil {
|
||||
l.attemptingRestart(err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Start command
|
||||
unboundCancel()
|
||||
unboundCtx, unboundCancel := context.WithCancel(ctx)
|
||||
stream, waitFn, err := l.conf.Start(unboundCtx, l.settings.VerbosityDetailsLevel)
|
||||
if err != nil {
|
||||
unboundCancel()
|
||||
l.fallbackToUnencryptedDNS()
|
||||
l.attemptingRestart(err)
|
||||
}
|
||||
|
||||
// Started successfully
|
||||
go l.streamMerger.Merge(unboundCtx, stream,
|
||||
command.MergeName("unbound"), command.MergeColor(constants.ColorUnbound()))
|
||||
l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound
|
||||
if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}); err != nil { // use Unbound
|
||||
unboundCancel()
|
||||
l.fallbackToUnencryptedDNS()
|
||||
l.attemptingRestart(err)
|
||||
}
|
||||
if err := l.conf.WaitForUnbound(); err != nil {
|
||||
unboundCancel()
|
||||
l.fallbackToUnencryptedDNS()
|
||||
l.attemptingRestart(err)
|
||||
}
|
||||
waitError := make(chan error)
|
||||
go func() {
|
||||
err := waitFn() // blocking
|
||||
if unboundCtx.Err() != context.Canceled {
|
||||
waitError <- fmt.Errorf("unbound: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for one of the three cases below
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
l.logger.Warn("context canceled: exiting loop")
|
||||
unboundCancel()
|
||||
close(waitError)
|
||||
close(done)
|
||||
return
|
||||
case <-restart: // triggered restart
|
||||
unboundCancel()
|
||||
close(waitError)
|
||||
l.logger.Info("restarting")
|
||||
case err := <-waitError: // unexpected error
|
||||
unboundCancel()
|
||||
close(waitError)
|
||||
l.fallbackToUnencryptedDNS()
|
||||
l.attemptingRestart(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *looper) fallbackToUnencryptedDNS() {
|
||||
// Try with user provided plaintext ip address
|
||||
targetIP := l.settings.PlaintextAddress
|
||||
if targetIP != nil {
|
||||
l.conf.UseDNSInternally(targetIP)
|
||||
if err := l.conf.UseDNSSystemWide(targetIP); err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Try with any IPv4 address from the providers chosen
|
||||
for _, provider := range l.settings.Providers {
|
||||
data := constants.DNSProviderMapping()[provider]
|
||||
for _, targetIP = range data.IPs {
|
||||
if targetIP.To4() != nil {
|
||||
l.conf.UseDNSInternally(targetIP)
|
||||
if err := l.conf.UseDNSSystemWide(targetIP); err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No IPv4 address found
|
||||
l.logger.Error("no ipv4 DNS address found for providers %s", l.settings.Providers)
|
||||
}
|
||||
|
||||
func (l *looper) RunRestartTicker(ctx context.Context, restart chan<- struct{}) {
|
||||
if l.settings.UpdatePeriod == 0 {
|
||||
return
|
||||
}
|
||||
ticker := time.NewTicker(l.settings.UpdatePeriod)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
ticker.Stop()
|
||||
return
|
||||
case <-ticker.C:
|
||||
restart <- struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
76
internal/openvpn/loop.go
Normal file
76
internal/openvpn/loop.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package openvpn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/golibs/command"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/settings"
|
||||
)
|
||||
|
||||
type Looper interface {
|
||||
Run(ctx context.Context, restart <-chan struct{}, done chan<- struct{})
|
||||
}
|
||||
|
||||
type looper struct {
|
||||
conf Configurator
|
||||
settings settings.OpenVPN
|
||||
logger logging.Logger
|
||||
streamMerger command.StreamMerger
|
||||
fatalOnError func(err error)
|
||||
}
|
||||
|
||||
func NewLooper(conf Configurator, settings settings.OpenVPN, logger logging.Logger,
|
||||
streamMerger command.StreamMerger, fatalOnError func(err error)) Looper {
|
||||
return &looper{
|
||||
conf: conf,
|
||||
settings: settings,
|
||||
logger: logger.WithPrefix("openvpn: "),
|
||||
streamMerger: streamMerger,
|
||||
fatalOnError: fatalOnError,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *looper) Run(ctx context.Context, restart <-chan struct{}, done chan<- struct{}) {
|
||||
select {
|
||||
case <-restart:
|
||||
case <-ctx.Done():
|
||||
close(done)
|
||||
return
|
||||
}
|
||||
for {
|
||||
openvpnCtx, openvpnCancel := context.WithCancel(ctx)
|
||||
stream, waitFn, err := l.conf.Start(openvpnCtx)
|
||||
l.fatalOnError(err)
|
||||
go l.streamMerger.Merge(openvpnCtx, stream,
|
||||
command.MergeName("openvpn"), command.MergeColor(constants.ColorOpenvpn()))
|
||||
waitError := make(chan error)
|
||||
go func() {
|
||||
err := waitFn() // blocking
|
||||
if openvpnCtx.Err() != context.Canceled {
|
||||
waitError <- fmt.Errorf("openvpn: %w", err)
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
l.logger.Warn("context canceled: exiting loop")
|
||||
openvpnCancel()
|
||||
close(waitError)
|
||||
close(done)
|
||||
return
|
||||
case <-restart: // triggered restart
|
||||
l.logger.Info("restarting")
|
||||
openvpnCancel()
|
||||
close(waitError)
|
||||
case err := <-waitError: // unexpected error
|
||||
l.logger.Warn(err)
|
||||
l.logger.Info("restarting")
|
||||
openvpnCancel()
|
||||
close(waitError)
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4,55 +4,37 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
type Server interface {
|
||||
SetOpenVPNRestart(f func())
|
||||
SetUnboundRestart(f func())
|
||||
Run(ctx context.Context) error
|
||||
Run(ctx context.Context, serverDone chan struct{})
|
||||
}
|
||||
|
||||
type server struct {
|
||||
address string
|
||||
logger logging.Logger
|
||||
restartOpenvpn func()
|
||||
restartOpenvpnSet context.Context
|
||||
restartOpenvpnSetSignal func()
|
||||
restartUnbound func()
|
||||
restartUnboundSet context.Context
|
||||
restartUnboundSetSignal func()
|
||||
sync.RWMutex
|
||||
address string
|
||||
logger logging.Logger
|
||||
restartOpenvpn chan<- struct{}
|
||||
restartUnbound chan<- struct{}
|
||||
}
|
||||
|
||||
func New(address string, logger logging.Logger) Server {
|
||||
restartOpenvpnSet, restartOpenvpnSetSignal := context.WithCancel(context.Background())
|
||||
restartUnboundSet, restartUnboundSetSignal := context.WithCancel(context.Background())
|
||||
func New(address string, logger logging.Logger, restartOpenvpn, restartUnbound chan<- struct{}) Server {
|
||||
return &server{
|
||||
address: address,
|
||||
logger: logger.WithPrefix("http server: "),
|
||||
restartOpenvpnSet: restartOpenvpnSet,
|
||||
restartOpenvpnSetSignal: restartOpenvpnSetSignal,
|
||||
restartUnboundSet: restartUnboundSet,
|
||||
restartUnboundSetSignal: restartUnboundSetSignal,
|
||||
address: address,
|
||||
logger: logger.WithPrefix("http server: "),
|
||||
restartOpenvpn: restartOpenvpn,
|
||||
restartUnbound: restartUnbound,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *server) Run(ctx context.Context) error {
|
||||
if s.restartOpenvpnSet.Err() == nil {
|
||||
s.logger.Warn("restartOpenvpn function is not set, waiting...")
|
||||
<-s.restartOpenvpnSet.Done()
|
||||
}
|
||||
if s.restartUnboundSet.Err() == nil {
|
||||
s.logger.Warn("restartUnbound function is not set, waiting...")
|
||||
<-s.restartUnboundSet.Done()
|
||||
}
|
||||
func (s *server) Run(ctx context.Context, serverDone chan struct{}) {
|
||||
server := http.Server{Addr: s.address, Handler: s.makeHandler()}
|
||||
go func() {
|
||||
defer close(serverDone)
|
||||
<-ctx.Done()
|
||||
s.logger.Warn("context canceled: exiting loop")
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||
@@ -60,24 +42,9 @@ func (s *server) Run(ctx context.Context) error {
|
||||
}
|
||||
}()
|
||||
s.logger.Info("listening on %s", s.address)
|
||||
return server.ListenAndServe()
|
||||
}
|
||||
|
||||
func (s *server) SetOpenVPNRestart(f func()) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
s.restartOpenvpn = f
|
||||
if s.restartOpenvpnSet.Err() == nil {
|
||||
s.restartOpenvpnSetSignal()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *server) SetUnboundRestart(f func()) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
s.restartUnbound = f
|
||||
if s.restartUnboundSet.Err() == nil {
|
||||
s.restartUnboundSetSignal()
|
||||
err := server.ListenAndServe()
|
||||
if err != nil && ctx.Err() != context.Canceled {
|
||||
s.logger.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,13 +55,9 @@ func (s *server) makeHandler() http.HandlerFunc {
|
||||
case http.MethodGet:
|
||||
switch r.RequestURI {
|
||||
case "/openvpn/actions/restart":
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
s.restartOpenvpn()
|
||||
s.restartOpenvpn <- struct{}{}
|
||||
case "/unbound/actions/restart":
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
s.restartUnbound()
|
||||
s.restartUnbound <- struct{}{}
|
||||
default:
|
||||
routeDoesNotExist(s.logger, w, r)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user