diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 58b2ccd6..32bc5a30 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -3,7 +3,6 @@ package main import ( "context" "fmt" - "net" "os" "os/signal" "regexp" @@ -211,40 +210,49 @@ func _main(background context.Context, args []string) int { go streamMerger.Merge(ctx, stderr, command.MergeName("shadowsocks error"), command.MergeColor(constants.ColorShadowsocksError())) } - httpServer := server.New("0.0.0.0:8000", logger) + restartOpenvpn := make(chan struct{}) + restartUnbound := make(chan struct{}) + openvpnDone := make(chan struct{}) + unboundDone := make(chan struct{}) + serverDone := make(chan struct{}) - go openvpnRunLoop(ctx, ovpnConf, streamMerger, logger, httpServer, waiter, fatalOnError) + openvpnLooper := openvpn.NewLooper(ovpnConf, allSettings.OpenVPN, logger, streamMerger, fatalOnError) + // wait for restartOpenvpn + go openvpnLooper.Run(ctx, restartOpenvpn, openvpnDone) - waiter.Add(func() error { - err := httpServer.Run(ctx) - logger.Error("http server: %s", err) - return err - }) - - startUnboundCh := make(chan struct{}) - go unboundRunLoop(ctx, startUnboundCh, logger, dnsConf, allSettings.DNS, allSettings.System.UID, allSettings.System.GID, waiter, streamMerger, httpServer) - if !allSettings.DNS.Enabled { - httpServer.SetUnboundRestart(func() {}) - dnsConf.UseDNSInternally(allSettings.DNS.PlaintextAddress) - if err := dnsConf.UseDNSSystemWide(allSettings.DNS.PlaintextAddress); err != nil { - logger.Error(err) - } - } + unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, logger, + streamMerger, allSettings.System.UID, allSettings.System.GID) + // wait for restartUnbound + go unboundLooper.Run(ctx, restartUnbound, unboundDone) go func() { + first := true + var restartTickerContext context.Context + var restartTickerCancel context.CancelFunc = func() {} for { select { case <-ctx.Done(): + restartTickerCancel() return case <-connectedCh: // blocks until openvpn is connected - if allSettings.DNS.Enabled { - startUnboundCh <- struct{}{} + if first { + first = false + restartUnbound <- struct{}{} } + restartTickerCancel() + restartTickerContext, restartTickerCancel = context.WithCancel(ctx) + go unboundLooper.RunRestartTicker(restartTickerContext, restartUnbound) onConnected(allSettings, logger, fileManager, routingConf, defaultInterface, providerConf) } } }() + httpServer := server.New("0.0.0.0:8000", logger, restartOpenvpn, restartUnbound) + go httpServer.Run(ctx, serverDone) + + // Start openvpn for the first time + restartOpenvpn <- struct{}{} + signalsCh := make(chan os.Signal, 1) signal.Notify(signalsCh, syscall.SIGINT, @@ -278,6 +286,9 @@ func _main(background context.Context, args []string) int { logger.Error(err) exitStatus = 1 } + <-serverDone + <-unboundDone + <-openvpnDone return exitStatus } @@ -348,34 +359,6 @@ func trimEventualProgramPrefix(s string) string { } } -func openvpnRunLoop(ctx context.Context, ovpnConf openvpn.Configurator, streamMerger command.StreamMerger, - logger logging.Logger, httpServer server.Server, waiter command.Waiter, fatalOnError func(err error)) { - logger = logger.WithPrefix("openvpn: ") - waitErrors := make(chan error) - for ctx.Err() == nil { - logger.Info("starting") - openvpnCtx, openvpnCancel := context.WithCancel(ctx) - stream, waitFn, err := ovpnConf.Start(openvpnCtx) - fatalOnError(err) - httpServer.SetOpenVPNRestart(openvpnCancel) - go streamMerger.Merge(openvpnCtx, stream, command.MergeName("openvpn"), command.MergeColor(constants.ColorOpenvpn())) - waiter.Add(func() error { - return <-waitErrors - }) - err = waitFn() - waitErrors <- fmt.Errorf("openvpn: %w", err) - switch { - case ctx.Err() != nil: - logger.Warn("context canceled: exiting openvpn run loop") - case openvpnCtx.Err() == context.Canceled: - logger.Info("triggered openvpn restart") - default: - logger.Warn(err) - openvpnCancel() - } - } -} - func onConnected(allSettings settings.Settings, logger logging.Logger, fileManager files.FileManager, routingConf routing.Routing, defaultInterface string, @@ -393,7 +376,7 @@ func onConnected(allSettings settings.Settings, } else { logger.Info("Gateway VPN IP address: %s", vpnGatewayIP) } - time.AfterFunc(7*time.Second, func() { // wait for Unbound to start - TODO use signal channel + time.AfterFunc(10*time.Second, func() { // wait for Unbound to start - TODO use signal channel publicIP, err := publicip.NewIPGetter(network.NewClient(3 * time.Second)).Get() if err != nil { logger.Error(err) @@ -411,126 +394,6 @@ func onConnected(allSettings settings.Settings, }) } -func fallbackToUnencryptedIPv4DNS(dnsConf dns.Configurator, providers []models.DNSProvider) error { - var targetIP net.IP - for _, provider := range providers { - data := constants.DNSProviderMapping()[provider] - for _, targetIP = range data.IPs { - if targetIP.To4() != nil { - dnsConf.UseDNSInternally(targetIP) - return dnsConf.UseDNSSystemWide(targetIP) - } - } - } - // No IPv4 address found - return fmt.Errorf("no ipv4 DNS address found for providers %s", providers) -} - -func unboundRun(ctx, oldCtx context.Context, oldCancel context.CancelFunc, timer *time.Timer, - dnsConf dns.Configurator, settings settings.DNS, uid, gid int, - streamMerger command.StreamMerger, waiter command.Waiter, httpServer server.Server) ( - newCtx context.Context, newCancel context.CancelFunc, setupErr, startErr, waitErr error) { - if timer != nil { - timer.Stop() - timer.Reset(settings.UpdatePeriod) - } - if err := dnsConf.DownloadRootHints(uid, gid); err != nil { - return oldCtx, oldCancel, err, nil, nil - } - if err := dnsConf.DownloadRootKey(uid, gid); err != nil { - return oldCtx, oldCancel, err, nil, nil - } - if err := dnsConf.MakeUnboundConf(settings, uid, gid); err != nil { - return oldCtx, oldCancel, err, nil, nil - } - newCtx, newCancel = context.WithCancel(ctx) - oldCancel() - stream, waitFn, err := dnsConf.Start(newCtx, settings.VerbosityDetailsLevel) - if err != nil { - return newCtx, newCancel, nil, err, nil - } - go streamMerger.Merge(newCtx, stream, command.MergeName("unbound"), command.MergeColor(constants.ColorUnbound())) - dnsConf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound - if err := dnsConf.UseDNSSystemWide(net.IP{127, 0, 0, 1}); err != nil { // use Unbound - return newCtx, newCancel, nil, err, nil - } - if err := dnsConf.WaitForUnbound(); err != nil { - return newCtx, newCancel, nil, err, nil - } - // Unbound is up and running at this point - httpServer.SetUnboundRestart(newCancel) - waitError := make(chan error) - waiterError := make(chan error) - waiter.Add(func() error { //nolint:scopelint - return <-waiterError - }) - go func() { - err := fmt.Errorf("unbound: %w", waitFn()) - waitError <- err - waiterError <- err - }() - if timer == nil { - waitErr := <-waitError - return newCtx, newCancel, nil, nil, waitErr - } - select { - case <-timer.C: - return newCtx, newCancel, nil, nil, nil - case waitErr := <-waitError: - return newCtx, newCancel, nil, nil, waitErr - } -} - -func unboundRunLoop(ctx context.Context, startCh chan struct{}, logger logging.Logger, dnsConf dns.Configurator, - settings settings.DNS, uid, gid int, - waiter command.Waiter, streamMerger command.StreamMerger, httpServer server.Server, -) { - logger = logger.WithPrefix("unbound dns over tls setup: ") - select { - case <-startCh: - case <-ctx.Done(): - logger.Warn("context canceled: exiting unbound run loop") - return - } - if err := fallbackToUnencryptedIPv4DNS(dnsConf, settings.Providers); err != nil { - logger.Error(err) - } - var timer *time.Timer - if settings.UpdatePeriod > 0 { - timer = time.NewTimer(settings.UpdatePeriod) - } - unboundCtx, unboundCancel := context.WithCancel(ctx) - defer unboundCancel() - for ctx.Err() == nil { - var setupErr, startErr, waitErr error - unboundCtx, unboundCancel, setupErr, startErr, waitErr = unboundRun( - ctx, unboundCtx, unboundCancel, timer, dnsConf, settings, - uid, gid, streamMerger, waiter, httpServer) - switch { - case ctx.Err() != nil: - logger.Warn("context canceled: exiting unbound run loop") - case timer != nil && !timer.Stop(): - logger.Info("planned restart of unbound") - case unboundCtx.Err() == context.Canceled: - logger.Info("triggered restart of unbound") - case setupErr != nil: - logger.Warn(setupErr) - case startErr != nil: - logger.Error(startErr) - unboundCancel() - if err := fallbackToUnencryptedIPv4DNS(dnsConf, settings.Providers); err != nil { - logger.Error(err) - } - case waitErr != nil: - logger.Warn(waitErr) - if err := fallbackToUnencryptedIPv4DNS(dnsConf, settings.Providers); err != nil { - logger.Error(err) - } - logger.Warn("restarting unbound because of unexpected exit") - } - } -} - func setupPortForwarding(logger logging.Logger, providerConf provider.Provider, filepath models.Filepath, uid, gid int) { pfLogger := logger.WithPrefix("port forwarding: ") var port uint16 diff --git a/internal/dns/loop.go b/internal/dns/loop.go new file mode 100644 index 00000000..ecd7cc86 --- /dev/null +++ b/internal/dns/loop.go @@ -0,0 +1,183 @@ +package dns + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/qdm12/golibs/command" + "github.com/qdm12/golibs/logging" + "github.com/qdm12/private-internet-access-docker/internal/constants" + "github.com/qdm12/private-internet-access-docker/internal/settings" +) + +type Looper interface { + Run(ctx context.Context, restart <-chan struct{}, done chan<- struct{}) + RunRestartTicker(ctx context.Context, restart chan<- struct{}) +} + +type looper struct { + conf Configurator + settings settings.DNS + logger logging.Logger + streamMerger command.StreamMerger + uid int + gid int +} + +func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger, + streamMerger command.StreamMerger, uid, gid int) Looper { + return &looper{ + conf: conf, + settings: settings, + logger: logger.WithPrefix("dns over tls: "), + uid: uid, + gid: gid, + streamMerger: streamMerger, + } +} + +func (l *looper) attemptingRestart(err error) { + l.logger.Warn(err) + l.logger.Info("attempting restart in 10 seconds") + time.Sleep(10 * time.Second) +} + +func (l *looper) Run(ctx context.Context, restart <-chan struct{}, done chan<- struct{}) { + l.fallbackToUnencryptedDNS() + select { + case <-restart: + case <-ctx.Done(): + close(done) + return + } + _, unboundCancel := context.WithCancel(ctx) + for { + if !l.settings.Enabled { + // wait for another restart signal to recheck if it is enabled + select { + case <-restart: + case <-ctx.Done(): + unboundCancel() + close(done) + return + } + } + if ctx.Err() == context.Canceled { + unboundCancel() + close(done) + return + } + + // Setup + if err := l.conf.DownloadRootHints(l.uid, l.gid); err != nil { + l.attemptingRestart(err) + continue + } + if err := l.conf.DownloadRootKey(l.uid, l.gid); err != nil { + l.attemptingRestart(err) + continue + } + if err := l.conf.MakeUnboundConf(l.settings, l.uid, l.gid); err != nil { + l.attemptingRestart(err) + continue + } + + // Start command + unboundCancel() + unboundCtx, unboundCancel := context.WithCancel(ctx) + stream, waitFn, err := l.conf.Start(unboundCtx, l.settings.VerbosityDetailsLevel) + if err != nil { + unboundCancel() + l.fallbackToUnencryptedDNS() + l.attemptingRestart(err) + } + + // Started successfully + go l.streamMerger.Merge(unboundCtx, stream, + command.MergeName("unbound"), command.MergeColor(constants.ColorUnbound())) + l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound + if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}); err != nil { // use Unbound + unboundCancel() + l.fallbackToUnencryptedDNS() + l.attemptingRestart(err) + } + if err := l.conf.WaitForUnbound(); err != nil { + unboundCancel() + l.fallbackToUnencryptedDNS() + l.attemptingRestart(err) + } + waitError := make(chan error) + go func() { + err := waitFn() // blocking + if unboundCtx.Err() != context.Canceled { + waitError <- fmt.Errorf("unbound: %w", err) + } + }() + + // Wait for one of the three cases below + select { + case <-ctx.Done(): + l.logger.Warn("context canceled: exiting loop") + unboundCancel() + close(waitError) + close(done) + return + case <-restart: // triggered restart + unboundCancel() + close(waitError) + l.logger.Info("restarting") + case err := <-waitError: // unexpected error + unboundCancel() + close(waitError) + l.fallbackToUnencryptedDNS() + l.attemptingRestart(err) + } + } +} + +func (l *looper) fallbackToUnencryptedDNS() { + // Try with user provided plaintext ip address + targetIP := l.settings.PlaintextAddress + if targetIP != nil { + l.conf.UseDNSInternally(targetIP) + if err := l.conf.UseDNSSystemWide(targetIP); err != nil { + l.logger.Error(err) + } + return + } + + // Try with any IPv4 address from the providers chosen + for _, provider := range l.settings.Providers { + data := constants.DNSProviderMapping()[provider] + for _, targetIP = range data.IPs { + if targetIP.To4() != nil { + l.conf.UseDNSInternally(targetIP) + if err := l.conf.UseDNSSystemWide(targetIP); err != nil { + l.logger.Error(err) + } + return + } + } + } + + // No IPv4 address found + l.logger.Error("no ipv4 DNS address found for providers %s", l.settings.Providers) +} + +func (l *looper) RunRestartTicker(ctx context.Context, restart chan<- struct{}) { + if l.settings.UpdatePeriod == 0 { + return + } + ticker := time.NewTicker(l.settings.UpdatePeriod) + for { + select { + case <-ctx.Done(): + ticker.Stop() + return + case <-ticker.C: + restart <- struct{}{} + } + } +} diff --git a/internal/openvpn/loop.go b/internal/openvpn/loop.go new file mode 100644 index 00000000..534230d4 --- /dev/null +++ b/internal/openvpn/loop.go @@ -0,0 +1,76 @@ +package openvpn + +import ( + "context" + "fmt" + "time" + + "github.com/qdm12/golibs/command" + "github.com/qdm12/golibs/logging" + "github.com/qdm12/private-internet-access-docker/internal/constants" + "github.com/qdm12/private-internet-access-docker/internal/settings" +) + +type Looper interface { + Run(ctx context.Context, restart <-chan struct{}, done chan<- struct{}) +} + +type looper struct { + conf Configurator + settings settings.OpenVPN + logger logging.Logger + streamMerger command.StreamMerger + fatalOnError func(err error) +} + +func NewLooper(conf Configurator, settings settings.OpenVPN, logger logging.Logger, + streamMerger command.StreamMerger, fatalOnError func(err error)) Looper { + return &looper{ + conf: conf, + settings: settings, + logger: logger.WithPrefix("openvpn: "), + streamMerger: streamMerger, + fatalOnError: fatalOnError, + } +} + +func (l *looper) Run(ctx context.Context, restart <-chan struct{}, done chan<- struct{}) { + select { + case <-restart: + case <-ctx.Done(): + close(done) + return + } + for { + openvpnCtx, openvpnCancel := context.WithCancel(ctx) + stream, waitFn, err := l.conf.Start(openvpnCtx) + l.fatalOnError(err) + go l.streamMerger.Merge(openvpnCtx, stream, + command.MergeName("openvpn"), command.MergeColor(constants.ColorOpenvpn())) + waitError := make(chan error) + go func() { + err := waitFn() // blocking + if openvpnCtx.Err() != context.Canceled { + waitError <- fmt.Errorf("openvpn: %w", err) + } + }() + select { + case <-ctx.Done(): + l.logger.Warn("context canceled: exiting loop") + openvpnCancel() + close(waitError) + close(done) + return + case <-restart: // triggered restart + l.logger.Info("restarting") + openvpnCancel() + close(waitError) + case err := <-waitError: // unexpected error + l.logger.Warn(err) + l.logger.Info("restarting") + openvpnCancel() + close(waitError) + time.Sleep(time.Second) + } + } +} diff --git a/internal/server/server.go b/internal/server/server.go index 324c2123..f05432b3 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -4,55 +4,37 @@ import ( "context" "fmt" "net/http" - "sync" "time" "github.com/qdm12/golibs/logging" ) type Server interface { - SetOpenVPNRestart(f func()) - SetUnboundRestart(f func()) - Run(ctx context.Context) error + Run(ctx context.Context, serverDone chan struct{}) } type server struct { - address string - logger logging.Logger - restartOpenvpn func() - restartOpenvpnSet context.Context - restartOpenvpnSetSignal func() - restartUnbound func() - restartUnboundSet context.Context - restartUnboundSetSignal func() - sync.RWMutex + address string + logger logging.Logger + restartOpenvpn chan<- struct{} + restartUnbound chan<- struct{} } -func New(address string, logger logging.Logger) Server { - restartOpenvpnSet, restartOpenvpnSetSignal := context.WithCancel(context.Background()) - restartUnboundSet, restartUnboundSetSignal := context.WithCancel(context.Background()) +func New(address string, logger logging.Logger, restartOpenvpn, restartUnbound chan<- struct{}) Server { return &server{ - address: address, - logger: logger.WithPrefix("http server: "), - restartOpenvpnSet: restartOpenvpnSet, - restartOpenvpnSetSignal: restartOpenvpnSetSignal, - restartUnboundSet: restartUnboundSet, - restartUnboundSetSignal: restartUnboundSetSignal, + address: address, + logger: logger.WithPrefix("http server: "), + restartOpenvpn: restartOpenvpn, + restartUnbound: restartUnbound, } } -func (s *server) Run(ctx context.Context) error { - if s.restartOpenvpnSet.Err() == nil { - s.logger.Warn("restartOpenvpn function is not set, waiting...") - <-s.restartOpenvpnSet.Done() - } - if s.restartUnboundSet.Err() == nil { - s.logger.Warn("restartUnbound function is not set, waiting...") - <-s.restartUnboundSet.Done() - } +func (s *server) Run(ctx context.Context, serverDone chan struct{}) { server := http.Server{Addr: s.address, Handler: s.makeHandler()} go func() { + defer close(serverDone) <-ctx.Done() + s.logger.Warn("context canceled: exiting loop") shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() if err := server.Shutdown(shutdownCtx); err != nil { @@ -60,24 +42,9 @@ func (s *server) Run(ctx context.Context) error { } }() s.logger.Info("listening on %s", s.address) - return server.ListenAndServe() -} - -func (s *server) SetOpenVPNRestart(f func()) { - s.Lock() - defer s.Unlock() - s.restartOpenvpn = f - if s.restartOpenvpnSet.Err() == nil { - s.restartOpenvpnSetSignal() - } -} - -func (s *server) SetUnboundRestart(f func()) { - s.Lock() - defer s.Unlock() - s.restartUnbound = f - if s.restartUnboundSet.Err() == nil { - s.restartUnboundSetSignal() + err := server.ListenAndServe() + if err != nil && ctx.Err() != context.Canceled { + s.logger.Error(err) } } @@ -88,13 +55,9 @@ func (s *server) makeHandler() http.HandlerFunc { case http.MethodGet: switch r.RequestURI { case "/openvpn/actions/restart": - s.RLock() - defer s.RUnlock() - s.restartOpenvpn() + s.restartOpenvpn <- struct{}{} case "/unbound/actions/restart": - s.RLock() - defer s.RUnlock() - s.restartUnbound() + s.restartUnbound <- struct{}{} default: routeDoesNotExist(s.logger, w, r) }