Fix several async issues

- race conditions between ctx.Done and waitError channel
- Sleep for retry cancels on cancelation of context
- Stops the any loop at the start if the context was canceled
- Mentions when loops exit
- Wait for errors on triggered loop restarts
This commit is contained in:
Quentin McGaw
2020-07-11 20:59:30 +00:00
parent 1ac06ee4a8
commit ccf11990f1
5 changed files with 67 additions and 48 deletions

View File

@@ -38,10 +38,12 @@ func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger,
} }
} }
func (l *looper) attemptingRestart(err error) { func (l *looper) logAndWait(ctx context.Context, err error) {
l.logger.Warn(err) l.logger.Warn(err)
l.logger.Info("attempting restart in 10 seconds") l.logger.Info("attempting restart in 10 seconds")
time.Sleep(10 * time.Second) ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
<-ctx.Done()
} }
func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) { func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) {
@@ -53,8 +55,13 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
case <-ctx.Done(): case <-ctx.Done():
return return
} }
_, unboundCancel := context.WithCancel(ctx) defer l.logger.Warn("loop exited")
for {
var unboundCtx context.Context
var unboundCancel context.CancelFunc = func() {}
var waitError chan error
triggeredRestart := false
for ctx.Err() == nil {
if !l.settings.Enabled { if !l.settings.Enabled {
// wait for another restart signal to recheck if it is enabled // wait for another restart signal to recheck if it is enabled
select { select {
@@ -64,33 +71,34 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
return return
} }
} }
if ctx.Err() == context.Canceled {
unboundCancel()
return
}
// Setup // Setup
if err := l.conf.DownloadRootHints(l.uid, l.gid); err != nil { if err := l.conf.DownloadRootHints(l.uid, l.gid); err != nil {
l.attemptingRestart(err) l.logAndWait(ctx, err)
continue continue
} }
if err := l.conf.DownloadRootKey(l.uid, l.gid); err != nil { if err := l.conf.DownloadRootKey(l.uid, l.gid); err != nil {
l.attemptingRestart(err) l.logAndWait(ctx, err)
continue continue
} }
if err := l.conf.MakeUnboundConf(l.settings, l.uid, l.gid); err != nil { if err := l.conf.MakeUnboundConf(l.settings, l.uid, l.gid); err != nil {
l.attemptingRestart(err) l.logAndWait(ctx, err)
continue continue
} }
// Start command if triggeredRestart {
unboundCancel() triggeredRestart = false
unboundCtx, unboundCancel := context.WithCancel(ctx) unboundCancel()
<-waitError
close(waitError)
}
unboundCtx, unboundCancel = context.WithCancel(context.Background())
stream, waitFn, err := l.conf.Start(unboundCtx, l.settings.VerbosityDetailsLevel) stream, waitFn, err := l.conf.Start(unboundCtx, l.settings.VerbosityDetailsLevel)
if err != nil { if err != nil {
unboundCancel() unboundCancel()
l.fallbackToUnencryptedDNS() l.fallbackToUnencryptedDNS()
l.attemptingRestart(err) l.logAndWait(ctx, err)
continue
} }
// Started successfully // Started successfully
@@ -98,16 +106,15 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
command.MergeName("unbound"), command.MergeColor(constants.ColorUnbound())) command.MergeName("unbound"), command.MergeColor(constants.ColorUnbound()))
l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound l.conf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound
if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}); err != nil { // use Unbound if err := l.conf.UseDNSSystemWide(net.IP{127, 0, 0, 1}); err != nil { // use Unbound
unboundCancel() l.logger.Error(err)
l.fallbackToUnencryptedDNS()
l.attemptingRestart(err)
} }
if err := l.conf.WaitForUnbound(); err != nil { if err := l.conf.WaitForUnbound(); err != nil {
unboundCancel() unboundCancel()
l.fallbackToUnencryptedDNS() l.fallbackToUnencryptedDNS()
l.attemptingRestart(err) l.logAndWait(ctx, err)
continue
} }
waitError := make(chan error) waitError = make(chan error)
go func() { go func() {
err := waitFn() // blocking err := waitFn() // blocking
waitError <- err waitError <- err
@@ -122,16 +129,17 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
close(waitError) close(waitError)
return return
case <-restart: // triggered restart case <-restart: // triggered restart
unboundCancel()
close(waitError)
l.logger.Info("restarting") l.logger.Info("restarting")
// unboundCancel occurs next loop run when the setup is complete
triggeredRestart = true
case err := <-waitError: // unexpected error case err := <-waitError: // unexpected error
unboundCancel()
close(waitError) close(waitError)
unboundCancel()
l.fallbackToUnencryptedDNS() l.fallbackToUnencryptedDNS()
l.attemptingRestart(err) l.logAndWait(ctx, err)
} }
} }
unboundCancel()
} }
func (l *looper) fallbackToUnencryptedDNS() { func (l *looper) fallbackToUnencryptedDNS() {

View File

@@ -36,10 +36,12 @@ func NewLooper(client network.Client, logger logging.Logger, fileManager files.F
} }
} }
func (l *looper) logAndWait(err error) { func (l *looper) logAndWait(ctx context.Context, err error) {
l.logger.Error(err) l.logger.Error(err)
l.logger.Info("retrying in 5 seconds") l.logger.Info("retrying in 5 seconds")
time.Sleep(5 * time.Second) ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel() // just for the linter
<-ctx.Done()
} }
func (l *looper) Run(ctx context.Context, restart <-chan struct{}) { func (l *looper) Run(ctx context.Context, restart <-chan struct{}) {
@@ -48,10 +50,12 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}) {
case <-ctx.Done(): case <-ctx.Done():
return return
} }
for { defer l.logger.Warn("loop exited")
for ctx.Err() == nil {
ip, err := l.getter.Get() ip, err := l.getter.Get()
if err != nil { if err != nil {
l.logAndWait(err) l.logAndWait(ctx, err)
continue continue
} }
l.logger.Info("Public IP address is %s", ip) l.logger.Info("Public IP address is %s", ip)
@@ -61,7 +65,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}) {
files.Ownership(l.uid, l.gid), files.Ownership(l.uid, l.gid),
files.Permissions(0600)) files.Permissions(0600))
if err != nil { if err != nil {
l.logAndWait(err) l.logAndWait(ctx, err)
continue continue
} }
select { select {

View File

@@ -37,6 +37,7 @@ func (s *server) Run(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done() defer wg.Done()
<-ctx.Done() <-ctx.Done()
s.logger.Warn("context canceled: exiting loop") s.logger.Warn("context canceled: exiting loop")
defer s.logger.Warn("loop exited")
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel() defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil { if err := server.Shutdown(shutdownCtx); err != nil {

View File

@@ -27,10 +27,12 @@ type looper struct {
gid int gid int
} }
func (l *looper) logAndWait(err error) { func (l *looper) logAndWait(ctx context.Context, err error) {
l.logger.Error(err) l.logger.Error(err)
l.logger.Info("retrying in 1 minute") l.logger.Info("retrying in 1 minute")
time.Sleep(time.Minute) ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel() // just for the linter
<-ctx.Done()
} }
func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings settings.ShadowSocks, dnsSettings settings.DNS, func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings settings.ShadowSocks, dnsSettings settings.DNS,
@@ -55,7 +57,9 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
case <-ctx.Done(): case <-ctx.Done():
return return
} }
for { defer l.logger.Warn("loop exited")
for ctx.Err() == nil {
nameserver := l.dnsSettings.PlaintextAddress.String() nameserver := l.dnsSettings.PlaintextAddress.String()
if l.dnsSettings.Enabled { if l.dnsSettings.Enabled {
nameserver = "127.0.0.1" nameserver = "127.0.0.1"
@@ -68,7 +72,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
l.uid, l.uid,
l.gid) l.gid)
if err != nil { if err != nil {
l.logAndWait(err) l.logAndWait(ctx, err)
continue continue
} }
err = l.firewallConf.AllowAnyIncomingOnPort(ctx, l.settings.Port) err = l.firewallConf.AllowAnyIncomingOnPort(ctx, l.settings.Port)
@@ -76,11 +80,11 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
if err != nil { if err != nil {
l.logger.Error(err) l.logger.Error(err)
} }
shadowsocksCtx, shadowsocksCancel := context.WithCancel(ctx) shadowsocksCtx, shadowsocksCancel := context.WithCancel(context.Background())
stdout, stderr, waitFn, err := l.conf.Start(ctx, "0.0.0.0", l.settings.Port, l.settings.Password, l.settings.Log) stdout, stderr, waitFn, err := l.conf.Start(shadowsocksCtx, "0.0.0.0", l.settings.Port, l.settings.Password, l.settings.Log)
if err != nil { if err != nil {
shadowsocksCancel() shadowsocksCancel()
l.logAndWait(err) l.logAndWait(ctx, err)
continue continue
} }
go l.streamMerger.Merge(shadowsocksCtx, stdout, go l.streamMerger.Merge(shadowsocksCtx, stdout,
@@ -102,13 +106,12 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
case <-restart: // triggered restart case <-restart: // triggered restart
l.logger.Info("restarting") l.logger.Info("restarting")
shadowsocksCancel() shadowsocksCancel()
<-waitError
close(waitError) close(waitError)
case err := <-waitError: // unexpected error case err := <-waitError: // unexpected error
l.logger.Warn(err)
l.logger.Info("restarting")
shadowsocksCancel() shadowsocksCancel()
close(waitError) close(waitError)
time.Sleep(time.Second) l.logAndWait(ctx, err)
} }
} }
} }

