Replace explicit channels with functions
This commit is contained in:
@@ -117,14 +117,6 @@ func _main(background context.Context, args []string) int {
|
|||||||
defer close(connectedCh)
|
defer close(connectedCh)
|
||||||
go collectStreamLines(ctx, streamMerger, logger, signalConnected)
|
go collectStreamLines(ctx, streamMerger, logger, signalConnected)
|
||||||
|
|
||||||
// TODO replace these with methods on loopers and pass loopers around
|
|
||||||
restartOpenvpn := make(chan struct{})
|
|
||||||
portForward := make(chan struct{})
|
|
||||||
restartUnbound := make(chan struct{})
|
|
||||||
restartPublicIP := make(chan struct{})
|
|
||||||
restartTinyproxy := make(chan struct{})
|
|
||||||
restartShadowsocks := make(chan struct{})
|
|
||||||
|
|
||||||
if allSettings.Firewall.Enabled {
|
if allSettings.Firewall.Enabled {
|
||||||
err := firewallConf.SetEnabled(ctx, true) // disabled by default
|
err := firewallConf.SetEnabled(ctx, true) // disabled by default
|
||||||
fatalOnError(err)
|
fatalOnError(err)
|
||||||
@@ -135,28 +127,34 @@ func _main(background context.Context, args []string) int {
|
|||||||
|
|
||||||
openvpnLooper := openvpn.NewLooper(allSettings.VPNSP, allSettings.OpenVPN, uid, gid,
|
openvpnLooper := openvpn.NewLooper(allSettings.VPNSP, allSettings.OpenVPN, uid, gid,
|
||||||
ovpnConf, firewallConf, logger, client, fileManager, streamMerger, fatalOnError)
|
ovpnConf, firewallConf, logger, client, fileManager, streamMerger, fatalOnError)
|
||||||
|
restartOpenvpn := openvpnLooper.Restart
|
||||||
|
portForward := openvpnLooper.PortForward
|
||||||
// wait for restartOpenvpn
|
// wait for restartOpenvpn
|
||||||
go openvpnLooper.Run(ctx, restartOpenvpn, portForward, wg)
|
go openvpnLooper.Run(ctx, wg)
|
||||||
|
|
||||||
unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, logger, streamMerger, uid, gid)
|
unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, logger, streamMerger, uid, gid)
|
||||||
|
restartUnbound := unboundLooper.Restart
|
||||||
// wait for restartUnbound
|
// wait for restartUnbound
|
||||||
go unboundLooper.Run(ctx, restartUnbound, wg)
|
go unboundLooper.Run(ctx, wg)
|
||||||
|
|
||||||
publicIPLooper := publicip.NewLooper(client, logger, fileManager, allSettings.System.IPStatusFilepath, uid, gid)
|
publicIPLooper := publicip.NewLooper(client, logger, fileManager, allSettings.System.IPStatusFilepath, uid, gid)
|
||||||
go publicIPLooper.Run(ctx, restartPublicIP)
|
restartPublicIP := publicIPLooper.Restart
|
||||||
go publicIPLooper.RunRestartTicker(ctx, restartPublicIP)
|
go publicIPLooper.Run(ctx)
|
||||||
|
go publicIPLooper.RunRestartTicker(ctx)
|
||||||
|
|
||||||
tinyproxyLooper := tinyproxy.NewLooper(tinyProxyConf, firewallConf, allSettings.TinyProxy, logger, streamMerger, uid, gid)
|
tinyproxyLooper := tinyproxy.NewLooper(tinyProxyConf, firewallConf, allSettings.TinyProxy, logger, streamMerger, uid, gid)
|
||||||
go tinyproxyLooper.Run(ctx, restartTinyproxy, wg)
|
restartTinyproxy := tinyproxyLooper.Restart
|
||||||
|
go tinyproxyLooper.Run(ctx, wg)
|
||||||
|
|
||||||
shadowsocksLooper := shadowsocks.NewLooper(shadowsocksConf, firewallConf, allSettings.ShadowSocks, allSettings.DNS, logger, streamMerger, uid, gid)
|
shadowsocksLooper := shadowsocks.NewLooper(shadowsocksConf, firewallConf, allSettings.ShadowSocks, allSettings.DNS, logger, streamMerger, uid, gid)
|
||||||
go shadowsocksLooper.Run(ctx, restartShadowsocks, wg)
|
restartShadowsocks := shadowsocksLooper.Restart
|
||||||
|
go shadowsocksLooper.Run(ctx, wg)
|
||||||
|
|
||||||
if allSettings.TinyProxy.Enabled {
|
if allSettings.TinyProxy.Enabled {
|
||||||
restartTinyproxy <- struct{}{}
|
restartTinyproxy()
|
||||||
}
|
}
|
||||||
if allSettings.ShadowSocks.Enabled {
|
if allSettings.ShadowSocks.Enabled {
|
||||||
restartShadowsocks <- struct{}{}
|
restartShadowsocks()
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
@@ -170,7 +168,7 @@ func _main(background context.Context, args []string) int {
|
|||||||
case <-connectedCh: // blocks until openvpn is connected
|
case <-connectedCh: // blocks until openvpn is connected
|
||||||
restartTickerCancel()
|
restartTickerCancel()
|
||||||
restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
|
restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
|
||||||
go unboundLooper.RunRestartTicker(restartTickerContext, restartUnbound)
|
go unboundLooper.RunRestartTicker(restartTickerContext)
|
||||||
onConnected(allSettings, logger, routingConf, portForward, restartUnbound, restartPublicIP)
|
onConnected(allSettings, logger, routingConf, portForward, restartUnbound, restartPublicIP)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -180,7 +178,7 @@ func _main(background context.Context, args []string) int {
|
|||||||
go httpServer.Run(ctx, wg)
|
go httpServer.Run(ctx, wg)
|
||||||
|
|
||||||
// Start openvpn for the first time
|
// Start openvpn for the first time
|
||||||
restartOpenvpn <- struct{}{}
|
restartOpenvpn()
|
||||||
|
|
||||||
signalsCh := make(chan os.Signal, 1)
|
signalsCh := make(chan os.Signal, 1)
|
||||||
signal.Notify(signalsCh,
|
signal.Notify(signalsCh,
|
||||||
@@ -291,14 +289,12 @@ func collectStreamLines(ctx context.Context, streamMerger command.StreamMerger,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func onConnected(allSettings settings.Settings, logger logging.Logger, routingConf routing.Routing,
|
func onConnected(allSettings settings.Settings, logger logging.Logger, routingConf routing.Routing,
|
||||||
portForward, restartUnbound, restartPublicIP chan<- struct{},
|
portForward, restartUnbound, restartPublicIP func(),
|
||||||
) {
|
) {
|
||||||
restartUnbound <- struct{}{}
|
restartUnbound()
|
||||||
restartPublicIP <- struct{}{}
|
restartPublicIP()
|
||||||
if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
|
if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
|
||||||
time.AfterFunc(5*time.Second, func() {
|
time.AfterFunc(5*time.Second, portForward)
|
||||||
portForward <- struct{}{}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
defaultInterface, _, err := routingConf.DefaultRoute()
|
defaultInterface, _, err := routingConf.DefaultRoute()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -13,8 +13,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Looper interface {
|
type Looper interface {
|
||||||
Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup)
|
Run(ctx context.Context, wg *sync.WaitGroup)
|
||||||
RunRestartTicker(ctx context.Context, restart chan<- struct{})
|
RunRestartTicker(ctx context.Context)
|
||||||
|
Restart()
|
||||||
}
|
}
|
||||||
|
|
||||||
type looper struct {
|
type looper struct {
|
||||||
@@ -24,6 +25,7 @@ type looper struct {
|
|||||||
streamMerger command.StreamMerger
|
streamMerger command.StreamMerger
|
||||||
uid int
|
uid int
|
||||||
gid int
|
gid int
|
||||||
|
restart chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger,
|
func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger,
|
||||||
@@ -35,9 +37,12 @@ func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger,
|
|||||||
uid: uid,
|
uid: uid,
|
||||||
gid: gid,
|
gid: gid,
|
||||||
streamMerger: streamMerger,
|
streamMerger: streamMerger,
|
||||||
|
restart: make(chan struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *looper) Restart() { l.restart <- struct{}{} }
|
||||||
|
|
||||||
func (l *looper) logAndWait(ctx context.Context, err error) {
|
func (l *looper) logAndWait(ctx context.Context, err error) {
|
||||||
l.logger.Warn(err)
|
l.logger.Warn(err)
|
||||||
l.logger.Info("attempting restart in 10 seconds")
|
l.logger.Info("attempting restart in 10 seconds")
|
||||||
@@ -46,12 +51,12 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
|
|||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) {
|
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
l.fallbackToUnencryptedDNS()
|
l.fallbackToUnencryptedDNS()
|
||||||
select {
|
select {
|
||||||
case <-restart:
|
case <-l.restart:
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -65,7 +70,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
|
|||||||
if !l.settings.Enabled {
|
if !l.settings.Enabled {
|
||||||
// wait for another restart signal to recheck if it is enabled
|
// wait for another restart signal to recheck if it is enabled
|
||||||
select {
|
select {
|
||||||
case <-restart:
|
case <-l.restart:
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
unboundCancel()
|
unboundCancel()
|
||||||
return
|
return
|
||||||
@@ -127,7 +132,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
|
|||||||
<-waitError
|
<-waitError
|
||||||
close(waitError)
|
close(waitError)
|
||||||
return
|
return
|
||||||
case <-restart: // triggered restart
|
case <-l.restart: // triggered restart
|
||||||
l.logger.Info("restarting")
|
l.logger.Info("restarting")
|
||||||
// unboundCancel occurs next loop run when the setup is complete
|
// unboundCancel occurs next loop run when the setup is complete
|
||||||
triggeredRestart = true
|
triggeredRestart = true
|
||||||
@@ -172,7 +177,7 @@ func (l *looper) fallbackToUnencryptedDNS() {
|
|||||||
l.logger.Error("no ipv4 DNS address found for providers %s", l.settings.Providers)
|
l.logger.Error("no ipv4 DNS address found for providers %s", l.settings.Providers)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) RunRestartTicker(ctx context.Context, restart chan<- struct{}) {
|
func (l *looper) RunRestartTicker(ctx context.Context) {
|
||||||
if l.settings.UpdatePeriod == 0 {
|
if l.settings.UpdatePeriod == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -183,7 +188,7 @@ func (l *looper) RunRestartTicker(ctx context.Context, restart chan<- struct{})
|
|||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
restart <- struct{}{}
|
l.restart <- struct{}{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,7 +18,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Looper interface {
|
type Looper interface {
|
||||||
Run(ctx context.Context, restart, portForward <-chan struct{}, wg *sync.WaitGroup)
|
Run(ctx context.Context, wg *sync.WaitGroup)
|
||||||
|
Restart()
|
||||||
|
PortForward()
|
||||||
}
|
}
|
||||||
|
|
||||||
type looper struct {
|
type looper struct {
|
||||||
@@ -37,6 +39,9 @@ type looper struct {
|
|||||||
fileManager files.FileManager
|
fileManager files.FileManager
|
||||||
streamMerger command.StreamMerger
|
streamMerger command.StreamMerger
|
||||||
fatalOnError func(err error)
|
fatalOnError func(err error)
|
||||||
|
// Internal channels
|
||||||
|
restart chan struct{}
|
||||||
|
portForwardSignals chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
|
func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
|
||||||
@@ -45,25 +50,30 @@ func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
|
|||||||
logger logging.Logger, client network.Client, fileManager files.FileManager,
|
logger logging.Logger, client network.Client, fileManager files.FileManager,
|
||||||
streamMerger command.StreamMerger, fatalOnError func(err error)) Looper {
|
streamMerger command.StreamMerger, fatalOnError func(err error)) Looper {
|
||||||
return &looper{
|
return &looper{
|
||||||
provider: provider,
|
provider: provider,
|
||||||
settings: settings,
|
settings: settings,
|
||||||
uid: uid,
|
uid: uid,
|
||||||
gid: gid,
|
gid: gid,
|
||||||
conf: conf,
|
conf: conf,
|
||||||
fw: fw,
|
fw: fw,
|
||||||
logger: logger.WithPrefix("openvpn: "),
|
logger: logger.WithPrefix("openvpn: "),
|
||||||
client: client,
|
client: client,
|
||||||
fileManager: fileManager,
|
fileManager: fileManager,
|
||||||
streamMerger: streamMerger,
|
streamMerger: streamMerger,
|
||||||
fatalOnError: fatalOnError,
|
fatalOnError: fatalOnError,
|
||||||
|
restart: make(chan struct{}),
|
||||||
|
portForwardSignals: make(chan struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) Run(ctx context.Context, restart, portForward <-chan struct{}, wg *sync.WaitGroup) {
|
func (l *looper) Restart() { l.restart <- struct{}{} }
|
||||||
|
func (l *looper) PortForward() { l.portForwardSignals <- struct{}{} }
|
||||||
|
|
||||||
|
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
select {
|
select {
|
||||||
case <-restart:
|
case <-l.restart:
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -107,7 +117,7 @@ func (l *looper) Run(ctx context.Context, restart, portForward <-chan struct{},
|
|||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case <-portForward:
|
case <-l.portForwardSignals:
|
||||||
l.portForward(ctx, providerConf, l.client)
|
l.portForward(ctx, providerConf, l.client)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -126,7 +136,7 @@ func (l *looper) Run(ctx context.Context, restart, portForward <-chan struct{},
|
|||||||
<-waitError
|
<-waitError
|
||||||
close(waitError)
|
close(waitError)
|
||||||
return
|
return
|
||||||
case <-restart: // triggered restart
|
case <-l.restart: // triggered restart
|
||||||
l.logger.Info("restarting")
|
l.logger.Info("restarting")
|
||||||
openvpnCancel()
|
openvpnCancel()
|
||||||
<-waitError
|
<-waitError
|
||||||
|
|||||||
@@ -11,8 +11,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Looper interface {
|
type Looper interface {
|
||||||
Run(ctx context.Context, restart <-chan struct{})
|
Run(ctx context.Context)
|
||||||
RunRestartTicker(ctx context.Context, restart chan<- struct{})
|
RunRestartTicker(ctx context.Context)
|
||||||
|
Restart()
|
||||||
}
|
}
|
||||||
|
|
||||||
type looper struct {
|
type looper struct {
|
||||||
@@ -22,6 +23,7 @@ type looper struct {
|
|||||||
ipStatusFilepath models.Filepath
|
ipStatusFilepath models.Filepath
|
||||||
uid int
|
uid int
|
||||||
gid int
|
gid int
|
||||||
|
restart chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLooper(client network.Client, logger logging.Logger, fileManager files.FileManager,
|
func NewLooper(client network.Client, logger logging.Logger, fileManager files.FileManager,
|
||||||
@@ -33,9 +35,12 @@ func NewLooper(client network.Client, logger logging.Logger, fileManager files.F
|
|||||||
ipStatusFilepath: ipStatusFilepath,
|
ipStatusFilepath: ipStatusFilepath,
|
||||||
uid: uid,
|
uid: uid,
|
||||||
gid: gid,
|
gid: gid,
|
||||||
|
restart: make(chan struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *looper) Restart() { l.restart <- struct{}{} }
|
||||||
|
|
||||||
func (l *looper) logAndWait(ctx context.Context, err error) {
|
func (l *looper) logAndWait(ctx context.Context, err error) {
|
||||||
l.logger.Error(err)
|
l.logger.Error(err)
|
||||||
l.logger.Info("retrying in 5 seconds")
|
l.logger.Info("retrying in 5 seconds")
|
||||||
@@ -44,9 +49,9 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
|
|||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) Run(ctx context.Context, restart <-chan struct{}) {
|
func (l *looper) Run(ctx context.Context) {
|
||||||
select {
|
select {
|
||||||
case <-restart:
|
case <-l.restart:
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -69,7 +74,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case <-restart: // triggered restart
|
case <-l.restart: // triggered restart
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
l.logger.Warn("context canceled: exiting loop")
|
l.logger.Warn("context canceled: exiting loop")
|
||||||
return
|
return
|
||||||
@@ -77,7 +82,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) RunRestartTicker(ctx context.Context, restart chan<- struct{}) {
|
func (l *looper) RunRestartTicker(ctx context.Context) {
|
||||||
ticker := time.NewTicker(time.Hour)
|
ticker := time.NewTicker(time.Hour)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -85,7 +90,7 @@ func (l *looper) RunRestartTicker(ctx context.Context, restart chan<- struct{})
|
|||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
restart <- struct{}{}
|
l.restart <- struct{}{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,11 +17,11 @@ type Server interface {
|
|||||||
type server struct {
|
type server struct {
|
||||||
address string
|
address string
|
||||||
logger logging.Logger
|
logger logging.Logger
|
||||||
restartOpenvpn chan<- struct{}
|
restartOpenvpn func()
|
||||||
restartUnbound chan<- struct{}
|
restartUnbound func()
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(address string, logger logging.Logger, restartOpenvpn, restartUnbound chan<- struct{}) Server {
|
func New(address string, logger logging.Logger, restartOpenvpn, restartUnbound func()) Server {
|
||||||
return &server{
|
return &server{
|
||||||
address: address,
|
address: address,
|
||||||
logger: logger.WithPrefix("http server: "),
|
logger: logger.WithPrefix("http server: "),
|
||||||
@@ -58,9 +58,9 @@ func (s *server) makeHandler() http.HandlerFunc {
|
|||||||
case http.MethodGet:
|
case http.MethodGet:
|
||||||
switch r.RequestURI {
|
switch r.RequestURI {
|
||||||
case "/openvpn/actions/restart":
|
case "/openvpn/actions/restart":
|
||||||
s.restartOpenvpn <- struct{}{}
|
s.restartOpenvpn()
|
||||||
case "/unbound/actions/restart":
|
case "/unbound/actions/restart":
|
||||||
s.restartUnbound <- struct{}{}
|
s.restartUnbound()
|
||||||
default:
|
default:
|
||||||
routeDoesNotExist(s.logger, w, r)
|
routeDoesNotExist(s.logger, w, r)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Looper interface {
|
type Looper interface {
|
||||||
Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup)
|
Run(ctx context.Context, wg *sync.WaitGroup)
|
||||||
|
Restart()
|
||||||
}
|
}
|
||||||
|
|
||||||
type looper struct {
|
type looper struct {
|
||||||
@@ -24,6 +25,7 @@ type looper struct {
|
|||||||
streamMerger command.StreamMerger
|
streamMerger command.StreamMerger
|
||||||
uid int
|
uid int
|
||||||
gid int
|
gid int
|
||||||
|
restart chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) logAndWait(ctx context.Context, err error) {
|
func (l *looper) logAndWait(ctx context.Context, err error) {
|
||||||
@@ -45,14 +47,17 @@ func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings s
|
|||||||
streamMerger: streamMerger,
|
streamMerger: streamMerger,
|
||||||
uid: uid,
|
uid: uid,
|
||||||
gid: gid,
|
gid: gid,
|
||||||
|
restart: make(chan struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) {
|
func (l *looper) Restart() { l.restart <- struct{}{} }
|
||||||
|
|
||||||
|
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
select {
|
select {
|
||||||
case <-restart:
|
case <-l.restart:
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -109,7 +114,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
|
|||||||
<-waitError
|
<-waitError
|
||||||
close(waitError)
|
close(waitError)
|
||||||
return
|
return
|
||||||
case <-restart: // triggered restart
|
case <-l.restart: // triggered restart
|
||||||
l.logger.Info("restarting")
|
l.logger.Info("restarting")
|
||||||
shadowsocksCancel()
|
shadowsocksCancel()
|
||||||
<-waitError
|
<-waitError
|
||||||
|
|||||||
@@ -12,7 +12,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Looper interface {
|
type Looper interface {
|
||||||
Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup)
|
Run(ctx context.Context, wg *sync.WaitGroup)
|
||||||
|
Restart()
|
||||||
}
|
}
|
||||||
|
|
||||||
type looper struct {
|
type looper struct {
|
||||||
@@ -23,6 +24,7 @@ type looper struct {
|
|||||||
streamMerger command.StreamMerger
|
streamMerger command.StreamMerger
|
||||||
uid int
|
uid int
|
||||||
gid int
|
gid int
|
||||||
|
restart chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) logAndWait(ctx context.Context, err error) {
|
func (l *looper) logAndWait(ctx context.Context, err error) {
|
||||||
@@ -43,14 +45,17 @@ func NewLooper(conf Configurator, firewallConf firewall.Configurator, settings s
|
|||||||
streamMerger: streamMerger,
|
streamMerger: streamMerger,
|
||||||
uid: uid,
|
uid: uid,
|
||||||
gid: gid,
|
gid: gid,
|
||||||
|
restart: make(chan struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.WaitGroup) {
|
func (l *looper) Restart() { l.restart <- struct{}{} }
|
||||||
|
|
||||||
|
func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
select {
|
select {
|
||||||
case <-restart:
|
case <-l.restart:
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -102,7 +107,7 @@ func (l *looper) Run(ctx context.Context, restart <-chan struct{}, wg *sync.Wait
|
|||||||
<-waitError
|
<-waitError
|
||||||
close(waitError)
|
close(waitError)
|
||||||
return
|
return
|
||||||
case <-restart: // triggered restart
|
case <-l.restart: // triggered restart
|
||||||
l.logger.Info("restarting")
|
l.logger.Info("restarting")
|
||||||
tinyproxyCancel()
|
tinyproxyCancel()
|
||||||
<-waitError
|
<-waitError
|
||||||
|
|||||||
Reference in New Issue
Block a user