Simplify main.go

This commit is contained in:
Quentin McGaw
2020-09-12 19:17:19 +00:00
parent 464c7074d0
commit 1fc1776dbf
4 changed files with 84 additions and 95 deletions

View File

@@ -156,10 +156,10 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
} }
} }
connectedCh, dnsReadyCh := make(chan struct{}), make(chan struct{}) tunnelReadyCh, dnsReadyCh := make(chan struct{}), make(chan struct{})
signalConnected := func() { connectedCh <- struct{}{} } signalTunnelReady := func() { tunnelReadyCh <- struct{}{} }
signalDNSReady := func() { dnsReadyCh <- struct{}{} } signalDNSReady := func() { dnsReadyCh <- struct{}{} }
defer close(connectedCh) defer close(tunnelReadyCh)
defer close(dnsReadyCh) defer close(dnsReadyCh)
if allSettings.Firewall.Enabled { if allSettings.Firewall.Enabled {
@@ -186,14 +186,10 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
go collectStreamLines(ctx, streamMerger, logger, signalConnected) go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady)
openvpnLooper := openvpn.NewLooper(allSettings.VPNSP, allSettings.OpenVPN, uid, gid, allServers, openvpnLooper := openvpn.NewLooper(allSettings.VPNSP, allSettings.OpenVPN, uid, gid, allServers,
ovpnConf, firewallConf, logger, client, fileManager, streamMerger, cancel) ovpnConf, firewallConf, logger, client, fileManager, streamMerger, cancel)
restartOpenvpn := openvpnLooper.Restart
portForward := openvpnLooper.PortForward
getOpenvpnSettings := openvpnLooper.GetSettings
getPortForwarded := openvpnLooper.GetPortForwarded
wg.Add(1) wg.Add(1)
// wait for restartOpenvpn // wait for restartOpenvpn
go openvpnLooper.Run(ctx, wg) go openvpnLooper.Run(ctx, wg)
@@ -205,19 +201,16 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
go updaterLooper.Run(ctx, wg) go updaterLooper.Run(ctx, wg)
unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, logger, streamMerger, uid, gid) unboundLooper := dns.NewLooper(dnsConf, allSettings.DNS, logger, streamMerger, uid, gid)
restartUnbound := unboundLooper.Restart
wg.Add(1) wg.Add(1)
// wait for restartUnbound or its ticker launched with RunRestartTicker // wait for unboundLooper.Restart or its ticker launched with RunRestartTicker
go unboundLooper.Run(ctx, wg, signalDNSReady) go unboundLooper.Run(ctx, wg, signalDNSReady)
publicIPLooper := publicip.NewLooper(client, logger, fileManager, allSettings.System.IPStatusFilepath, allSettings.PublicIPPeriod, uid, gid) publicIPLooper := publicip.NewLooper(client, logger, fileManager, allSettings.System.IPStatusFilepath, allSettings.PublicIPPeriod, uid, gid)
restartPublicIP := publicIPLooper.Restart
setPublicIPPeriod := publicIPLooper.SetPeriod
wg.Add(1) wg.Add(1)
go publicIPLooper.Run(ctx, wg) go publicIPLooper.Run(ctx, wg)
wg.Add(1) wg.Add(1)
go publicIPLooper.RunRestartTicker(ctx, wg) go publicIPLooper.RunRestartTicker(ctx, wg)
setPublicIPPeriod(allSettings.PublicIPPeriod) // call after RunRestartTicker publicIPLooper.SetPeriod(allSettings.PublicIPPeriod) // call after RunRestartTicker
tinyproxyLooper := tinyproxy.NewLooper(tinyProxyConf, firewallConf, allSettings.TinyProxy, logger, streamMerger, uid, gid, defaultInterface) tinyproxyLooper := tinyproxy.NewLooper(tinyProxyConf, firewallConf, allSettings.TinyProxy, logger, streamMerger, uid, gid, defaultInterface)
restartTinyproxy := tinyproxyLooper.Restart restartTinyproxy := tinyproxyLooper.Restart
@@ -236,52 +229,18 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
restartShadowsocks() restartShadowsocks()
} }
versionInformation := func() {
if !allSettings.VersionInformation {
return
}
message, err := versionpkg.GetMessage(version, commit, httpClient)
if err != nil {
logger.Error(err)
return
}
logger.Info(message)
}
wg.Add(1) wg.Add(1)
go func() { go routeReadyEvents(ctx, wg, tunnelReadyCh, dnsReadyCh,
defer wg.Done() unboundLooper, updaterLooper, publicIPLooper, routingConf, logger, httpClient,
tickerWg := &sync.WaitGroup{} allSettings.VersionInformation, allSettings.OpenVPN.Provider.PortForwarding.Enabled, openvpnLooper.PortForward,
// for linters only )
var restartTickerContext context.Context
var restartTickerCancel context.CancelFunc = func() {}
for {
select {
case <-ctx.Done():
restartTickerCancel() // for linters only
tickerWg.Wait()
return
case <-connectedCh: // blocks until openvpn is connected
restartTickerCancel() // stop previous restart tickers
tickerWg.Wait()
restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
tickerWg.Add(2)
go unboundLooper.RunRestartTicker(restartTickerContext, tickerWg)
go updaterLooper.RunRestartTicker(restartTickerContext, tickerWg)
onConnected(allSettings, logger, routingConf, portForward, restartUnbound)
case <-dnsReadyCh:
restartPublicIP() // TODO do not restart if disabled
versionInformation()
}
}
}()
httpServer := server.New("0.0.0.0:8000", logger, restartOpenvpn, restartUnbound, updaterLooper.Restart, httpServer := server.New("0.0.0.0:8000", logger, openvpnLooper, unboundLooper, updaterLooper)
getOpenvpnSettings, getPortForwarded)
wg.Add(1) wg.Add(1)
go httpServer.Run(ctx, wg) go httpServer.Run(ctx, wg)
// Start openvpn for the first time // Start openvpn for the first time
restartOpenvpn() openvpnLooper.Restart()
signalsCh := make(chan os.Signal, 1) signalsCh := make(chan os.Signal, 1)
signal.Notify(signalsCh, signal.Notify(signalsCh,
@@ -352,7 +311,7 @@ func printVersions(ctx context.Context, logger logging.Logger, versionFunctions
} }
} }
func collectStreamLines(ctx context.Context, streamMerger command.StreamMerger, logger logging.Logger, signalConnected func()) { func collectStreamLines(ctx context.Context, streamMerger command.StreamMerger, logger logging.Logger, signalTunnelReady func()) {
// Blocking line merging paramsReader for all programs: openvpn, tinyproxy, unbound and shadowsocks // Blocking line merging paramsReader for all programs: openvpn, tinyproxy, unbound and shadowsocks
logger.Info("Launching standard output merger") logger.Info("Launching standard output merger")
streamMerger.CollectLines(ctx, func(line string) { streamMerger.CollectLines(ctx, func(line string) {
@@ -369,29 +328,61 @@ func collectStreamLines(ctx context.Context, streamMerger command.StreamMerger,
logger.Error(line) logger.Error(line)
} }
if strings.Contains(line, "Initialization Sequence Completed") { if strings.Contains(line, "Initialization Sequence Completed") {
signalConnected() signalTunnelReady()
} }
}, func(err error) { }, func(err error) {
logger.Warn(err) logger.Warn(err)
}) })
} }
func onConnected(allSettings settings.Settings, logger logging.Logger, routingConf routing.Routing, func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dnsReadyCh <-chan struct{},
portForward, restartUnbound func(), unboundLooper dns.Looper, updaterLooper updater.Looper, publicIPLooper publicip.Looper,
) { routing routing.Routing, logger logging.Logger, httpClient *http.Client,
restartUnbound() versionInformation, portForwardingEnabled bool, startPortForward func()) {
if allSettings.OpenVPN.Provider.PortForwarding.Enabled { defer wg.Done()
time.AfterFunc(5*time.Second, portForward) tickerWg := &sync.WaitGroup{}
} // for linters only
defaultInterface, _, err := routingConf.DefaultRoute() var restartTickerContext context.Context
if err != nil { var restartTickerCancel context.CancelFunc = func() {}
logger.Warn(err) for {
} else { select {
vpnGatewayIP, err := routingConf.VPNGatewayIP(defaultInterface) case <-ctx.Done():
if err != nil { restartTickerCancel() // for linters only
logger.Warn(err) tickerWg.Wait()
} else { return
logger.Info("Gateway VPN IP address: %s", vpnGatewayIP) case <-tunnelReadyCh: // blocks until openvpn is connected
unboundLooper.Restart()
restartTickerCancel() // stop previous restart tickers
tickerWg.Wait()
restartTickerContext, restartTickerCancel = context.WithCancel(ctx)
tickerWg.Add(2)
go unboundLooper.RunRestartTicker(restartTickerContext, tickerWg)
go updaterLooper.RunRestartTicker(restartTickerContext, tickerWg)
if portForwardingEnabled {
time.AfterFunc(5*time.Second, startPortForward)
}
defaultInterface, _, err := routing.DefaultRoute()
if err != nil {
logger.Warn(err)
} else {
vpnGatewayIP, err := routing.VPNGatewayIP(defaultInterface)
if err != nil {
logger.Warn(err)
} else {
logger.Info("Gateway VPN IP address: %s", vpnGatewayIP)
}
}
case <-dnsReadyCh:
publicIPLooper.Restart() // TODO do not restart if disabled
if !versionInformation {
break
}
message, err := versionpkg.GetMessage(version, commit, httpClient)
if err != nil {
logger.Error(err)
break
}
logger.Info(message)
} }
} }
} }

