Simplified loop mechanism for openvpn and dns

- Refers to #91
- http control server starts without waiting for unbound and/or openvpn
- Trying to get rid of waiter and use channels directly
- Simpler main.go
- More robust logic overall
This commit is contained in:
Quentin McGaw
2020-07-08 13:14:39 +00:00
parent dd529a48fa
commit 7a136db085
4 changed files with 309 additions and 224 deletions

View File

@@ -3,7 +3,6 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"os" "os"
"os/signal" "os/signal"
"regexp" "regexp"
@@ -211,40 +210,49 @@ func _main(background context.Context, args []string) int {
go streamMerger.Merge(ctx, stderr, command.MergeName("shadowsocks error"), command.MergeColor(constants.ColorShadowsocksError())) go streamMerger.Merge(ctx, stderr, command.MergeName("shadowsocks error"), command.MergeColor(constants.ColorShadowsocksError()))
} }
httpServer := server.New("0.0.0.0:8000", logger) restartOpenvpn := make(chan struct{})
restartUnbound := make(chan struct{})
openvpnDone := make(chan struct{})
unboundDone := make(chan struct{})
serverDone := make(chan struct{})
go openvpnRunLoop(ctx, ovpnConf, streamMerger, logger, httpServer, waiter, fatalOnError) openvpnLooper := openvpn.NewLooper(ovpnConf, allSettings.OpenVPN, logger, streamMerger, fatalOnError)
// wait for restartOpenvpn
go openvpnLooper.Run(ctx, restartOpenvpn, openvpnDone)
waiter.Add(func() error { unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, logger,
err := httpServer.Run(ctx) streamMerger, allSettings.System.UID, allSettings.System.GID)
logger.Error("http server: %s", err) // wait for restartUnbound
return err go unboundLooper.Run(ctx, restartUnbound, unboundDone)
})
startUnboundCh := make(chan struct{})
go unboundRunLoop(ctx, startUnboundCh, logger, dnsConf, allSettings.DNS, allSettings.System.UID, allSettings.System.GID, waiter, streamMerger, httpServer)
if !allSettings.DNS.Enabled {
httpServer.SetUnboundRestart(func() {})
dnsConf.UseDNSInternally(allSettings.DNS.PlaintextAddress)
if err := dnsConf.UseDNSSystemWide(allSettings.DNS.PlaintextAddress); err != nil {
logger.Error(err)
}
}
go func() { go func() {
first := true
var restartTickerContext context.Context
var restartTickerCancel context.CancelFunc = func() {}
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
restartTickerCancel()
return return
case <-connectedCh: // blocks until openvpn is connected case <-connectedCh: // blocks until openvpn is connected
if allSettings.DNS.Enabled { if first {
startUnboundCh <- struct{}{} first = false
restartUnbound <- struct{}{}
} }
restartTickerCancel()
restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
go unboundLooper.RunRestartTicker(restartTickerContext, restartUnbound)
onConnected(allSettings, logger, fileManager, routingConf, defaultInterface, providerConf) onConnected(allSettings, logger, fileManager, routingConf, defaultInterface, providerConf)
} }
} }
}() }()
httpServer := server.New("0.0.0.0:8000", logger, restartOpenvpn, restartUnbound)
go httpServer.Run(ctx, serverDone)
// Start openvpn for the first time
restartOpenvpn <- struct{}{}
signalsCh := make(chan os.Signal, 1) signalsCh := make(chan os.Signal, 1)
signal.Notify(signalsCh, signal.Notify(signalsCh,
syscall.SIGINT, syscall.SIGINT,
@@ -278,6 +286,9 @@ func _main(background context.Context, args []string) int {
logger.Error(err) logger.Error(err)
exitStatus = 1 exitStatus = 1
} }
<-serverDone
<-unboundDone
<-openvpnDone
return exitStatus return exitStatus
} }
@@ -348,34 +359,6 @@ func trimEventualProgramPrefix(s string) string {
} }
} }
func openvpnRunLoop(ctx context.Context, ovpnConf openvpn.Configurator, streamMerger command.StreamMerger,
logger logging.Logger, httpServer server.Server, waiter command.Waiter, fatalOnError func(err error)) {
logger = logger.WithPrefix("openvpn: ")
waitErrors := make(chan error)
for ctx.Err() == nil {
logger.Info("starting")
openvpnCtx, openvpnCancel := context.WithCancel(ctx)
stream, waitFn, err := ovpnConf.Start(openvpnCtx)
fatalOnError(err)
httpServer.SetOpenVPNRestart(openvpnCancel)
go streamMerger.Merge(openvpnCtx, stream, command.MergeName("openvpn"), command.MergeColor(constants.ColorOpenvpn()))
waiter.Add(func() error {
return <-waitErrors
})
err = waitFn()
waitErrors <- fmt.Errorf("openvpn: %w", err)
switch {
case ctx.Err() != nil:
logger.Warn("context canceled: exiting openvpn run loop")
case openvpnCtx.Err() == context.Canceled:
logger.Info("triggered openvpn restart")
default:
logger.Warn(err)
openvpnCancel()
}
}
}
func onConnected(allSettings settings.Settings, func onConnected(allSettings settings.Settings,
logger logging.Logger, fileManager files.FileManager, logger logging.Logger, fileManager files.FileManager,
routingConf routing.Routing, defaultInterface string, routingConf routing.Routing, defaultInterface string,
@@ -393,7 +376,7 @@ func onConnected(allSettings settings.Settings,
} else { } else {
logger.Info("Gateway VPN IP address: %s", vpnGatewayIP) logger.Info("Gateway VPN IP address: %s", vpnGatewayIP)
} }
time.AfterFunc(7*time.Second, func() { // wait for Unbound to start - TODO use signal channel time.AfterFunc(10*time.Second, func() { // wait for Unbound to start - TODO use signal channel
publicIP, err := publicip.NewIPGetter(network.NewClient(3 * time.Second)).Get() publicIP, err := publicip.NewIPGetter(network.NewClient(3 * time.Second)).Get()
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
@@ -411,126 +394,6 @@ func onConnected(allSettings settings.Settings,
}) })
} }
func fallbackToUnencryptedIPv4DNS(dnsConf dns.Configurator, providers []models.DNSProvider) error {
var targetIP net.IP
for _, provider := range providers {
data := constants.DNSProviderMapping()[provider]
for _, targetIP = range data.IPs {
if targetIP.To4() != nil {
dnsConf.UseDNSInternally(targetIP)
return dnsConf.UseDNSSystemWide(targetIP)
}
}
}
// No IPv4 address found
return fmt.Errorf("no ipv4 DNS address found for providers %s", providers)
}
func unboundRun(ctx, oldCtx context.Context, oldCancel context.CancelFunc, timer *time.Timer,
dnsConf dns.Configurator, settings settings.DNS, uid, gid int,
streamMerger command.StreamMerger, waiter command.Waiter, httpServer server.Server) (
newCtx context.Context, newCancel context.CancelFunc, setupErr, startErr, waitErr error) {
if timer != nil {
timer.Stop()
timer.Reset(settings.UpdatePeriod)
}
if err := dnsConf.DownloadRootHints(uid, gid); err != nil {
return oldCtx, oldCancel, err, nil, nil
}
if err := dnsConf.DownloadRootKey(uid, gid); err != nil {
return oldCtx, oldCancel, err, nil, nil
}
if err := dnsConf.MakeUnboundConf(settings, uid, gid); err != nil {
return oldCtx, oldCancel, err, nil, nil
}
newCtx, newCancel = context.WithCancel(ctx)
oldCancel()
stream, waitFn, err := dnsConf.Start(newCtx, settings.VerbosityDetailsLevel)
if err != nil {
return newCtx, newCancel, nil, err, nil
}
go streamMerger.Merge(newCtx, stream, command.MergeName("unbound"), command.MergeColor(constants.ColorUnbound()))
dnsConf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound
if err := dnsConf.UseDNSSystemWide(net.IP{127, 0, 0, 1}); err != nil { // use Unbound
return newCtx, newCancel, nil, err, nil
}
if err := dnsConf.WaitForUnbound(); err != nil {
return newCtx, newCancel, nil, err, nil
}
// Unbound is up and running at this point
httpServer.SetUnboundRestart(newCancel)
waitError := make(chan error)
waiterError := make(chan error)
waiter.Add(func() error { //nolint:scopelint
return <-waiterError
})
go func() {
err := fmt.Errorf("unbound: %w", waitFn())
waitError <- err
waiterError <- err
}()
if timer == nil {
waitErr := <-waitError
return newCtx, newCancel, nil, nil, waitErr
}
select {
case <-timer.C:
return newCtx, newCancel, nil, nil, nil
case waitErr := <-waitError:
return newCtx, newCancel, nil, nil, waitErr
}
}
func unboundRunLoop(ctx context.Context, startCh chan struct{}, logger logging.Logger, dnsConf dns.Configurator,
settings settings.DNS, uid, gid int,
waiter command.Waiter, streamMerger command.StreamMerger, httpServer server.Server,
) {
logger = logger.WithPrefix("unbound dns over tls setup: ")
select {
case <-startCh:
case <-ctx.Done():
logger.Warn("context canceled: exiting unbound run loop")
return
}
if err := fallbackToUnencryptedIPv4DNS(dnsConf, settings.Providers); err != nil {
logger.Error(err)
}
var timer *time.Timer
if settings.UpdatePeriod > 0 {
timer = time.NewTimer(settings.UpdatePeriod)
}
unboundCtx, unboundCancel := context.WithCancel(ctx)
defer unboundCancel()
for ctx.Err() == nil {
var setupErr, startErr, waitErr error
unboundCtx, unboundCancel, setupErr, startErr, waitErr = unboundRun(
ctx, unboundCtx, unboundCancel, timer, dnsConf, settings,
uid, gid, streamMerger, waiter, httpServer)
switch {
case ctx.Err() != nil:
logger.Warn("context canceled: exiting unbound run loop")
case timer != nil && !timer.Stop():
logger.Info("planned restart of unbound")
case unboundCtx.Err() == context.Canceled:
logger.Info("triggered restart of unbound")
case setupErr != nil:
logger.Warn(setupErr)
case startErr != nil:
logger.Error(startErr)
unboundCancel()
if err := fallbackToUnencryptedIPv4DNS(dnsConf, settings.Providers); err != nil {
logger.Error(err)
}
case waitErr != nil:
logger.Warn(waitErr)
if err := fallbackToUnencryptedIPv4DNS(dnsConf, settings.Providers); err != nil {
logger.Error(err)
}
logger.Warn("restarting unbound because of unexpected exit")
}
}
}
func setupPortForwarding(logger logging.Logger, providerConf provider.Provider, filepath models.Filepath, uid, gid int) { func setupPortForwarding(logger logging.Logger, providerConf provider.Provider, filepath models.Filepath, uid, gid int) {
pfLogger := logger.WithPrefix("port forwarding: ") pfLogger := logger.WithPrefix("port forwarding: ")
var port uint16 var port uint16

183
internal/dns/loop.go Normal file
View File

@@ -0,0 +1,183 @@
package dns
import (
"context"
"fmt"
"net"
"time"
"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/private-internet-access-docker/internal/constants"
"github.com/qdm12/private-internet-access-docker/internal/settings"
)
type Looper interface {
Run(ctx context.Context, restart <-chan struct{}, done chan<- struct{})
RunRestartTicker(ctx context.Context, restart chan<- struct{})
}
type looper struct {
conf Configurator
settings settings.DNS
logger logging.Logger
streamMerger command.StreamMerger
uid int
gid int
}
func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger,
streamMerger command.StreamMerger, uid, gid int) Looper {
return &looper{
conf: conf,
settings: settings,
logger: logger.WithPrefix("dns over tls: "),
uid: uid,
gid: gid,
streamMerger: streamMerger,
}
}
func (l *looper) attemptingRestart(err error) {
l.logger.Warn(err)
l.logger.Info("attempting restart in 10 seconds")
time.Sleep(10 * time.Second)
}
func (l *looper) Run(ctx context.Context, restart <-chan struct{}, done chan<- struct{}) {
l.fallbackToUnencryptedDNS()
select {
case <-restart:
case <-ctx.Done():
close(done)
return
}
_, unboundCancel := context.WithCancel(ctx)
for {
if !l.settings.Enabled {
// wait for another restart signal to recheck if it is enabled
select {
case <-restart:
case <-ctx.Done():
unboundCancel()
close(done)
return
}
}
if ctx.Err() == context.Canceled {
unboundCancel()
close(done)
return
}
// Setup
if err := l.conf.DownloadRootHints(l.uid, l.gid); err != nil {
l.attemptingRestart(err)
continue
}
if err := l.conf.DownloadRootKey(l.uid, l.gid); err != nil {
l.attemptingRestart(err)
continue
}
if err := l.conf.MakeUnboundConf(l.settings, l.uid, l.gid); err != nil {
l.attemptingRestart(err)
continue
}
// Start command
unboundCancel()
unboundCtx, unboundCancel := context.WithCancel(ctx)
stream, waitFn, err := l.conf.Start(unboundCtx, l.settings.VerbosityDetailsLevel)
if err != nil {
unboundCancel()
l.fallbackToUnencryptedDNS()
l.attemptingRestart(err)
}
// Started successfully
go l.streamMerger.Merge(unboundCtx, stream,
command.MergeName("unbound"), command.MergeColor(constants.ColorUnbound()))
l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound
if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}); err != nil { // use Unbound
unboundCancel()
l.fallbackToUnencryptedDNS()
l.attemptingRestart(err)
}
if err := l.conf.WaitForUnbound(); err != nil {
unboundCancel()
l.fallbackToUnencryptedDNS()
l.attemptingRestart(err)
}
waitError := make(chan error)
go func() {
err := waitFn() // blocking
if unboundCtx.Err() != context.Canceled {
waitError <- fmt.Errorf("unbound: %w", err)
}
}()
// Wait for one of the three cases below
select {
case <-ctx.Done():
l.logger.Warn("context canceled: exiting loop")
unboundCancel()
close(waitError)
close(done)
return
case <-restart: // triggered restart
unboundCancel()
close(waitError)
l.logger.Info("restarting")
case err := <-waitError: // unexpected error
unboundCancel()
close(waitError)
l.fallbackToUnencryptedDNS()
l.attemptingRestart(err)
}
}
}
func (l *looper) fallbackToUnencryptedDNS() {
// Try with user provided plaintext ip address
targetIP := l.settings.PlaintextAddress
if targetIP != nil {
l.conf.UseDNSInternally(targetIP)
if err := l.conf.UseDNSSystemWide(targetIP); err != nil {
l.logger.Error(err)
}
return
}
// Try with any IPv4 address from the providers chosen
for _, provider := range l.settings.Providers {
data := constants.DNSProviderMapping()[provider]
for _, targetIP = range data.IPs {
if targetIP.To4() != nil {
l.conf.UseDNSInternally(targetIP)
if err := l.conf.UseDNSSystemWide(targetIP); err != nil {
l.logger.Error(err)
}
return
}
}
}
// No IPv4 address found
l.logger.Error("no ipv4 DNS address found for providers %s", l.settings.Providers)
}
func (l *looper) RunRestartTicker(ctx context.Context, restart chan<- struct{}) {
if l.settings.UpdatePeriod == 0 {
return
}
ticker := time.NewTicker(l.settings.UpdatePeriod)
for {
select {
case <-ctx.Done():
ticker.Stop()
return
case <-ticker.C:
restart <- struct{}{}
}
}
}

