Maint: create OpenVPN runner in VPN run loop

This commit is contained in:
Quentin McGaw (desktop)
2021-08-19 14:45:57 +00:00
parent 3d8e61900b
commit 9218c7ef19
7 changed files with 48 additions and 26 deletions

View File

@@ -358,7 +358,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
vpnLogger := logger.NewChild(logging.Settings{Prefix: "vpn: "}) vpnLogger := logger.NewChild(logging.Settings{Prefix: "vpn: "})
vpnLooper := vpn.NewLoop(allSettings.VPN, allSettings.VPN.Provider, vpnLooper := vpn.NewLoop(allSettings.VPN, allSettings.VPN.Provider,
allServers, ovpnConf, firewallConf, routingConf, portForwardLooper, allServers, ovpnConf, firewallConf, routingConf, portForwardLooper,
publicIPLooper, unboundLooper, vpnLogger, httpClient, cmder, publicIPLooper, unboundLooper, vpnLogger, httpClient,
buildInfo, allSettings.VersionInformation) buildInfo, allSettings.VersionInformation)
openvpnHandler, openvpnCtx, openvpnDone := goshutdown.NewGoRoutineHandler( openvpnHandler, openvpnCtx, openvpnDone := goshutdown.NewGoRoutineHandler(
"openvpn", goshutdown.GoRoutineSettings{Timeout: time.Second}) "openvpn", goshutdown.GoRoutineSettings{Timeout: time.Second})

View File

@@ -11,7 +11,6 @@ var _ Interface = (*Configurator)(nil)
type Interface interface { type Interface interface {
VersionGetter VersionGetter
AuthWriter AuthWriter
Runner
Writer Writer
} }

View File