View File

@@ -154,6 +154,7 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
continue continue
} }
// Needs the stream line from main.go to know when the tunnel is up
go func(ctx context.Context) { go func(ctx context.Context) {
for { for {
select { select {

View File

@@ -6,7 +6,7 @@ import (
) )
func (s *server) handleGetPortForwarded(w http.ResponseWriter) { func (s *server) handleGetPortForwarded(w http.ResponseWriter) {
port := s.getPortForwarded() port := s.openvpnLooper.GetPortForwarded()
data, err := json.Marshal(struct { data, err := json.Marshal(struct {
Port uint16 `json:"port"` Port uint16 `json:"port"`
}{port}) }{port})
@@ -22,7 +22,7 @@ func (s *server) handleGetPortForwarded(w http.ResponseWriter) {
} }
func (s *server) handleGetOpenvpnSettings(w http.ResponseWriter) { func (s *server) handleGetOpenvpnSettings(w http.ResponseWriter) {
settings := s.getOpenvpnSettings() settings := s.openvpnLooper.GetSettings()
data, err := json.Marshal(settings) data, err := json.Marshal(settings)
if err != nil { if err != nil {
s.logger.Warn(err) s.logger.Warn(err)

View File

@@ -8,7 +8,9 @@ import (
"sync" "sync"
"time" "time"
"github.com/qdm12/gluetun/internal/settings" "github.com/qdm12/gluetun/internal/dns"
"github.com/qdm12/gluetun/internal/openvpn"
"github.com/qdm12/gluetun/internal/updater"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
@@ -17,27 +19,22 @@ type Server interface {
} }
type server struct { type server struct {
address string address string
logger logging.Logger logger logging.Logger
restartOpenvpn func() openvpnLooper openvpn.Looper
restartUnbound func() unboundLooper dns.Looper
restartUpdater func() updaterLooper updater.Looper
getOpenvpnSettings func() settings.OpenVPN lookupIP func(host string) ([]net.IP, error)
getPortForwarded func() uint16
lookupIP func(host string) ([]net.IP, error)
} }
func New(address string, logger logging.Logger, restartOpenvpn, restartUnbound, restartUpdater func(), func New(address string, logger logging.Logger, openvpnLooper openvpn.Looper, unboundLooper dns.Looper, updaterLooper updater.Looper) Server {
getOpenvpnSettings func() settings.OpenVPN, getPortForwarded func() uint16) Server {
return &server{ return &server{
address: address, address: address,
logger: logger.WithPrefix("http server: "), logger: logger.WithPrefix("http server: "),
restartOpenvpn: restartOpenvpn, openvpnLooper: openvpnLooper,
restartUnbound: restartUnbound, unboundLooper: unboundLooper,
restartUpdater: restartUpdater, updaterLooper: updaterLooper,
getOpenvpnSettings: getOpenvpnSettings, lookupIP: net.LookupIP,
getPortForwarded: getPortForwarded,
lookupIP: net.LookupIP,
} }
} }
@@ -68,10 +65,10 @@ func (s *server) makeHandler() http.HandlerFunc {
case http.MethodGet: case http.MethodGet:
switch r.RequestURI { switch r.RequestURI {
case "/openvpn/actions/restart": case "/openvpn/actions/restart":
s.restartOpenvpn() s.openvpnLooper.Restart()
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
case "/unbound/actions/restart": case "/unbound/actions/restart":
s.restartUnbound() s.unboundLooper.Restart()
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
case "/openvpn/portforwarded": case "/openvpn/portforwarded":
s.handleGetPortForwarded(w) s.handleGetPortForwarded(w)
@@ -80,7 +77,7 @@ func (s *server) makeHandler() http.HandlerFunc {
case "/health": case "/health":
s.handleHealth(w) s.handleHealth(w)
case "/updater/restart": case "/updater/restart":
s.restartUpdater() s.updaterLooper.Restart()
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
default: default:
routeDoesNotExist(s.logger, w, r) routeDoesNotExist(s.logger, w, r)