76
internal/openvpn/loop.go Normal file
View File

@@ -0,0 +1,76 @@
package openvpn
import (
"context"
"fmt"
"time"
"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/private-internet-access-docker/internal/constants"
"github.com/qdm12/private-internet-access-docker/internal/settings"
)
type Looper interface {
Run(ctx context.Context, restart <-chan struct{}, done chan<- struct{})
}
type looper struct {
conf Configurator
settings settings.OpenVPN
logger logging.Logger
streamMerger command.StreamMerger
fatalOnError func(err error)
}
func NewLooper(conf Configurator, settings settings.OpenVPN, logger logging.Logger,
streamMerger command.StreamMerger, fatalOnError func(err error)) Looper {
return &looper{
conf: conf,
settings: settings,
logger: logger.WithPrefix("openvpn: "),
streamMerger: streamMerger,
fatalOnError: fatalOnError,
}
}
func (l *looper) Run(ctx context.Context, restart <-chan struct{}, done chan<- struct{}) {
select {
case <-restart:
case <-ctx.Done():
close(done)
return
}
for {
openvpnCtx, openvpnCancel := context.WithCancel(ctx)
stream, waitFn, err := l.conf.Start(openvpnCtx)
l.fatalOnError(err)
go l.streamMerger.Merge(openvpnCtx, stream,
command.MergeName("openvpn"), command.MergeColor(constants.ColorOpenvpn()))
waitError := make(chan error)
go func() {
err := waitFn() // blocking
if openvpnCtx.Err() != context.Canceled {
waitError <- fmt.Errorf("openvpn: %w", err)
}
}()
select {
case <-ctx.Done():
l.logger.Warn("context canceled: exiting loop")
openvpnCancel()
close(waitError)
close(done)
return
case <-restart: // triggered restart
l.logger.Info("restarting")
openvpnCancel()
close(waitError)
case err := <-waitError: // unexpected error
l.logger.Warn(err)
l.logger.Info("restarting")
openvpnCancel()
close(waitError)
time.Sleep(time.Second)
}
}
}

