Maint: port forwarding refactoring (#543)

- portforward package
- portforward run loop
- Less functional arguments and cycles
This commit is contained in:
Quentin McGaw
2021-07-28 08:35:44 -07:00
committed by GitHub
parent c777f8d97d
commit 2998cf5e48
25 changed files with 639 additions and 255 deletions

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net"
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
@@ -24,6 +23,7 @@ import (
"github.com/qdm12/gluetun/internal/httpproxy" "github.com/qdm12/gluetun/internal/httpproxy"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/openvpn" "github.com/qdm12/gluetun/internal/openvpn"
"github.com/qdm12/gluetun/internal/portforward"
"github.com/qdm12/gluetun/internal/publicip" "github.com/qdm12/gluetun/internal/publicip"
"github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/gluetun/internal/routing"
"github.com/qdm12/gluetun/internal/server" "github.com/qdm12/gluetun/internal/server"
@@ -321,8 +321,16 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
tickersGroupHandler := goshutdown.NewGroupHandler("tickers", defaultGroupSettings) tickersGroupHandler := goshutdown.NewGroupHandler("tickers", defaultGroupSettings)
otherGroupHandler := goshutdown.NewGroupHandler("other", defaultGroupSettings) otherGroupHandler := goshutdown.NewGroupHandler("other", defaultGroupSettings)
portForwardLogger := logger.NewChild(logging.Settings{Prefix: "port forwarding: "})
portForwardLooper := portforward.NewLoop(allSettings.OpenVPN.Provider.PortForwarding,
httpClient, firewallConf, portForwardLogger)
portForwardHandler, portForwardCtx, portForwardDone := goshutdown.NewGoRoutineHandler(
"port forwarding", goshutdown.GoRoutineSettings{Timeout: time.Second})
go portForwardLooper.Run(portForwardCtx, portForwardDone)
openvpnLogger := logger.NewChild(logging.Settings{Prefix: "openvpn: "})
openvpnLooper := openvpn.NewLoop(allSettings.OpenVPN, nonRootUsername, puid, pgid, allServers, openvpnLooper := openvpn.NewLoop(allSettings.OpenVPN, nonRootUsername, puid, pgid, allServers,
ovpnConf, firewallConf, logger, httpClient, tunnelReadyCh) ovpnConf, firewallConf, routingConf, portForwardLooper, openvpnLogger, httpClient, tunnelReadyCh)
openvpnHandler, openvpnCtx, openvpnDone := goshutdown.NewGoRoutineHandler( openvpnHandler, openvpnCtx, openvpnDone := goshutdown.NewGoRoutineHandler(
"openvpn", goshutdown.GoRoutineSettings{Timeout: time.Second}) "openvpn", goshutdown.GoRoutineSettings{Timeout: time.Second})
// wait for restartOpenvpn // wait for restartOpenvpn
@@ -378,8 +386,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
"events routing", defaultGoRoutineSettings) "events routing", defaultGoRoutineSettings)
go routeReadyEvents(eventsRoutingCtx, eventsRoutingDone, buildInfo, tunnelReadyCh, go routeReadyEvents(eventsRoutingCtx, eventsRoutingDone, buildInfo, tunnelReadyCh,
unboundLooper, updaterLooper, publicIPLooper, routingConf, logger, httpClient, unboundLooper, updaterLooper, publicIPLooper, routingConf, logger, httpClient,
allSettings.VersionInformation, allSettings.OpenVPN.Provider.PortForwarding.Enabled, openvpnLooper.PortForward, allSettings.VersionInformation)
)
controlGroupHandler.Add(eventsRoutingHandler) controlGroupHandler.Add(eventsRoutingHandler)
controlServerAddress := ":" + strconv.Itoa(int(allSettings.ControlServer.Port)) controlServerAddress := ":" + strconv.Itoa(int(allSettings.ControlServer.Port))
@@ -406,7 +413,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
} }
orderHandler := goshutdown.NewOrder("gluetun", orderSettings) orderHandler := goshutdown.NewOrder("gluetun", orderSettings)
orderHandler.Append(controlGroupHandler, tickersGroupHandler, healthServerHandler, orderHandler.Append(controlGroupHandler, tickersGroupHandler, healthServerHandler,
openvpnHandler, otherGroupHandler) openvpnHandler, portForwardHandler, otherGroupHandler)
// Start openvpn for the first time in a blocking call // Start openvpn for the first time in a blocking call
// until openvpn is launched // until openvpn is launched
@@ -414,13 +421,6 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
<-ctx.Done() <-ctx.Done()
if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
logger.Info("Clearing forwarded port status file " + allSettings.OpenVPN.Provider.PortForwarding.Filepath)
if err := os.Remove(allSettings.OpenVPN.Provider.PortForwarding.Filepath); err != nil {
logger.Error(err.Error())
}
}
return orderHandler.Shutdown(context.Background()) return orderHandler.Shutdown(context.Background())
} }
@@ -450,7 +450,7 @@ func routeReadyEvents(ctx context.Context, done chan<- struct{}, buildInfo model
tunnelReadyCh <-chan struct{}, tunnelReadyCh <-chan struct{},
unboundLooper dns.Looper, updaterLooper updater.Looper, publicIPLooper publicip.Looper, unboundLooper dns.Looper, updaterLooper updater.Looper, publicIPLooper publicip.Looper,
routing routing.VPNGetter, logger logging.Logger, httpClient *http.Client, routing routing.VPNGetter, logger logging.Logger, httpClient *http.Client,
versionInformation, portForwardingEnabled bool, startPortForward func(vpnGateway net.IP)) { versionInformation bool) {
defer close(done) defer close(done)
// for linters only // for linters only
@@ -503,15 +503,6 @@ func routeReadyEvents(ctx context.Context, done chan<- struct{}, buildInfo model
updaterTickerDone = make(chan struct{}) updaterTickerDone = make(chan struct{})
go unboundLooper.RunRestartTicker(restartTickerContext, unboundTickerDone) go unboundLooper.RunRestartTicker(restartTickerContext, unboundTickerDone)
go updaterLooper.RunRestartTicker(restartTickerContext, updaterTickerDone) go updaterLooper.RunRestartTicker(restartTickerContext, updaterTickerDone)
if portForwardingEnabled {
// vpnGateway required only for PIA
vpnGateway, err := routing.VPNLocalGatewayIP()
if err != nil {
logger.Error("cannot get VPN local gateway IP: " + err.Error())
}
logger.Info("VPN gateway IP address: " + vpnGateway.String())
startPortForward(vpnGateway)
}
} }
} }
} }

