Simplified loop mechanism for openvpn and dns
- Refers to #91 - http control server starts without waiting for unbound and/or openvpn - Trying to get rid of waiter and use channels directly - Simpler main.go - More robust logic overall
This commit is contained in:
@@ -3,7 +3,6 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"regexp"
|
||||
@@ -211,40 +210,49 @@ func _main(background context.Context, args []string) int {
|
||||
go streamMerger.Merge(ctx, stderr, command.MergeName("shadowsocks error"), command.MergeColor(constants.ColorShadowsocksError()))
|
||||
}
|
||||
|
||||
httpServer := server.New("0.0.0.0:8000", logger)
|
||||
restartOpenvpn := make(chan struct{})
|
||||
restartUnbound := make(chan struct{})
|
||||
openvpnDone := make(chan struct{})
|
||||
unboundDone := make(chan struct{})
|
||||
serverDone := make(chan struct{})
|
||||
|
||||
go openvpnRunLoop(ctx, ovpnConf, streamMerger, logger, httpServer, waiter, fatalOnError)
|
||||
openvpnLooper := openvpn.NewLooper(ovpnConf, allSettings.OpenVPN, logger, streamMerger, fatalOnError)
|
||||
// wait for restartOpenvpn
|
||||
go openvpnLooper.Run(ctx, restartOpenvpn, openvpnDone)
|
||||
|
||||
waiter.Add(func() error {
|
||||
err := httpServer.Run(ctx)
|
||||
logger.Error("http server: %s", err)
|
||||
return err
|
||||
})
|
||||
|
||||
startUnboundCh := make(chan struct{})
|
||||
go unboundRunLoop(ctx, startUnboundCh, logger, dnsConf, allSettings.DNS, allSettings.System.UID, allSettings.System.GID, waiter, streamMerger, httpServer)
|
||||
if !allSettings.DNS.Enabled {
|
||||
httpServer.SetUnboundRestart(func() {})
|
||||
dnsConf.UseDNSInternally(allSettings.DNS.PlaintextAddress)
|
||||
if err := dnsConf.UseDNSSystemWide(allSettings.DNS.PlaintextAddress); err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
}
|
||||
unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, logger,
|
||||
streamMerger, allSettings.System.UID, allSettings.System.GID)
|
||||
// wait for restartUnbound
|
||||
go unboundLooper.Run(ctx, restartUnbound, unboundDone)
|
||||
|
||||
go func() {
|
||||
first := true
|
||||
var restartTickerContext context.Context
|
||||
var restartTickerCancel context.CancelFunc = func() {}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
restartTickerCancel()
|
||||
return
|
||||
case <-connectedCh: // blocks until openvpn is connected
|
||||
if allSettings.DNS.Enabled {
|
||||
startUnboundCh <- struct{}{}
|
||||
if first {
|
||||
first = false
|
||||
restartUnbound <- struct{}{}
|
||||
}
|
||||
restartTickerCancel()
|
||||
restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
|
||||
go unboundLooper.RunRestartTicker(restartTickerContext, restartUnbound)
|
||||
onConnected(allSettings, logger, fileManager, routingConf, defaultInterface, providerConf)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
httpServer := server.New("0.0.0.0:8000", logger, restartOpenvpn, restartUnbound)
|
||||
go httpServer.Run(ctx, serverDone)
|
||||
|
||||
// Start openvpn for the first time
|
||||
restartOpenvpn <- struct{}{}
|
||||
|
||||
signalsCh := make(chan os.Signal, 1)
|
||||
signal.Notify(signalsCh,
|
||||
syscall.SIGINT,
|
||||
@@ -278,6 +286,9 @@ func _main(background context.Context, args []string) int {
|
||||
logger.Error(err)
|
||||
exitStatus = 1
|
||||
}
|
||||
<-serverDone
|
||||
<-unboundDone
|
||||
<-openvpnDone
|
||||
return exitStatus
|
||||
}
|
||||
|
||||
@@ -348,34 +359,6 @@ func trimEventualProgramPrefix(s string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func openvpnRunLoop(ctx context.Context, ovpnConf openvpn.Configurator, streamMerger command.StreamMerger,
|
||||
logger logging.Logger, httpServer server.Server, waiter command.Waiter, fatalOnError func(err error)) {
|
||||
logger = logger.WithPrefix("openvpn: ")
|
||||
waitErrors := make(chan error)
|
||||
for ctx.Err() == nil {
|
||||
logger.Info("starting")
|
||||
openvpnCtx, openvpnCancel := context.WithCancel(ctx)
|
||||
stream, waitFn, err := ovpnConf.Start(openvpnCtx)
|
||||
fatalOnError(err)
|
||||
httpServer.SetOpenVPNRestart(openvpnCancel)
|
||||
go streamMerger.Merge(openvpnCtx, stream, command.MergeName("openvpn"), command.MergeColor(constants.ColorOpenvpn()))
|
||||
waiter.Add(func() error {
|
||||
return <-waitErrors
|
||||
})
|
||||
err = waitFn()
|
||||
waitErrors <- fmt.Errorf("openvpn: %w", err)
|
||||
switch {
|
||||
case ctx.Err() != nil:
|
||||
logger.Warn("context canceled: exiting openvpn run loop")
|
||||
case openvpnCtx.Err() == context.Canceled:
|
||||
logger.Info("triggered openvpn restart")
|
||||
default:
|
||||
logger.Warn(err)
|
||||
openvpnCancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func onConnected(allSettings settings.Settings,
|
||||
logger logging.Logger, fileManager files.FileManager,
|
||||
routingConf routing.Routing, defaultInterface string,
|
||||
@@ -393,7 +376,7 @@ func onConnected(allSettings settings.Settings,
|
||||
} else {
|
||||
logger.Info("Gateway VPN IP address: %s", vpnGatewayIP)
|
||||
}
|
||||
time.AfterFunc(7*time.Second, func() { // wait for Unbound to start - TODO use signal channel
|
||||
time.AfterFunc(10*time.Second, func() { // wait for Unbound to start - TODO use signal channel
|
||||
publicIP, err := publicip.NewIPGetter(network.NewClient(3 * time.Second)).Get()
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
@@ -411,126 +394,6 @@ func onConnected(allSettings settings.Settings,
|
||||
})
|
||||
}
|
||||
|
||||
func fallbackToUnencryptedIPv4DNS(dnsConf dns.Configurator, providers []models.DNSProvider) error {
|
||||
var targetIP net.IP
|
||||
for _, provider := range providers {
|
||||
data := constants.DNSProviderMapping()[provider]
|
||||
for _, targetIP = range data.IPs {
|
||||
if targetIP.To4() != nil {
|
||||
dnsConf.UseDNSInternally(targetIP)
|
||||
return dnsConf.UseDNSSystemWide(targetIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
// No IPv4 address found
|
||||
return fmt.Errorf("no ipv4 DNS address found for providers %s", providers)
|
||||
}
|
||||
|
||||
func unboundRun(ctx, oldCtx context.Context, oldCancel context.CancelFunc, timer *time.Timer,
|
||||
dnsConf dns.Configurator, settings settings.DNS, uid, gid int,
|
||||
streamMerger command.StreamMerger, waiter command.Waiter, httpServer server.Server) (
|
||||
newCtx context.Context, newCancel context.CancelFunc, setupErr, startErr, waitErr error) {
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
timer.Reset(settings.UpdatePeriod)
|
||||
}
|
||||
if err := dnsConf.DownloadRootHints(uid, gid); err != nil {
|
||||
return oldCtx, oldCancel, err, nil, nil
|
||||
}
|
||||
if err := dnsConf.DownloadRootKey(uid, gid); err != nil {
|
||||
return oldCtx, oldCancel, err, nil, nil
|
||||
}
|
||||
if err := dnsConf.MakeUnboundConf(settings, uid, gid); err != nil {
|
||||
return oldCtx, oldCancel, err, nil, nil
|
||||
}
|
||||
newCtx, newCancel = context.WithCancel(ctx)
|
||||
oldCancel()
|
||||
stream, waitFn, err := dnsConf.Start(newCtx, settings.VerbosityDetailsLevel)
|
||||
if err != nil {
|
||||
return newCtx, newCancel, nil, err, nil
|
||||
}
|
||||
go streamMerger.Merge(newCtx, stream, command.MergeName("unbound"), command.MergeColor(constants.ColorUnbound()))
|
||||
dnsConf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound
|
||||
if err := dnsConf.UseDNSSystemWide(net.IP{127, 0, 0, 1}); err != nil { // use Unbound
|
||||
return newCtx, newCancel, nil, err, nil
|
||||
}
|
||||
if err := dnsConf.WaitForUnbound(); err != nil {
|
||||
return newCtx, newCancel, nil, err, nil
|
||||
}
|
||||
// Unbound is up and running at this point
|
||||
httpServer.SetUnboundRestart(newCancel)
|
||||
waitError := make(chan error)
|
||||
waiterError := make(chan error)
|
||||
waiter.Add(func() error { //nolint:scopelint
|
||||
return <-waiterError
|
||||
})
|
||||
go func() {
|
||||
err := fmt.Errorf("unbound: %w", waitFn())
|
||||
waitError <- err
|
||||
waiterError <- err
|
||||
}()
|
||||
if timer == nil {
|
||||
waitErr := <-waitError
|
||||
return newCtx, newCancel, nil, nil, waitErr
|
||||
}
|
||||
select {
|
||||
case <-timer.C:
|
||||
return newCtx, newCancel, nil, nil, nil
|
||||
case waitErr := <-waitError:
|
||||
return newCtx, newCancel, nil, nil, waitErr
|
||||
}
|
||||
}
|
||||
|
||||
func unboundRunLoop(ctx context.Context, startCh chan struct{}, logger logging.Logger, dnsConf dns.Configurator,
|
||||
settings settings.DNS, uid, gid int,
|
||||
waiter command.Waiter, streamMerger command.StreamMerger, httpServer server.Server,
|
||||
) {
|
||||
logger = logger.WithPrefix("unbound dns over tls setup: ")
|
||||
select {
|
||||
case <-startCh:
|
||||
case <-ctx.Done():
|
||||
logger.Warn("context canceled: exiting unbound run loop")
|
||||
return
|
||||
}
|
||||
if err := fallbackToUnencryptedIPv4DNS(dnsConf, settings.Providers); err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
var timer *time.Timer
|
||||
if settings.UpdatePeriod > 0 {
|
||||
timer = time.NewTimer(settings.UpdatePeriod)
|
||||
}
|
||||
unboundCtx, unboundCancel := context.WithCancel(ctx)
|
||||
defer unboundCancel()
|
||||
for ctx.Err() == nil {
|
||||
var setupErr, startErr, waitErr error
|
||||
unboundCtx, unboundCancel, setupErr, startErr, waitErr = unboundRun(
|
||||
ctx, unboundCtx, unboundCancel, timer, dnsConf, settings,
|
||||
uid, gid, streamMerger, waiter, httpServer)
|
||||
switch {
|
||||
case ctx.Err() != nil:
|
||||
logger.Warn("context canceled: exiting unbound run loop")
|
||||
case timer != nil && !timer.Stop():
|
||||
logger.Info("planned restart of unbound")
|
||||
case unboundCtx.Err() == context.Canceled:
|
||||
logger.Info("triggered restart of unbound")
|
||||
case setupErr != nil:
|
||||
logger.Warn(setupErr)
|
||||
case startErr != nil:
|
||||
logger.Error(startErr)
|
||||
unboundCancel()
|
||||
if err := fallbackToUnencryptedIPv4DNS(dnsConf, settings.Providers); err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
case waitErr != nil:
|
||||
logger.Warn(waitErr)
|
||||
if err := fallbackToUnencryptedIPv4DNS(dnsConf, settings.Providers); err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
logger.Warn("restarting unbound because of unexpected exit")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func setupPortForwarding(logger logging.Logger, providerConf provider.Provider, filepath models.Filepath, uid, gid int) {
|
||||
pfLogger := logger.WithPrefix("port forwarding: ")
|
||||
var port uint16
|
||||
|
||||
Reference in New Issue
Block a user