Maint: minor DNS loop fixes and changes

This commit is contained in:
Quentin McGaw (desktop)
2021-07-16 21:21:09 +00:00
parent 39068dda17
commit 7e50c95823
2 changed files with 11 additions and 15 deletions

View File

@@ -36,6 +36,7 @@ type looper struct {
blockBuilder blacklist.Builder blockBuilder blacklist.Builder
client *http.Client client *http.Client
logger logging.Logger logger logging.Logger
userTrigger bool
start <-chan struct{} start <-chan struct{}
running chan<- models.LoopStatus running chan<- models.LoopStatus
stop <-chan struct{} stop <-chan struct{}
@@ -65,6 +66,7 @@ func NewLooper(conf unbound.Configurator, settings configuration.DNS, client *ht
blockBuilder: blacklist.NewBuilder(client), blockBuilder: blacklist.NewBuilder(client),
client: client, client: client,
logger: logger, logger: logger,
userTrigger: true,
start: start, start: start,
running: running, running: running,
stop: stop, stop: stop,
@@ -93,9 +95,9 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
} }
} }
func (l *looper) signalOrSetStatus(userTriggered *bool, status models.LoopStatus) { func (l *looper) signalOrSetStatus(status models.LoopStatus) {
if *userTriggered { if l.userTrigger {
*userTriggered = false l.userTrigger = false
select { select {
case l.running <- status: case l.running <- status:
default: // receiver droppped out - avoid deadlock on events routing when shutting down default: // receiver droppped out - avoid deadlock on events routing when shutting down
@@ -118,8 +120,6 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
return return
} }
userTriggered := true
for ctx.Err() == nil { for ctx.Err() == nil {
// Upper scope variables for Unbound only // Upper scope variables for Unbound only
// Their values are to be used if DOT=off // Their values are to be used if DOT=off
@@ -133,11 +133,11 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
if err == nil { if err == nil {
l.backoffTime = defaultBackoffTime l.backoffTime = defaultBackoffTime
l.logger.Info("ready") l.logger.Info("ready")
l.signalOrSetStatus(&userTriggered, constants.Running) l.signalOrSetStatus(constants.Running)
break break
} }
l.signalOrSetStatus(&userTriggered, constants.Crashed) l.signalOrSetStatus(constants.Crashed)
if ctx.Err() != nil { if ctx.Err() != nil {
return return
@@ -155,7 +155,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
l.useUnencryptedDNS(fallback) l.useUnencryptedDNS(fallback)
} }
userTriggered = false l.userTrigger = false
stayHere := true stayHere := true
for stayHere { for stayHere {
@@ -167,7 +167,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
closeStreams() closeStreams()
return return
case <-l.stop: case <-l.stop:
userTriggered = true l.userTrigger = true
l.logger.Info("stopping") l.logger.Info("stopping")
const fallback = false const fallback = false
l.useUnencryptedDNS(fallback) l.useUnencryptedDNS(fallback)
@@ -178,23 +178,19 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) {
closeStreams() closeStreams()
l.stopped <- struct{}{} l.stopped <- struct{}{}
case <-l.start: case <-l.start:
userTriggered = true l.userTrigger = true
l.logger.Info("starting") l.logger.Info("starting")
stayHere = false stayHere = false
case err := <-waitError: // unexpected error case err := <-waitError: // unexpected error
close(waitError) close(waitError)
closeStreams() closeStreams()
l.state.Lock() // prevent SetStatus from running in parallel
unboundCancel() unboundCancel()
l.state.SetStatus(constants.Crashed) l.state.SetStatus(constants.Crashed)
const fallback = true const fallback = true
l.useUnencryptedDNS(fallback) l.useUnencryptedDNS(fallback)
l.logAndWait(ctx, err) l.logAndWait(ctx, err)
stayHere = false stayHere = false
l.state.Unlock()
} }
} }
} }

View File

@@ -120,7 +120,7 @@ func (s *state) ApplyStatus(ctx context.Context, status models.LoopStatus) (
} }
s.SetStatus(newStatus) s.SetStatus(newStatus)
return status.String(), nil return newStatus.String(), nil
default: default:
return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s", return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s",
ErrInvalidStatus, status, constants.Running, constants.Stopped) ErrInvalidStatus, status, constants.Running, constants.Stopped)