diff --git a/internal/httpproxy/state/state.go b/internal/httpproxy/state/state.go new file mode 100644 index 00000000..73240b36 --- /dev/null +++ b/internal/httpproxy/state/state.go @@ -0,0 +1,28 @@ +package state + +import ( + "sync" + + "github.com/qdm12/gluetun/internal/configuration" + "github.com/qdm12/gluetun/internal/loopstate" +) + +var _ Manager = (*State)(nil) + +type Manager interface { + SettingsGetterSetter +} + +func New(statusApplier loopstate.Applier, + settings configuration.HTTPProxy) *State { + return &State{ + statusApplier: statusApplier, + settings: settings, + } +} + +type State struct { + statusApplier loopstate.Applier + settings configuration.HTTPProxy + settingsMu sync.RWMutex +} diff --git a/internal/loopstate/apply.go b/internal/loopstate/apply.go new file mode 100644 index 00000000..a3f59dc3 --- /dev/null +++ b/internal/loopstate/apply.go @@ -0,0 +1,77 @@ +package loopstate + +import ( + "context" + "errors" + "fmt" + + "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/models" +) + +type Applier interface { + ApplyStatus(ctx context.Context, status models.LoopStatus) ( + outcome string, err error) +} + +var ErrInvalidStatus = errors.New("invalid status") + +// ApplyStatus sends signals to the running loop depending on the +// current status and status requested, such that its next status +// matches the requested one. It is thread safe and a synchronous call +// since it waits to the loop to fully change its status. +func (s *State) ApplyStatus(ctx context.Context, status models.LoopStatus) ( + outcome string, err error) { + // prevent simultaneous loop changes by restricting + // multiple ApplyStatus calls to run sequentially. + s.loopMu.Lock() + defer s.loopMu.Unlock() + + // not a read lock as we want to modify it eventually in + // the code below before any other call. + s.statusMu.Lock() + existingStatus := s.status + + switch status { + case constants.Running: + if existingStatus != constants.Stopped { + return "already " + existingStatus.String(), nil + } + + s.status = constants.Starting + s.statusMu.Unlock() + s.start <- struct{}{} + + // Wait for the loop to react to the start signal + newStatus := constants.Starting // for canceled context + select { + case <-ctx.Done(): + case newStatus = <-s.running: + } + s.SetStatus(newStatus) + + return newStatus.String(), nil + case constants.Stopped: + if existingStatus != constants.Running { + return "already " + existingStatus.String(), nil + } + + s.status = constants.Stopping + s.statusMu.Unlock() + s.stop <- struct{}{} + + // Wait for the loop to react to the stop signal + newStatus := constants.Stopping // for canceled context + select { + case <-ctx.Done(): + case <-s.stopped: + newStatus = constants.Stopped + } + s.SetStatus(newStatus) + + return newStatus.String(), nil + default: + return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s", + ErrInvalidStatus, status, constants.Running, constants.Stopped) + } +} diff --git a/internal/loopstate/get.go b/internal/loopstate/get.go new file mode 100644 index 00000000..fb675dfa --- /dev/null +++ b/internal/loopstate/get.go @@ -0,0 +1,14 @@ +package loopstate + +import "github.com/qdm12/gluetun/internal/models" + +type Getter interface { + GetStatus() (status models.LoopStatus) +} + +// GetStatus gets the status thread safely. +func (s *State) GetStatus() (status models.LoopStatus) { + s.statusMu.RLock() + defer s.statusMu.RUnlock() + return s.status +} diff --git a/internal/loopstate/lock.go b/internal/loopstate/lock.go new file mode 100644 index 00000000..3907fef4 --- /dev/null +++ b/internal/loopstate/lock.go @@ -0,0 +1,9 @@ +package loopstate + +type Locker interface { + Lock() + Unlock() +} + +func (s *State) Lock() { s.loopMu.Lock() } +func (s *State) Unlock() { s.loopMu.Unlock() } diff --git a/internal/loopstate/set.go b/internal/loopstate/set.go new file mode 100644 index 00000000..69f52b00 --- /dev/null +++ b/internal/loopstate/set.go @@ -0,0 +1,16 @@ +package loopstate + +import "github.com/qdm12/gluetun/internal/models" + +type Setter interface { + SetStatus(status models.LoopStatus) +} + +// SetStatus sets the status thread safely. +// It should only be called by the loop internal code since +// it does not interact with the loop code directly. +func (s *State) SetStatus(status models.LoopStatus) { + s.statusMu.Lock() + defer s.statusMu.Unlock() + s.status = status +} diff --git a/internal/loopstate/state.go b/internal/loopstate/state.go new file mode 100644 index 00000000..2a8384bb --- /dev/null +++ b/internal/loopstate/state.go @@ -0,0 +1,38 @@ +package loopstate + +import ( + "sync" + + "github.com/qdm12/gluetun/internal/models" +) + +type Manager interface { + Locker + Getter + Setter + Applier +} + +func New(status models.LoopStatus, + start chan<- struct{}, running <-chan models.LoopStatus, + stop chan<- struct{}, stopped <-chan struct{}) *State { + return &State{ + status: status, + start: start, + running: running, + stop: stop, + stopped: stopped, + } +} + +type State struct { + loopMu sync.RWMutex + + status models.LoopStatus + statusMu sync.RWMutex + + start chan<- struct{} + running <-chan models.LoopStatus + stop chan<- struct{} + stopped <-chan struct{} +} diff --git a/internal/openvpn/loop.go b/internal/openvpn/loop.go index 2cadc585..966f1336 100644 --- a/internal/openvpn/loop.go +++ b/internal/openvpn/loop.go @@ -11,7 +11,9 @@ import ( "github.com/qdm12/gluetun/internal/configuration" "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/firewall" + "github.com/qdm12/gluetun/internal/loopstate" "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/openvpn/state" "github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/golibs/logging" @@ -32,7 +34,8 @@ type Looper interface { } type looper struct { - state *state + statusManager loopstate.Manager + state state.Manager // Fixed parameters username string puid int @@ -71,10 +74,11 @@ func NewLooper(settings configuration.OpenVPN, stop := make(chan struct{}) stopped := make(chan struct{}) - state := newState(constants.Stopped, settings, allServers, - start, running, stop, stopped) + statusManager := loopstate.New(constants.Stopped, start, running, stop, stopped) + state := state.New(statusManager, settings, allServers) return &looper{ + statusManager: statusManager, state: state, username: username, puid: puid, @@ -107,7 +111,7 @@ func (l *looper) signalOrSetStatus(status models.LoopStatus) { default: // receiver calling ApplyStatus droppped out } } else { - l.state.SetStatus(status) + l.statusManager.SetStatus(status) } } @@ -230,15 +234,15 @@ func (l *looper) Run(ctx context.Context, done chan<- struct{}) { close(waitError) closeStreams() - l.state.Lock() // prevent SetStatus from running in parallel + l.statusManager.Lock() // prevent SetStatus from running in parallel openvpnCancel() - l.state.SetStatus(constants.Crashed) + l.statusManager.SetStatus(constants.Crashed) <-portForwardDone l.logAndWait(ctx, err) stayHere = false - l.state.Unlock() + l.statusManager.Unlock() } } openvpnCancel() @@ -265,18 +269,13 @@ func (l *looper) logAndWait(ctx context.Context, err error) { // You should therefore always call it in a goroutine. func (l *looper) portForward(ctx context.Context, providerConf provider.Provider, client *http.Client, gateway net.IP) { - l.state.portForwardedMu.RLock() - settings := l.state.settings - l.state.portForwardedMu.RUnlock() + settings := l.state.GetSettings() if !settings.Provider.PortForwarding.Enabled { return } syncState := func(port uint16) (pfFilepath string) { - l.state.portForwardedMu.Lock() - defer l.state.portForwardedMu.Unlock() - l.state.portForwarded = port - l.state.settingsMu.RLock() - defer l.state.settingsMu.RUnlock() + l.state.SetPortForwarded(port) + settings := l.state.GetSettings() return settings.Provider.PortForwarding.Filepath } providerConf.PortForward(ctx, client, l.pfLogger, @@ -294,27 +293,3 @@ func (l *looper) writeOpenvpnConf(lines []string) error { } return file.Close() } - -func (l *looper) GetStatus() (status models.LoopStatus) { - return l.state.GetStatus() -} -func (l *looper) ApplyStatus(ctx context.Context, status models.LoopStatus) ( - outcome string, err error) { - return l.state.ApplyStatus(ctx, status) -} -func (l *looper) GetSettings() (settings configuration.OpenVPN) { - return l.state.GetSettings() -} -func (l *looper) SetSettings(ctx context.Context, settings configuration.OpenVPN) ( - outcome string) { - return l.state.SetSettings(ctx, settings) -} -func (l *looper) GetServers() (servers models.AllServers) { - return l.state.GetServers() -} -func (l *looper) SetServers(servers models.AllServers) { - l.state.SetServers(servers) -} -func (l *looper) GetPortForwarded() (port uint16) { - return l.state.GetPortForwarded() -} diff --git a/internal/openvpn/portforwarded.go b/internal/openvpn/portforwarded.go new file mode 100644 index 00000000..0ce6ab80 --- /dev/null +++ b/internal/openvpn/portforwarded.go @@ -0,0 +1,5 @@ +package openvpn + +func (l *looper) GetPortForwarded() (port uint16) { + return l.state.GetPortForwarded() +} diff --git a/internal/openvpn/servers.go b/internal/openvpn/servers.go new file mode 100644 index 00000000..00c28fe0 --- /dev/null +++ b/internal/openvpn/servers.go @@ -0,0 +1,11 @@ +package openvpn + +import "github.com/qdm12/gluetun/internal/models" + +func (l *looper) GetServers() (servers models.AllServers) { + return l.state.GetServers() +} + +func (l *looper) SetServers(servers models.AllServers) { + l.state.SetServers(servers) +} diff --git a/internal/openvpn/settings.go b/internal/openvpn/settings.go new file mode 100644 index 00000000..890a3f62 --- /dev/null +++ b/internal/openvpn/settings.go @@ -0,0 +1,16 @@ +package openvpn + +import ( + "context" + + "github.com/qdm12/gluetun/internal/configuration" +) + +func (l *looper) GetSettings() (settings configuration.OpenVPN) { + return l.state.GetSettings() +} + +func (l *looper) SetSettings(ctx context.Context, settings configuration.OpenVPN) ( + outcome string) { + return l.state.SetSettings(ctx, settings) +} diff --git a/internal/openvpn/state.go b/internal/openvpn/state.go deleted file mode 100644 index 47e7dfc9..00000000 --- a/internal/openvpn/state.go +++ /dev/null @@ -1,179 +0,0 @@ -package openvpn - -import ( - "context" - "errors" - "fmt" - "reflect" - "sync" - - "github.com/qdm12/gluetun/internal/configuration" - "github.com/qdm12/gluetun/internal/constants" - "github.com/qdm12/gluetun/internal/models" -) - -func newState(status models.LoopStatus, - settings configuration.OpenVPN, allServers models.AllServers, - start chan<- struct{}, running <-chan models.LoopStatus, - stop chan<- struct{}, stopped <-chan struct{}) *state { - return &state{ - status: status, - settings: settings, - allServers: allServers, - start: start, - running: running, - stop: stop, - stopped: stopped, - } -} - -type state struct { - loopMu sync.RWMutex - - status models.LoopStatus - statusMu sync.RWMutex - - settings configuration.OpenVPN - settingsMu sync.RWMutex - - allServers models.AllServers - allServersMu sync.RWMutex - - portForwarded uint16 - portForwardedMu sync.RWMutex - - start chan<- struct{} - running <-chan models.LoopStatus - stop chan<- struct{} - stopped <-chan struct{} -} - -func (s *state) Lock() { s.loopMu.Lock() } -func (s *state) Unlock() { s.loopMu.Unlock() } - -// SetStatus sets the status thread safely. -// It should only be called by the loop internal code since -// it does not interact with the loop code directly. -func (s *state) SetStatus(status models.LoopStatus) { - s.statusMu.Lock() - defer s.statusMu.Unlock() - s.status = status -} - -// GetStatus gets the status thread safely. -func (s *state) GetStatus() (status models.LoopStatus) { - s.statusMu.RLock() - defer s.statusMu.RUnlock() - return s.status -} - -func (s *state) GetSettingsAndServers() (settings configuration.OpenVPN, - allServers models.AllServers) { - s.settingsMu.RLock() - s.allServersMu.RLock() - settings = s.settings - allServers = s.allServers - s.settingsMu.RUnlock() - s.allServersMu.RUnlock() - return settings, allServers -} - -var ErrInvalidStatus = errors.New("invalid status") - -// ApplyStatus sends signals to the running loop depending on the -// current status and status requested, such that its next status -// matches the requested one. It is thread safe and a synchronous call -// since it waits to the loop to fully change its status. -func (s *state) ApplyStatus(ctx context.Context, status models.LoopStatus) ( - outcome string, err error) { - // prevent simultaneous loop changes by restricting - // multiple SetStatus calls to run sequentially. - s.loopMu.Lock() - defer s.loopMu.Unlock() - - // not a read lock as we want to modify it eventually in - // the code below before any other call. - s.statusMu.Lock() - existingStatus := s.status - - switch status { - case constants.Running: - if existingStatus != constants.Stopped { - return "already " + existingStatus.String(), nil - } - - s.status = constants.Starting - s.statusMu.Unlock() - s.start <- struct{}{} - - // Wait for the loop to react to the start signal - newStatus := constants.Starting // for canceled context - select { - case <-ctx.Done(): - case newStatus = <-s.running: - } - s.SetStatus(newStatus) - - return newStatus.String(), nil - case constants.Stopped: - if existingStatus != constants.Running { - return "already " + existingStatus.String(), nil - } - - s.status = constants.Stopping - s.statusMu.Unlock() - s.stop <- struct{}{} - - // Wait for the loop to react to the stop signal - newStatus := constants.Stopping // for canceled context - select { - case <-ctx.Done(): - case <-s.stopped: - newStatus = constants.Stopped - } - s.SetStatus(newStatus) - - return newStatus.String(), nil - default: - return "", fmt.Errorf("%w: %s: it can only be one of: %s, %s", - ErrInvalidStatus, status, constants.Running, constants.Stopped) - } -} - -func (s *state) GetSettings() (settings configuration.OpenVPN) { - s.settingsMu.RLock() - defer s.settingsMu.RUnlock() - return s.settings -} - -func (s *state) SetSettings(ctx context.Context, settings configuration.OpenVPN) ( - outcome string) { - s.settingsMu.Lock() - defer s.settingsMu.Unlock() - settingsUnchanged := reflect.DeepEqual(s.settings, settings) - if settingsUnchanged { - return "settings left unchanged" - } - s.settings = settings - _, _ = s.ApplyStatus(ctx, constants.Stopped) - outcome, _ = s.ApplyStatus(ctx, constants.Running) - return outcome -} - -func (s *state) GetServers() (servers models.AllServers) { - s.allServersMu.RLock() - defer s.allServersMu.RUnlock() - return s.allServers -} - -func (s *state) SetServers(servers models.AllServers) { - s.allServersMu.Lock() - defer s.allServersMu.Unlock() - s.allServers = servers -} - -func (s *state) GetPortForwarded() (port uint16) { - s.portForwardedMu.RLock() - defer s.portForwardedMu.RUnlock() - return s.portForwarded -} diff --git a/internal/openvpn/state/portforwarded.go b/internal/openvpn/state/portforwarded.go new file mode 100644 index 00000000..9e7e6314 --- /dev/null +++ b/internal/openvpn/state/portforwarded.go @@ -0,0 +1,22 @@ +package state + +type PortForwardedGetterSetter interface { + GetPortForwarded() (port uint16) + SetPortForwarded(port uint16) +} + +// GetPortForwarded is used by the control HTTP server +// to obtain the port currently forwarded. +func (s *State) GetPortForwarded() (port uint16) { + s.portForwardedMu.RLock() + defer s.portForwardedMu.RUnlock() + return s.portForwarded +} + +// SetPortForwarded is only used from within the OpenVPN loop +// to set the port forwarded. +func (s *State) SetPortForwarded(port uint16) { + s.portForwardedMu.Lock() + defer s.portForwardedMu.Unlock() + s.portForwarded = port +} diff --git a/internal/openvpn/state/servers.go b/internal/openvpn/state/servers.go new file mode 100644 index 00000000..547312fe --- /dev/null +++ b/internal/openvpn/state/servers.go @@ -0,0 +1,20 @@ +package state + +import "github.com/qdm12/gluetun/internal/models" + +type ServersGetterSetter interface { + GetServers() (servers models.AllServers) + SetServers(servers models.AllServers) +} + +func (s *State) GetServers() (servers models.AllServers) { + s.allServersMu.RLock() + defer s.allServersMu.RUnlock() + return s.allServers +} + +func (s *State) SetServers(servers models.AllServers) { + s.allServersMu.Lock() + defer s.allServersMu.Unlock() + s.allServers = servers +} diff --git a/internal/openvpn/state/settings.go b/internal/openvpn/state/settings.go new file mode 100644 index 00000000..1f9df456 --- /dev/null +++ b/internal/openvpn/state/settings.go @@ -0,0 +1,35 @@ +package state + +import ( + "context" + "reflect" + + "github.com/qdm12/gluetun/internal/configuration" + "github.com/qdm12/gluetun/internal/constants" +) + +type SettingsGetterSetter interface { + GetSettings() (settings configuration.OpenVPN) + SetSettings(ctx context.Context, settings configuration.OpenVPN) ( + outcome string) +} + +func (s *State) GetSettings() (settings configuration.OpenVPN) { + s.settingsMu.RLock() + defer s.settingsMu.RUnlock() + return s.settings +} + +func (s *State) SetSettings(ctx context.Context, settings configuration.OpenVPN) ( + outcome string) { + s.settingsMu.Lock() + defer s.settingsMu.Unlock() + settingsUnchanged := reflect.DeepEqual(s.settings, settings) + if settingsUnchanged { + return "settings left unchanged" + } + s.settings = settings + _, _ = s.statusApplier.ApplyStatus(ctx, constants.Stopped) + outcome, _ = s.statusApplier.ApplyStatus(ctx, constants.Running) + return outcome +} diff --git a/internal/openvpn/state/state.go b/internal/openvpn/state/state.go new file mode 100644 index 00000000..1b2fb257 --- /dev/null +++ b/internal/openvpn/state/state.go @@ -0,0 +1,53 @@ +package state + +import ( + "sync" + + "github.com/qdm12/gluetun/internal/configuration" + "github.com/qdm12/gluetun/internal/loopstate" + "github.com/qdm12/gluetun/internal/models" +) + +var _ Manager = (*State)(nil) + +type Manager interface { + SettingsGetterSetter + ServersGetterSetter + PortForwardedGetterSetter + GetSettingsAndServers() (settings configuration.OpenVPN, + allServers models.AllServers) +} + +func New(statusApplier loopstate.Applier, + settings configuration.OpenVPN, + allServers models.AllServers) *State { + return &State{ + statusApplier: statusApplier, + settings: settings, + allServers: allServers, + } +} + +type State struct { + statusApplier loopstate.Applier + + settings configuration.OpenVPN + settingsMu sync.RWMutex + + allServers models.AllServers + allServersMu sync.RWMutex + + portForwarded uint16 + portForwardedMu sync.RWMutex +} + +func (s *State) GetSettingsAndServers() (settings configuration.OpenVPN, + allServers models.AllServers) { + s.settingsMu.RLock() + s.allServersMu.RLock() + settings = s.settings + allServers = s.allServers + s.settingsMu.RUnlock() + s.allServersMu.RUnlock() + return settings, allServers +} diff --git a/internal/openvpn/status.go b/internal/openvpn/status.go new file mode 100644 index 00000000..1ba530dc --- /dev/null +++ b/internal/openvpn/status.go @@ -0,0 +1,16 @@ +package openvpn + +import ( + "context" + + "github.com/qdm12/gluetun/internal/models" +) + +func (l *looper) GetStatus() (status models.LoopStatus) { + return l.statusManager.GetStatus() +} + +func (l *looper) ApplyStatus(ctx context.Context, status models.LoopStatus) ( + outcome string, err error) { + return l.statusManager.ApplyStatus(ctx, status) +}