View File

@@ -26,10 +26,12 @@ type looper struct {
gid int gid int
} }
func (l *looper) logAndWait(err error) { func (l *looper) logAndWait(ctx context.Context, err error) {
l.logger.Error(err) l.logger.Error(err)
l.logger.Info("retrying in 1 minute") l.logger.Info("retrying in 1 minute")
time.Sleep(time.Minute) ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel() // just for the linter
<-ctx.Done()
} }
func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings settings.TinyProxy, func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings settings.TinyProxy,
@@ -53,7 +55,9 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
case <-ctx.Done(): case <-ctx.Done():
return return
} }
for { defer l.logger.Warn("loop exited")
for ctx.Err() == nil {
err := l.conf.MakeConf( err := l.conf.MakeConf(
l.settings.LogLevel, l.settings.LogLevel,
l.settings.Port, l.settings.Port,
@@ -62,7 +66,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
l.uid, l.uid,
l.gid) l.gid)
if err != nil { if err != nil {
l.logAndWait(err) l.logAndWait(ctx, err)
continue continue
} }
err = l.firewallConf.AllowAnyIncomingOnPort(ctx, l.settings.Port) err = l.firewallConf.AllowAnyIncomingOnPort(ctx, l.settings.Port)
@@ -70,11 +74,11 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
if err != nil { if err != nil {
l.logger.Error(err) l.logger.Error(err)
} }
tinyproxyCtx, tinyproxyCancel := context.WithCancel(ctx) tinyproxyCtx, tinyproxyCancel := context.WithCancel(context.Background())
stream, waitFn, err := l.conf.Start(tinyproxyCtx) stream, waitFn, err := l.conf.Start(tinyproxyCtx)
if err != nil { if err != nil {
tinyproxyCancel() tinyproxyCancel()
l.logAndWait(err) l.logAndWait(ctx, err)
continue continue
} }
go l.streamMerger.Merge(tinyproxyCtx, stream, go l.streamMerger.Merge(tinyproxyCtx, stream,
@@ -94,13 +98,12 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
case <-restart: // triggered restart case <-restart: // triggered restart
l.logger.Info("restarting") l.logger.Info("restarting")
tinyproxyCancel() tinyproxyCancel()
<-waitError
close(waitError) close(waitError)
case err := <-waitError: // unexpected error case err := <-waitError: // unexpected error
l.logger.Warn(err)
l.logger.Info("restarting")
tinyproxyCancel() tinyproxyCancel()
close(waitError) close(waitError)
time.Sleep(time.Second) l.logAndWait(ctx, err)
} }
} }
} }