diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 562fceda..cfe9f956 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -384,7 +384,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, // Start openvpn for the first time in a blocking call // until openvpn is launched - _, _ = openvpnLooper.SetStatus(ctx, constants.Running) // TODO option to disable with variable + _, _ = openvpnLooper.ApplyStatus(ctx, constants.Running) // TODO option to disable with variable <-ctx.Done() diff --git a/internal/openvpn/loop.go b/internal/openvpn/loop.go index 48ccf192..d233355a 100644 --- a/internal/openvpn/loop.go +++ b/internal/openvpn/loop.go @@ -5,7 +5,6 @@ import ( "net" "net/http" "strings" - "sync" "time" "github.com/qdm12/gluetun/internal/configuration" @@ -21,7 +20,7 @@ import ( type Looper interface { Run(ctx context.Context, done chan<- struct{}) GetStatus() (status models.LoopStatus) - SetStatus(ctx context.Context, status models.LoopStatus) ( + ApplyStatus(ctx context.Context, status models.LoopStatus) ( outcome string, err error) GetSettings() (settings configuration.OpenVPN) SetSettings(ctx context.Context, settings configuration.OpenVPN) ( @@ -33,7 +32,7 @@ type Looper interface { } type looper struct { - state state + state *state // Fixed parameters username string puid int @@ -48,15 +47,16 @@ type looper struct { openFile os.OpenFileFunc tunnelReady chan<- struct{} healthy <-chan bool - // Internal channels and locks - loopLock sync.Mutex - running chan models.LoopStatus - stop, stopped chan struct{} - start chan struct{} + // Internal channels and values + stop <-chan struct{} + stopped chan<- struct{} + start <-chan struct{} + running chan<- models.LoopStatus portForwardSignals chan net.IP - crashed bool - backoffTime time.Duration - healthWaitTime time.Duration + userTrigger bool + // Internal constant values + backoffTime time.Duration + healthWaitTime time.Duration } const ( @@ -69,12 +69,16 @@ func NewLooper(settings configuration.OpenVPN, conf Configurator, fw firewall.Configurator, routing routing.Routing, logger logging.ParentLogger, client *http.Client, openFile os.OpenFileFunc, tunnelReady chan<- struct{}, healthy <-chan bool) Looper { + start := make(chan struct{}) + running := make(chan models.LoopStatus) + stop := make(chan struct{}) + stopped := make(chan struct{}) + + state := newState(constants.Stopped, settings, allServers, + start, running, stop, stopped) + return &looper{ - state: state{ - status: constants.Stopped, - settings: settings, - allServers: allServers, - }, + state: state, username: username, puid: puid, pgid: pgid, @@ -87,11 +91,12 @@ func NewLooper(settings configuration.OpenVPN, openFile: openFile, tunnelReady: tunnelReady, healthy: healthy, - start: make(chan struct{}), - running: make(chan models.LoopStatus), - stop: make(chan struct{}), - stopped: make(chan struct{}), + start: start, + running: running, + stop: stop, + stopped: stopped, portForwardSignals: make(chan net.IP), + userTrigger: true, backoffTime: defaultBackoffTime, healthWaitTime: defaultHealthWaitTime, } @@ -99,15 +104,21 @@ func NewLooper(settings configuration.OpenVPN, func (l *looper) PortForward(vpnGateway net.IP) { l.portForwardSignals <- vpnGateway } -func (l *looper) signalCrashedStatus() { - if !l.crashed { - l.crashed = true - l.running <- constants.Crashed +func (l *looper) signalOrSetStatus(status models.LoopStatus) { + if l.userTrigger { + l.userTrigger = false + select { + case l.running <- status: + default: // receiver calling ApplyStatus droppped out + } + } else { + l.state.SetStatus(status) } } func (l *looper) Run(ctx context.Context, done chan<- struct{}) { //nolint:gocognit defer close(done) + select { case <-l.start: case <-ctx.Done(): @@ -115,17 +126,17 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) { //nolint:gocog } for ctx.Err() == nil { - settings, allServers := l.state.getSettingsAndServers() + settings, allServers := l.state.GetSettingsAndServers() providerConf := provider.New(settings.Provider.Name, allServers, time.Now) var connection models.OpenVPNConnection var lines []string var err error - if len(settings.Config) == 0 { + if settings.Config == "" { connection, err = providerConf.GetOpenVPNConnection(settings.Provider.ServerSelection) if err != nil { - l.signalCrashedStatus() + l.signalOrSetStatus(constants.Crashed) l.logAndWait(ctx, err) continue } @@ -133,28 +144,30 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) { //nolint:gocog } else { lines, connection, err = l.processCustomConfig(settings) if err != nil { - l.signalCrashedStatus() + l.signalOrSetStatus(constants.Crashed) l.logAndWait(ctx, err) continue } } if err := writeOpenvpnConf(lines, l.openFile); err != nil { - l.signalCrashedStatus() + l.signalOrSetStatus(constants.Crashed) l.logAndWait(ctx, err) continue } if settings.User != "" { - if err := l.conf.WriteAuthFile(settings.User, settings.Password, l.puid, l.pgid); err != nil { - l.signalCrashedStatus() + err := l.conf.WriteAuthFile( + settings.User, settings.Password, l.puid, l.pgid) + if err != nil { + l.signalOrSetStatus(constants.Crashed) l.logAndWait(ctx, err) continue } } if err := l.fw.SetVPNConnection(ctx, connection); err != nil { - l.signalCrashedStatus() + l.signalOrSetStatus(constants.Crashed) l.logAndWait(ctx, err) continue } @@ -164,7 +177,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) { //nolint:gocog stdoutLines, stderrLines, waitError, err := l.conf.Start(openvpnCtx, settings.Version) if err != nil { openvpnCancel() - l.signalCrashedStatus() + l.signalOrSetStatus(constants.Crashed) l.logAndWait(ctx, err) continue } @@ -190,13 +203,8 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) { //nolint:gocog } }(openvpnCtx) - if l.crashed { - l.crashed = false - l.backoffTime = defaultBackoffTime - l.state.setStatusWithLock(constants.Running) - } else { - l.running <- constants.Running - } + l.backoffTime = defaultBackoffTime + l.signalOrSetStatus(constants.Running) stayHere := true for stayHere { @@ -209,6 +217,7 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) { //nolint:gocog <-portForwardDone return case <-l.stop: + l.userTrigger = true l.logger.Info("stopping") openvpnCancel() <-waitError @@ -218,17 +227,22 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) { //nolint:gocog <-portForwardDone l.stopped <- struct{}{} case <-l.start: + l.userTrigger = true l.logger.Info("starting") stayHere = false case err := <-waitError: // unexpected error - openvpnCancel() close(waitError) closeStreams() + + l.state.Lock() // prevent SetStatus from running in parallel + + openvpnCancel() + l.state.SetStatus(constants.Crashed) <-portForwardDone - l.state.setStatusWithLock(constants.Crashed) l.logAndWait(ctx, err) - l.crashed = true stayHere = false + + l.state.Unlock() case healthy := <-l.healthy: if healthy { continue @@ -238,19 +252,19 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) { //nolint:gocog if healthy || ctx.Err() != nil { continue } - l.crashed = true // flag as crashed - l.state.setStatusWithLock(constants.Stopping) + l.logger.Warn("unhealthy program: restarting openvpn") + l.state.SetStatus(constants.Stopping) openvpnCancel() <-waitError close(waitError) closeStreams() <-portForwardDone - l.state.setStatusWithLock(constants.Stopped) + l.state.SetStatus(constants.Stopped) stayHere = false } } - openvpnCancel() // just for the linter + openvpnCancel() } } @@ -258,7 +272,7 @@ func (l *looper) logAndWait(ctx context.Context, err error) { if err != nil { l.logger.Error(err) } - l.logger.Info("retrying in %s", l.backoffTime) + l.logger.Info("retrying in " + l.backoffTime.String()) timer := time.NewTimer(l.backoffTime) l.backoffTime *= 2 select { @@ -333,3 +347,27 @@ func writeOpenvpnConf(lines []string, openFile os.OpenFileFunc) error { } return file.Close() } + +func (l *looper) GetStatus() (status models.LoopStatus) { + return l.state.GetStatus() +} +func (l *looper) ApplyStatus(ctx context.Context, status models.LoopStatus) ( + outcome string, err error) { + return l.state.ApplyStatus(ctx, status) +} +func (l *looper) GetSettings() (settings configuration.OpenVPN) { + return l.state.GetSettings() +} +func (l *looper) SetSettings(ctx context.Context, settings configuration.OpenVPN) ( + outcome string) { + return l.state.SetSettings(ctx, settings) +} +func (l *looper) GetServers() (servers models.AllServers) { + return l.state.GetServers() +} +func (l *looper) SetServers(servers models.AllServers) { + l.state.SetServers(servers) +} +func (l *looper) GetPortForwarded() (port uint16) { + return l.state.GetPortForwarded() +} diff --git a/internal/openvpn/state.go b/internal/openvpn/state.go index 9bce13ed..6774f829 100644 --- a/internal/openvpn/state.go +++ b/internal/openvpn/state.go @@ -12,24 +12,63 @@ import ( "github.com/qdm12/gluetun/internal/models" ) -type state struct { - status models.LoopStatus - settings configuration.OpenVPN - allServers models.AllServers - portForwarded uint16 - statusMu sync.RWMutex - settingsMu sync.RWMutex - allServersMu sync.RWMutex - portForwardedMu sync.RWMutex +func newState(status models.LoopStatus, + settings configuration.OpenVPN, allServers models.AllServers, + start chan<- struct{}, running <-chan models.LoopStatus, + stop chan<- struct{}, stopped <-chan struct{}) *state { + return &state{ + status: status, + settings: settings, + allServers: allServers, + start: start, + running: running, + stop: stop, + stopped: stopped, + } } -func (s *state) setStatusWithLock(status models.LoopStatus) { +type state struct { + loopMu sync.RWMutex + + status models.LoopStatus + statusMu sync.RWMutex + + settings configuration.OpenVPN + settingsMu sync.RWMutex + + allServers models.AllServers + allServersMu sync.RWMutex + + portForwarded uint16 + portForwardedMu sync.RWMutex + + start chan<- struct{} + running <-chan models.LoopStatus + stop chan<- struct{} + stopped <-chan struct{} +} + +func (s *state) Lock() { s.loopMu.Lock() } +func (s *state) Unlock() { s.loopMu.Unlock() } + +// SetStatus sets the status thread safely. +// It should only be called by the loop internal code since +// it does not interact with the loop code directly. +func (s *state) SetStatus(status models.LoopStatus) { s.statusMu.Lock() defer s.statusMu.Unlock() s.status = status } -func (s *state) getSettingsAndServers() (settings configuration.OpenVPN, allServers models.AllServers) { +// GetStatus gets the status thread safely. +func (s *state) GetStatus() (status models.LoopStatus) { + s.statusMu.RLock() + defer s.statusMu.RUnlock() + return s.status +} + +func (s *state) GetSettingsAndServers() (settings configuration.OpenVPN, + allServers models.AllServers) { s.settingsMu.RLock() s.allServersMu.RLock() settings = s.settings @@ -39,100 +78,102 @@ func (s *state) getSettingsAndServers() (settings configuration.OpenVPN, allServ return settings, allServers } -func (l *looper) GetStatus() (status models.LoopStatus) { - l.state.statusMu.RLock() - defer l.state.statusMu.RUnlock() - return l.state.status -} - var ErrInvalidStatus = errors.New("invalid status") -func (l *looper) SetStatus(ctx context.Context, status models.LoopStatus) ( +// ApplyStatus sends signals to the running loop depending on the +// current status and status requested, such that its next status +// matches the requested one. It is thread safe and a synchronous call +// since it waits to the loop to fully change its status. +func (s *state) ApplyStatus(ctx context.Context, status models.LoopStatus) ( outcome string, err error) { - l.state.statusMu.Lock() - defer l.state.statusMu.Unlock() - existingStatus := l.state.status + // prevent simultaneous loop changes by restricting + // multiple SetStatus calls to run sequentially. + s.loopMu.Lock() + defer s.loopMu.Unlock() + + // not a read lock as we want to modify it eventually in + // the code below before any other call. + s.statusMu.Lock() + existingStatus := s.status switch status { case constants.Running: - switch existingStatus { - case constants.Starting, constants.Running, constants.Stopping, constants.Crashed: - return fmt.Sprintf("already %s", existingStatus), nil + if existingStatus != constants.Stopped { + return "already " + existingStatus.String(), nil } - l.loopLock.Lock() - defer l.loopLock.Unlock() - l.state.status = constants.Starting - l.state.statusMu.Unlock() - l.start <- struct{}{} + s.status = constants.Starting + s.statusMu.Unlock() + s.start <- struct{}{} + + // Wait for the loop to react to the start signal newStatus := constants.Starting // for canceled context select { case <-ctx.Done(): - case newStatus = <-l.running: + case newStatus = <-s.running: } - l.state.statusMu.Lock() - l.state.status = newStatus + s.SetStatus(newStatus) + return newStatus.String(), nil case constants.Stopped: - switch existingStatus { - case constants.Starting, constants.Stopping, constants.Stopped, constants.Crashed: - return fmt.Sprintf("already %s", existingStatus), nil + if existingStatus != constants.Running { + return "already " + existingStatus.String(), nil } - l.loopLock.Lock() - defer l.loopLock.Unlock() - l.state.status = constants.Stopping - l.state.statusMu.Unlock() - l.stop <- struct{}{} + s.status = constants.Stopping + s.statusMu.Unlock() + s.stop <- struct{}{} + + // Wait for the loop to react to the stop signal newStatus := constants.Stopping // for canceled context select { case <-ctx.Done(): - case <-l.stopped: + case <-s.stopped: newStatus = constants.Stopped } - l.state.statusMu.Lock() - l.state.status = newStatus - return status.String(), nil + s.SetStatus(newStatus) + + return newStatus.String(), nil default: return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s", ErrInvalidStatus, status, constants.Running, constants.Stopped) } } -func (l *looper) GetSettings() (settings configuration.OpenVPN) { - l.state.settingsMu.RLock() - defer l.state.settingsMu.RUnlock() - return l.state.settings +func (s *state) GetSettings() (settings configuration.OpenVPN) { + s.settingsMu.RLock() + defer s.settingsMu.RUnlock() + return s.settings } -func (l *looper) SetSettings(ctx context.Context, settings configuration.OpenVPN) ( +func (s *state) SetSettings(ctx context.Context, settings configuration.OpenVPN) ( outcome string) { - l.state.settingsMu.Lock() - settingsUnchanged := reflect.DeepEqual(l.state.settings, settings) + s.settingsMu.Lock() + defer s.settingsMu.Unlock() + settingsUnchanged := reflect.DeepEqual(s.settings, settings) if settingsUnchanged { - l.state.settingsMu.Unlock() return "settings left unchanged" } - l.state.settings = settings - _, _ = l.SetStatus(ctx, constants.Stopped) - outcome, _ = l.SetStatus(ctx, constants.Running) + s.settings = settings + _, _ = s.ApplyStatus(ctx, constants.Stopped) + outcome, _ = s.ApplyStatus(ctx, constants.Running) return outcome } -func (l *looper) GetServers() (servers models.AllServers) { - l.state.allServersMu.RLock() - defer l.state.allServersMu.RUnlock() - return l.state.allServers +func (s *state) GetServers() (servers models.AllServers) { + s.allServersMu.RLock() + defer s.allServersMu.RUnlock() + return s.allServers } -func (l *looper) SetServers(servers models.AllServers) { - l.state.allServersMu.Lock() - defer l.state.allServersMu.Unlock() - l.state.allServers = servers +func (s *state) SetServers(servers models.AllServers) { + s.allServersMu.Lock() + defer s.allServersMu.Unlock() + s.allServers = servers } -func (l *looper) GetPortForwarded() (port uint16) { - l.state.portForwardedMu.RLock() - defer l.state.portForwardedMu.RUnlock() - return l.state.portForwarded +func (s *state) GetPortForwarded() (port uint16) { + s.portForwardedMu.RLock() + defer s.portForwardedMu.RUnlock() + return s.portForwarded } diff --git a/internal/server/handlerv0.go b/internal/server/handlerv0.go index d180c613..ce7fe9a2 100644 --- a/internal/server/handlerv0.go +++ b/internal/server/handlerv0.go @@ -39,9 +39,9 @@ func (h *handlerV0) ServeHTTP(w http.ResponseWriter, r *http.Request) { case "/version": http.Redirect(w, r, "/v1/version", http.StatusPermanentRedirect) case "/openvpn/actions/restart": - outcome, _ := h.openvpn.SetStatus(h.ctx, constants.Stopped) + outcome, _ := h.openvpn.ApplyStatus(h.ctx, constants.Stopped) h.logger.Info("openvpn: %s", outcome) - outcome, _ = h.openvpn.SetStatus(h.ctx, constants.Running) + outcome, _ = h.openvpn.ApplyStatus(h.ctx, constants.Running) h.logger.Info("openvpn: %s", outcome) if _, err := w.Write([]byte("openvpn restarted, please consider using the /v1/ API in the future.")); err != nil { h.logger.Warn(err) diff --git a/internal/server/openvpn.go b/internal/server/openvpn.go index 60996ded..cc24ad02 100644 --- a/internal/server/openvpn.go +++ b/internal/server/openvpn.go @@ -79,7 +79,7 @@ func (h *openvpnHandler) setStatus(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusBadRequest) return } - outcome, err := h.looper.SetStatus(h.ctx, status) + outcome, err := h.looper.ApplyStatus(h.ctx, status) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return