diff --git a/internal/dns/loop.go b/internal/dns/loop.go index 21cb3410..28018f15 100644 --- a/internal/dns/loop.go +++ b/internal/dns/loop.go @@ -93,22 +93,58 @@ func (l *looper) logAndWait(ctx context.Context, err error) { <-ctx.Done() } +func (l *looper) waitForFirstStart(ctx context.Context) { + for { + select { + case <-l.stop: + l.setEnabled(false) + l.logger.Info("not started yet") + case <-l.restart: + if l.isEnabled() { + return + } + l.logger.Info("not restarting because disabled") + case <-l.start: + l.setEnabled(true) + return + case <-ctx.Done(): + return + } + } +} + +func (l *looper) waitForSubsequentStart(ctx context.Context, unboundCancel context.CancelFunc) { + if l.isEnabled() { + return + } + for { + // wait for a signal to re-enable + select { + case <-l.stop: + l.logger.Info("already disabled") + case <-l.restart: + if !l.isEnabled() { + l.logger.Info("not restarting because disabled") + } else { + return + } + case <-l.start: + l.setEnabled(true) + return + case <-ctx.Done(): + unboundCancel() + return + } + } +} + func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { wg.Add(1) defer wg.Done() l.fallbackToUnencryptedDNS() - waitForStart := true - for waitForStart { - select { - case <-l.stop: - l.logger.Info("not started yet") - case <-l.restart: - waitForStart = false - case <-l.start: - waitForStart = false - case <-ctx.Done(): - return - } + l.waitForFirstStart(ctx) + if ctx.Err() != nil { + return } defer l.logger.Warn("loop exited") @@ -118,20 +154,7 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { triggeredRestart := false l.setEnabled(true) for ctx.Err() == nil { - for !l.isEnabled() { - // wait for a signal to re-enable - select { - case <-l.stop: - l.logger.Info("already disabled") - case <-l.restart: - l.setEnabled(true) - case <-l.start: - l.setEnabled(true) - case <-ctx.Done(): - unboundCancel() - return - } - } + l.waitForSubsequentStart(ctx, unboundCancel) settings := l.GetSettings()