diff --git a/internal/dns/loop.go b/internal/dns/loop.go index eb11c4e7..90d0c2ae 100644 --- a/internal/dns/loop.go +++ b/internal/dns/loop.go @@ -34,6 +34,8 @@ type looper struct { start chan struct{} stop chan struct{} updateTicker chan struct{} + timeNow func() time.Time + timeSince func(time.Time) time.Duration } func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger, @@ -49,6 +51,8 @@ func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger, start: make(chan struct{}), stop: make(chan struct{}), updateTicker: make(chan struct{}), + timeNow: time.Now, + timeSince: time.Since, } } @@ -285,24 +289,45 @@ func (l *looper) useUnencryptedDNS(fallback bool) { func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() - ticker := time.NewTicker(time.Hour) + // Timer that acts as a ticker + timer := time.NewTimer(time.Hour) + timer.Stop() + timerIsStopped := true settings := l.GetSettings() if settings.UpdatePeriod > 0 { - ticker = time.NewTicker(settings.UpdatePeriod) - } else { - ticker.Stop() + timer.Reset(settings.UpdatePeriod) + timerIsStopped = false } + lastTick := time.Unix(0, 0) for { select { case <-ctx.Done(): - ticker.Stop() + if !timerIsStopped && !timer.Stop() { + <-timer.C + } return - case <-ticker.C: + case <-timer.C: + lastTick = l.timeNow() l.restart <- struct{}{} + settings := l.GetSettings() + timer.Reset(settings.UpdatePeriod) case <-l.updateTicker: - ticker.Stop() - period := l.GetSettings().UpdatePeriod - ticker = time.NewTicker(period) + if !timer.Stop() { + <-timer.C + } + timerIsStopped = true + settings := l.GetSettings() + newUpdatePeriod := settings.UpdatePeriod + if newUpdatePeriod == 0 { + continue + } + var waited time.Duration + if lastTick.UnixNano() != 0 { + waited = l.timeSince(lastTick) + } + leftToWait := newUpdatePeriod - waited + timer.Reset(leftToWait) + timerIsStopped = false } } } diff --git a/internal/provider/piav4.go b/internal/provider/piav4.go index bce9352f..5866ab28 100644 --- a/internal/provider/piav4.go +++ b/internal/provider/piav4.go @@ -151,11 +151,9 @@ func (p *piaV4) PortForward(ctx context.Context, client *http.Client, } expiryTimer := time.NewTimer(durationToExpiration) - defer expiryTimer.Stop() const keepAlivePeriod = 15 * time.Minute - keepAliveTicker := time.NewTicker(keepAlivePeriod) - defer keepAliveTicker.Stop() - + // Timer behaving as a ticker + keepAliveTimer := time.NewTimer(keepAlivePeriod) for { select { case <-ctx.Done(): @@ -164,11 +162,18 @@ func (p *piaV4) PortForward(ctx context.Context, client *http.Client, if err := fw.RemoveAllowedPort(removeCtx, data.Port); err != nil { pfLogger.Error(err) } + if !keepAliveTimer.Stop() { + <-keepAliveTimer.C + } + if !expiryTimer.Stop() { + <-expiryTimer.C + } return - case <-keepAliveTicker.C: + case <-keepAliveTimer.C: if err := bindPIAPort(client, gateway, data); err != nil { pfLogger.Error(err) } + keepAliveTimer.Reset(keepAlivePeriod) case <-expiryTimer.C: pfLogger.Warn("Forward port has expired on %s, getting another one", data.Expiration.Format(time.RFC1123)) oldPort := data.Port @@ -199,7 +204,10 @@ func (p *piaV4) PortForward(ctx context.Context, client *http.Client, if err := bindPIAPort(client, gateway, data); err != nil { pfLogger.Error(err) } - keepAliveTicker.Reset(keepAlivePeriod) + if !keepAliveTimer.Stop() { + <-keepAliveTimer.C + } + keepAliveTimer.Reset(keepAlivePeriod) expiryTimer.Reset(durationToExpiration) } } diff --git a/internal/publicip/loop.go b/internal/publicip/loop.go index 22a781ec..8fd8fc91 100644 --- a/internal/publicip/loop.go +++ b/internal/publicip/loop.go @@ -32,6 +32,8 @@ type looper struct { restart chan struct{} stop chan struct{} updateTicker chan struct{} + timeNow func() time.Time + timeSince func(time.Time) time.Duration } func NewLooper(client network.Client, logger logging.Logger, fileManager files.FileManager, @@ -47,6 +49,8 @@ func NewLooper(client network.Client, logger logging.Logger, fileManager files.F restart: make(chan struct{}), stop: make(chan struct{}), updateTicker: make(chan struct{}), + timeNow: time.Now, + timeSince: time.Since, } } @@ -127,23 +131,42 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() - ticker := time.NewTicker(time.Hour) + timer := time.NewTimer(time.Hour) + timer.Stop() // 1 hour, cannot be a race condition + timerIsStopped := true period := l.GetPeriod() if period > 0 { - ticker = time.NewTicker(period) - } else { - ticker.Stop() + timer.Reset(period) + timerIsStopped = false } + lastTick := time.Unix(0, 0) for { select { case <-ctx.Done(): - ticker.Stop() + if !timerIsStopped && !timer.Stop() { + <-timer.C + } return - case <-ticker.C: + case <-timer.C: + lastTick = l.timeNow() l.restart <- struct{}{} + timer.Reset(l.GetPeriod()) case <-l.updateTicker: - ticker.Stop() - ticker = time.NewTicker(l.GetPeriod()) + if !timer.Stop() { + <-timer.C + } + timerIsStopped = true + period := l.GetPeriod() + if period == 0 { + continue + } + var waited time.Duration + if lastTick.UnixNano() > 0 { + waited = l.timeSince(lastTick) + } + leftToWait := period - waited + timer.Reset(leftToWait) + timerIsStopped = false } } } diff --git a/internal/updater/loop.go b/internal/updater/loop.go index ef9c578b..608f0cd8 100644 --- a/internal/updater/loop.go +++ b/internal/updater/loop.go @@ -30,6 +30,8 @@ type looper struct { restart chan struct{} stop chan struct{} updateTicker chan struct{} + timeNow func() time.Time + timeSince func(time.Time) time.Duration } func NewLooper(options Options, period time.Duration, currentServers models.AllServers, @@ -45,6 +47,8 @@ func NewLooper(options Options, period time.Duration, currentServers models.AllS restart: make(chan struct{}), stop: make(chan struct{}), updateTicker: make(chan struct{}), + timeNow: time.Now, + timeSince: time.Since, } } @@ -125,23 +129,41 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) { func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() - ticker := time.NewTicker(time.Hour) - period := l.GetPeriod() - if period > 0 { - ticker = time.NewTicker(period) - } else { - ticker.Stop() + timer := time.NewTimer(time.Hour) + timer.Stop() + timerIsStopped := true + if period := l.GetPeriod(); period > 0 { + timerIsStopped = false + timer.Reset(period) } + lastTick := time.Unix(0, 0) for { select { case <-ctx.Done(): - ticker.Stop() + if !timerIsStopped && !timer.Stop() { + <-timer.C + } return - case <-ticker.C: + case <-timer.C: + lastTick = l.timeNow() l.restart <- struct{}{} + timer.Reset(l.GetPeriod()) case <-l.updateTicker: - ticker.Stop() - ticker = time.NewTicker(l.GetPeriod()) + if !timerIsStopped && !timer.Stop() { + <-timer.C + } + timerIsStopped = true + period := l.GetPeriod() + if period == 0 { + continue + } + var waited time.Duration + if lastTick.UnixNano() > 0 { + waited = l.timeSince(lastTick) + } + leftToWait := period - waited + timer.Reset(leftToWait) + timerIsStopped = false } } }