@@ -4,17 +4,27 @@ import (
"context" "context"
"github.com/qdm12/gluetun/internal/configuration" "github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
type Runner interface { type Runner struct {
Run(ctx context.Context, errCh chan<- error, ready chan<- struct{}, settings configuration.OpenVPN
logger logging.Logger, settings configuration.OpenVPN) starter command.Starter
logger logging.Logger
} }
func (c *Configurator) Run(ctx context.Context, errCh chan<- error, func NewRunner(settings configuration.OpenVPN, starter command.Starter,
ready chan<- struct{}, logger logging.Logger, settings configuration.OpenVPN) { logger logging.Logger) *Runner {
stdoutLines, stderrLines, waitError, err := c.start(ctx, settings.Version, settings.Flags) return &Runner{
starter: starter,
logger: logger,
settings: settings,
}
}
func (r *Runner) Run(ctx context.Context, errCh chan<- error, ready chan<- struct{}) {
stdoutLines, stderrLines, waitError, err := start(ctx, r.starter, r.settings.Version, r.settings.Flags)
if err != nil { if err != nil {
errCh <- err errCh <- err
return return
@@ -22,7 +32,7 @@ func (c *Configurator) Run(ctx context.Context, errCh chan<- error,
streamCtx, streamCancel := context.WithCancel(context.Background()) streamCtx, streamCancel := context.WithCancel(context.Background())
streamDone := make(chan struct{}) streamDone := make(chan struct{})
go streamLines(streamCtx, streamDone, logger, go streamLines(streamCtx, streamDone, r.logger,
stdoutLines, stderrLines, ready) stdoutLines, stderrLines, ready)
select { select {

View File

@@ -8,6 +8,7 @@ import (
"syscall" "syscall"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/golibs/command"
) )
var ErrVersionUnknown = errors.New("OpenVPN version is unknown") var ErrVersionUnknown = errors.New("OpenVPN version is unknown")
@@ -17,7 +18,7 @@ const (
binOpenvpn25 = "openvpn" binOpenvpn25 = "openvpn"
) )
func (c *Configurator) start(ctx context.Context, version string, flags []string) ( func start(ctx context.Context, starter command.Starter, version string, flags []string) (
stdoutLines, stderrLines chan string, waitError chan error, err error) { stdoutLines, stderrLines chan string, waitError chan error, err error) {
var bin string var bin string
switch version { switch version {
@@ -29,12 +30,10 @@ func (c *Configurator) start(ctx context.Context, version string, flags []string
return nil, nil, nil, fmt.Errorf("%w: %s", ErrVersionUnknown, version) return nil, nil, nil, fmt.Errorf("%w: %s", ErrVersionUnknown, version)
} }
c.logger.Info("starting OpenVPN " + version)
args := []string{"--config", constants.OpenVPNConf} args := []string{"--config", constants.OpenVPNConf}
args = append(args, flags...) args = append(args, flags...)
cmd := exec.CommandContext(ctx, bin, args...) cmd := exec.CommandContext(ctx, bin, args...)
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
return c.cmder.Start(cmd) return starter.Start(cmd)
} }

View File

@@ -15,6 +15,7 @@ import (
"github.com/qdm12/gluetun/internal/publicip" "github.com/qdm12/gluetun/internal/publicip"
"github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/gluetun/internal/routing"
"github.com/qdm12/gluetun/internal/vpn/state" "github.com/qdm12/gluetun/internal/vpn/state"
"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
@@ -42,6 +43,7 @@ type Loop struct {
publicip publicip.Looper publicip publicip.Looper
dnsLooper dns.Looper dnsLooper dns.Looper
// Other objects // Other objects
starter command.Starter // for OpenVPN
logger logging.Logger logger logging.Logger
client *http.Client client *http.Client
// Internal channels and values // Internal channels and values
@@ -67,7 +69,7 @@ func NewLoop(vpnSettings configuration.VPN,
providerSettings configuration.Provider, providerSettings configuration.Provider,
allServers models.AllServers, openvpnConf openvpn.Interface, allServers models.AllServers, openvpnConf openvpn.Interface,
fw firewallConfigurer, routing routing.VPNGetter, fw firewallConfigurer, routing routing.VPNGetter,
portForward portforward.StartStopper, portForward portforward.StartStopper, starter command.Starter,
publicip publicip.Looper, dnsLooper dns.Looper, publicip publicip.Looper, dnsLooper dns.Looper,
logger logging.Logger, client *http.Client, logger logging.Logger, client *http.Client,
buildInfo models.BuildInformation, versionInfo bool) *Loop { buildInfo models.BuildInformation, versionInfo bool) *Loop {
@@ -90,6 +92,7 @@ func NewLoop(vpnSettings configuration.VPN,
portForward: portForward, portForward: portForward,
publicip: publicip, publicip: publicip,
dnsLooper: dnsLooper, dnsLooper: dnsLooper,
starter: starter,
logger: logger, logger: logger,
client: client, client: client,
start: start, start: start,

View File

@@ -11,6 +11,8 @@ import (
"github.com/qdm12/gluetun/internal/openvpn" "github.com/qdm12/gluetun/internal/openvpn"
"github.com/qdm12/gluetun/internal/openvpn/custom" "github.com/qdm12/gluetun/internal/openvpn/custom"
"github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/logging"
) )
var ( var (
@@ -24,8 +26,9 @@ var (
// It returns a serverName for port forwarding (PIA) and an error if it fails. // It returns a serverName for port forwarding (PIA) and an error if it fails.
func setupOpenVPN(ctx context.Context, fw firewall.VPNConnectionSetter, func setupOpenVPN(ctx context.Context, fw firewall.VPNConnectionSetter,
openvpnConf openvpn.Interface, providerConf provider.Provider, openvpnConf openvpn.Interface, providerConf provider.Provider,
openVPNSettings configuration.OpenVPN, providerSettings configuration.Provider) ( openVPNSettings configuration.OpenVPN, providerSettings configuration.Provider,
serverName string, err error) { starter command.Starter, logger logging.Logger) (
runner vpnRunner, serverName string, err error) {
var connection models.Connection var connection models.Connection
var lines []string var lines []string
if openVPNSettings.Config == "" { if openVPNSettings.Config == "" {
@@ -37,23 +40,25 @@ func setupOpenVPN(ctx context.Context, fw firewall.VPNConnectionSetter,
lines, connection, err = custom.BuildConfig(openVPNSettings) lines, connection, err = custom.BuildConfig(openVPNSettings)
} }
if err != nil { if err != nil {
return "", fmt.Errorf("%w: %s", errBuildConfig, err) return nil, "", fmt.Errorf("%w: %s", errBuildConfig, err)
} }
if err := openvpnConf.WriteConfig(lines); err != nil { if err := openvpnConf.WriteConfig(lines); err != nil {
return "", fmt.Errorf("%w: %s", errWriteConfig, err) return nil, "", fmt.Errorf("%w: %s", errWriteConfig, err)
} }
if openVPNSettings.User != "" { if openVPNSettings.User != "" {
err := openvpnConf.WriteAuthFile(openVPNSettings.User, openVPNSettings.Password) err := openvpnConf.WriteAuthFile(openVPNSettings.User, openVPNSettings.Password)
if err != nil { if err != nil {
return "", fmt.Errorf("%w: %s", errWriteAuth, err) return nil, "", fmt.Errorf("%w: %s", errWriteAuth, err)
} }
} }
if err := fw.SetVPNConnection(ctx, connection); err != nil { if err := fw.SetVPNConnection(ctx, connection); err != nil {
return "", fmt.Errorf("%w: %s", errFirewall, err) return nil, "", fmt.Errorf("%w: %s", errFirewall, err)
} }
return connection.Hostname, nil runner = openvpn.NewRunner(openVPNSettings, starter, logger)
return runner, connection.Hostname, nil
} }

View File

@@ -12,6 +12,10 @@ type Runner interface {
Run(ctx context.Context, done chan<- struct{}) Run(ctx context.Context, done chan<- struct{})
} }
type vpnRunner interface {
Run(ctx context.Context, errCh chan<- error, ready chan<- struct{})
}
func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
defer close(done) defer close(done)
@@ -26,7 +30,10 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
providerConf := provider.New(providerSettings.Name, allServers, time.Now) providerConf := provider.New(providerSettings.Name, allServers, time.Now)
serverName, err := setupOpenVPN(ctx, l.fw, l.openvpnConf, providerConf, VPNSettings.OpenVPN, providerSettings) vpnRunner, serverName, err := setupOpenVPN(ctx, l.fw,
l.openvpnConf, providerConf,
VPNSettings.OpenVPN, providerSettings,
l.starter, l.logger)
if err != nil { if err != nil {
l.crashed(ctx, err) l.crashed(ctx, err)
continue continue
@@ -41,8 +48,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
waitError := make(chan error) waitError := make(chan error)
tunnelReady := make(chan struct{}) tunnelReady := make(chan struct{})
go l.openvpnConf.Run(openvpnCtx, waitError, tunnelReady, go vpnRunner.Run(openvpnCtx, waitError, tunnelReady)
l.logger, VPNSettings.OpenVPN)
if err := l.waitForError(ctx, waitError); err != nil { if err := l.waitForError(ctx, waitError); err != nil {
openvpnCancel() openvpnCancel()