diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 10db23a9..eb3e333d 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -358,7 +358,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, vpnLogger := logger.NewChild(logging.Settings{Prefix: "vpn: "}) vpnLooper := vpn.NewLoop(allSettings.VPN, allSettings.VPN.Provider, allServers, ovpnConf, firewallConf, routingConf, portForwardLooper, - publicIPLooper, unboundLooper, vpnLogger, httpClient, + cmder, publicIPLooper, unboundLooper, vpnLogger, httpClient, buildInfo, allSettings.VersionInformation) openvpnHandler, openvpnCtx, openvpnDone := goshutdown.NewGoRoutineHandler( "openvpn", goshutdown.GoRoutineSettings{Timeout: time.Second}) diff --git a/internal/openvpn/openvpn.go b/internal/openvpn/openvpn.go index fa23e4be..8b79fb6c 100644 --- a/internal/openvpn/openvpn.go +++ b/internal/openvpn/openvpn.go @@ -11,7 +11,6 @@ var _ Interface = (*Configurator)(nil) type Interface interface { VersionGetter AuthWriter - Runner Writer } diff --git a/internal/openvpn/run.go b/internal/openvpn/run.go index 760ea44c..bde03641 100644 --- a/internal/openvpn/run.go +++ b/internal/openvpn/run.go @@ -4,17 +4,27 @@ import ( "context" "github.com/qdm12/gluetun/internal/configuration" + "github.com/qdm12/golibs/command" "github.com/qdm12/golibs/logging" ) -type Runner interface { - Run(ctx context.Context, errCh chan<- error, ready chan<- struct{}, - logger logging.Logger, settings configuration.OpenVPN) +type Runner struct { + settings configuration.OpenVPN + starter command.Starter + logger logging.Logger } -func (c *Configurator) Run(ctx context.Context, errCh chan<- error, - ready chan<- struct{}, logger logging.Logger, settings configuration.OpenVPN) { - stdoutLines, stderrLines, waitError, err := c.start(ctx, settings.Version, settings.Flags) +func NewRunner(settings configuration.OpenVPN, starter command.Starter, + logger logging.Logger) *Runner { + return &Runner{ + starter: starter, + logger: logger, + settings: settings, + } +} + +func (r *Runner) Run(ctx context.Context, errCh chan<- error, ready chan<- struct{}) { + stdoutLines, stderrLines, waitError, err := start(ctx, r.starter, r.settings.Version, r.settings.Flags) if err != nil { errCh <- err return @@ -22,7 +32,7 @@ func (c *Configurator) Run(ctx context.Context, errCh chan<- error, streamCtx, streamCancel := context.WithCancel(context.Background()) streamDone := make(chan struct{}) - go streamLines(streamCtx, streamDone, logger, + go streamLines(streamCtx, streamDone, r.logger, stdoutLines, stderrLines, ready) select { diff --git a/internal/openvpn/start.go b/internal/openvpn/start.go index 87985e3f..986f5528 100644 --- a/internal/openvpn/start.go +++ b/internal/openvpn/start.go @@ -8,6 +8,7 @@ import ( "syscall" "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/golibs/command" ) var ErrVersionUnknown = errors.New("OpenVPN version is unknown") @@ -17,7 +18,7 @@ const ( binOpenvpn25 = "openvpn" ) -func (c *Configurator) start(ctx context.Context, version string, flags []string) ( +func start(ctx context.Context, starter command.Starter, version string, flags []string) ( stdoutLines, stderrLines chan string, waitError chan error, err error) { var bin string switch version { @@ -29,12 +30,10 @@ func (c *Configurator) start(ctx context.Context, version string, flags []string return nil, nil, nil, fmt.Errorf("%w: %s", ErrVersionUnknown, version) } - c.logger.Info("starting OpenVPN " + version) - args := []string{"--config", constants.OpenVPNConf} args = append(args, flags...) cmd := exec.CommandContext(ctx, bin, args...) cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} - return c.cmder.Start(cmd) + return starter.Start(cmd) } diff --git a/internal/vpn/loop.go b/internal/vpn/loop.go index fc95a077..6226f66e 100644 --- a/internal/vpn/loop.go +++ b/internal/vpn/loop.go @@ -15,6 +15,7 @@ import ( "github.com/qdm12/gluetun/internal/publicip" "github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/gluetun/internal/vpn/state" + "github.com/qdm12/golibs/command" "github.com/qdm12/golibs/logging" ) @@ -42,8 +43,9 @@ type Loop struct { publicip publicip.Looper dnsLooper dns.Looper // Other objects - logger logging.Logger - client *http.Client + starter command.Starter // for OpenVPN + logger logging.Logger + client *http.Client // Internal channels and values stop <-chan struct{} stopped chan<- struct{} @@ -67,7 +69,7 @@ func NewLoop(vpnSettings configuration.VPN, providerSettings configuration.Provider, allServers models.AllServers, openvpnConf openvpn.Interface, fw firewallConfigurer, routing routing.VPNGetter, - portForward portforward.StartStopper, + portForward portforward.StartStopper, starter command.Starter, publicip publicip.Looper, dnsLooper dns.Looper, logger logging.Logger, client *http.Client, buildInfo models.BuildInformation, versionInfo bool) *Loop { @@ -90,6 +92,7 @@ func NewLoop(vpnSettings configuration.VPN, portForward: portForward, publicip: publicip, dnsLooper: dnsLooper, + starter: starter, logger: logger, client: client, start: start, diff --git a/internal/vpn/openvpn.go b/internal/vpn/openvpn.go index f5a2346e..212a5ec8 100644 --- a/internal/vpn/openvpn.go +++ b/internal/vpn/openvpn.go @@ -11,6 +11,8 @@ import ( "github.com/qdm12/gluetun/internal/openvpn" "github.com/qdm12/gluetun/internal/openvpn/custom" "github.com/qdm12/gluetun/internal/provider" + "github.com/qdm12/golibs/command" + "github.com/qdm12/golibs/logging" ) var ( @@ -24,8 +26,9 @@ var ( // It returns a serverName for port forwarding (PIA) and an error if it fails. func setupOpenVPN(ctx context.Context, fw firewall.VPNConnectionSetter, openvpnConf openvpn.Interface, providerConf provider.Provider, - openVPNSettings configuration.OpenVPN, providerSettings configuration.Provider) ( - serverName string, err error) { + openVPNSettings configuration.OpenVPN, providerSettings configuration.Provider, + starter command.Starter, logger logging.Logger) ( + runner vpnRunner, serverName string, err error) { var connection models.Connection var lines []string if openVPNSettings.Config == "" { @@ -37,23 +40,25 @@ func setupOpenVPN(ctx context.Context, fw firewall.VPNConnectionSetter, lines, connection, err = custom.BuildConfig(openVPNSettings) } if err != nil { - return "", fmt.Errorf("%w: %s", errBuildConfig, err) + return nil, "", fmt.Errorf("%w: %s", errBuildConfig, err) } if err := openvpnConf.WriteConfig(lines); err != nil { - return "", fmt.Errorf("%w: %s", errWriteConfig, err) + return nil, "", fmt.Errorf("%w: %s", errWriteConfig, err) } if openVPNSettings.User != "" { err := openvpnConf.WriteAuthFile(openVPNSettings.User, openVPNSettings.Password) if err != nil { - return "", fmt.Errorf("%w: %s", errWriteAuth, err) + return nil, "", fmt.Errorf("%w: %s", errWriteAuth, err) } } if err := fw.SetVPNConnection(ctx, connection); err != nil { - return "", fmt.Errorf("%w: %s", errFirewall, err) + return nil, "", fmt.Errorf("%w: %s", errFirewall, err) } - return connection.Hostname, nil + runner = openvpn.NewRunner(openVPNSettings, starter, logger) + + return runner, connection.Hostname, nil } diff --git a/internal/vpn/run.go b/internal/vpn/run.go index 72bdff5b..28e8cc20 100644 --- a/internal/vpn/run.go +++ b/internal/vpn/run.go @@ -12,6 +12,10 @@ type Runner interface { Run(ctx context.Context, done chan<- struct{}) } +type vpnRunner interface { + Run(ctx context.Context, errCh chan<- error, ready chan<- struct{}) +} + func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { defer close(done) @@ -26,7 +30,10 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { providerConf := provider.New(providerSettings.Name, allServers, time.Now) - serverName, err := setupOpenVPN(ctx, l.fw, l.openvpnConf, providerConf, VPNSettings.OpenVPN, providerSettings) + vpnRunner, serverName, err := setupOpenVPN(ctx, l.fw, + l.openvpnConf, providerConf, + VPNSettings.OpenVPN, providerSettings, + l.starter, l.logger) if err != nil { l.crashed(ctx, err) continue @@ -41,8 +48,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { waitError := make(chan error) tunnelReady := make(chan struct{}) - go l.openvpnConf.Run(openvpnCtx, waitError, tunnelReady, - l.logger, VPNSettings.OpenVPN) + go vpnRunner.Run(openvpnCtx, waitError, tunnelReady) if err := l.waitForError(ctx, waitError); err != nil { openvpnCancel()