diff --git a/cmd/main.go b/cmd/main.go index a4ec4eca..9a982d81 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -35,7 +35,7 @@ import ( "github.com/qdm12/private-internet-access-docker/internal/windscribe" ) -func main() { +func main() { //nolint:gocognit logger, err := logging.NewLogger(logging.ConsoleEncoding, logging.InfoLevel, -1) if err != nil { panic(err) @@ -116,19 +116,14 @@ func main() { err = firewallConf.AcceptAll(ctx) e.FatalOnError(err) + connected, signalConnected := context.WithCancel(context.Background()) go func() { // Blocking line merging paramsReader for all programs: openvpn, tinyproxy, unbound and shadowsocks logger.Info("Launching standard output merger") streamMerger.CollectLines(ctx, func(line string) { logger.Info(line) if strings.Contains(line, "Initialization Sequence Completed") { - go onConnected(logger, routingConf, fileManager, piaConf, - defaultInterface, - allSettings.PIA.PortForwarding.Enabled, - allSettings.PIA.PortForwarding.Filepath, - allSettings.System.IPStatusFilepath, - allSettings.System.UID, - allSettings.System.GID) + signalConnected() } }, func(err error) { logger.Error(err) @@ -304,6 +299,52 @@ func main() { return err }) + go func() { + <-connected.Done() // blocks until openvpn is connected + + ip, err := routingConf.CurrentPublicIP(defaultInterface) + if err != nil { + logger.Error(err) + } else { + logger.Info("Tunnel IP is %s, see more information at https://ipinfo.io/%s", ip, ip) + err = fileManager.WriteLinesToFile( + string(allSettings.System.IPStatusFilepath), + []string{ip.String()}, + files.Ownership(allSettings.System.UID, allSettings.System.GID), + files.Permissions(0400)) + if err != nil { + logger.Error(err) + } + } + + if allSettings.PIA.PortForwarding.Enabled { + pfLogger := logger.WithPrefix("port forwarding: ") + var port uint16 + var err error + for { + port, err = piaConf.GetPortForward() + if err != nil { + pfLogger.Error(err) + pfLogger.Info("retrying in 5 seconds...") + time.Sleep(5 * time.Second) + } else { + pfLogger.Info("port forwarded is %d", port) + break + } + } + pfLogger.Info("writing forwarded port to %s", allSettings.PIA.PortForwarding.Filepath) + if err := piaConf.WritePortForward(allSettings.PIA.PortForwarding.Filepath, port, allSettings.System.UID, allSettings.System.GID); err != nil { + pfLogger.Error(err) + } + pfLogger.Info("allowing forwarded port %d through firewall", port) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := piaConf.AllowPortForwardFirewall(ctx, constants.TUN, port); err != nil { + pfLogger.Error(err) + } + } + }() + signalsCh := make(chan os.Signal, 1) signal.Notify(signalsCh, syscall.SIGINT, @@ -332,58 +373,3 @@ func main() { logger.Error(err) } } - -func onConnected( - logger logging.Logger, - routingConf routing.Routing, - fileManager files.FileManager, - piaConf pia.Configurator, - defaultInterface string, - portForwarding bool, - portForwardingFilepath models.Filepath, - ipStatusFilepath models.Filepath, - uid, gid int, -) { - ip, err := routingConf.CurrentPublicIP(defaultInterface) - if err != nil { - logger.Error(err) - } else { - logger.Info("Tunnel IP is %s, see more information at https://ipinfo.io/%s", ip, ip) - err := fileManager.WriteLinesToFile( - string(ipStatusFilepath), - []string{ip.String()}, - files.Ownership(uid, gid), - files.Permissions(0400)) - if err != nil { - logger.Error(err) - } - } - if !portForwarding { - return - } - time.AfterFunc(5*time.Second, func() { - pfLogger := logger.WithPrefix("port forwarding: ") - var port uint16 - for { - port, err = piaConf.GetPortForward() - if err != nil { - pfLogger.Error(err) - pfLogger.Info("retrying in 5 seconds...") - time.Sleep(5 * time.Second) - } else { - pfLogger.Info("port forwarded is %d", port) - break - } - } - pfLogger.Info("writing forwarded port to %s", portForwardingFilepath) - if err := piaConf.WritePortForward(portForwardingFilepath, port, uid, gid); err != nil { - pfLogger.Error(err) - } - pfLogger.Info("allowing forwarded port %d through firewall", port) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - if err := piaConf.AllowPortForwardFirewall(ctx, constants.TUN, port); err != nil { - pfLogger.Error(err) - } - }) -}