Maint: port forwarding refactoring (#543)
- portforward package - portforward run loop - Less functional arguments and cycles
This commit is contained in:
32
internal/portforward/firewall.go
Normal file
32
internal/portforward/firewall.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package portforward
|
||||
|
||||
import "context"
|
||||
|
||||
// firewallBlockPort obtains the state port thread safely and blocks
|
||||
// it in the firewall if it is not the zero value (0).
|
||||
func (l *Loop) firewallBlockPort(ctx context.Context) {
|
||||
port := l.state.GetPortForwarded()
|
||||
if port == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
err := l.portAllower.RemoveAllowedPort(ctx, port)
|
||||
if err != nil {
|
||||
l.logger.Error("cannot block previous port in firewall: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// firewallAllowPort obtains the state port thread safely and allows
|
||||
// it in the firewall if it is not the zero value (0).
|
||||
func (l *Loop) firewallAllowPort(ctx context.Context) {
|
||||
port := l.state.GetPortForwarded()
|
||||
if port == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
startData := l.state.GetStartData()
|
||||
err := l.portAllower.SetAllowedPort(ctx, port, startData.Interface)
|
||||
if err != nil {
|
||||
l.logger.Error("cannot allow port through firewall: " + err.Error())
|
||||
}
|
||||
}
|
||||
37
internal/portforward/fs.go
Normal file
37
internal/portforward/fs.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
func (l *Loop) removePortForwardedFile() {
|
||||
filepath := l.state.GetSettings().Filepath
|
||||
l.logger.Info("removing port file " + filepath)
|
||||
if err := os.Remove(filepath); err != nil {
|
||||
l.logger.Error(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Loop) writePortForwardedFile(port uint16) {
|
||||
filepath := l.state.GetSettings().Filepath
|
||||
l.logger.Info("writing port file " + filepath)
|
||||
if err := writePortForwardedToFile(filepath, port); err != nil {
|
||||
l.logger.Error(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func writePortForwardedToFile(filepath string, port uint16) (err error) {
|
||||
file, err := os.OpenFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = file.Write([]byte(fmt.Sprint(port)))
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return file.Close()
|
||||
}
|
||||
9
internal/portforward/get.go
Normal file
9
internal/portforward/get.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package portforward
|
||||
|
||||
import "github.com/qdm12/gluetun/internal/portforward/state"
|
||||
|
||||
type Getter = state.PortForwardedGetter
|
||||
|
||||
func (l *Loop) GetPortForwarded() (port uint16) {
|
||||
return l.state.GetPortForwarded()
|
||||
}
|
||||
22
internal/portforward/helpers.go
Normal file
22
internal/portforward/helpers.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (l *Loop) logAndWait(ctx context.Context, err error) {
|
||||
if err != nil {
|
||||
l.logger.Error(err.Error())
|
||||
}
|
||||
l.logger.Info("retrying in " + l.backoffTime.String())
|
||||
timer := time.NewTimer(l.backoffTime)
|
||||
l.backoffTime *= 2
|
||||
select {
|
||||
case <-timer.C:
|
||||
case <-ctx.Done():
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
}
|
||||
}
|
||||
71
internal/portforward/loop.go
Normal file
71
internal/portforward/loop.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/loopstate"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/portforward/state"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
var _ Looper = (*Loop)(nil)
|
||||
|
||||
type Looper interface {
|
||||
Runner
|
||||
loopstate.Getter
|
||||
StartStopper
|
||||
SettingsGetSetter
|
||||
Getter
|
||||
}
|
||||
|
||||
type Loop struct {
|
||||
statusManager loopstate.Manager
|
||||
state state.Manager
|
||||
// Objects
|
||||
client *http.Client
|
||||
portAllower firewall.PortAllower
|
||||
logger logging.Logger
|
||||
// Internal channels and locks
|
||||
start chan struct{}
|
||||
running chan models.LoopStatus
|
||||
stop chan struct{}
|
||||
stopped chan struct{}
|
||||
startMu sync.Mutex
|
||||
backoffTime time.Duration
|
||||
userTrigger bool
|
||||
}
|
||||
|
||||
const defaultBackoffTime = 5 * time.Second
|
||||
|
||||
func NewLoop(settings configuration.PortForwarding,
|
||||
client *http.Client, portAllower firewall.PortAllower,
|
||||
logger logging.Logger) *Loop {
|
||||
start := make(chan struct{})
|
||||
running := make(chan models.LoopStatus)
|
||||
stop := make(chan struct{})
|
||||
stopped := make(chan struct{})
|
||||
|
||||
statusManager := loopstate.New(constants.Stopped, start, running, stop, stopped)
|
||||
state := state.New(statusManager, settings)
|
||||
|
||||
return &Loop{
|
||||
statusManager: statusManager,
|
||||
state: state,
|
||||
// Objects
|
||||
client: client,
|
||||
portAllower: portAllower,
|
||||
logger: logger,
|
||||
start: start,
|
||||
running: running,
|
||||
stop: stop,
|
||||
stopped: stopped,
|
||||
userTrigger: true,
|
||||
backoffTime: defaultBackoffTime,
|
||||
}
|
||||
}
|
||||
97
internal/portforward/run.go
Normal file
97
internal/portforward/run.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
)
|
||||
|
||||
type Runner interface {
|
||||
Run(ctx context.Context, done chan<- struct{})
|
||||
}
|
||||
|
||||
func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
defer close(done)
|
||||
|
||||
select {
|
||||
case <-l.start: // l.state.SetStartData called beforehand
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
for ctx.Err() == nil {
|
||||
pfCtx, pfCancel := context.WithCancel(ctx)
|
||||
|
||||
portCh := make(chan uint16)
|
||||
errorCh := make(chan error)
|
||||
|
||||
startData := l.state.GetStartData()
|
||||
|
||||
go func(ctx context.Context, startData StartData) {
|
||||
port, err := startData.PortForwarder.PortForward(ctx, l.client, l.logger,
|
||||
startData.Gateway, startData.ServerName)
|
||||
if err != nil {
|
||||
errorCh <- err
|
||||
return
|
||||
}
|
||||
portCh <- port
|
||||
|
||||
// Infinite loop
|
||||
err = startData.PortForwarder.KeepPortForward(ctx, l.client, l.logger,
|
||||
port, startData.Gateway, startData.ServerName)
|
||||
errorCh <- err
|
||||
}(pfCtx, startData)
|
||||
|
||||
if l.userTrigger {
|
||||
l.userTrigger = false
|
||||
l.running <- constants.Running
|
||||
} else { // crash
|
||||
l.backoffTime = defaultBackoffTime
|
||||
l.statusManager.SetStatus(constants.Running)
|
||||
}
|
||||
|
||||
stayHere := true
|
||||
for stayHere {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
pfCancel()
|
||||
<-errorCh
|
||||
close(errorCh)
|
||||
close(portCh)
|
||||
l.removePortForwardedFile()
|
||||
l.firewallBlockPort(ctx)
|
||||
l.state.SetPortForwarded(0)
|
||||
return
|
||||
case <-l.start:
|
||||
l.userTrigger = true
|
||||
l.logger.Info("starting")
|
||||
pfCancel()
|
||||
stayHere = false
|
||||
case <-l.stop:
|
||||
l.userTrigger = true
|
||||
l.logger.Info("stopping")
|
||||
pfCancel()
|
||||
<-errorCh
|
||||
l.removePortForwardedFile()
|
||||
l.firewallBlockPort(ctx)
|
||||
l.state.SetPortForwarded(0)
|
||||
l.stopped <- struct{}{}
|
||||
case port := <-portCh:
|
||||
l.logger.Info("port forwarded is " + strconv.Itoa(int(port)))
|
||||
l.firewallBlockPort(ctx)
|
||||
l.state.SetPortForwarded(port)
|
||||
l.firewallAllowPort(ctx)
|
||||
l.writePortForwardedFile(port)
|
||||
case err := <-errorCh:
|
||||
pfCancel()
|
||||
close(errorCh)
|
||||
close(portCh)
|
||||
l.statusManager.SetStatus(constants.Crashed)
|
||||
l.logAndWait(ctx, err)
|
||||
stayHere = false
|
||||
}
|
||||
}
|
||||
pfCancel() // for linting
|
||||
}
|
||||
}
|
||||
19
internal/portforward/settings.go
Normal file
19
internal/portforward/settings.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration"
|
||||
"github.com/qdm12/gluetun/internal/portforward/state"
|
||||
)
|
||||
|
||||
type SettingsGetSetter = state.SettingsGetSetter
|
||||
|
||||
func (l *Loop) GetSettings() (settings configuration.PortForwarding) {
|
||||
return l.state.GetSettings()
|
||||
}
|
||||
|
||||
func (l *Loop) SetSettings(ctx context.Context, settings configuration.PortForwarding) (
|
||||
outcome string) {
|
||||
return l.state.SetSettings(ctx, settings)
|
||||
}
|
||||
26
internal/portforward/state/portforwarded.go
Normal file
26
internal/portforward/state/portforwarded.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package state
|
||||
|
||||
type PortForwardedGetterSetter interface {
|
||||
PortForwardedGetter
|
||||
SetPortForwarded(port uint16)
|
||||
}
|
||||
|
||||
type PortForwardedGetter interface {
|
||||
GetPortForwarded() (port uint16)
|
||||
}
|
||||
|
||||
// GetPortForwarded is used by the control HTTP server
|
||||
// to obtain the port currently forwarded.
|
||||
func (s *State) GetPortForwarded() (port uint16) {
|
||||
s.portForwardedMu.RLock()
|
||||
defer s.portForwardedMu.RUnlock()
|
||||
return s.portForwarded
|
||||
}
|
||||
|
||||
// SetPortForwarded is only used from within the OpenVPN loop
|
||||
// to set the port forwarded.
|
||||
func (s *State) SetPortForwarded(port uint16) {
|
||||
s.portForwardedMu.Lock()
|
||||
defer s.portForwardedMu.Unlock()
|
||||
s.portForwarded = port
|
||||
}
|
||||
55
internal/portforward/state/settings.go
Normal file
55
internal/portforward/state/settings.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"reflect"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
)
|
||||
|
||||
type SettingsGetSetter interface {
|
||||
GetSettings() (settings configuration.PortForwarding)
|
||||
SetSettings(ctx context.Context,
|
||||
settings configuration.PortForwarding) (outcome string)
|
||||
}
|
||||
|
||||
func (s *State) GetSettings() (settings configuration.PortForwarding) {
|
||||
s.settingsMu.RLock()
|
||||
defer s.settingsMu.RUnlock()
|
||||
return s.settings
|
||||
}
|
||||
|
||||
func (s *State) SetSettings(ctx context.Context, settings configuration.PortForwarding) (
|
||||
outcome string) {
|
||||
s.settingsMu.Lock()
|
||||
|
||||
settingsUnchanged := reflect.DeepEqual(s.settings, settings)
|
||||
if settingsUnchanged {
|
||||
s.settingsMu.Unlock()
|
||||
return "settings left unchanged"
|
||||
}
|
||||
|
||||
if s.settings.Filepath != settings.Filepath {
|
||||
_ = os.Rename(s.settings.Filepath, settings.Filepath)
|
||||
}
|
||||
|
||||
newEnabled := settings.Enabled
|
||||
previousEnabled := s.settings.Enabled
|
||||
|
||||
s.settings = settings
|
||||
s.settingsMu.Unlock()
|
||||
|
||||
switch {
|
||||
case !newEnabled && !previousEnabled:
|
||||
case newEnabled && previousEnabled:
|
||||
// no need to restart for now since we os.Rename the file here.
|
||||
case newEnabled && !previousEnabled:
|
||||
_, _ = s.statusApplier.ApplyStatus(ctx, constants.Running)
|
||||
case !newEnabled && previousEnabled:
|
||||
_, _ = s.statusApplier.ApplyStatus(ctx, constants.Stopped)
|
||||
}
|
||||
|
||||
return "settings updated"
|
||||
}
|
||||
39
internal/portforward/state/startdata.go
Normal file
39
internal/portforward/state/startdata.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
)
|
||||
|
||||
type StartData struct {
|
||||
PortForwarder provider.PortForwarder
|
||||
Gateway net.IP // needed for PIA
|
||||
ServerName string // needed for PIA
|
||||
Interface string // tun0 or wg0 for example
|
||||
}
|
||||
|
||||
type StartDataGetterSetter interface {
|
||||
StartDataGetter
|
||||
StartDataSetter
|
||||
}
|
||||
|
||||
type StartDataGetter interface {
|
||||
GetStartData() (startData StartData)
|
||||
}
|
||||
|
||||
func (s *State) GetStartData() (startData StartData) {
|
||||
s.startDataMu.RLock()
|
||||
defer s.startDataMu.RUnlock()
|
||||
return s.startData
|
||||
}
|
||||
|
||||
type StartDataSetter interface {
|
||||
SetStartData(startData StartData)
|
||||
}
|
||||
|
||||
func (s *State) SetStartData(startData StartData) {
|
||||
s.startDataMu.Lock()
|
||||
defer s.startDataMu.Unlock()
|
||||
s.startData = startData
|
||||
}
|
||||
37
internal/portforward/state/state.go
Normal file
37
internal/portforward/state/state.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration"
|
||||
"github.com/qdm12/gluetun/internal/loopstate"
|
||||
)
|
||||
|
||||
var _ Manager = (*State)(nil)
|
||||
|
||||
type Manager interface {
|
||||
SettingsGetSetter
|
||||
PortForwardedGetterSetter
|
||||
StartDataGetterSetter
|
||||
}
|
||||
|
||||
func New(statusApplier loopstate.Applier,
|
||||
settings configuration.PortForwarding) *State {
|
||||
return &State{
|
||||
statusApplier: statusApplier,
|
||||
settings: settings,
|
||||
}
|
||||
}
|
||||
|
||||
type State struct {
|
||||
statusApplier loopstate.Applier
|
||||
|
||||
settings configuration.PortForwarding
|
||||
settingsMu sync.RWMutex
|
||||
|
||||
portForwarded uint16
|
||||
portForwardedMu sync.RWMutex
|
||||
|
||||
startData StartData
|
||||
startDataMu sync.RWMutex
|
||||
}
|
||||
33
internal/portforward/status.go
Normal file
33
internal/portforward/status.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/portforward/state"
|
||||
)
|
||||
|
||||
func (l *Loop) GetStatus() (status models.LoopStatus) {
|
||||
return l.statusManager.GetStatus()
|
||||
}
|
||||
|
||||
type StartData = state.StartData
|
||||
|
||||
type StartStopper interface {
|
||||
Start(ctx context.Context, data StartData) (
|
||||
outcome string, err error)
|
||||
Stop(ctx context.Context) (outcome string, err error)
|
||||
}
|
||||
|
||||
func (l *Loop) Start(ctx context.Context, data StartData) (
|
||||
outcome string, err error) {
|
||||
l.startMu.Lock()
|
||||
defer l.startMu.Unlock()
|
||||
l.state.SetStartData(data)
|
||||
return l.statusManager.ApplyStatus(ctx, constants.Running)
|
||||
}
|
||||
|
||||
func (l *Loop) Stop(ctx context.Context) (outcome string, err error) {
|
||||
return l.statusManager.ApplyStatus(ctx, constants.Stopped)
|
||||
}
|
||||
Reference in New Issue
Block a user