View File

@@ -42,6 +42,7 @@ func (l *Loop) collectLines(stdout, stderr <-chan string, done chan<- struct{})
} }
if strings.Contains(line, "Initialization Sequence Completed") { if strings.Contains(line, "Initialization Sequence Completed") {
l.tunnelReady <- struct{}{} l.tunnelReady <- struct{}{}
l.startPFCh <- struct{}{}
} }
} }
} }

View File

@@ -1,7 +1,6 @@
package openvpn package openvpn
import ( import (
"net"
"net/http" "net/http"
"time" "time"
@@ -11,6 +10,8 @@ import (
"github.com/qdm12/gluetun/internal/loopstate" "github.com/qdm12/gluetun/internal/loopstate"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/openvpn/state" "github.com/qdm12/gluetun/internal/openvpn/state"
"github.com/qdm12/gluetun/internal/portforward"
"github.com/qdm12/gluetun/internal/routing"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
@@ -22,8 +23,6 @@ type Looper interface {
loopstate.Applier loopstate.Applier
SettingsGetSetter SettingsGetSetter
ServersGetterSetter ServersGetterSetter
PortForwadedGetter
PortForwader
} }
type Loop struct { type Loop struct {
@@ -35,19 +34,21 @@ type Loop struct {
pgid int pgid int
targetConfPath string targetConfPath string
// Configurators // Configurators
conf StarterAuthWriter conf StarterAuthWriter
fw firewallConfigurer fw firewallConfigurer
routing routing.VPNLocalGatewayIPGetter
portForward portforward.StartStopper
// Other objects // Other objects
logger, pfLogger logging.Logger logger logging.Logger
client *http.Client client *http.Client
tunnelReady chan<- struct{} tunnelReady chan<- struct{}
// Internal channels and values // Internal channels and values
stop <-chan struct{} stop <-chan struct{}
stopped chan<- struct{} stopped chan<- struct{}
start <-chan struct{} start <-chan struct{}
running chan<- models.LoopStatus running chan<- models.LoopStatus
portForwardSignals chan net.IP userTrigger bool
userTrigger bool startPFCh chan struct{}
// Internal constant values // Internal constant values
backoffTime time.Duration backoffTime time.Duration
} }
@@ -63,7 +64,8 @@ const (
func NewLoop(settings configuration.OpenVPN, username string, func NewLoop(settings configuration.OpenVPN, username string,
puid, pgid int, allServers models.AllServers, conf Configurator, puid, pgid int, allServers models.AllServers, conf Configurator,
fw firewallConfigurer, logger logging.ParentLogger, fw firewallConfigurer, routing routing.VPNLocalGatewayIPGetter,
portForward portforward.StartStopper, logger logging.Logger,
client *http.Client, tunnelReady chan<- struct{}) *Loop { client *http.Client, tunnelReady chan<- struct{}) *Loop {
start := make(chan struct{}) start := make(chan struct{})
running := make(chan models.LoopStatus) running := make(chan models.LoopStatus)
@@ -74,24 +76,25 @@ func NewLoop(settings configuration.OpenVPN, username string,
state := state.New(statusManager, settings, allServers) state := state.New(statusManager, settings, allServers)
return &Loop{ return &Loop{
statusManager: statusManager, statusManager: statusManager,
state: state, state: state,
username: username, username: username,
puid: puid, puid: puid,
pgid: pgid, pgid: pgid,
targetConfPath: constants.OpenVPNConf, targetConfPath: constants.OpenVPNConf,
conf: conf, conf: conf,
fw: fw, fw: fw,
logger: logger.NewChild(logging.Settings{Prefix: "openvpn: "}), routing: routing,
pfLogger: logger.NewChild(logging.Settings{Prefix: "port forwarding: "}), portForward: portForward,
client: client, logger: logger,
tunnelReady: tunnelReady, client: client,
start: start, tunnelReady: tunnelReady,
running: running, start: start,
stop: stop, running: running,
stopped: stopped, stop: stop,
portForwardSignals: make(chan net.IP), stopped: stopped,
userTrigger: true, userTrigger: true,
backoffTime: defaultBackoffTime, startPFCh: make(chan struct{}),
backoffTime: defaultBackoffTime,
} }
} }

View File

@@ -0,0 +1,47 @@
package openvpn
import (
"context"
"time"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/portforward"
"github.com/qdm12/gluetun/internal/provider"
)
func (l *Loop) startPortForwarding(ctx context.Context,
portForwarder provider.PortForwarder, serverName string) {
if !l.GetSettings().Provider.PortForwarding.Enabled {
return
}
// only used for PIA for now
gateway, err := l.routing.VPNLocalGatewayIP()
if err != nil {
l.logger.Error("cannot obtain VPN local gateway IP: " + err.Error())
return
}
l.logger.Info("VPN gateway IP address: " + gateway.String())
pfData := portforward.StartData{
PortForwarder: portForwarder,
Gateway: gateway,
ServerName: serverName,
Interface: constants.TUN,
}
_, err = l.portForward.Start(ctx, pfData)
if err != nil {
l.logger.Error("cannot start port forwarding: " + err.Error())
}
}
func (l *Loop) stopPortForwarding(ctx context.Context, timeout time.Duration) {
if timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
_, err := l.portForward.Stop(ctx)
if err != nil {
l.logger.Error("cannot stop port forwarding: " + err.Error())
}
}

View File

@@ -1,39 +0,0 @@
package openvpn
import (
"context"
"net"
"net/http"
"github.com/qdm12/gluetun/internal/openvpn/state"
"github.com/qdm12/gluetun/internal/provider"
)
type PortForwadedGetter = state.PortForwardedGetter
func (l *Loop) GetPortForwarded() (port uint16) {
return l.state.GetPortForwarded()
}
type PortForwader interface {
PortForward(vpnGatewayIP net.IP)
}
func (l *Loop) PortForward(vpnGateway net.IP) { l.portForwardSignals <- vpnGateway }
// portForward is a blocking operation which may or may not be infinite.
// You should therefore always call it in a goroutine.
func (l *Loop) portForward(ctx context.Context,
providerConf provider.Provider, client *http.Client, gateway net.IP) {
settings := l.state.GetSettings()
if !settings.Provider.PortForwarding.Enabled {
return
}
syncState := func(port uint16) (pfFilepath string) {
l.state.SetPortForwarded(port)
settings := l.state.GetSettings()
return settings.Provider.PortForwarding.Filepath
}
providerConf.PortForward(ctx, client, l.pfLogger,
gateway, l.fw, syncState)
}

View File

@@ -88,41 +88,31 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
<-lineCollectionDone <-lineCollectionDone
} }
// Needs the stream line from main.go to know when the tunnel is up
portForwardDone := make(chan struct{})
go func(ctx context.Context) {
defer close(portForwardDone)
select {
// TODO have a way to disable pf with a context
case <-ctx.Done():
return
case gateway := <-l.portForwardSignals:
l.portForward(ctx, providerConf, l.client, gateway)
}
}(openvpnCtx)
l.backoffTime = defaultBackoffTime l.backoffTime = defaultBackoffTime
l.signalOrSetStatus(constants.Running) l.signalOrSetStatus(constants.Running)
stayHere := true stayHere := true
for stayHere { for stayHere {
select { select {
case <-l.startPFCh:
l.startPortForwarding(ctx, providerConf, connection.Hostname)
case <-ctx.Done(): case <-ctx.Done():
const pfTimeout = 100 * time.Millisecond
l.stopPortForwarding(context.Background(), pfTimeout)
openvpnCancel() openvpnCancel()
<-waitError <-waitError
close(waitError) close(waitError)
closeStreams() closeStreams()
<-portForwardDone
return return
case <-l.stop: case <-l.stop:
l.userTrigger = true l.userTrigger = true
l.logger.Info("stopping") l.logger.Info("stopping")
l.stopPortForwarding(ctx, 0)
openvpnCancel() openvpnCancel()
<-waitError <-waitError
// do not close waitError or the waitError // do not close waitError or the waitError
// select case will trigger // select case will trigger
closeStreams() closeStreams()
<-portForwardDone
l.stopped <- struct{}{} l.stopped <- struct{}{}
case <-l.start: case <-l.start:
l.userTrigger = true l.userTrigger = true
@@ -134,9 +124,9 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
l.statusManager.Lock() // prevent SetStatus from running in parallel l.statusManager.Lock() // prevent SetStatus from running in parallel
l.stopPortForwarding(ctx, 0)
openvpnCancel() openvpnCancel()
l.statusManager.SetStatus(constants.Crashed) l.statusManager.SetStatus(constants.Crashed)
<-portForwardDone
l.logAndWait(ctx, err) l.logAndWait(ctx, err)
stayHere = false stayHere = false

View File

@@ -13,7 +13,6 @@ var _ Manager = (*State)(nil)
type Manager interface { type Manager interface {
SettingsGetSetter SettingsGetSetter
ServersGetterSetter ServersGetterSetter
PortForwardedGetterSetter
GetSettingsAndServers() (settings configuration.OpenVPN, GetSettingsAndServers() (settings configuration.OpenVPN,
allServers models.AllServers) allServers models.AllServers)
} }
@@ -36,9 +35,6 @@ type State struct {
allServers models.AllServers allServers models.AllServers
allServersMu sync.RWMutex allServersMu sync.RWMutex
portForwarded uint16
portForwardedMu sync.RWMutex
} }
func (s *State) GetSettingsAndServers() (settings configuration.OpenVPN, func (s *State) GetSettingsAndServers() (settings configuration.OpenVPN,

View File

@@ -0,0 +1,32 @@
package portforward
import "context"
// firewallBlockPort obtains the state port thread safely and blocks
// it in the firewall if it is not the zero value (0).
func (l *Loop) firewallBlockPort(ctx context.Context) {
port := l.state.GetPortForwarded()
if port == 0 {
return
}
err := l.portAllower.RemoveAllowedPort(ctx, port)
if err != nil {
l.logger.Error("cannot block previous port in firewall: " + err.Error())
}
}
// firewallAllowPort obtains the state port thread safely and allows
// it in the firewall if it is not the zero value (0).
func (l *Loop) firewallAllowPort(ctx context.Context) {
port := l.state.GetPortForwarded()
if port == 0 {
return
}
startData := l.state.GetStartData()
err := l.portAllower.SetAllowedPort(ctx, port, startData.Interface)
if err != nil {
l.logger.Error("cannot allow port through firewall: " + err.Error())
}
}

View File

@@ -0,0 +1,37 @@
package portforward
import (
"fmt"
"os"
)
func (l *Loop) removePortForwardedFile() {
filepath := l.state.GetSettings().Filepath
l.logger.Info("removing port file " + filepath)
if err := os.Remove(filepath); err != nil {
l.logger.Error(err.Error())
}
}
func (l *Loop) writePortForwardedFile(port uint16) {
filepath := l.state.GetSettings().Filepath
l.logger.Info("writing port file " + filepath)
if err := writePortForwardedToFile(filepath, port); err != nil {
l.logger.Error(err.Error())
}
}
func writePortForwardedToFile(filepath string, port uint16) (err error) {
file, err := os.OpenFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
if err != nil {
return err
}
_, err = file.Write([]byte(fmt.Sprint(port)))
if err != nil {
_ = file.Close()
return err
}
return file.Close()
}

View File

@@ -0,0 +1,9 @@
package portforward
import "github.com/qdm12/gluetun/internal/portforward/state"
type Getter = state.PortForwardedGetter
func (l *Loop) GetPortForwarded() (port uint16) {
return l.state.GetPortForwarded()
}

View File

@@ -0,0 +1,22 @@
package portforward
import (
"context"
"time"
)
func (l *Loop) logAndWait(ctx context.Context, err error) {
if err != nil {
l.logger.Error(err.Error())
}
l.logger.Info("retrying in " + l.backoffTime.String())
timer := time.NewTimer(l.backoffTime)
l.backoffTime *= 2
select {
case <-timer.C:
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
}
}

View File

@@ -0,0 +1,71 @@
package portforward
import (
"net/http"
"sync"
"time"
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/loopstate"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/portforward/state"
"github.com/qdm12/golibs/logging"
)
var _ Looper = (*Loop)(nil)
type Looper interface {
Runner
loopstate.Getter
StartStopper
SettingsGetSetter
Getter
}
type Loop struct {
statusManager loopstate.Manager
state state.Manager
// Objects
client *http.Client
portAllower firewall.PortAllower
logger logging.Logger
// Internal channels and locks
start chan struct{}
running chan models.LoopStatus
stop chan struct{}
stopped chan struct{}
startMu sync.Mutex
backoffTime time.Duration
userTrigger bool
}
const defaultBackoffTime = 5 * time.Second
func NewLoop(settings configuration.PortForwarding,
client *http.Client, portAllower firewall.PortAllower,
logger logging.Logger) *Loop {
start := make(chan struct{})
running := make(chan models.LoopStatus)
stop := make(chan struct{})
stopped := make(chan struct{})
statusManager := loopstate.New(constants.Stopped, start, running, stop, stopped)
state := state.New(statusManager, settings)
return &Loop{
statusManager: statusManager,
state: state,
// Objects
client: client,
portAllower: portAllower,
logger: logger,
start: start,
running: running,
stop: stop,
stopped: stopped,
userTrigger: true,
backoffTime: defaultBackoffTime,
}
}

View File

@@ -0,0 +1,97 @@
package portforward
import (
"context"
"strconv"
"github.com/qdm12/gluetun/internal/constants"
)
type Runner interface {
Run(ctx context.Context, done chan<- struct{})
}
func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
defer close(done)
select {
case <-l.start: // l.state.SetStartData called beforehand
case <-ctx.Done():
return
}
for ctx.Err() == nil {
pfCtx, pfCancel := context.WithCancel(ctx)
portCh := make(chan uint16)
errorCh := make(chan error)
startData := l.state.GetStartData()
go func(ctx context.Context, startData StartData) {
port, err := startData.PortForwarder.PortForward(ctx, l.client, l.logger,
startData.Gateway, startData.ServerName)
if err != nil {
errorCh <- err
return
}
portCh <- port
// Infinite loop
err = startData.PortForwarder.KeepPortForward(ctx, l.client, l.logger,
port, startData.Gateway, startData.ServerName)
errorCh <- err
}(pfCtx, startData)
if l.userTrigger {
l.userTrigger = false
l.running <- constants.Running
} else { // crash
l.backoffTime = defaultBackoffTime
l.statusManager.SetStatus(constants.Running)
}
stayHere := true
for stayHere {
select {
case <-ctx.Done():
pfCancel()
<-errorCh
close(errorCh)
close(portCh)
l.removePortForwardedFile()
l.firewallBlockPort(ctx)
l.state.SetPortForwarded(0)
return
case <-l.start:
l.userTrigger = true
l.logger.Info("starting")
pfCancel()
stayHere = false
case <-l.stop:
l.userTrigger = true
l.logger.Info("stopping")
pfCancel()
<-errorCh
l.removePortForwardedFile()
l.firewallBlockPort(ctx)
l.state.SetPortForwarded(0)
l.stopped <- struct{}{}
case port := <-portCh:
l.logger.Info("port forwarded is " + strconv.Itoa(int(port)))
l.firewallBlockPort(ctx)
l.state.SetPortForwarded(port)
l.firewallAllowPort(ctx)
l.writePortForwardedFile(port)
case err := <-errorCh:
pfCancel()
close(errorCh)
close(portCh)
l.statusManager.SetStatus(constants.Crashed)
l.logAndWait(ctx, err)
stayHere = false
}
}
pfCancel() // for linting
}
}

View File

@@ -0,0 +1,19 @@
package portforward
import (
"context"
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/portforward/state"
)
type SettingsGetSetter = state.SettingsGetSetter
func (l *Loop) GetSettings() (settings configuration.PortForwarding) {
return l.state.GetSettings()
}
func (l *Loop) SetSettings(ctx context.Context, settings configuration.PortForwarding) (
outcome string) {
return l.state.SetSettings(ctx, settings)
}

View File

@@ -0,0 +1,55 @@
package state
import (
"context"
"os"
"reflect"
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/constants"
)
type SettingsGetSetter interface {
GetSettings() (settings configuration.PortForwarding)
SetSettings(ctx context.Context,
settings configuration.PortForwarding) (outcome string)
}
func (s *State) GetSettings() (settings configuration.PortForwarding) {
s.settingsMu.RLock()
defer s.settingsMu.RUnlock()
return s.settings
}
func (s *State) SetSettings(ctx context.Context, settings configuration.PortForwarding) (
outcome string) {
s.settingsMu.Lock()
settingsUnchanged := reflect.DeepEqual(s.settings, settings)
if settingsUnchanged {
s.settingsMu.Unlock()
return "settings left unchanged"
}
if s.settings.Filepath != settings.Filepath {
_ = os.Rename(s.settings.Filepath, settings.Filepath)
}
newEnabled := settings.Enabled
previousEnabled := s.settings.Enabled
s.settings = settings
s.settingsMu.Unlock()
switch {
case !newEnabled && !previousEnabled:
case newEnabled && previousEnabled:
// no need to restart for now since we os.Rename the file here.
case newEnabled && !previousEnabled:
_, _ = s.statusApplier.ApplyStatus(ctx, constants.Running)
case !newEnabled && previousEnabled:
_, _ = s.statusApplier.ApplyStatus(ctx, constants.Stopped)
}
return "settings updated"
}

View File

@@ -0,0 +1,39 @@
package state
import (
"net"
"github.com/qdm12/gluetun/internal/provider"
)
type StartData struct {
PortForwarder provider.PortForwarder
Gateway net.IP // needed for PIA
ServerName string // needed for PIA
Interface string // tun0 or wg0 for example
}
type StartDataGetterSetter interface {
StartDataGetter
StartDataSetter
}
type StartDataGetter interface {
GetStartData() (startData StartData)
}
func (s *State) GetStartData() (startData StartData) {
s.startDataMu.RLock()
defer s.startDataMu.RUnlock()
return s.startData
}
type StartDataSetter interface {
SetStartData(startData StartData)
}
func (s *State) SetStartData(startData StartData) {
s.startDataMu.Lock()
defer s.startDataMu.Unlock()
s.startData = startData
}

View File

@@ -0,0 +1,37 @@
package state
import (
"sync"
"github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/loopstate"
)
var _ Manager = (*State)(nil)
type Manager interface {
SettingsGetSetter
PortForwardedGetterSetter
StartDataGetterSetter
}
func New(statusApplier loopstate.Applier,
settings configuration.PortForwarding) *State {
return &State{
statusApplier: statusApplier,
settings: settings,
}
}
type State struct {
statusApplier loopstate.Applier
settings configuration.PortForwarding
settingsMu sync.RWMutex
portForwarded uint16
portForwardedMu sync.RWMutex
startData StartData
startDataMu sync.RWMutex
}

View File

@@ -0,0 +1,33 @@
package portforward
import (
"context"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/portforward/state"
)
func (l *Loop) GetStatus() (status models.LoopStatus) {
return l.statusManager.GetStatus()
}
type StartData = state.StartData
type StartStopper interface {
Start(ctx context.Context, data StartData) (
outcome string, err error)
Stop(ctx context.Context) (outcome string, err error)
}
func (l *Loop) Start(ctx context.Context, data StartData) (
outcome string, err error) {
l.startMu.Lock()
defer l.startMu.Unlock()
l.state.SetStartData(data)
return l.statusManager.ApplyStatus(ctx, constants.Running)
}
func (l *Loop) Stop(ctx context.Context) (outcome string, err error) {
return l.statusManager.ApplyStatus(ctx, constants.Stopped)
}

View File

@@ -31,6 +31,7 @@ func (p *PIA) GetOpenVPNConnection(selection configuration.ServerSelection) (
IP: IP, IP: IP,
Port: port, Port: port,
Protocol: protocol, Protocol: protocol,
Hostname: server.ServerName, // used for port forwarding TLS
} }
connections = append(connections, connection) connections = append(connections, connection)
} }

View File

@@ -15,48 +15,51 @@ import (
"strings" "strings"
"time" "time"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/format" "github.com/qdm12/golibs/format"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
var ( var (
ErrBindPort = errors.New("cannot bind port") ErrGatewayIPIsNil = errors.New("gateway IP address is nil")
ErrServerNameEmpty = errors.New("server name is empty")
ErrCreateHTTPClient = errors.New("cannot create custom HTTP client")
ErrReadSavedPortForwardData = errors.New("cannot read saved port forwarded data")
ErrRefreshPortForwardData = errors.New("cannot refresh port forward data")
ErrBindPort = errors.New("cannot bind port")
) )
// PortForward obtains a VPN server side port forwarded from PIA. // PortForward obtains a VPN server side port forwarded from PIA.
//nolint:gocognit
func (p *PIA) PortForward(ctx context.Context, client *http.Client, func (p *PIA) PortForward(ctx context.Context, client *http.Client,
logger logging.Logger, gateway net.IP, portAllower firewall.PortAllower, logger logging.Logger, gateway net.IP, serverName string) (
syncState func(port uint16) (pfFilepath string)) { port uint16, err error) {
commonName := p.activeServer.ServerName // commonName := p.activeServer.ServerName
if !p.activeServer.PortForward { // if !p.activeServer.PortForward {
logger.Error("The server " + commonName + // logger.Error("The server " + commonName +
" (region " + p.activeServer.Region + ") does not support port forwarding") // " (region " + p.activeServer.Region + ") does not support port forwarding")
return // return
} // }
if gateway == nil { if gateway == nil {
logger.Error("aborting because: VPN gateway IP address was not found") return 0, ErrGatewayIPIsNil
return } else if serverName == "" {
return 0, ErrServerNameEmpty
} }
privateIPClient, err := newHTTPClient(commonName) privateIPClient, err := newHTTPClient(serverName)
if err != nil { if err != nil {
logger.Error("aborting because: " + err.Error()) return 0, fmt.Errorf("%w: %s", ErrCreateHTTPClient, err)
return
} }
data, err := readPIAPortForwardData(p.portForwardPath) data, err := readPIAPortForwardData(p.portForwardPath)
if err != nil { if err != nil {
logger.Error(err.Error()) return 0, fmt.Errorf("%w: %s", ErrReadSavedPortForwardData, err)
} }
dataFound := data.Port > 0 dataFound := data.Port > 0
durationToExpiration := data.Expiration.Sub(p.timeNow()) durationToExpiration := data.Expiration.Sub(p.timeNow())
expired := durationToExpiration <= 0 expired := durationToExpiration <= 0
if dataFound { if dataFound {
logger.Info("Found persistent forwarded port data for port " + strconv.Itoa(int(data.Port))) logger.Info("Found saved forwarded port data for port " + strconv.Itoa(int(data.Port)))
if expired { if expired {
logger.Warn("Forwarded port data expired on " + logger.Warn("Forwarded port data expired on " +
data.Expiration.Format(time.RFC1123) + ", getting another one") data.Expiration.Format(time.RFC1123) + ", getting another one")
@@ -66,99 +69,65 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
} }
if !dataFound || expired { if !dataFound || expired {
tryUntilSuccessful(ctx, logger, func() error { data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, p.portForwardPath, p.authFilePath)
p.portForwardPath, p.authFilePath) if err != nil {
return err return 0, fmt.Errorf("%w: %s", ErrRefreshPortForwardData, err)
})
if ctx.Err() != nil {
return
} }
durationToExpiration = data.Expiration.Sub(p.timeNow()) durationToExpiration = data.Expiration.Sub(p.timeNow())
} }
logger.Info("Port forwarded is " + strconv.Itoa(int(data.Port)) + logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration))
" expiring in " + format.FriendlyDuration(durationToExpiration))
// First time binding // First time binding
tryUntilSuccessful(ctx, logger, func() error { if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil { return 0, fmt.Errorf("%w: %s", ErrBindPort, err)
return fmt.Errorf("%w: %s", ErrBindPort, err)
}
return nil
})
if ctx.Err() != nil {
return
} }
filepath := syncState(data.Port) return data.Port, nil
logger.Info("Writing port to " + filepath) }
if err := writePortForwardedToFile(filepath, data.Port); err != nil {
logger.Error(err.Error()) var (
ErrPortForwardedExpired = errors.New("port forwarded data expired")
)
func (p *PIA) KeepPortForward(ctx context.Context, client *http.Client,
logger logging.Logger, port uint16, gateway net.IP, serverName string) (
err error) {
privateIPClient, err := newHTTPClient(serverName)
if err != nil {
return fmt.Errorf("%w: %s", ErrCreateHTTPClient, err)
} }
if err := portAllower.SetAllowedPort(ctx, data.Port, string(constants.TUN)); err != nil { data, err := readPIAPortForwardData(p.portForwardPath)
logger.Error(err.Error()) if err != nil {
return fmt.Errorf("%w: %s", ErrReadSavedPortForwardData, err)
} }
durationToExpiration := data.Expiration.Sub(p.timeNow())
expiryTimer := time.NewTimer(durationToExpiration) expiryTimer := time.NewTimer(durationToExpiration)
const keepAlivePeriod = 15 * time.Minute const keepAlivePeriod = 15 * time.Minute
// Timer behaving as a ticker // Timer behaving as a ticker
keepAliveTimer := time.NewTimer(keepAlivePeriod) keepAliveTimer := time.NewTimer(keepAlivePeriod)
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
removeCtx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := portAllower.RemoveAllowedPort(removeCtx, data.Port); err != nil {
logger.Error(err.Error())
}
if !keepAliveTimer.Stop() { if !keepAliveTimer.Stop() {
<-keepAliveTimer.C <-keepAliveTimer.C
} }
if !expiryTimer.Stop() { if !expiryTimer.Stop() {
<-expiryTimer.C <-expiryTimer.C
} }
return return ctx.Err()
case <-keepAliveTimer.C: case <-keepAliveTimer.C:
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil { err := bindPort(ctx, privateIPClient, gateway, data)
logger.Error("cannot bind port: " + err.Error()) if err != nil {
return fmt.Errorf("%w: %s", ErrBindPort, err)
} }
keepAliveTimer.Reset(keepAlivePeriod) keepAliveTimer.Reset(keepAlivePeriod)
case <-expiryTimer.C: case <-expiryTimer.C:
logger.Warn("Forward port has expired on " + return fmt.Errorf("%w: on %s", ErrPortForwardedExpired,
data.Expiration.Format(time.RFC1123) + ", getting another one") data.Expiration.Format(time.RFC1123))
oldPort := data.Port
for {
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
p.portForwardPath, p.authFilePath)
if err != nil {
logger.Error(err.Error())
continue
}
break
}
durationToExpiration := data.Expiration.Sub(p.timeNow())
logger.Info("Port forwarded is " + strconv.Itoa(int(data.Port)) +
" expiring in " + format.FriendlyDuration(durationToExpiration))
if err := portAllower.RemoveAllowedPort(ctx, oldPort); err != nil {
logger.Error(err.Error())
}
if err := portAllower.SetAllowedPort(ctx, data.Port, string(constants.TUN)); err != nil {
logger.Error(err.Error())
}
filepath := syncState(data.Port)
logger.Info("Writing port to " + filepath)
if err := writePortForwardedToFile(filepath, data.Port); err != nil {
logger.Error("Cannot write port forward data to file: " + err.Error())
}
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
logger.Error("Cannot bind port: " + err.Error())
}
if !keepAliveTimer.Stop() {
<-keepAliveTimer.C
}
keepAliveTimer.Reset(keepAlivePeriod)
expiryTimer.Reset(durationToExpiration)
} }
} }
} }
@@ -463,21 +432,6 @@ func bindPort(ctx context.Context, client *http.Client, gateway net.IP, data pia
return nil return nil
} }
func writePortForwardedToFile(filepath string, port uint16) (err error) {
file, err := os.OpenFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
if err != nil {
return err
}
_, err = file.Write([]byte(fmt.Sprintf("%d", port)))
if err != nil {
_ = file.Close()
return err
}
return file.Close()
}
// replaceInErr is used to remove sensitive information from errors. // replaceInErr is used to remove sensitive information from errors.
func replaceInErr(err error, substitutions map[string]string) error { func replaceInErr(err error, substitutions map[string]string) error {
s := replaceInString(err.Error(), substitutions) s := replaceInString(err.Error(), substitutions)

View File

@@ -1,31 +0,0 @@
package privateinternetaccess
import (
"context"
"time"
"github.com/qdm12/golibs/logging"
)
func tryUntilSuccessful(ctx context.Context, logger logging.Logger, fn func() error) {
const initialRetryPeriod = 5 * time.Second
retryPeriod := initialRetryPeriod
for {
err := fn()
if err == nil {
break
}
logger.Error(err.Error())
logger.Info("Trying again in " + retryPeriod.String())
timer := time.NewTimer(retryPeriod)
select {
case <-timer.C:
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
return
}
retryPeriod *= 2
}
}

View File

@@ -10,7 +10,6 @@ import (
"github.com/qdm12/gluetun/internal/configuration" "github.com/qdm12/gluetun/internal/configuration"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider/cyberghost" "github.com/qdm12/gluetun/internal/provider/cyberghost"
"github.com/qdm12/gluetun/internal/provider/fastestvpn" "github.com/qdm12/gluetun/internal/provider/fastestvpn"
@@ -36,9 +35,16 @@ import (
type Provider interface { type Provider interface {
GetOpenVPNConnection(selection configuration.ServerSelection) (connection models.OpenVPNConnection, err error) GetOpenVPNConnection(selection configuration.ServerSelection) (connection models.OpenVPNConnection, err error)
BuildConf(connection models.OpenVPNConnection, username string, settings configuration.OpenVPN) (lines []string) BuildConf(connection models.OpenVPNConnection, username string, settings configuration.OpenVPN) (lines []string)
PortForwarder
}
type PortForwarder interface {
PortForward(ctx context.Context, client *http.Client, PortForward(ctx context.Context, client *http.Client,
pfLogger logging.Logger, gateway net.IP, portAllower firewall.PortAllower, logger logging.Logger, gateway net.IP, serverName string) (
syncState func(port uint16) (pfFilepath string)) port uint16, err error)
KeepPortForward(ctx context.Context, client *http.Client,
logger logging.Logger, port uint16, gateway net.IP, serverName string) (
err error)
} }
func New(provider string, allServers models.AllServers, timeNow func() time.Time) Provider { func New(provider string, allServers models.AllServers, timeNow func() time.Time) Provider {

View File

@@ -2,17 +2,21 @@ package utils
import ( import (
"context" "context"
"errors"
"fmt"
"net" "net"
"net/http" "net/http"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
type NoPortForwarder interface { type NoPortForwarder interface {
PortForward(ctx context.Context, client *http.Client, PortForward(ctx context.Context, client *http.Client,
pfLogger logging.Logger, gateway net.IP, portAllower firewall.PortAllower, logger logging.Logger, gateway net.IP, serverName string) (
syncState func(port uint16) (pfFilepath string)) port uint16, err error)
KeepPortForward(ctx context.Context, client *http.Client,
logger logging.Logger, port uint16, gateway net.IP, serverName string) (
err error)
} }
type NoPortForwarding struct { type NoPortForwarding struct {
@@ -25,8 +29,16 @@ func NewNoPortForwarding(providerName string) *NoPortForwarding {
} }
} }
var ErrPortForwardingNotSupported = errors.New("custom port forwarding obtention is not supported")
func (n *NoPortForwarding) PortForward(ctx context.Context, client *http.Client, func (n *NoPortForwarding) PortForward(ctx context.Context, client *http.Client,
pfLogger logging.Logger, gateway net.IP, portAllower firewall.PortAllower, logger logging.Logger, gateway net.IP, serverName string) (
syncState func(port uint16) (pfFilepath string)) { port uint16, err error) {
panic("custom port forwarding obtention is not supported for " + n.providerName) return 0, fmt.Errorf("%w: for %s", ErrPortForwardingNotSupported, n.providerName)
}
func (n *NoPortForwarding) KeepPortForward(ctx context.Context, client *http.Client,
logger logging.Logger, port uint16, gateway net.IP, serverName string) (
err error) {
return fmt.Errorf("%w: for %s", ErrPortForwardingNotSupported, n.providerName)
} }

View File

@@ -7,6 +7,7 @@ import (
"strings" "strings"
"github.com/qdm12/gluetun/internal/openvpn" "github.com/qdm12/gluetun/internal/openvpn"
"github.com/qdm12/gluetun/internal/portforward"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
) )
@@ -22,6 +23,7 @@ func newOpenvpnHandler(ctx context.Context, looper openvpn.Looper,
type openvpnHandler struct { type openvpnHandler struct {
ctx context.Context ctx context.Context
looper openvpn.Looper looper openvpn.Looper
pf portforward.Getter
logger logging.Logger logger logging.Logger
} }
@@ -105,7 +107,7 @@ func (h *openvpnHandler) getSettings(w http.ResponseWriter) {
} }
func (h *openvpnHandler) getPortForwarded(w http.ResponseWriter) { func (h *openvpnHandler) getPortForwarded(w http.ResponseWriter) {
port := h.looper.GetPortForwarded() port := h.pf.GetPortForwarded()
encoder := json.NewEncoder(w) encoder := json.NewEncoder(w)
data := portWrapper{Port: port} data := portWrapper{Port: port}
if err := encoder.Encode(data); err != nil { if err := encoder.Encode(data); err != nil {