diff --git a/internal/dns/loop.go b/internal/dns/loop.go index c732ebf9..aa90664c 100644 --- a/internal/dns/loop.go +++ b/internal/dns/loop.go @@ -10,6 +10,8 @@ import ( "github.com/qdm12/dns/pkg/unbound" "github.com/qdm12/gluetun/internal/configuration" "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/dns/state" + "github.com/qdm12/gluetun/internal/loopstate" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/golibs/logging" ) @@ -19,26 +21,28 @@ var _ Looper = (*Loop)(nil) type Looper interface { Runner RestartTickerRunner - StatusGetterApplier + loopstate.Applier + loopstate.Getter SettingsGetterSetter } type Loop struct { - state *state - conf unbound.Configurator - resolvConf string - blockBuilder blacklist.Builder - client *http.Client - logger logging.Logger - userTrigger bool - start <-chan struct{} - running chan<- models.LoopStatus - stop <-chan struct{} - stopped chan<- struct{} - updateTicker <-chan struct{} - backoffTime time.Duration - timeNow func() time.Time - timeSince func(time.Time) time.Duration + statusManager loopstate.Manager + state state.Manager + conf unbound.Configurator + resolvConf string + blockBuilder blacklist.Builder + client *http.Client + logger logging.Logger + userTrigger bool + start <-chan struct{} + running chan<- models.LoopStatus + stop <-chan struct{} + stopped chan<- struct{} + updateTicker <-chan struct{} + backoffTime time.Duration + timeNow func() time.Time + timeSince func(time.Time) time.Duration } const defaultBackoffTime = 10 * time.Second @@ -51,24 +55,26 @@ func NewLoop(conf unbound.Configurator, settings configuration.DNS, client *http stopped := make(chan struct{}) updateTicker := make(chan struct{}) - state := newState(constants.Stopped, settings, start, running, stop, stopped, updateTicker) + statusManager := loopstate.New(constants.Stopped, start, running, stop, stopped) + state := state.New(statusManager, settings, updateTicker) return &Loop{ - state: state, - conf: conf, - resolvConf: "/etc/resolv.conf", - blockBuilder: blacklist.NewBuilder(client), - client: client, - logger: logger, - userTrigger: true, - start: start, - running: running, - stop: stop, - stopped: stopped, - updateTicker: updateTicker, - backoffTime: defaultBackoffTime, - timeNow: time.Now, - timeSince: time.Since, + statusManager: statusManager, + state: state, + conf: conf, + resolvConf: "/etc/resolv.conf", + blockBuilder: blacklist.NewBuilder(client), + client: client, + logger: logger, + userTrigger: true, + start: start, + running: running, + stop: stop, + stopped: stopped, + updateTicker: updateTicker, + backoffTime: defaultBackoffTime, + timeNow: time.Now, + timeSince: time.Since, } } @@ -96,6 +102,6 @@ func (l *Loop) signalOrSetStatus(status models.LoopStatus) { default: // receiver dropped out - avoid deadlock on events routing when shutting down } } else { - l.state.SetStatus(status) + l.statusManager.SetStatus(status) } } diff --git a/internal/dns/run.go b/internal/dns/run.go index 58ddfa3f..0838faa2 100644 --- a/internal/dns/run.go +++ b/internal/dns/run.go @@ -90,7 +90,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { closeStreams() unboundCancel() - l.state.SetStatus(constants.Crashed) + l.statusManager.SetStatus(constants.Crashed) const fallback = true l.useUnencryptedDNS(fallback) l.logAndWait(ctx, err) diff --git a/internal/dns/state.go b/internal/dns/state.go deleted file mode 100644 index 9b214905..00000000 --- a/internal/dns/state.go +++ /dev/null @@ -1,164 +0,0 @@ -package dns - -import ( - "context" - "errors" - "fmt" - "reflect" - "sync" - - "github.com/qdm12/gluetun/internal/configuration" - "github.com/qdm12/gluetun/internal/constants" - "github.com/qdm12/gluetun/internal/models" -) - -func newState(status models.LoopStatus, settings configuration.DNS, - start chan<- struct{}, running <-chan models.LoopStatus, - stop chan<- struct{}, stopped <-chan struct{}, - updateTicker chan<- struct{}) *state { - return &state{ - status: status, - settings: settings, - start: start, - running: running, - stop: stop, - stopped: stopped, - updateTicker: updateTicker, - } -} - -type state struct { - loopMu sync.RWMutex - - status models.LoopStatus - statusMu sync.RWMutex - - settings configuration.DNS - settingsMu sync.RWMutex - - start chan<- struct{} - running <-chan models.LoopStatus - stop chan<- struct{} - stopped <-chan struct{} - - updateTicker chan<- struct{} -} - -func (s *state) Lock() { s.loopMu.Lock() } -func (s *state) Unlock() { s.loopMu.Unlock() } - -// SetStatus sets the status thread safely. -// It should only be called by the loop internal code since -// it does not interact with the loop code directly. -func (s *state) SetStatus(status models.LoopStatus) { - s.statusMu.Lock() - defer s.statusMu.Unlock() - s.status = status -} - -// GetStatus gets the status thread safely. -func (s *state) GetStatus() (status models.LoopStatus) { - s.statusMu.RLock() - defer s.statusMu.RUnlock() - return s.status -} - -var ErrInvalidStatus = errors.New("invalid status") - -// ApplyStatus sends signals to the running loop depending on the -// current status and status requested, such that its next status -// matches the requested one. It is thread safe and a synchronous call -// since it waits to the loop to fully change its status. -func (s *state) ApplyStatus(ctx context.Context, status models.LoopStatus) ( - outcome string, err error) { - // prevent simultaneous loop changes by restricting - // multiple SetStatus calls to run sequentially. - s.loopMu.Lock() - defer s.loopMu.Unlock() - - // not a read lock as we want to modify it eventually in - // the code below before any other call. - s.statusMu.Lock() - existingStatus := s.status - - switch status { - case constants.Running: - if existingStatus != constants.Stopped { - // starting, running, stopping, crashed - s.statusMu.Unlock() - return "already " + existingStatus.String(), nil - } - - s.status = constants.Starting - s.statusMu.Unlock() - s.start <- struct{}{} - - // Wait for the loop to react to the start signal - newStatus := constants.Starting // for canceled context - select { - case <-ctx.Done(): - case newStatus = <-s.running: - } - s.SetStatus(newStatus) - - return newStatus.String(), nil - case constants.Stopped: - if existingStatus != constants.Running { - return "already " + existingStatus.String(), nil - } - - s.status = constants.Stopping - s.statusMu.Unlock() - s.stop <- struct{}{} - - // Wait for the loop to react to the stop signal - newStatus := constants.Stopping // for canceled context - select { - case <-ctx.Done(): - case <-s.stopped: - newStatus = constants.Stopped - } - s.SetStatus(newStatus) - - return newStatus.String(), nil - default: - return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s", - ErrInvalidStatus, status, constants.Running, constants.Stopped) - } -} - -func (s *state) GetSettings() (settings configuration.DNS) { - s.settingsMu.RLock() - defer s.settingsMu.RUnlock() - return s.settings -} - -func (s *state) SetSettings(ctx context.Context, settings configuration.DNS) ( - outcome string) { - s.settingsMu.Lock() - defer s.settingsMu.Unlock() - - settingsUnchanged := reflect.DeepEqual(s.settings, settings) - if settingsUnchanged { - return "settings left unchanged" - } - - // Check for only update period change - tempSettings := s.settings - tempSettings.UpdatePeriod = settings.UpdatePeriod - onlyUpdatePeriodChanged := reflect.DeepEqual(tempSettings, settings) - - s.settings = settings - - if onlyUpdatePeriodChanged { - s.updateTicker <- struct{}{} - return "update period changed" - } - - // Restart - _, _ = s.ApplyStatus(ctx, constants.Stopped) - if settings.Enabled { - outcome, _ = s.ApplyStatus(ctx, constants.Running) - } - return outcome -} diff --git a/internal/dns/state/settings.go b/internal/dns/state/settings.go new file mode 100644 index 00000000..af5ef272 --- /dev/null +++ b/internal/dns/state/settings.go @@ -0,0 +1,51 @@ +package state + +import ( + "context" + "reflect" + + "github.com/qdm12/gluetun/internal/configuration" + "github.com/qdm12/gluetun/internal/constants" +) + +type SettingsGetterSetter interface { + GetSettings() (settings configuration.DNS) + SetSettings(ctx context.Context, + settings configuration.DNS) (outcome string) +} + +func (s *State) GetSettings() (settings configuration.DNS) { + s.settingsMu.RLock() + defer s.settingsMu.RUnlock() + return s.settings +} + +func (s *State) SetSettings(ctx context.Context, settings configuration.DNS) ( + outcome string) { + s.settingsMu.Lock() + defer s.settingsMu.Unlock() + + settingsUnchanged := reflect.DeepEqual(s.settings, settings) + if settingsUnchanged { + return "settings left unchanged" + } + + // Check for only update period change + tempSettings := s.settings + tempSettings.UpdatePeriod = settings.UpdatePeriod + onlyUpdatePeriodChanged := reflect.DeepEqual(tempSettings, settings) + + s.settings = settings + + if onlyUpdatePeriodChanged { + s.updateTicker <- struct{}{} + return "update period changed" + } + + // Restart + _, _ = s.statusApplier.ApplyStatus(ctx, constants.Stopped) + if settings.Enabled { + outcome, _ = s.statusApplier.ApplyStatus(ctx, constants.Running) + } + return outcome +} diff --git a/internal/dns/state/state.go b/internal/dns/state/state.go new file mode 100644 index 00000000..52679d7e --- /dev/null +++ b/internal/dns/state/state.go @@ -0,0 +1,33 @@ +package state + +import ( + "sync" + + "github.com/qdm12/gluetun/internal/configuration" + "github.com/qdm12/gluetun/internal/loopstate" +) + +var _ Manager = (*State)(nil) + +type Manager interface { + SettingsGetterSetter +} + +func New(statusApplier loopstate.Applier, + settings configuration.DNS, + updateTicker chan<- struct{}) *State { + return &State{ + statusApplier: statusApplier, + settings: settings, + updateTicker: updateTicker, + } +} + +type State struct { + statusApplier loopstate.Applier + + settings configuration.DNS + settingsMu sync.RWMutex + + updateTicker chan<- struct{} +} diff --git a/internal/dns/status.go b/internal/dns/status.go index ee492ab9..d7f7d2c2 100644 --- a/internal/dns/status.go +++ b/internal/dns/status.go @@ -6,23 +6,11 @@ import ( "github.com/qdm12/gluetun/internal/models" ) -type StatusGetterApplier interface { - StatusGetter - StatusApplier -} - -type StatusGetter interface { - GetStatus() (status models.LoopStatus) -} - -func (l *Loop) GetStatus() (status models.LoopStatus) { return l.state.GetStatus() } - -type StatusApplier interface { - ApplyStatus(ctx context.Context, status models.LoopStatus) ( - outcome string, err error) +func (l *Loop) GetStatus() (status models.LoopStatus) { + return l.statusManager.GetStatus() } func (l *Loop) ApplyStatus(ctx context.Context, status models.LoopStatus) ( outcome string, err error) { - return l.state.ApplyStatus(ctx, status) + return l.statusManager.ApplyStatus(ctx, status) } diff --git a/internal/dns/ticker.go b/internal/dns/ticker.go index 4918903b..4eac320f 100644 --- a/internal/dns/ticker.go +++ b/internal/dns/ticker.go @@ -36,15 +36,15 @@ func (l *Loop) RunRestartTicker(ctx context.Context, done chan<- struct{}) { status := l.GetStatus() if status == constants.Running { if err := l.updateFiles(ctx); err != nil { - l.state.SetStatus(constants.Crashed) + l.statusManager.SetStatus(constants.Crashed) l.logger.Error(err.Error()) l.logger.Warn("skipping Unbound restart due to failed files update") continue } } - _, _ = l.ApplyStatus(ctx, constants.Stopped) - _, _ = l.ApplyStatus(ctx, constants.Running) + _, _ = l.statusManager.ApplyStatus(ctx, constants.Stopped) + _, _ = l.statusManager.ApplyStatus(ctx, constants.Running) settings := l.GetSettings() timer.Reset(settings.UpdatePeriod)