diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index b2fe62fc..94c79083 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -53,42 +53,79 @@ func main() { Commit: commit, BuildDate: buildDate, } + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + + logger, err := logging.NewLogger(logging.ConsoleEncoding, logging.InfoLevel) + if err != nil { + fmt.Println(err) + nativeos.Exit(1) + } + args := nativeos.Args os := os.New() osUser := user.New() unix := unix.New() cli := cli.New() - nativeos.Exit(_main(ctx, buildInfo, args, os, osUser, unix, cli)) + + errorCh := make(chan error) + go func() { + errorCh <- _main(ctx, buildInfo, args, logger, os, osUser, unix, cli) + }() + + signalsCh := make(chan nativeos.Signal, 1) + signal.Notify(signalsCh, + syscall.SIGINT, + syscall.SIGTERM, + nativeos.Interrupt, + ) + + select { + case signal := <-signalsCh: + logger.Warn("Caught OS signal %s, shutting down", signal) + case err := <-errorCh: + logger.Error(err) + close(errorCh) + } + + cancel() + + const shutdownGracePeriod = 5 * time.Second + timer := time.NewTimer(shutdownGracePeriod) + select { + case <-errorCh: + if !timer.Stop() { + <-timer.C + } + logger.Info("Shutdown successful") + case <-timer.C: + logger.Warn("Shutdown timed out") + } + + nativeos.Exit(1) } //nolint:gocognit,gocyclo func _main(background context.Context, buildInfo models.BuildInformation, - args []string, os os.OS, osUser user.OSUser, unix unix.Unix, - cli cli.CLI) int { + args []string, logger logging.Logger, os os.OS, osUser user.OSUser, unix unix.Unix, + cli cli.CLI) error { if len(args) > 1 { // cli operation - var err error switch args[1] { case "healthcheck": - err = cli.HealthCheck(background) + return cli.HealthCheck(background) case "clientkey": - err = cli.ClientKey(args[2:], os.OpenFile) + return cli.ClientKey(args[2:], os.OpenFile) case "openvpnconfig": - err = cli.OpenvpnConfig(os) + return cli.OpenvpnConfig(os) case "update": - err = cli.Update(args[2:], os) + return cli.Update(args[2:], os) default: - err = fmt.Errorf("command %q is unknown", args[1]) + return fmt.Errorf("command %q is unknown", args[1]) } - if err != nil { - fmt.Println(err) - return 1 - } - return 0 } ctx, cancel := context.WithCancel(background) defer cancel() - logger := createLogger() const clientTimeout = 15 * time.Second httpClient := &http.Client{Timeout: clientTimeout} @@ -114,26 +151,22 @@ func _main(background context.Context, buildInfo models.BuildInformation, allSettings, err := settings.GetAllSettings(paramsReader) if err != nil { - logger.Error(err) - return 1 + return err } logger.Info(allSettings.String()) if err := os.MkdirAll("/tmp/gluetun", 0644); err != nil { - logger.Error(err) - return 1 + return err } if err := os.MkdirAll("/gluetun", 0644); err != nil { - logger.Error(err) - return 1 + return err } // TODO run this in a loop or in openvpn to reload from file without restarting storage := storage.New(logger, os, constants.ServersData) allServers, err := storage.SyncServers(constants.GetAllServers()) if err != nil { - logger.Error(err) - return 1 + return err } // Should never change @@ -142,16 +175,14 @@ func _main(background context.Context, buildInfo models.BuildInformation, const defaultUsername = "nonrootuser" nonRootUsername, err := alpineConf.CreateUser(defaultUsername, puid) if err != nil { - logger.Error(err) - return 1 + return err } if nonRootUsername != defaultUsername { logger.Info("using existing username %s corresponding to user id %d", nonRootUsername, puid) } if err := os.Chown("/etc/unbound", puid, pgid); err != nil { - logger.Error(err) - return 1 + return err } if allSettings.Firewall.Debug { @@ -161,27 +192,23 @@ func _main(background context.Context, buildInfo models.BuildInformation, defaultInterface, defaultGateway, err := routingConf.DefaultRoute() if err != nil { - logger.Error(err) - return 1 + return err } localSubnet, err := routingConf.LocalSubnet() if err != nil { - logger.Error(err) - return 1 + return err } defaultIP, err := routingConf.DefaultIP() if err != nil { - logger.Error(err) - return 1 + return err } firewallConf.SetNetworkInformation(defaultInterface, defaultGateway, localSubnet, defaultIP) if err := routingConf.Setup(); err != nil { - logger.Error(err) - return 1 + return err } defer func() { routingConf.SetVerbose(false) @@ -191,20 +218,17 @@ func _main(background context.Context, buildInfo models.BuildInformation, }() if err := firewallConf.SetOutboundSubnets(ctx, allSettings.Firewall.OutboundSubnets); err != nil { - logger.Error(err) - return 1 + return err } if err := routingConf.SetOutboundRoutes(allSettings.Firewall.OutboundSubnets); err != nil { - logger.Error(err) - return 1 + return err } if err := ovpnConf.CheckTUN(); err != nil { logger.Warn(err) err = ovpnConf.CreateTUN() if err != nil { - logger.Error(err) - return 1 + return err } } @@ -217,30 +241,27 @@ func _main(background context.Context, buildInfo models.BuildInformation, if allSettings.Firewall.Enabled { err := firewallConf.SetEnabled(ctx, true) // disabled by default if err != nil { - logger.Error(err) - return 1 + return err } } for _, vpnPort := range allSettings.Firewall.VPNInputPorts { err = firewallConf.SetAllowedPort(ctx, vpnPort, string(constants.TUN)) if err != nil { - logger.Error(err) - return 1 + return err } } for _, port := range allSettings.Firewall.InputPorts { err = firewallConf.SetAllowedPort(ctx, port, defaultInterface) if err != nil { - logger.Error(err) - return 1 + return err } } // TODO move inside firewall? wg := &sync.WaitGroup{} - go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady) + go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady) // TODO waitgroup openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, nonRootUsername, puid, pgid, allServers, ovpnConf, firewallConf, routingConf, logger, httpClient, os.OpenFile, streamMerger, cancel) @@ -296,55 +317,18 @@ func _main(background context.Context, buildInfo models.BuildInformation, // until openvpn is launched _, _ = openvpnLooper.SetStatus(constants.Running) // TODO option to disable with variable - signalsCh := make(chan nativeos.Signal, 1) - signal.Notify(signalsCh, - syscall.SIGINT, - syscall.SIGTERM, - nativeos.Interrupt, - ) - shutdownErrorsCount := 0 - select { - case signal := <-signalsCh: - logger.Warn("Caught OS signal %s, shutting down", signal) - cancel() - case <-ctx.Done(): - logger.Warn("context canceled, shutting down") - } + <-ctx.Done() + if allSettings.OpenVPN.Provider.PortForwarding.Enabled { logger.Info("Clearing forwarded port status file %s", allSettings.OpenVPN.Provider.PortForwarding.Filepath) if err := os.Remove(string(allSettings.OpenVPN.Provider.PortForwarding.Filepath)); err != nil { logger.Error(err) - shutdownErrorsCount++ } } - const shutdownGracePeriod = 5 * time.Second - waiting, waited := context.WithTimeout(context.Background(), shutdownGracePeriod) - go func() { - defer waited() - wg.Wait() - }() - <-waiting.Done() - if waiting.Err() == context.DeadlineExceeded { - if shutdownErrorsCount > 0 { - logger.Warn("Shutdown had %d errors", shutdownErrorsCount) - } - logger.Warn("Shutdown timed out") - return 1 - } - if shutdownErrorsCount > 0 { - logger.Warn("Shutdown had %d errors") - return 1 - } - logger.Info("Shutdown successful") - return 0 -} -func createLogger() logging.Logger { - logger, err := logging.NewLogger(logging.ConsoleEncoding, logging.InfoLevel) - if err != nil { - panic(err) - } - return logger + wg.Wait() + + return nil } func printVersions(ctx context.Context, logger logging.Logger,