Replace explicit channels with functions

This commit is contained in:
Quentin McGaw
2020-07-15 01:34:46 +00:00
parent 8c7c8f7d5a
commit 616ba0c538
7 changed files with 94 additions and 68 deletions

View File

@@ -18,7 +18,9 @@ import (
)
type Looper interface {
Run(ctx context.Context, restart, portForward <-chan struct{}, wg *sync.WaitGroup)
Run(ctx context.Context, wg *sync.WaitGroup)
Restart()
PortForward()
}
type looper struct {
@@ -37,6 +39,9 @@ type looper struct {
fileManager files.FileManager
streamMerger command.StreamMerger
fatalOnError func(err error)
// Internal channels
restart chan struct{}
portForwardSignals chan struct{}
}
func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
@@ -45,25 +50,30 @@ func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
logger logging.Logger, client network.Client, fileManager files.FileManager,
streamMerger command.StreamMerger, fatalOnError func(err error)) Looper {
return &looper{
provider: provider,
settings: settings,
uid: uid,
gid: gid,
conf: conf,
fw: fw,
logger: logger.WithPrefix("openvpn: "),
client: client,
fileManager: fileManager,
streamMerger: streamMerger,
fatalOnError: fatalOnError,
provider: provider,
settings: settings,
uid: uid,
gid: gid,
conf: conf,
fw: fw,
logger: logger.WithPrefix("openvpn: "),
client: client,
fileManager: fileManager,
streamMerger: streamMerger,
fatalOnError: fatalOnError,
restart: make(chan struct{}),
portForwardSignals: make(chan struct{}),
}
}
func (l *looper) Run(ctx context.Context, restart, portForward <-chan struct{}, wg *sync.WaitGroup) {
func (l *looper) Restart() { l.restart <- struct{}{} }
func (l *looper) PortForward() { l.portForwardSignals <- struct{}{} }
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
wg.Add(1)
defer wg.Done()
select {
case <-restart:
case <-l.restart:
case <-ctx.Done():
return
}
@@ -107,7 +117,7 @@ func (l *looper) Run(ctx context.Context, restart, portForward <-chan struct{},
select {
case <-ctx.Done():
return
case <-portForward:
case <-l.portForwardSignals:
l.portForward(ctx, providerConf, l.client)
}
}
@@ -126,7 +136,7 @@ func (l *looper) Run(ctx context.Context, restart, portForward <-chan struct{},
<-waitError
close(waitError)
return
case <-restart: // triggered restart
case <-l.restart: // triggered restart
l.logger.Info("restarting")
openvpnCancel()
<-waitError