Maint: port forwarding refactoring (#543)
- portforward package - portforward run loop - Less functional arguments and cycles
This commit is contained in:
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
@@ -24,6 +23,7 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/httpproxy"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/openvpn"
|
||||
"github.com/qdm12/gluetun/internal/portforward"
|
||||
"github.com/qdm12/gluetun/internal/publicip"
|
||||
"github.com/qdm12/gluetun/internal/routing"
|
||||
"github.com/qdm12/gluetun/internal/server"
|
||||
@@ -321,8 +321,16 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
tickersGroupHandler := goshutdown.NewGroupHandler("tickers", defaultGroupSettings)
|
||||
otherGroupHandler := goshutdown.NewGroupHandler("other", defaultGroupSettings)
|
||||
|
||||
portForwardLogger := logger.NewChild(logging.Settings{Prefix: "port forwarding: "})
|
||||
portForwardLooper := portforward.NewLoop(allSettings.OpenVPN.Provider.PortForwarding,
|
||||
httpClient, firewallConf, portForwardLogger)
|
||||
portForwardHandler, portForwardCtx, portForwardDone := goshutdown.NewGoRoutineHandler(
|
||||
"port forwarding", goshutdown.GoRoutineSettings{Timeout: time.Second})
|
||||
go portForwardLooper.Run(portForwardCtx, portForwardDone)
|
||||
|
||||
openvpnLogger := logger.NewChild(logging.Settings{Prefix: "openvpn: "})
|
||||
openvpnLooper := openvpn.NewLoop(allSettings.OpenVPN, nonRootUsername, puid, pgid, allServers,
|
||||
ovpnConf, firewallConf, logger, httpClient, tunnelReadyCh)
|
||||
ovpnConf, firewallConf, routingConf, portForwardLooper, openvpnLogger, httpClient, tunnelReadyCh)
|
||||
openvpnHandler, openvpnCtx, openvpnDone := goshutdown.NewGoRoutineHandler(
|
||||
"openvpn", goshutdown.GoRoutineSettings{Timeout: time.Second})
|
||||
// wait for restartOpenvpn
|
||||
@@ -378,8 +386,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
"events routing", defaultGoRoutineSettings)
|
||||
go routeReadyEvents(eventsRoutingCtx, eventsRoutingDone, buildInfo, tunnelReadyCh,
|
||||
unboundLooper, updaterLooper, publicIPLooper, routingConf, logger, httpClient,
|
||||
allSettings.VersionInformation, allSettings.OpenVPN.Provider.PortForwarding.Enabled, openvpnLooper.PortForward,
|
||||
)
|
||||
allSettings.VersionInformation)
|
||||
controlGroupHandler.Add(eventsRoutingHandler)
|
||||
|
||||
controlServerAddress := ":" + strconv.Itoa(int(allSettings.ControlServer.Port))
|
||||
@@ -406,7 +413,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
}
|
||||
orderHandler := goshutdown.NewOrder("gluetun", orderSettings)
|
||||
orderHandler.Append(controlGroupHandler, tickersGroupHandler, healthServerHandler,
|
||||
openvpnHandler, otherGroupHandler)
|
||||
openvpnHandler, portForwardHandler, otherGroupHandler)
|
||||
|
||||
// Start openvpn for the first time in a blocking call
|
||||
// until openvpn is launched
|
||||
@@ -414,13 +421,6 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
||||
|
||||
<-ctx.Done()
|
||||
|
||||
if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
|
||||
logger.Info("Clearing forwarded port status file " + allSettings.OpenVPN.Provider.PortForwarding.Filepath)
|
||||
if err := os.Remove(allSettings.OpenVPN.Provider.PortForwarding.Filepath); err != nil {
|
||||
logger.Error(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
return orderHandler.Shutdown(context.Background())
|
||||
}
|
||||
|
||||
@@ -450,7 +450,7 @@ func routeReadyEvents(ctx context.Context, done chan<- struct{}, buildInfo model
|
||||
tunnelReadyCh <-chan struct{},
|
||||
unboundLooper dns.Looper, updaterLooper updater.Looper, publicIPLooper publicip.Looper,
|
||||
routing routing.VPNGetter, logger logging.Logger, httpClient *http.Client,
|
||||
versionInformation, portForwardingEnabled bool, startPortForward func(vpnGateway net.IP)) {
|
||||
versionInformation bool) {
|
||||
defer close(done)
|
||||
|
||||
// for linters only
|
||||
@@ -503,15 +503,6 @@ func routeReadyEvents(ctx context.Context, done chan<- struct{}, buildInfo model
|
||||
updaterTickerDone = make(chan struct{})
|
||||
go unboundLooper.RunRestartTicker(restartTickerContext, unboundTickerDone)
|
||||
go updaterLooper.RunRestartTicker(restartTickerContext, updaterTickerDone)
|
||||
if portForwardingEnabled {
|
||||
// vpnGateway required only for PIA
|
||||
vpnGateway, err := routing.VPNLocalGatewayIP()
|
||||
if err != nil {
|
||||
logger.Error("cannot get VPN local gateway IP: " + err.Error())
|
||||
}
|
||||
logger.Info("VPN gateway IP address: " + vpnGateway.String())
|
||||
startPortForward(vpnGateway)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,6 +42,7 @@ func (l *Loop) collectLines(stdout, stderr <-chan string, done chan<- struct{})
|
||||
}
|
||||
if strings.Contains(line, "Initialization Sequence Completed") {
|
||||
l.tunnelReady <- struct{}{}
|
||||
l.startPFCh <- struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package openvpn
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@@ -11,6 +10,8 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/loopstate"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/openvpn/state"
|
||||
"github.com/qdm12/gluetun/internal/portforward"
|
||||
"github.com/qdm12/gluetun/internal/routing"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
@@ -22,8 +23,6 @@ type Looper interface {
|
||||
loopstate.Applier
|
||||
SettingsGetSetter
|
||||
ServersGetterSetter
|
||||
PortForwadedGetter
|
||||
PortForwader
|
||||
}
|
||||
|
||||
type Loop struct {
|
||||
@@ -35,19 +34,21 @@ type Loop struct {
|
||||
pgid int
|
||||
targetConfPath string
|
||||
// Configurators
|
||||
conf StarterAuthWriter
|
||||
fw firewallConfigurer
|
||||
conf StarterAuthWriter
|
||||
fw firewallConfigurer
|
||||
routing routing.VPNLocalGatewayIPGetter
|
||||
portForward portforward.StartStopper
|
||||
// Other objects
|
||||
logger, pfLogger logging.Logger
|
||||
client *http.Client
|
||||
tunnelReady chan<- struct{}
|
||||
logger logging.Logger
|
||||
client *http.Client
|
||||
tunnelReady chan<- struct{}
|
||||
// Internal channels and values
|
||||
stop <-chan struct{}
|
||||
stopped chan<- struct{}
|
||||
start <-chan struct{}
|
||||
running chan<- models.LoopStatus
|
||||
portForwardSignals chan net.IP
|
||||
userTrigger bool
|
||||
stop <-chan struct{}
|
||||
stopped chan<- struct{}
|
||||
start <-chan struct{}
|
||||
running chan<- models.LoopStatus
|
||||
userTrigger bool
|
||||
startPFCh chan struct{}
|
||||
// Internal constant values
|
||||
backoffTime time.Duration
|
||||
}
|
||||
@@ -63,7 +64,8 @@ const (
|
||||
|
||||
func NewLoop(settings configuration.OpenVPN, username string,
|
||||
puid, pgid int, allServers models.AllServers, conf Configurator,
|
||||
fw firewallConfigurer, logger logging.ParentLogger,
|
||||
fw firewallConfigurer, routing routing.VPNLocalGatewayIPGetter,
|
||||
portForward portforward.StartStopper, logger logging.Logger,
|
||||
client *http.Client, tunnelReady chan<- struct{}) *Loop {
|
||||
start := make(chan struct{})
|
||||
running := make(chan models.LoopStatus)
|
||||
@@ -74,24 +76,25 @@ func NewLoop(settings configuration.OpenVPN, username string,
|
||||
state := state.New(statusManager, settings, allServers)
|
||||
|
||||
return &Loop{
|
||||
statusManager: statusManager,
|
||||
state: state,
|
||||
username: username,
|
||||
puid: puid,
|
||||
pgid: pgid,
|
||||
targetConfPath: constants.OpenVPNConf,
|
||||
conf: conf,
|
||||
fw: fw,
|
||||
logger: logger.NewChild(logging.Settings{Prefix: "openvpn: "}),
|
||||
pfLogger: logger.NewChild(logging.Settings{Prefix: "port forwarding: "}),
|
||||
client: client,
|
||||
tunnelReady: tunnelReady,
|
||||
start: start,
|
||||
running: running,
|
||||
stop: stop,
|
||||
stopped: stopped,
|
||||
portForwardSignals: make(chan net.IP),
|
||||
userTrigger: true,
|
||||
backoffTime: defaultBackoffTime,
|
||||
statusManager: statusManager,
|
||||
state: state,
|
||||
username: username,
|
||||
puid: puid,
|
||||
pgid: pgid,
|
||||
targetConfPath: constants.OpenVPNConf,
|
||||
conf: conf,
|
||||
fw: fw,
|
||||
routing: routing,
|
||||
portForward: portForward,
|
||||
logger: logger,
|
||||
client: client,
|
||||
tunnelReady: tunnelReady,
|
||||
start: start,
|
||||
running: running,
|
||||
stop: stop,
|
||||
stopped: stopped,
|
||||
userTrigger: true,
|
||||
startPFCh: make(chan struct{}),
|
||||
backoffTime: defaultBackoffTime,
|
||||
}
|
||||
}
|
||||
|
||||
47
internal/openvpn/portforward.go
Normal file
47
internal/openvpn/portforward.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package openvpn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/portforward"
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
)
|
||||
|
||||
func (l *Loop) startPortForwarding(ctx context.Context,
|
||||
portForwarder provider.PortForwarder, serverName string) {
|
||||
if !l.GetSettings().Provider.PortForwarding.Enabled {
|
||||
return
|
||||
}
|
||||
// only used for PIA for now
|
||||
gateway, err := l.routing.VPNLocalGatewayIP()
|
||||
if err != nil {
|
||||
l.logger.Error("cannot obtain VPN local gateway IP: " + err.Error())
|
||||
return
|
||||
}
|
||||
l.logger.Info("VPN gateway IP address: " + gateway.String())
|
||||
pfData := portforward.StartData{
|
||||
PortForwarder: portForwarder,
|
||||
Gateway: gateway,
|
||||
ServerName: serverName,
|
||||
Interface: constants.TUN,
|
||||
}
|
||||
_, err = l.portForward.Start(ctx, pfData)
|
||||
if err != nil {
|
||||
l.logger.Error("cannot start port forwarding: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Loop) stopPortForwarding(ctx context.Context, timeout time.Duration) {
|
||||
if timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
_, err := l.portForward.Stop(ctx)
|
||||
if err != nil {
|
||||
l.logger.Error("cannot stop port forwarding: " + err.Error())
|
||||
}
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
package openvpn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/openvpn/state"
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
)
|
||||
|
||||
type PortForwadedGetter = state.PortForwardedGetter
|
||||
|
||||
func (l *Loop) GetPortForwarded() (port uint16) {
|
||||
return l.state.GetPortForwarded()
|
||||
}
|
||||
|
||||
type PortForwader interface {
|
||||
PortForward(vpnGatewayIP net.IP)
|
||||
}
|
||||
|
||||
func (l *Loop) PortForward(vpnGateway net.IP) { l.portForwardSignals <- vpnGateway }
|
||||
|
||||
// portForward is a blocking operation which may or may not be infinite.
|
||||
// You should therefore always call it in a goroutine.
|
||||
func (l *Loop) portForward(ctx context.Context,
|
||||
providerConf provider.Provider, client *http.Client, gateway net.IP) {
|
||||
settings := l.state.GetSettings()
|
||||
if !settings.Provider.PortForwarding.Enabled {
|
||||
return
|
||||
}
|
||||
syncState := func(port uint16) (pfFilepath string) {
|
||||
l.state.SetPortForwarded(port)
|
||||
settings := l.state.GetSettings()
|
||||
return settings.Provider.PortForwarding.Filepath
|
||||
}
|
||||
providerConf.PortForward(ctx, client, l.pfLogger,
|
||||
gateway, l.fw, syncState)
|
||||
}
|
||||
@@ -88,41 +88,31 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
<-lineCollectionDone
|
||||
}
|
||||
|
||||
// Needs the stream line from main.go to know when the tunnel is up
|
||||
portForwardDone := make(chan struct{})
|
||||
go func(ctx context.Context) {
|
||||
defer close(portForwardDone)
|
||||
select {
|
||||
// TODO have a way to disable pf with a context
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case gateway := <-l.portForwardSignals:
|
||||
l.portForward(ctx, providerConf, l.client, gateway)
|
||||
}
|
||||
}(openvpnCtx)
|
||||
|
||||
l.backoffTime = defaultBackoffTime
|
||||
l.signalOrSetStatus(constants.Running)
|
||||
|
||||
stayHere := true
|
||||
for stayHere {
|
||||
select {
|
||||
case <-l.startPFCh:
|
||||
l.startPortForwarding(ctx, providerConf, connection.Hostname)
|
||||
case <-ctx.Done():
|
||||
const pfTimeout = 100 * time.Millisecond
|
||||
l.stopPortForwarding(context.Background(), pfTimeout)
|
||||
openvpnCancel()
|
||||
<-waitError
|
||||
close(waitError)
|
||||
closeStreams()
|
||||
<-portForwardDone
|
||||
return
|
||||
case <-l.stop:
|
||||
l.userTrigger = true
|
||||
l.logger.Info("stopping")
|
||||
l.stopPortForwarding(ctx, 0)
|
||||
openvpnCancel()
|
||||
<-waitError
|
||||
// do not close waitError or the waitError
|
||||
// select case will trigger
|
||||
closeStreams()
|
||||
<-portForwardDone
|
||||
l.stopped <- struct{}{}
|
||||
case <-l.start:
|
||||
l.userTrigger = true
|
||||
@@ -134,9 +124,9 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
|
||||
l.statusManager.Lock() // prevent SetStatus from running in parallel
|
||||
|
||||
l.stopPortForwarding(ctx, 0)
|
||||
openvpnCancel()
|
||||
l.statusManager.SetStatus(constants.Crashed)
|
||||
<-portForwardDone
|
||||
l.logAndWait(ctx, err)
|
||||
stayHere = false
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ var _ Manager = (*State)(nil)
|
||||
type Manager interface {
|
||||
SettingsGetSetter
|
||||
ServersGetterSetter
|
||||
PortForwardedGetterSetter
|
||||
GetSettingsAndServers() (settings configuration.OpenVPN,
|
||||
allServers models.AllServers)
|
||||
}
|
||||
@@ -36,9 +35,6 @@ type State struct {
|
||||
|
||||
allServers models.AllServers
|
||||
allServersMu sync.RWMutex
|
||||
|
||||
portForwarded uint16
|
||||
portForwardedMu sync.RWMutex
|
||||
}
|
||||
|
||||
func (s *State) GetSettingsAndServers() (settings configuration.OpenVPN,
|
||||
|
||||
32
internal/portforward/firewall.go
Normal file
32
internal/portforward/firewall.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package portforward
|
||||
|
||||
import "context"
|
||||
|
||||
// firewallBlockPort obtains the state port thread safely and blocks
|
||||
// it in the firewall if it is not the zero value (0).
|
||||
func (l *Loop) firewallBlockPort(ctx context.Context) {
|
||||
port := l.state.GetPortForwarded()
|
||||
if port == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
err := l.portAllower.RemoveAllowedPort(ctx, port)
|
||||
if err != nil {
|
||||
l.logger.Error("cannot block previous port in firewall: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// firewallAllowPort obtains the state port thread safely and allows
|
||||
// it in the firewall if it is not the zero value (0).
|
||||
func (l *Loop) firewallAllowPort(ctx context.Context) {
|
||||
port := l.state.GetPortForwarded()
|
||||
if port == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
startData := l.state.GetStartData()
|
||||
err := l.portAllower.SetAllowedPort(ctx, port, startData.Interface)
|
||||
if err != nil {
|
||||
l.logger.Error("cannot allow port through firewall: " + err.Error())
|
||||
}
|
||||
}
|
||||
37
internal/portforward/fs.go
Normal file
37
internal/portforward/fs.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
func (l *Loop) removePortForwardedFile() {
|
||||
filepath := l.state.GetSettings().Filepath
|
||||
l.logger.Info("removing port file " + filepath)
|
||||
if err := os.Remove(filepath); err != nil {
|
||||
l.logger.Error(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Loop) writePortForwardedFile(port uint16) {
|
||||
filepath := l.state.GetSettings().Filepath
|
||||
l.logger.Info("writing port file " + filepath)
|
||||
if err := writePortForwardedToFile(filepath, port); err != nil {
|
||||
l.logger.Error(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func writePortForwardedToFile(filepath string, port uint16) (err error) {
|
||||
file, err := os.OpenFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = file.Write([]byte(fmt.Sprint(port)))
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return file.Close()
|
||||
}
|
||||
9
internal/portforward/get.go
Normal file
9
internal/portforward/get.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package portforward
|
||||
|
||||
import "github.com/qdm12/gluetun/internal/portforward/state"
|
||||
|
||||
type Getter = state.PortForwardedGetter
|
||||
|
||||
func (l *Loop) GetPortForwarded() (port uint16) {
|
||||
return l.state.GetPortForwarded()
|
||||
}
|
||||
22
internal/portforward/helpers.go
Normal file
22
internal/portforward/helpers.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (l *Loop) logAndWait(ctx context.Context, err error) {
|
||||
if err != nil {
|
||||
l.logger.Error(err.Error())
|
||||
}
|
||||
l.logger.Info("retrying in " + l.backoffTime.String())
|
||||
timer := time.NewTimer(l.backoffTime)
|
||||
l.backoffTime *= 2
|
||||
select {
|
||||
case <-timer.C:
|
||||
case <-ctx.Done():
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
}
|
||||
}
|
||||
71
internal/portforward/loop.go
Normal file
71
internal/portforward/loop.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/loopstate"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/portforward/state"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
var _ Looper = (*Loop)(nil)
|
||||
|
||||
type Looper interface {
|
||||
Runner
|
||||
loopstate.Getter
|
||||
StartStopper
|
||||
SettingsGetSetter
|
||||
Getter
|
||||
}
|
||||
|
||||
type Loop struct {
|
||||
statusManager loopstate.Manager
|
||||
state state.Manager
|
||||
// Objects
|
||||
client *http.Client
|
||||
portAllower firewall.PortAllower
|
||||
logger logging.Logger
|
||||
// Internal channels and locks
|
||||
start chan struct{}
|
||||
running chan models.LoopStatus
|
||||
stop chan struct{}
|
||||
stopped chan struct{}
|
||||
startMu sync.Mutex
|
||||
backoffTime time.Duration
|
||||
userTrigger bool
|
||||
}
|
||||
|
||||
const defaultBackoffTime = 5 * time.Second
|
||||
|
||||
func NewLoop(settings configuration.PortForwarding,
|
||||
client *http.Client, portAllower firewall.PortAllower,
|
||||
logger logging.Logger) *Loop {
|
||||
start := make(chan struct{})
|
||||
running := make(chan models.LoopStatus)
|
||||
stop := make(chan struct{})
|
||||
stopped := make(chan struct{})
|
||||
|
||||
statusManager := loopstate.New(constants.Stopped, start, running, stop, stopped)
|
||||
state := state.New(statusManager, settings)
|
||||
|
||||
return &Loop{
|
||||
statusManager: statusManager,
|
||||
state: state,
|
||||
// Objects
|
||||
client: client,
|
||||
portAllower: portAllower,
|
||||
logger: logger,
|
||||
start: start,
|
||||
running: running,
|
||||
stop: stop,
|
||||
stopped: stopped,
|
||||
userTrigger: true,
|
||||
backoffTime: defaultBackoffTime,
|
||||
}
|
||||
}
|
||||
97
internal/portforward/run.go
Normal file
97
internal/portforward/run.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
)
|
||||
|
||||
type Runner interface {
|
||||
Run(ctx context.Context, done chan<- struct{})
|
||||
}
|
||||
|
||||
func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
|
||||
defer close(done)
|
||||
|
||||
select {
|
||||
case <-l.start: // l.state.SetStartData called beforehand
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
for ctx.Err() == nil {
|
||||
pfCtx, pfCancel := context.WithCancel(ctx)
|
||||
|
||||
portCh := make(chan uint16)
|
||||
errorCh := make(chan error)
|
||||
|
||||
startData := l.state.GetStartData()
|
||||
|
||||
go func(ctx context.Context, startData StartData) {
|
||||
port, err := startData.PortForwarder.PortForward(ctx, l.client, l.logger,
|
||||
startData.Gateway, startData.ServerName)
|
||||
if err != nil {
|
||||
errorCh <- err
|
||||
return
|
||||
}
|
||||
portCh <- port
|
||||
|
||||
// Infinite loop
|
||||
err = startData.PortForwarder.KeepPortForward(ctx, l.client, l.logger,
|
||||
port, startData.Gateway, startData.ServerName)
|
||||
errorCh <- err
|
||||
}(pfCtx, startData)
|
||||
|
||||
if l.userTrigger {
|
||||
l.userTrigger = false
|
||||
l.running <- constants.Running
|
||||
} else { // crash
|
||||
l.backoffTime = defaultBackoffTime
|
||||
l.statusManager.SetStatus(constants.Running)
|
||||
}
|
||||
|
||||
stayHere := true
|
||||
for stayHere {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
pfCancel()
|
||||
<-errorCh
|
||||
close(errorCh)
|
||||
close(portCh)
|
||||
l.removePortForwardedFile()
|
||||
l.firewallBlockPort(ctx)
|
||||
l.state.SetPortForwarded(0)
|
||||
return
|
||||
case <-l.start:
|
||||
l.userTrigger = true
|
||||
l.logger.Info("starting")
|
||||
pfCancel()
|
||||
stayHere = false
|
||||
case <-l.stop:
|
||||
l.userTrigger = true
|
||||
l.logger.Info("stopping")
|
||||
pfCancel()
|
||||
<-errorCh
|
||||
l.removePortForwardedFile()
|
||||
l.firewallBlockPort(ctx)
|
||||
l.state.SetPortForwarded(0)
|
||||
l.stopped <- struct{}{}
|
||||
case port := <-portCh:
|
||||
l.logger.Info("port forwarded is " + strconv.Itoa(int(port)))
|
||||
l.firewallBlockPort(ctx)
|
||||
l.state.SetPortForwarded(port)
|
||||
l.firewallAllowPort(ctx)
|
||||
l.writePortForwardedFile(port)
|
||||
case err := <-errorCh:
|
||||
pfCancel()
|
||||
close(errorCh)
|
||||
close(portCh)
|
||||
l.statusManager.SetStatus(constants.Crashed)
|
||||
l.logAndWait(ctx, err)
|
||||
stayHere = false
|
||||
}
|
||||
}
|
||||
pfCancel() // for linting
|
||||
}
|
||||
}
|
||||
19
internal/portforward/settings.go
Normal file
19
internal/portforward/settings.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration"
|
||||
"github.com/qdm12/gluetun/internal/portforward/state"
|
||||
)
|
||||
|
||||
type SettingsGetSetter = state.SettingsGetSetter
|
||||
|
||||
func (l *Loop) GetSettings() (settings configuration.PortForwarding) {
|
||||
return l.state.GetSettings()
|
||||
}
|
||||
|
||||
func (l *Loop) SetSettings(ctx context.Context, settings configuration.PortForwarding) (
|
||||
outcome string) {
|
||||
return l.state.SetSettings(ctx, settings)
|
||||
}
|
||||
55
internal/portforward/state/settings.go
Normal file
55
internal/portforward/state/settings.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"reflect"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
)
|
||||
|
||||
type SettingsGetSetter interface {
|
||||
GetSettings() (settings configuration.PortForwarding)
|
||||
SetSettings(ctx context.Context,
|
||||
settings configuration.PortForwarding) (outcome string)
|
||||
}
|
||||
|
||||
func (s *State) GetSettings() (settings configuration.PortForwarding) {
|
||||
s.settingsMu.RLock()
|
||||
defer s.settingsMu.RUnlock()
|
||||
return s.settings
|
||||
}
|
||||
|
||||
func (s *State) SetSettings(ctx context.Context, settings configuration.PortForwarding) (
|
||||
outcome string) {
|
||||
s.settingsMu.Lock()
|
||||
|
||||
settingsUnchanged := reflect.DeepEqual(s.settings, settings)
|
||||
if settingsUnchanged {
|
||||
s.settingsMu.Unlock()
|
||||
return "settings left unchanged"
|
||||
}
|
||||
|
||||
if s.settings.Filepath != settings.Filepath {
|
||||
_ = os.Rename(s.settings.Filepath, settings.Filepath)
|
||||
}
|
||||
|
||||
newEnabled := settings.Enabled
|
||||
previousEnabled := s.settings.Enabled
|
||||
|
||||
s.settings = settings
|
||||
s.settingsMu.Unlock()
|
||||
|
||||
switch {
|
||||
case !newEnabled && !previousEnabled:
|
||||
case newEnabled && previousEnabled:
|
||||
// no need to restart for now since we os.Rename the file here.
|
||||
case newEnabled && !previousEnabled:
|
||||
_, _ = s.statusApplier.ApplyStatus(ctx, constants.Running)
|
||||
case !newEnabled && previousEnabled:
|
||||
_, _ = s.statusApplier.ApplyStatus(ctx, constants.Stopped)
|
||||
}
|
||||
|
||||
return "settings updated"
|
||||
}
|
||||
39
internal/portforward/state/startdata.go
Normal file
39
internal/portforward/state/startdata.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
)
|
||||
|
||||
type StartData struct {
|
||||
PortForwarder provider.PortForwarder
|
||||
Gateway net.IP // needed for PIA
|
||||
ServerName string // needed for PIA
|
||||
Interface string // tun0 or wg0 for example
|
||||
}
|
||||
|
||||
type StartDataGetterSetter interface {
|
||||
StartDataGetter
|
||||
StartDataSetter
|
||||
}
|
||||
|
||||
type StartDataGetter interface {
|
||||
GetStartData() (startData StartData)
|
||||
}
|
||||
|
||||
func (s *State) GetStartData() (startData StartData) {
|
||||
s.startDataMu.RLock()
|
||||
defer s.startDataMu.RUnlock()
|
||||
return s.startData
|
||||
}
|
||||
|
||||
type StartDataSetter interface {
|
||||
SetStartData(startData StartData)
|
||||
}
|
||||
|
||||
func (s *State) SetStartData(startData StartData) {
|
||||
s.startDataMu.Lock()
|
||||
defer s.startDataMu.Unlock()
|
||||
s.startData = startData
|
||||
}
|
||||
37
internal/portforward/state/state.go
Normal file
37
internal/portforward/state/state.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration"
|
||||
"github.com/qdm12/gluetun/internal/loopstate"
|
||||
)
|
||||
|
||||
var _ Manager = (*State)(nil)
|
||||
|
||||
type Manager interface {
|
||||
SettingsGetSetter
|
||||
PortForwardedGetterSetter
|
||||
StartDataGetterSetter
|
||||
}
|
||||
|
||||
func New(statusApplier loopstate.Applier,
|
||||
settings configuration.PortForwarding) *State {
|
||||
return &State{
|
||||
statusApplier: statusApplier,
|
||||
settings: settings,
|
||||
}
|
||||
}
|
||||
|
||||
type State struct {
|
||||
statusApplier loopstate.Applier
|
||||
|
||||
settings configuration.PortForwarding
|
||||
settingsMu sync.RWMutex
|
||||
|
||||
portForwarded uint16
|
||||
portForwardedMu sync.RWMutex
|
||||
|
||||
startData StartData
|
||||
startDataMu sync.RWMutex
|
||||
}
|
||||
33
internal/portforward/status.go
Normal file
33
internal/portforward/status.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/portforward/state"
|
||||
)
|
||||
|
||||
func (l *Loop) GetStatus() (status models.LoopStatus) {
|
||||
return l.statusManager.GetStatus()
|
||||
}
|
||||
|
||||
type StartData = state.StartData
|
||||
|
||||
type StartStopper interface {
|
||||
Start(ctx context.Context, data StartData) (
|
||||
outcome string, err error)
|
||||
Stop(ctx context.Context) (outcome string, err error)
|
||||
}
|
||||
|
||||
func (l *Loop) Start(ctx context.Context, data StartData) (
|
||||
outcome string, err error) {
|
||||
l.startMu.Lock()
|
||||
defer l.startMu.Unlock()
|
||||
l.state.SetStartData(data)
|
||||
return l.statusManager.ApplyStatus(ctx, constants.Running)
|
||||
}
|
||||
|
||||
func (l *Loop) Stop(ctx context.Context) (outcome string, err error) {
|
||||
return l.statusManager.ApplyStatus(ctx, constants.Stopped)
|
||||
}
|
||||
@@ -31,6 +31,7 @@ func (p *PIA) GetOpenVPNConnection(selection configuration.ServerSelection) (
|
||||
IP: IP,
|
||||
Port: port,
|
||||
Protocol: protocol,
|
||||
Hostname: server.ServerName, // used for port forwarding TLS
|
||||
}
|
||||
connections = append(connections, connection)
|
||||
}
|
||||
|
||||
@@ -15,48 +15,51 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/golibs/format"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrBindPort = errors.New("cannot bind port")
|
||||
ErrGatewayIPIsNil = errors.New("gateway IP address is nil")
|
||||
ErrServerNameEmpty = errors.New("server name is empty")
|
||||
ErrCreateHTTPClient = errors.New("cannot create custom HTTP client")
|
||||
ErrReadSavedPortForwardData = errors.New("cannot read saved port forwarded data")
|
||||
ErrRefreshPortForwardData = errors.New("cannot refresh port forward data")
|
||||
ErrBindPort = errors.New("cannot bind port")
|
||||
)
|
||||
|
||||
// PortForward obtains a VPN server side port forwarded from PIA.
|
||||
//nolint:gocognit
|
||||
func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
||||
logger logging.Logger, gateway net.IP, portAllower firewall.PortAllower,
|
||||
syncState func(port uint16) (pfFilepath string)) {
|
||||
commonName := p.activeServer.ServerName
|
||||
if !p.activeServer.PortForward {
|
||||
logger.Error("The server " + commonName +
|
||||
" (region " + p.activeServer.Region + ") does not support port forwarding")
|
||||
return
|
||||
}
|
||||
logger logging.Logger, gateway net.IP, serverName string) (
|
||||
port uint16, err error) {
|
||||
// commonName := p.activeServer.ServerName
|
||||
// if !p.activeServer.PortForward {
|
||||
// logger.Error("The server " + commonName +
|
||||
// " (region " + p.activeServer.Region + ") does not support port forwarding")
|
||||
// return
|
||||
// }
|
||||
if gateway == nil {
|
||||
logger.Error("aborting because: VPN gateway IP address was not found")
|
||||
return
|
||||
return 0, ErrGatewayIPIsNil
|
||||
} else if serverName == "" {
|
||||
return 0, ErrServerNameEmpty
|
||||
}
|
||||
|
||||
privateIPClient, err := newHTTPClient(commonName)
|
||||
privateIPClient, err := newHTTPClient(serverName)
|
||||
if err != nil {
|
||||
logger.Error("aborting because: " + err.Error())
|
||||
return
|
||||
return 0, fmt.Errorf("%w: %s", ErrCreateHTTPClient, err)
|
||||
}
|
||||
|
||||
data, err := readPIAPortForwardData(p.portForwardPath)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
return 0, fmt.Errorf("%w: %s", ErrReadSavedPortForwardData, err)
|
||||
}
|
||||
|
||||
dataFound := data.Port > 0
|
||||
durationToExpiration := data.Expiration.Sub(p.timeNow())
|
||||
expired := durationToExpiration <= 0
|
||||
|
||||
if dataFound {
|
||||
logger.Info("Found persistent forwarded port data for port " + strconv.Itoa(int(data.Port)))
|
||||
logger.Info("Found saved forwarded port data for port " + strconv.Itoa(int(data.Port)))
|
||||
if expired {
|
||||
logger.Warn("Forwarded port data expired on " +
|
||||
data.Expiration.Format(time.RFC1123) + ", getting another one")
|
||||
@@ -66,99 +69,65 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
||||
}
|
||||
|
||||
if !dataFound || expired {
|
||||
tryUntilSuccessful(ctx, logger, func() error {
|
||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
|
||||
p.portForwardPath, p.authFilePath)
|
||||
return err
|
||||
})
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
|
||||
p.portForwardPath, p.authFilePath)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("%w: %s", ErrRefreshPortForwardData, err)
|
||||
}
|
||||
durationToExpiration = data.Expiration.Sub(p.timeNow())
|
||||
}
|
||||
logger.Info("Port forwarded is " + strconv.Itoa(int(data.Port)) +
|
||||
" expiring in " + format.FriendlyDuration(durationToExpiration))
|
||||
logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration))
|
||||
|
||||
// First time binding
|
||||
tryUntilSuccessful(ctx, logger, func() error {
|
||||
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrBindPort, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
|
||||
return 0, fmt.Errorf("%w: %s", ErrBindPort, err)
|
||||
}
|
||||
|
||||
filepath := syncState(data.Port)
|
||||
logger.Info("Writing port to " + filepath)
|
||||
if err := writePortForwardedToFile(filepath, data.Port); err != nil {
|
||||
logger.Error(err.Error())
|
||||
return data.Port, nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrPortForwardedExpired = errors.New("port forwarded data expired")
|
||||
)
|
||||
|
||||
func (p *PIA) KeepPortForward(ctx context.Context, client *http.Client,
|
||||
logger logging.Logger, port uint16, gateway net.IP, serverName string) (
|
||||
err error) {
|
||||
privateIPClient, err := newHTTPClient(serverName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrCreateHTTPClient, err)
|
||||
}
|
||||
|
||||
if err := portAllower.SetAllowedPort(ctx, data.Port, string(constants.TUN)); err != nil {
|
||||
logger.Error(err.Error())
|
||||
data, err := readPIAPortForwardData(p.portForwardPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrReadSavedPortForwardData, err)
|
||||
}
|
||||
|
||||
durationToExpiration := data.Expiration.Sub(p.timeNow())
|
||||
expiryTimer := time.NewTimer(durationToExpiration)
|
||||
const keepAlivePeriod = 15 * time.Minute
|
||||
// Timer behaving as a ticker
|
||||
keepAliveTimer := time.NewTimer(keepAlivePeriod)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
removeCtx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := portAllower.RemoveAllowedPort(removeCtx, data.Port); err != nil {
|
||||
logger.Error(err.Error())
|
||||
}
|
||||
if !keepAliveTimer.Stop() {
|
||||
<-keepAliveTimer.C
|
||||
}
|
||||
if !expiryTimer.Stop() {
|
||||
<-expiryTimer.C
|
||||
}
|
||||
return
|
||||
return ctx.Err()
|
||||
case <-keepAliveTimer.C:
|
||||
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
|
||||
logger.Error("cannot bind port: " + err.Error())
|
||||
err := bindPort(ctx, privateIPClient, gateway, data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrBindPort, err)
|
||||
}
|
||||
keepAliveTimer.Reset(keepAlivePeriod)
|
||||
case <-expiryTimer.C:
|
||||
logger.Warn("Forward port has expired on " +
|
||||
data.Expiration.Format(time.RFC1123) + ", getting another one")
|
||||
oldPort := data.Port
|
||||
for {
|
||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
|
||||
p.portForwardPath, p.authFilePath)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
durationToExpiration := data.Expiration.Sub(p.timeNow())
|
||||
logger.Info("Port forwarded is " + strconv.Itoa(int(data.Port)) +
|
||||
" expiring in " + format.FriendlyDuration(durationToExpiration))
|
||||
if err := portAllower.RemoveAllowedPort(ctx, oldPort); err != nil {
|
||||
logger.Error(err.Error())
|
||||
}
|
||||
if err := portAllower.SetAllowedPort(ctx, data.Port, string(constants.TUN)); err != nil {
|
||||
logger.Error(err.Error())
|
||||
}
|
||||
filepath := syncState(data.Port)
|
||||
logger.Info("Writing port to " + filepath)
|
||||
if err := writePortForwardedToFile(filepath, data.Port); err != nil {
|
||||
logger.Error("Cannot write port forward data to file: " + err.Error())
|
||||
}
|
||||
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
|
||||
logger.Error("Cannot bind port: " + err.Error())
|
||||
}
|
||||
if !keepAliveTimer.Stop() {
|
||||
<-keepAliveTimer.C
|
||||
}
|
||||
keepAliveTimer.Reset(keepAlivePeriod)
|
||||
expiryTimer.Reset(durationToExpiration)
|
||||
return fmt.Errorf("%w: on %s", ErrPortForwardedExpired,
|
||||
data.Expiration.Format(time.RFC1123))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -463,21 +432,6 @@ func bindPort(ctx context.Context, client *http.Client, gateway net.IP, data pia
|
||||
return nil
|
||||
}
|
||||
|
||||
func writePortForwardedToFile(filepath string, port uint16) (err error) {
|
||||
file, err := os.OpenFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = file.Write([]byte(fmt.Sprintf("%d", port)))
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return file.Close()
|
||||
}
|
||||
|
||||
// replaceInErr is used to remove sensitive information from errors.
|
||||
func replaceInErr(err error, substitutions map[string]string) error {
|
||||
s := replaceInString(err.Error(), substitutions)
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
package privateinternetaccess
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
func tryUntilSuccessful(ctx context.Context, logger logging.Logger, fn func() error) {
|
||||
const initialRetryPeriod = 5 * time.Second
|
||||
retryPeriod := initialRetryPeriod
|
||||
for {
|
||||
err := fn()
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
logger.Error(err.Error())
|
||||
logger.Info("Trying again in " + retryPeriod.String())
|
||||
timer := time.NewTimer(retryPeriod)
|
||||
select {
|
||||
case <-timer.C:
|
||||
case <-ctx.Done():
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
return
|
||||
}
|
||||
retryPeriod *= 2
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/configuration"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider/cyberghost"
|
||||
"github.com/qdm12/gluetun/internal/provider/fastestvpn"
|
||||
@@ -36,9 +35,16 @@ import (
|
||||
type Provider interface {
|
||||
GetOpenVPNConnection(selection configuration.ServerSelection) (connection models.OpenVPNConnection, err error)
|
||||
BuildConf(connection models.OpenVPNConnection, username string, settings configuration.OpenVPN) (lines []string)
|
||||
PortForwarder
|
||||
}
|
||||
|
||||
type PortForwarder interface {
|
||||
PortForward(ctx context.Context, client *http.Client,
|
||||
pfLogger logging.Logger, gateway net.IP, portAllower firewall.PortAllower,
|
||||
syncState func(port uint16) (pfFilepath string))
|
||||
logger logging.Logger, gateway net.IP, serverName string) (
|
||||
port uint16, err error)
|
||||
KeepPortForward(ctx context.Context, client *http.Client,
|
||||
logger logging.Logger, port uint16, gateway net.IP, serverName string) (
|
||||
err error)
|
||||
}
|
||||
|
||||
func New(provider string, allServers models.AllServers, timeNow func() time.Time) Provider {
|
||||
|
||||
@@ -2,17 +2,21 @@ package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
type NoPortForwarder interface {
|
||||
PortForward(ctx context.Context, client *http.Client,
|
||||
pfLogger logging.Logger, gateway net.IP, portAllower firewall.PortAllower,
|
||||
syncState func(port uint16) (pfFilepath string))
|
||||
logger logging.Logger, gateway net.IP, serverName string) (
|
||||
port uint16, err error)
|
||||
KeepPortForward(ctx context.Context, client *http.Client,
|
||||
logger logging.Logger, port uint16, gateway net.IP, serverName string) (
|
||||
err error)
|
||||
}
|
||||
|
||||
type NoPortForwarding struct {
|
||||
@@ -25,8 +29,16 @@ func NewNoPortForwarding(providerName string) *NoPortForwarding {
|
||||
}
|
||||
}
|
||||
|
||||
var ErrPortForwardingNotSupported = errors.New("custom port forwarding obtention is not supported")
|
||||
|
||||
func (n *NoPortForwarding) PortForward(ctx context.Context, client *http.Client,
|
||||
pfLogger logging.Logger, gateway net.IP, portAllower firewall.PortAllower,
|
||||
syncState func(port uint16) (pfFilepath string)) {
|
||||
panic("custom port forwarding obtention is not supported for " + n.providerName)
|
||||
logger logging.Logger, gateway net.IP, serverName string) (
|
||||
port uint16, err error) {
|
||||
return 0, fmt.Errorf("%w: for %s", ErrPortForwardingNotSupported, n.providerName)
|
||||
}
|
||||
|
||||
func (n *NoPortForwarding) KeepPortForward(ctx context.Context, client *http.Client,
|
||||
logger logging.Logger, port uint16, gateway net.IP, serverName string) (
|
||||
err error) {
|
||||
return fmt.Errorf("%w: for %s", ErrPortForwardingNotSupported, n.providerName)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/openvpn"
|
||||
"github.com/qdm12/gluetun/internal/portforward"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
@@ -22,6 +23,7 @@ func newOpenvpnHandler(ctx context.Context, looper openvpn.Looper,
|
||||
type openvpnHandler struct {
|
||||
ctx context.Context
|
||||
looper openvpn.Looper
|
||||
pf portforward.Getter
|
||||
logger logging.Logger
|
||||
}
|
||||
|
||||
@@ -105,7 +107,7 @@ func (h *openvpnHandler) getSettings(w http.ResponseWriter) {
|
||||
}
|
||||
|
||||
func (h *openvpnHandler) getPortForwarded(w http.ResponseWriter) {
|
||||
port := h.looper.GetPortForwarded()
|
||||
port := h.pf.GetPortForwarded()
|
||||
encoder := json.NewEncoder(w)
|
||||
data := portWrapper{Port: port}
|
||||
if err := encoder.Encode(data); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user