Main function improved

- More explicit cli operation
- Using ctx and os.Args injected for eventual later testing
- Returning exit code
- Cli code moved to cli package
This commit is contained in:
Quentin McGaw
2020-06-02 23:03:18 +00:00
parent a7739b6f5d
commit 3ab1298b1f
2 changed files with 28 additions and 11 deletions

View File

@@ -13,14 +13,13 @@ import (
"github.com/qdm12/golibs/command" "github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/files" "github.com/qdm12/golibs/files"
libhealthcheck "github.com/qdm12/golibs/healthcheck"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network" "github.com/qdm12/golibs/network"
"github.com/qdm12/private-internet-access-docker/internal/alpine" "github.com/qdm12/private-internet-access-docker/internal/alpine"
"github.com/qdm12/private-internet-access-docker/internal/cli"
"github.com/qdm12/private-internet-access-docker/internal/constants" "github.com/qdm12/private-internet-access-docker/internal/constants"
"github.com/qdm12/private-internet-access-docker/internal/dns" "github.com/qdm12/private-internet-access-docker/internal/dns"
"github.com/qdm12/private-internet-access-docker/internal/firewall" "github.com/qdm12/private-internet-access-docker/internal/firewall"
"github.com/qdm12/private-internet-access-docker/internal/healthcheck"
"github.com/qdm12/private-internet-access-docker/internal/models" "github.com/qdm12/private-internet-access-docker/internal/models"
"github.com/qdm12/private-internet-access-docker/internal/mullvad" "github.com/qdm12/private-internet-access-docker/internal/mullvad"
"github.com/qdm12/private-internet-access-docker/internal/openvpn" "github.com/qdm12/private-internet-access-docker/internal/openvpn"
@@ -37,14 +36,26 @@ import (
) )
func main() { func main() {
if libhealthcheck.Mode(os.Args) { ctx := context.Background()
if err := healthcheck.HealthCheck(); err != nil { os.Exit(_main(ctx, os.Args))
fmt.Println(err) }
os.Exit(1)
func _main(background context.Context, args []string) int {
if len(args) > 1 { // cli operation
var err error
switch args[1] {
case "healthcheck":
err = cli.HealthCheck()
default:
err = fmt.Errorf("command %q is unknown", args[1])
} }
os.Exit(0) if err != nil {
fmt.Println(err)
return 1
}
return 0
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(background)
defer cancel() defer cancel()
logger := createLogger() logger := createLogger()
fatalOnError := makeFatalOnError(logger, cancel) fatalOnError := makeFatalOnError(logger, cancel)
@@ -282,9 +293,11 @@ func main() {
syscall.SIGTERM, syscall.SIGTERM,
os.Interrupt, os.Interrupt,
) )
exitStatus := 0
select { select {
case signal := <-signalsCh: case signal := <-signalsCh:
logger.Warn("Caught OS signal %s, shutting down", signal) logger.Warn("Caught OS signal %s, shutting down", signal)
exitStatus = 1
cancel() cancel()
case <-ctx.Done(): case <-ctx.Done():
logger.Warn("context canceled, shutting down") logger.Warn("context canceled, shutting down")
@@ -292,18 +305,22 @@ func main() {
logger.Info("Clearing ip status file %s", allSettings.System.IPStatusFilepath) logger.Info("Clearing ip status file %s", allSettings.System.IPStatusFilepath)
if err := fileManager.Remove(string(allSettings.System.IPStatusFilepath)); err != nil { if err := fileManager.Remove(string(allSettings.System.IPStatusFilepath)); err != nil {
logger.Error(err) logger.Error(err)
exitStatus = 1
} }
if allSettings.PIA.PortForwarding.Enabled { if allSettings.PIA.PortForwarding.Enabled {
logger.Info("Clearing forwarded port status file %s", allSettings.PIA.PortForwarding.Filepath) logger.Info("Clearing forwarded port status file %s", allSettings.PIA.PortForwarding.Filepath)
if err := fileManager.Remove(string(allSettings.PIA.PortForwarding.Filepath)); err != nil { if err := fileManager.Remove(string(allSettings.PIA.PortForwarding.Filepath)); err != nil {
logger.Error(err) logger.Error(err)
exitStatus = 1
} }
} }
timeoutCtx, cancel := context.WithTimeout(context.Background(), time.Second) timeoutCtx, timeoutCancel := context.WithTimeout(background, time.Second)
defer cancel() defer timeoutCancel()
for _, err := range waiter.WaitForAll(timeoutCtx) { for _, err := range waiter.WaitForAll(timeoutCtx) {
logger.Error(err) logger.Error(err)
exitStatus = 1
} }
return exitStatus
} }
func makeFatalOnError(logger logging.Logger, cancel func()) func(err error) { func makeFatalOnError(logger logging.Logger, cancel func()) func(err error) {

View File

@@ -1,4 +1,4 @@
package healthcheck package cli
import ( import (
"fmt" "fmt"