Replace explicit channels with functions

This commit is contained in:
Quentin McGaw
2020-07-15 01:34:46 +00:00
parent 8c7c8f7d5a
commit 616ba0c538
7 changed files with 94 additions and 68 deletions

View File

@@ -117,14 +117,6 @@ func _main(background context.Context, args []string) int {
defer close(connectedCh)
go collectStreamLines(ctx, streamMerger, logger, signalConnected)
// TODO replace these with methods on loopers and pass loopers around
restartOpenvpn := make(chan struct{})
portForward := make(chan struct{})
restartUnbound := make(chan struct{})
restartPublicIP := make(chan struct{})
restartTinyproxy := make(chan struct{})
restartShadowsocks := make(chan struct{})
if allSettings.Firewall.Enabled {
err := firewallConf.SetEnabled(ctx, true) // disabled by default
fatalOnError(err)
@@ -135,28 +127,34 @@ func _main(background context.Context, args []string) int {
openvpnLooper := openvpn.NewLooper(allSettings.VPNSP, allSettings.OpenVPN, uid, gid,
ovpnConf, firewallConf, logger, client, fileManager, streamMerger, fatalOnError)
restartOpenvpn := openvpnLooper.Restart
portForward := openvpnLooper.PortForward
// wait for restartOpenvpn
go openvpnLooper.Run(ctx, restartOpenvpn, portForward, wg)
go openvpnLooper.Run(ctx, wg)
unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, logger, streamMerger, uid, gid)
restartUnbound := unboundLooper.Restart
// wait for restartUnbound
go unboundLooper.Run(ctx, restartUnbound, wg)
go unboundLooper.Run(ctx, wg)
publicIPLooper := publicip.NewLooper(client, logger, fileManager, allSettings.System.IPStatusFilepath, uid, gid)
go publicIPLooper.Run(ctx, restartPublicIP)
go publicIPLooper.RunRestartTicker(ctx, restartPublicIP)
restartPublicIP := publicIPLooper.Restart
go publicIPLooper.Run(ctx)
go publicIPLooper.RunRestartTicker(ctx)
tinyproxyLooper := tinyproxy.NewLooper(tinyProxyConf, firewallConf, allSettings.TinyProxy, logger, streamMerger, uid, gid)
go tinyproxyLooper.Run(ctx, restartTinyproxy, wg)
restartTinyproxy := tinyproxyLooper.Restart
go tinyproxyLooper.Run(ctx, wg)
shadowsocksLooper := shadowsocks.NewLooper(shadowsocksConf, firewallConf, allSettings.ShadowSocks, allSettings.DNS, logger, streamMerger, uid, gid)
go shadowsocksLooper.Run(ctx, restartShadowsocks, wg)
restartShadowsocks := shadowsocksLooper.Restart
go shadowsocksLooper.Run(ctx, wg)
if allSettings.TinyProxy.Enabled {
restartTinyproxy <- struct{}{}
restartTinyproxy()
}
if allSettings.ShadowSocks.Enabled {
restartShadowsocks <- struct{}{}
restartShadowsocks()
}
go func() {
@@ -170,7 +168,7 @@ func _main(background context.Context, args []string) int {
case <-connectedCh: // blocks until openvpn is connected
restartTickerCancel()
restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
go unboundLooper.RunRestartTicker(restartTickerContext, restartUnbound)
go unboundLooper.RunRestartTicker(restartTickerContext)
onConnected(allSettings, logger, routingConf, portForward, restartUnbound, restartPublicIP)
}
}
@@ -180,7 +178,7 @@ func _main(background context.Context, args []string) int {
go httpServer.Run(ctx, wg)
// Start openvpn for the first time
restartOpenvpn <- struct{}{}
restartOpenvpn()
signalsCh := make(chan os.Signal, 1)
signal.Notify(signalsCh,
@@ -291,14 +289,12 @@ func collectStreamLines(ctx context.Context, streamMerger command.StreamMerger,
}
func onConnected(allSettings settings.Settings, logger logging.Logger, routingConf routing.Routing,
portForward, restartUnbound, restartPublicIP chan<- struct{},
portForward, restartUnbound, restartPublicIP func(),
) {
restartUnbound <- struct{}{}
restartPublicIP <- struct{}{}
restartUnbound()
restartPublicIP()
if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
time.AfterFunc(5*time.Second, func() {
portForward <- struct{}{}
})
time.AfterFunc(5*time.Second, portForward)
}
defaultInterface, _, err := routingConf.DefaultRoute()
if err != nil {

View File

@@ -13,8 +13,9 @@ import (
)
type Looper interface {
Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup)
RunRestartTicker(ctx context.Context, restart chan<- struct{})
Run(ctx context.Context, wg *sync.WaitGroup)
RunRestartTicker(ctx context.Context)
Restart()
}
type looper struct {
@@ -24,6 +25,7 @@ type looper struct {
streamMerger command.StreamMerger
uid int
gid int
restart chan struct{}
}
func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger,
@@ -35,9 +37,12 @@ func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger,
uid: uid,
gid: gid,
streamMerger: streamMerger,
restart: make(chan struct{}),
}
}
func (l *looper) Restart() { l.restart <- struct{}{} }
func (l *looper) logAndWait(ctx context.Context, err error) {
l.logger.Warn(err)
l.logger.Info("attempting restart in 10 seconds")
@@ -46,12 +51,12 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
<-ctx.Done()
}
func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) {
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
wg.Add(1)
defer wg.Done()
l.fallbackToUnencryptedDNS()
select {
case <-restart:
case <-l.restart:
case <-ctx.Done():
return
}
@@ -65,7 +70,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
if !l.settings.Enabled {
// wait for another restart signal to recheck if it is enabled
select {
case <-restart:
case <-l.restart:
case <-ctx.Done():
unboundCancel()
return
@@ -127,7 +132,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
<-waitError
close(waitError)
return
case <-restart: // triggered restart
case <-l.restart: // triggered restart
l.logger.Info("restarting")
// unboundCancel occurs next loop run when the setup is complete
triggeredRestart = true
@@ -172,7 +177,7 @@ func (l *looper) fallbackToUnencryptedDNS() {
l.logger.Error("no ipv4 DNS address found for providers %s", l.settings.Providers)
}
func (l *looper) RunRestartTicker(ctx context.Context, restart chan<- struct{}) {
func (l *looper) RunRestartTicker(ctx context.Context) {
if l.settings.UpdatePeriod == 0 {
return
}
@@ -183,7 +188,7 @@ func (l *looper) RunRestartTicker(ctx context.Context, restart chan<- struct{})
ticker.Stop()
return
case <-ticker.C:
restart <- struct{}{}
l.restart <- struct{}{}
}
}
}

View File

@@ -18,7 +18,9 @@ import (
)
type Looper interface {
Run(ctx context.Context, restart, portForward <-chan struct{}, wg *sync.WaitGroup)
Run(ctx context.Context, wg *sync.WaitGroup)
Restart()
PortForward()
}
type looper struct {
@@ -37,6 +39,9 @@ type looper struct {
fileManager files.FileManager
streamMerger command.StreamMerger
fatalOnError func(err error)
// Internal channels
restart chan struct{}
portForwardSignals chan struct{}
}
func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
@@ -56,14 +61,19 @@ func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
fileManager: fileManager,
streamMerger: streamMerger,
fatalOnError: fatalOnError,
restart: make(chan struct{}),
portForwardSignals: make(chan struct{}),
}
}
func (l *looper) Run(ctx context.Context, restart, portForward <-chan struct{}, wg *sync.WaitGroup) {
func (l *looper) Restart() { l.restart <- struct{}{} }
func (l *looper) PortForward() { l.portForwardSignals <- struct{}{} }
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
wg.Add(1)
defer wg.Done()
select {
case <-restart:
case <-l.restart:
case <-ctx.Done():
return
}
@@ -107,7 +117,7 @@ func (l *looper) Run(ctx context.Context, restart, portForward <-chan struct{},
select {
case <-ctx.Done():
return
case <-portForward:
case <-l.portForwardSignals:
l.portForward(ctx, providerConf, l.client)
}
}
@@ -126,7 +136,7 @@ func (l *looper) Run(ctx context.Context, restart, portForward <-chan struct{},
<-waitError
close(waitError)
return
case <-restart: // triggered restart
case <-l.restart: // triggered restart
l.logger.Info("restarting")
openvpnCancel()
<-waitError

View File

@@ -11,8 +11,9 @@ import (
)
type Looper interface {
Run(ctx context.Context, restart <-chan struct{})
RunRestartTicker(ctx context.Context, restart chan<- struct{})
Run(ctx context.Context)
RunRestartTicker(ctx context.Context)
Restart()
}
type looper struct {
@@ -22,6 +23,7 @@ type looper struct {
ipStatusFilepath models.Filepath
uid int
gid int
restart chan struct{}
}
func NewLooper(client network.Client, logger logging.Logger, fileManager files.FileManager,
@@ -33,9 +35,12 @@ func NewLooper(client network.Client, logger logging.Logger, fileManager files.F
ipStatusFilepath: ipStatusFilepath,
uid: uid,
gid: gid,
restart: make(chan struct{}),
}
}
func (l *looper) Restart() { l.restart <- struct{}{} }
func (l *looper) logAndWait(ctx context.Context, err error) {
l.logger.Error(err)
l.logger.Info("retrying in 5 seconds")
@@ -44,9 +49,9 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
<-ctx.Done()
}
func (l *looper) Run(ctx context.Context, restart <-chan struct{}) {
func (l *looper) Run(ctx context.Context) {
select {
case <-restart:
case <-l.restart:
case <-ctx.Done():
return
}
@@ -69,7 +74,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}) {
continue
}
select {
case <-restart: // triggered restart
case <-l.restart: // triggered restart
case <-ctx.Done():
l.logger.Warn("context canceled: exiting loop")
return
@@ -77,7 +82,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}) {
}
}
func (l *looper) RunRestartTicker(ctx context.Context, restart chan<- struct{}) {
func (l *looper) RunRestartTicker(ctx context.Context) {
ticker := time.NewTicker(time.Hour)
for {
select {
@@ -85,7 +90,7 @@ func (l *looper) RunRestartTicker(ctx context.Context, restart chan<- struct{})
ticker.Stop()
return
case <-ticker.C:
restart <- struct{}{}
l.restart <- struct{}{}
}
}
}

View File

@@ -17,11 +17,11 @@ type Server interface {
type server struct {
address string
logger logging.Logger
restartOpenvpn chan<- struct{}
restartUnbound chan<- struct{}
restartOpenvpn func()
restartUnbound func()
}
func New(address string, logger logging.Logger, restartOpenvpn, restartUnbound chan<- struct{}) Server {
func New(address string, logger logging.Logger, restartOpenvpn, restartUnbound func()) Server {
return &server{
address: address,
logger: logger.WithPrefix("http server: "),
@@ -58,9 +58,9 @@ func (s *server) makeHandler() http.HandlerFunc {
case http.MethodGet:
switch r.RequestURI {
case "/openvpn/actions/restart":
s.restartOpenvpn <- struct{}{}
s.restartOpenvpn()
case "/unbound/actions/restart":
s.restartUnbound <- struct{}{}
s.restartUnbound()
default:
routeDoesNotExist(s.logger, w, r)
}

View File

@@ -12,7 +12,8 @@ import (
)
type Looper interface {
Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup)
Run(ctx context.Context, wg *sync.WaitGroup)
Restart()
}
type looper struct {
@@ -24,6 +25,7 @@ type looper struct {
streamMerger command.StreamMerger
uid int
gid int
restart chan struct{}
}
func (l *looper) logAndWait(ctx context.Context, err error) {
@@ -45,14 +47,17 @@ func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings s
streamMerger: streamMerger,
uid: uid,
gid: gid,
restart: make(chan struct{}),
}
}
func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) {
func (l *looper) Restart() { l.restart <- struct{}{} }
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
wg.Add(1)
defer wg.Done()
select {
case <-restart:
case <-l.restart:
case <-ctx.Done():
return
}
@@ -109,7 +114,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
<-waitError
close(waitError)
return
case <-restart: // triggered restart
case <-l.restart: // triggered restart
l.logger.Info("restarting")
shadowsocksCancel()
<-waitError

View File

@@ -12,7 +12,8 @@ import (
)
type Looper interface {
Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup)
Run(ctx context.Context, wg *sync.WaitGroup)
Restart()
}
type looper struct {
@@ -23,6 +24,7 @@ type looper struct {
streamMerger command.StreamMerger
uid int
gid int
restart chan struct{}
}
func (l *looper) logAndWait(ctx context.Context, err error) {
@@ -43,14 +45,17 @@ func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings s
streamMerger: streamMerger,
uid: uid,
gid: gid,
restart: make(chan struct{}),
}
}
func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) {
func (l *looper) Restart() { l.restart <- struct{}{} }
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
wg.Add(1)
defer wg.Done()
select {
case <-restart:
case <-l.restart:
case <-ctx.Done():
return
}
@@ -102,7 +107,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
<-waitError
close(waitError)
return
case <-restart: // triggered restart
case <-l.restart: // triggered restart
l.logger.Info("restarting")
tinyproxyCancel()
<-waitError