View File

@@ -4,55 +4,37 @@ import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
"sync"
"time" "time"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
type Server interface { type Server interface {
SetOpenVPNRestart(f func()) Run(ctx context.Context, serverDone chan struct{})
SetUnboundRestart(f func())
Run(ctx context.Context) error
} }
type server struct { type server struct {
address string address string
logger logging.Logger logger logging.Logger
restartOpenvpn func() restartOpenvpn chan<- struct{}
restartOpenvpnSet context.Context restartUnbound chan<- struct{}
restartOpenvpnSetSignal func()
restartUnbound func()
restartUnboundSet context.Context
restartUnboundSetSignal func()
sync.RWMutex
} }
func New(address string, logger logging.Logger) Server { func New(address string, logger logging.Logger, restartOpenvpn, restartUnbound chan<- struct{}) Server {
restartOpenvpnSet, restartOpenvpnSetSignal := context.WithCancel(context.Background())
restartUnboundSet, restartUnboundSetSignal := context.WithCancel(context.Background())
return &server{ return &server{
address: address, address: address,
logger: logger.WithPrefix("http server: "), logger: logger.WithPrefix("http server: "),
restartOpenvpnSet: restartOpenvpnSet, restartOpenvpn: restartOpenvpn,
restartOpenvpnSetSignal: restartOpenvpnSetSignal, restartUnbound: restartUnbound,
restartUnboundSet: restartUnboundSet,
restartUnboundSetSignal: restartUnboundSetSignal,
} }
} }
func (s *server) Run(ctx context.Context) error { func (s *server) Run(ctx context.Context, serverDone chan struct{}) {
if s.restartOpenvpnSet.Err() == nil {
s.logger.Warn("restartOpenvpn function is not set, waiting...")
<-s.restartOpenvpnSet.Done()
}
if s.restartUnboundSet.Err() == nil {
s.logger.Warn("restartUnbound function is not set, waiting...")
<-s.restartUnboundSet.Done()
}
server := http.Server{Addr: s.address, Handler: s.makeHandler()} server := http.Server{Addr: s.address, Handler: s.makeHandler()}
go func() { go func() {
defer close(serverDone)
<-ctx.Done() <-ctx.Done()
s.logger.Warn("context canceled: exiting loop")
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel() defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil { if err := server.Shutdown(shutdownCtx); err != nil {
@@ -60,24 +42,9 @@ func (s *server) Run(ctx context.Context) error {
} }
}() }()
s.logger.Info("listening on %s", s.address) s.logger.Info("listening on %s", s.address)
return server.ListenAndServe() err := server.ListenAndServe()
} if err != nil && ctx.Err() != context.Canceled {
s.logger.Error(err)
func (s *server) SetOpenVPNRestart(f func()) {
s.Lock()
defer s.Unlock()
s.restartOpenvpn = f
if s.restartOpenvpnSet.Err() == nil {
s.restartOpenvpnSetSignal()
}
}
func (s *server) SetUnboundRestart(f func()) {
s.Lock()
defer s.Unlock()
s.restartUnbound = f
if s.restartUnboundSet.Err() == nil {
s.restartUnboundSetSignal()
} }
} }
@@ -88,13 +55,9 @@ func (s *server) makeHandler() http.HandlerFunc {
case http.MethodGet: case http.MethodGet:
switch r.RequestURI { switch r.RequestURI {
case "/openvpn/actions/restart": case "/openvpn/actions/restart":
s.RLock() s.restartOpenvpn <- struct{}{}
defer s.RUnlock()
s.restartOpenvpn()
case "/unbound/actions/restart": case "/unbound/actions/restart":
s.RLock() s.restartUnbound <- struct{}{}
defer s.RUnlock()
s.restartUnbound()
default: default:
routeDoesNotExist(s.logger, w, r) routeDoesNotExist(s.logger, w, r)
} }