diff --git a/internal/openvpn/helpers.go b/internal/openvpn/helpers.go new file mode 100644 index 00000000..e8c2f44b --- /dev/null +++ b/internal/openvpn/helpers.go @@ -0,0 +1,36 @@ +package openvpn + +import ( + "context" + "time" + + "github.com/qdm12/gluetun/internal/models" +) + +func (l *looper) signalOrSetStatus(status models.LoopStatus) { + if l.userTrigger { + l.userTrigger = false + select { + case l.running <- status: + default: // receiver calling ApplyStatus dropped out + } + } else { + l.statusManager.SetStatus(status) + } +} + +func (l *looper) logAndWait(ctx context.Context, err error) { + if err != nil { + l.logger.Error(err.Error()) + } + l.logger.Info("retrying in " + l.backoffTime.String()) + timer := time.NewTimer(l.backoffTime) + l.backoffTime *= 2 + select { + case <-timer.C: + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + } +} diff --git a/internal/openvpn/loop.go b/internal/openvpn/loop.go index 966f1336..34d7a3ef 100644 --- a/internal/openvpn/loop.go +++ b/internal/openvpn/loop.go @@ -4,8 +4,6 @@ import ( "context" "net" "net/http" - "os" - "strings" "time" "github.com/qdm12/gluetun/internal/configuration" @@ -14,7 +12,6 @@ import ( "github.com/qdm12/gluetun/internal/loopstate" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/openvpn/state" - "github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/golibs/logging" ) @@ -100,196 +97,3 @@ func NewLooper(settings configuration.OpenVPN, backoffTime: defaultBackoffTime, } } - -func (l *looper) PortForward(vpnGateway net.IP) { l.portForwardSignals <- vpnGateway } - -func (l *looper) signalOrSetStatus(status models.LoopStatus) { - if l.userTrigger { - l.userTrigger = false - select { - case l.running <- status: - default: // receiver calling ApplyStatus droppped out - } - } else { - l.statusManager.SetStatus(status) - } -} - -func (l *looper) Run(ctx context.Context, done chan<- struct{}) { - defer close(done) - - select { - case <-l.start: - case <-ctx.Done(): - return - } - - for ctx.Err() == nil { - settings, allServers := l.state.GetSettingsAndServers() - - providerConf := provider.New(settings.Provider.Name, allServers, time.Now) - - var connection models.OpenVPNConnection - var lines []string - var err error - if settings.Config == "" { - connection, err = providerConf.GetOpenVPNConnection(settings.Provider.ServerSelection) - if err != nil { - l.signalOrSetStatus(constants.Crashed) - l.logAndWait(ctx, err) - continue - } - lines = providerConf.BuildConf(connection, l.username, settings) - } else { - lines, connection, err = l.processCustomConfig(settings) - if err != nil { - l.signalOrSetStatus(constants.Crashed) - l.logAndWait(ctx, err) - continue - } - } - - if err := l.writeOpenvpnConf(lines); err != nil { - l.signalOrSetStatus(constants.Crashed) - l.logAndWait(ctx, err) - continue - } - - if settings.User != "" { - err := l.conf.WriteAuthFile( - settings.User, settings.Password, l.puid, l.pgid) - if err != nil { - l.signalOrSetStatus(constants.Crashed) - l.logAndWait(ctx, err) - continue - } - } - - if err := l.fw.SetVPNConnection(ctx, connection); err != nil { - l.signalOrSetStatus(constants.Crashed) - l.logAndWait(ctx, err) - continue - } - - openvpnCtx, openvpnCancel := context.WithCancel(context.Background()) - - stdoutLines, stderrLines, waitError, err := l.conf.Start( - openvpnCtx, settings.Version, settings.Flags) - if err != nil { - openvpnCancel() - l.signalOrSetStatus(constants.Crashed) - l.logAndWait(ctx, err) - continue - } - - lineCollectionDone := make(chan struct{}) - go l.collectLines(stdoutLines, stderrLines, lineCollectionDone) - closeStreams := func() { - close(stdoutLines) - close(stderrLines) - <-lineCollectionDone - } - - // Needs the stream line from main.go to know when the tunnel is up - portForwardDone := make(chan struct{}) - go func(ctx context.Context) { - defer close(portForwardDone) - select { - // TODO have a way to disable pf with a context - case <-ctx.Done(): - return - case gateway := <-l.portForwardSignals: - l.portForward(ctx, providerConf, l.client, gateway) - } - }(openvpnCtx) - - l.backoffTime = defaultBackoffTime - l.signalOrSetStatus(constants.Running) - - stayHere := true - for stayHere { - select { - case <-ctx.Done(): - openvpnCancel() - <-waitError - close(waitError) - closeStreams() - <-portForwardDone - return - case <-l.stop: - l.userTrigger = true - l.logger.Info("stopping") - openvpnCancel() - <-waitError - // do not close waitError or the waitError - // select case will trigger - closeStreams() - <-portForwardDone - l.stopped <- struct{}{} - case <-l.start: - l.userTrigger = true - l.logger.Info("starting") - stayHere = false - case err := <-waitError: // unexpected error - close(waitError) - closeStreams() - - l.statusManager.Lock() // prevent SetStatus from running in parallel - - openvpnCancel() - l.statusManager.SetStatus(constants.Crashed) - <-portForwardDone - l.logAndWait(ctx, err) - stayHere = false - - l.statusManager.Unlock() - } - } - openvpnCancel() - } -} - -func (l *looper) logAndWait(ctx context.Context, err error) { - if err != nil { - l.logger.Error(err.Error()) - } - l.logger.Info("retrying in " + l.backoffTime.String()) - timer := time.NewTimer(l.backoffTime) - l.backoffTime *= 2 - select { - case <-timer.C: - case <-ctx.Done(): - if !timer.Stop() { - <-timer.C - } - } -} - -// portForward is a blocking operation which may or may not be infinite. -// You should therefore always call it in a goroutine. -func (l *looper) portForward(ctx context.Context, - providerConf provider.Provider, client *http.Client, gateway net.IP) { - settings := l.state.GetSettings() - if !settings.Provider.PortForwarding.Enabled { - return - } - syncState := func(port uint16) (pfFilepath string) { - l.state.SetPortForwarded(port) - settings := l.state.GetSettings() - return settings.Provider.PortForwarding.Filepath - } - providerConf.PortForward(ctx, client, l.pfLogger, - gateway, l.fw, syncState) -} - -func (l *looper) writeOpenvpnConf(lines []string) error { - file, err := os.OpenFile(l.targetConfPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) - if err != nil { - return err - } - _, err = file.WriteString(strings.Join(lines, "\n")) - if err != nil { - return err - } - return file.Close() -} diff --git a/internal/openvpn/portforwarded.go b/internal/openvpn/portforwarded.go index 0ce6ab80..0677473d 100644 --- a/internal/openvpn/portforwarded.go +++ b/internal/openvpn/portforwarded.go @@ -1,5 +1,46 @@ package openvpn +import ( + "context" + "net" + "net/http" + "os" + "strings" + + "github.com/qdm12/gluetun/internal/provider" +) + func (l *looper) GetPortForwarded() (port uint16) { return l.state.GetPortForwarded() } + +func (l *looper) PortForward(vpnGateway net.IP) { l.portForwardSignals <- vpnGateway } + +// portForward is a blocking operation which may or may not be infinite. +// You should therefore always call it in a goroutine. +func (l *looper) portForward(ctx context.Context, + providerConf provider.Provider, client *http.Client, gateway net.IP) { + settings := l.state.GetSettings() + if !settings.Provider.PortForwarding.Enabled { + return + } + syncState := func(port uint16) (pfFilepath string) { + l.state.SetPortForwarded(port) + settings := l.state.GetSettings() + return settings.Provider.PortForwarding.Filepath + } + providerConf.PortForward(ctx, client, l.pfLogger, + gateway, l.fw, syncState) +} + +func (l *looper) writeOpenvpnConf(lines []string) error { + file, err := os.OpenFile(l.targetConfPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) + if err != nil { + return err + } + _, err = file.WriteString(strings.Join(lines, "\n")) + if err != nil { + return err + } + return file.Close() +} diff --git a/internal/openvpn/run.go b/internal/openvpn/run.go new file mode 100644 index 00000000..1311339f --- /dev/null +++ b/internal/openvpn/run.go @@ -0,0 +1,144 @@ +package openvpn + +import ( + "context" + "time" + + "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/provider" +) + +func (l *looper) Run(ctx context.Context, done chan<- struct{}) { + defer close(done) + + select { + case <-l.start: + case <-ctx.Done(): + return + } + + for ctx.Err() == nil { + settings, allServers := l.state.GetSettingsAndServers() + + providerConf := provider.New(settings.Provider.Name, allServers, time.Now) + + var connection models.OpenVPNConnection + var lines []string + var err error + if settings.Config == "" { + connection, err = providerConf.GetOpenVPNConnection(settings.Provider.ServerSelection) + if err != nil { + l.signalOrSetStatus(constants.Crashed) + l.logAndWait(ctx, err) + continue + } + lines = providerConf.BuildConf(connection, l.username, settings) + } else { + lines, connection, err = l.processCustomConfig(settings) + if err != nil { + l.signalOrSetStatus(constants.Crashed) + l.logAndWait(ctx, err) + continue + } + } + + if err := l.writeOpenvpnConf(lines); err != nil { + l.signalOrSetStatus(constants.Crashed) + l.logAndWait(ctx, err) + continue + } + + if settings.User != "" { + err := l.conf.WriteAuthFile( + settings.User, settings.Password, l.puid, l.pgid) + if err != nil { + l.signalOrSetStatus(constants.Crashed) + l.logAndWait(ctx, err) + continue + } + } + + if err := l.fw.SetVPNConnection(ctx, connection); err != nil { + l.signalOrSetStatus(constants.Crashed) + l.logAndWait(ctx, err) + continue + } + + openvpnCtx, openvpnCancel := context.WithCancel(context.Background()) + + stdoutLines, stderrLines, waitError, err := l.conf.Start( + openvpnCtx, settings.Version, settings.Flags) + if err != nil { + openvpnCancel() + l.signalOrSetStatus(constants.Crashed) + l.logAndWait(ctx, err) + continue + } + + lineCollectionDone := make(chan struct{}) + go l.collectLines(stdoutLines, stderrLines, lineCollectionDone) + closeStreams := func() { + close(stdoutLines) + close(stderrLines) + <-lineCollectionDone + } + + // Needs the stream line from main.go to know when the tunnel is up + portForwardDone := make(chan struct{}) + go func(ctx context.Context) { + defer close(portForwardDone) + select { + // TODO have a way to disable pf with a context + case <-ctx.Done(): + return + case gateway := <-l.portForwardSignals: + l.portForward(ctx, providerConf, l.client, gateway) + } + }(openvpnCtx) + + l.backoffTime = defaultBackoffTime + l.signalOrSetStatus(constants.Running) + + stayHere := true + for stayHere { + select { + case <-ctx.Done(): + openvpnCancel() + <-waitError + close(waitError) + closeStreams() + <-portForwardDone + return + case <-l.stop: + l.userTrigger = true + l.logger.Info("stopping") + openvpnCancel() + <-waitError + // do not close waitError or the waitError + // select case will trigger + closeStreams() + <-portForwardDone + l.stopped <- struct{}{} + case <-l.start: + l.userTrigger = true + l.logger.Info("starting") + stayHere = false + case err := <-waitError: // unexpected error + close(waitError) + closeStreams() + + l.statusManager.Lock() // prevent SetStatus from running in parallel + + openvpnCancel() + l.statusManager.SetStatus(constants.Crashed) + <-portForwardDone + l.logAndWait(ctx, err) + stayHere = false + + l.statusManager.Unlock() + } + } + openvpnCancel() + } +}