diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 35a87e3d..ef4192f7 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "net" "net/http" "os" "os/signal" @@ -24,6 +23,7 @@ import ( "github.com/qdm12/gluetun/internal/httpproxy" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/openvpn" + "github.com/qdm12/gluetun/internal/portforward" "github.com/qdm12/gluetun/internal/publicip" "github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/gluetun/internal/server" @@ -321,8 +321,16 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, tickersGroupHandler := goshutdown.NewGroupHandler("tickers", defaultGroupSettings) otherGroupHandler := goshutdown.NewGroupHandler("other", defaultGroupSettings) + portForwardLogger := logger.NewChild(logging.Settings{Prefix: "port forwarding: "}) + portForwardLooper := portforward.NewLoop(allSettings.OpenVPN.Provider.PortForwarding, + httpClient, firewallConf, portForwardLogger) + portForwardHandler, portForwardCtx, portForwardDone := goshutdown.NewGoRoutineHandler( + "port forwarding", goshutdown.GoRoutineSettings{Timeout: time.Second}) + go portForwardLooper.Run(portForwardCtx, portForwardDone) + + openvpnLogger := logger.NewChild(logging.Settings{Prefix: "openvpn: "}) openvpnLooper := openvpn.NewLoop(allSettings.OpenVPN, nonRootUsername, puid, pgid, allServers, - ovpnConf, firewallConf, logger, httpClient, tunnelReadyCh) + ovpnConf, firewallConf, routingConf, portForwardLooper, openvpnLogger, httpClient, tunnelReadyCh) openvpnHandler, openvpnCtx, openvpnDone := goshutdown.NewGoRoutineHandler( "openvpn", goshutdown.GoRoutineSettings{Timeout: time.Second}) // wait for restartOpenvpn @@ -378,8 +386,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, "events routing", defaultGoRoutineSettings) go routeReadyEvents(eventsRoutingCtx, eventsRoutingDone, buildInfo, tunnelReadyCh, unboundLooper, updaterLooper, publicIPLooper, routingConf, logger, httpClient, - allSettings.VersionInformation, allSettings.OpenVPN.Provider.PortForwarding.Enabled, openvpnLooper.PortForward, - ) + allSettings.VersionInformation) controlGroupHandler.Add(eventsRoutingHandler) controlServerAddress := ":" + strconv.Itoa(int(allSettings.ControlServer.Port)) @@ -406,7 +413,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, } orderHandler := goshutdown.NewOrder("gluetun", orderSettings) orderHandler.Append(controlGroupHandler, tickersGroupHandler, healthServerHandler, - openvpnHandler, otherGroupHandler) + openvpnHandler, portForwardHandler, otherGroupHandler) // Start openvpn for the first time in a blocking call // until openvpn is launched @@ -414,13 +421,6 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, <-ctx.Done() - if allSettings.OpenVPN.Provider.PortForwarding.Enabled { - logger.Info("Clearing forwarded port status file " + allSettings.OpenVPN.Provider.PortForwarding.Filepath) - if err := os.Remove(allSettings.OpenVPN.Provider.PortForwarding.Filepath); err != nil { - logger.Error(err.Error()) - } - } - return orderHandler.Shutdown(context.Background()) } @@ -450,7 +450,7 @@ func routeReadyEvents(ctx context.Context, done chan<- struct{}, buildInfo model tunnelReadyCh <-chan struct{}, unboundLooper dns.Looper, updaterLooper updater.Looper, publicIPLooper publicip.Looper, routing routing.VPNGetter, logger logging.Logger, httpClient *http.Client, - versionInformation, portForwardingEnabled bool, startPortForward func(vpnGateway net.IP)) { + versionInformation bool) { defer close(done) // for linters only @@ -503,15 +503,6 @@ func routeReadyEvents(ctx context.Context, done chan<- struct{}, buildInfo model updaterTickerDone = make(chan struct{}) go unboundLooper.RunRestartTicker(restartTickerContext, unboundTickerDone) go updaterLooper.RunRestartTicker(restartTickerContext, updaterTickerDone) - if portForwardingEnabled { - // vpnGateway required only for PIA - vpnGateway, err := routing.VPNLocalGatewayIP() - if err != nil { - logger.Error("cannot get VPN local gateway IP: " + err.Error()) - } - logger.Info("VPN gateway IP address: " + vpnGateway.String()) - startPortForward(vpnGateway) - } } } } diff --git a/internal/openvpn/logs.go b/internal/openvpn/logs.go index 31f01e1b..f7ab7181 100644 --- a/internal/openvpn/logs.go +++ b/internal/openvpn/logs.go @@ -42,6 +42,7 @@ func (l *Loop) collectLines(stdout, stderr <-chan string, done chan<- struct{}) } if strings.Contains(line, "Initialization Sequence Completed") { l.tunnelReady <- struct{}{} + l.startPFCh <- struct{}{} } } } diff --git a/internal/openvpn/loop.go b/internal/openvpn/loop.go index 4bbcdc6e..68107e9f 100644 --- a/internal/openvpn/loop.go +++ b/internal/openvpn/loop.go @@ -1,7 +1,6 @@ package openvpn import ( - "net" "net/http" "time" @@ -11,6 +10,8 @@ import ( "github.com/qdm12/gluetun/internal/loopstate" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/openvpn/state" + "github.com/qdm12/gluetun/internal/portforward" + "github.com/qdm12/gluetun/internal/routing" "github.com/qdm12/golibs/logging" ) @@ -22,8 +23,6 @@ type Looper interface { loopstate.Applier SettingsGetSetter ServersGetterSetter - PortForwadedGetter - PortForwader } type Loop struct { @@ -35,19 +34,21 @@ type Loop struct { pgid int targetConfPath string // Configurators - conf StarterAuthWriter - fw firewallConfigurer + conf StarterAuthWriter + fw firewallConfigurer + routing routing.VPNLocalGatewayIPGetter + portForward portforward.StartStopper // Other objects - logger, pfLogger logging.Logger - client *http.Client - tunnelReady chan<- struct{} + logger logging.Logger + client *http.Client + tunnelReady chan<- struct{} // Internal channels and values - stop <-chan struct{} - stopped chan<- struct{} - start <-chan struct{} - running chan<- models.LoopStatus - portForwardSignals chan net.IP - userTrigger bool + stop <-chan struct{} + stopped chan<- struct{} + start <-chan struct{} + running chan<- models.LoopStatus + userTrigger bool + startPFCh chan struct{} // Internal constant values backoffTime time.Duration } @@ -63,7 +64,8 @@ const ( func NewLoop(settings configuration.OpenVPN, username string, puid, pgid int, allServers models.AllServers, conf Configurator, - fw firewallConfigurer, logger logging.ParentLogger, + fw firewallConfigurer, routing routing.VPNLocalGatewayIPGetter, + portForward portforward.StartStopper, logger logging.Logger, client *http.Client, tunnelReady chan<- struct{}) *Loop { start := make(chan struct{}) running := make(chan models.LoopStatus) @@ -74,24 +76,25 @@ func NewLoop(settings configuration.OpenVPN, username string, state := state.New(statusManager, settings, allServers) return &Loop{ - statusManager: statusManager, - state: state, - username: username, - puid: puid, - pgid: pgid, - targetConfPath: constants.OpenVPNConf, - conf: conf, - fw: fw, - logger: logger.NewChild(logging.Settings{Prefix: "openvpn: "}), - pfLogger: logger.NewChild(logging.Settings{Prefix: "port forwarding: "}), - client: client, - tunnelReady: tunnelReady, - start: start, - running: running, - stop: stop, - stopped: stopped, - portForwardSignals: make(chan net.IP), - userTrigger: true, - backoffTime: defaultBackoffTime, + statusManager: statusManager, + state: state, + username: username, + puid: puid, + pgid: pgid, + targetConfPath: constants.OpenVPNConf, + conf: conf, + fw: fw, + routing: routing, + portForward: portForward, + logger: logger, + client: client, + tunnelReady: tunnelReady, + start: start, + running: running, + stop: stop, + stopped: stopped, + userTrigger: true, + startPFCh: make(chan struct{}), + backoffTime: defaultBackoffTime, } } diff --git a/internal/openvpn/portforward.go b/internal/openvpn/portforward.go new file mode 100644 index 00000000..d026d55b --- /dev/null +++ b/internal/openvpn/portforward.go @@ -0,0 +1,47 @@ +package openvpn + +import ( + "context" + "time" + + "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/portforward" + "github.com/qdm12/gluetun/internal/provider" +) + +func (l *Loop) startPortForwarding(ctx context.Context, + portForwarder provider.PortForwarder, serverName string) { + if !l.GetSettings().Provider.PortForwarding.Enabled { + return + } + // only used for PIA for now + gateway, err := l.routing.VPNLocalGatewayIP() + if err != nil { + l.logger.Error("cannot obtain VPN local gateway IP: " + err.Error()) + return + } + l.logger.Info("VPN gateway IP address: " + gateway.String()) + pfData := portforward.StartData{ + PortForwarder: portForwarder, + Gateway: gateway, + ServerName: serverName, + Interface: constants.TUN, + } + _, err = l.portForward.Start(ctx, pfData) + if err != nil { + l.logger.Error("cannot start port forwarding: " + err.Error()) + } +} + +func (l *Loop) stopPortForwarding(ctx context.Context, timeout time.Duration) { + if timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + + _, err := l.portForward.Stop(ctx) + if err != nil { + l.logger.Error("cannot stop port forwarding: " + err.Error()) + } +} diff --git a/internal/openvpn/portforwarded.go b/internal/openvpn/portforwarded.go deleted file mode 100644 index 82e59d73..00000000 --- a/internal/openvpn/portforwarded.go +++ /dev/null @@ -1,39 +0,0 @@ -package openvpn - -import ( - "context" - "net" - "net/http" - - "github.com/qdm12/gluetun/internal/openvpn/state" - "github.com/qdm12/gluetun/internal/provider" -) - -type PortForwadedGetter = state.PortForwardedGetter - -func (l *Loop) GetPortForwarded() (port uint16) { - return l.state.GetPortForwarded() -} - -type PortForwader interface { - PortForward(vpnGatewayIP net.IP) -} - -func (l *Loop) PortForward(vpnGateway net.IP) { l.portForwardSignals <- vpnGateway } - -// portForward is a blocking operation which may or may not be infinite. -// You should therefore always call it in a goroutine. -func (l *Loop) portForward(ctx context.Context, - providerConf provider.Provider, client *http.Client, gateway net.IP) { - settings := l.state.GetSettings() - if !settings.Provider.PortForwarding.Enabled { - return - } - syncState := func(port uint16) (pfFilepath string) { - l.state.SetPortForwarded(port) - settings := l.state.GetSettings() - return settings.Provider.PortForwarding.Filepath - } - providerConf.PortForward(ctx, client, l.pfLogger, - gateway, l.fw, syncState) -} diff --git a/internal/openvpn/run.go b/internal/openvpn/run.go index 51fbe38a..2c7382eb 100644 --- a/internal/openvpn/run.go +++ b/internal/openvpn/run.go @@ -88,41 +88,31 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { <-lineCollectionDone } - // Needs the stream line from main.go to know when the tunnel is up - portForwardDone := make(chan struct{}) - go func(ctx context.Context) { - defer close(portForwardDone) - select { - // TODO have a way to disable pf with a context - case <-ctx.Done(): - return - case gateway := <-l.portForwardSignals: - l.portForward(ctx, providerConf, l.client, gateway) - } - }(openvpnCtx) - l.backoffTime = defaultBackoffTime l.signalOrSetStatus(constants.Running) stayHere := true for stayHere { select { + case <-l.startPFCh: + l.startPortForwarding(ctx, providerConf, connection.Hostname) case <-ctx.Done(): + const pfTimeout = 100 * time.Millisecond + l.stopPortForwarding(context.Background(), pfTimeout) openvpnCancel() <-waitError close(waitError) closeStreams() - <-portForwardDone return case <-l.stop: l.userTrigger = true l.logger.Info("stopping") + l.stopPortForwarding(ctx, 0) openvpnCancel() <-waitError // do not close waitError or the waitError // select case will trigger closeStreams() - <-portForwardDone l.stopped <- struct{}{} case <-l.start: l.userTrigger = true @@ -134,9 +124,9 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { l.statusManager.Lock() // prevent SetStatus from running in parallel + l.stopPortForwarding(ctx, 0) openvpnCancel() l.statusManager.SetStatus(constants.Crashed) - <-portForwardDone l.logAndWait(ctx, err) stayHere = false diff --git a/internal/openvpn/state/state.go b/internal/openvpn/state/state.go index a9333901..32beb6e8 100644 --- a/internal/openvpn/state/state.go +++ b/internal/openvpn/state/state.go @@ -13,7 +13,6 @@ var _ Manager = (*State)(nil) type Manager interface { SettingsGetSetter ServersGetterSetter - PortForwardedGetterSetter GetSettingsAndServers() (settings configuration.OpenVPN, allServers models.AllServers) } @@ -36,9 +35,6 @@ type State struct { allServers models.AllServers allServersMu sync.RWMutex - - portForwarded uint16 - portForwardedMu sync.RWMutex } func (s *State) GetSettingsAndServers() (settings configuration.OpenVPN, diff --git a/internal/portforward/firewall.go b/internal/portforward/firewall.go new file mode 100644 index 00000000..bf187927 --- /dev/null +++ b/internal/portforward/firewall.go @@ -0,0 +1,32 @@ +package portforward + +import "context" + +// firewallBlockPort obtains the state port thread safely and blocks +// it in the firewall if it is not the zero value (0). +func (l *Loop) firewallBlockPort(ctx context.Context) { + port := l.state.GetPortForwarded() + if port == 0 { + return + } + + err := l.portAllower.RemoveAllowedPort(ctx, port) + if err != nil { + l.logger.Error("cannot block previous port in firewall: " + err.Error()) + } +} + +// firewallAllowPort obtains the state port thread safely and allows +// it in the firewall if it is not the zero value (0). +func (l *Loop) firewallAllowPort(ctx context.Context) { + port := l.state.GetPortForwarded() + if port == 0 { + return + } + + startData := l.state.GetStartData() + err := l.portAllower.SetAllowedPort(ctx, port, startData.Interface) + if err != nil { + l.logger.Error("cannot allow port through firewall: " + err.Error()) + } +} diff --git a/internal/portforward/fs.go b/internal/portforward/fs.go new file mode 100644 index 00000000..9cd529f5 --- /dev/null +++ b/internal/portforward/fs.go @@ -0,0 +1,37 @@ +package portforward + +import ( + "fmt" + "os" +) + +func (l *Loop) removePortForwardedFile() { + filepath := l.state.GetSettings().Filepath + l.logger.Info("removing port file " + filepath) + if err := os.Remove(filepath); err != nil { + l.logger.Error(err.Error()) + } +} + +func (l *Loop) writePortForwardedFile(port uint16) { + filepath := l.state.GetSettings().Filepath + l.logger.Info("writing port file " + filepath) + if err := writePortForwardedToFile(filepath, port); err != nil { + l.logger.Error(err.Error()) + } +} + +func writePortForwardedToFile(filepath string, port uint16) (err error) { + file, err := os.OpenFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) + if err != nil { + return err + } + + _, err = file.Write([]byte(fmt.Sprint(port))) + if err != nil { + _ = file.Close() + return err + } + + return file.Close() +} diff --git a/internal/portforward/get.go b/internal/portforward/get.go new file mode 100644 index 00000000..90722df6 --- /dev/null +++ b/internal/portforward/get.go @@ -0,0 +1,9 @@ +package portforward + +import "github.com/qdm12/gluetun/internal/portforward/state" + +type Getter = state.PortForwardedGetter + +func (l *Loop) GetPortForwarded() (port uint16) { + return l.state.GetPortForwarded() +} diff --git a/internal/portforward/helpers.go b/internal/portforward/helpers.go new file mode 100644 index 00000000..5dd28654 --- /dev/null +++ b/internal/portforward/helpers.go @@ -0,0 +1,22 @@ +package portforward + +import ( + "context" + "time" +) + +func (l *Loop) logAndWait(ctx context.Context, err error) { + if err != nil { + l.logger.Error(err.Error()) + } + l.logger.Info("retrying in " + l.backoffTime.String()) + timer := time.NewTimer(l.backoffTime) + l.backoffTime *= 2 + select { + case <-timer.C: + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + } +} diff --git a/internal/portforward/loop.go b/internal/portforward/loop.go new file mode 100644 index 00000000..e4843ef2 --- /dev/null +++ b/internal/portforward/loop.go @@ -0,0 +1,71 @@ +package portforward + +import ( + "net/http" + "sync" + "time" + + "github.com/qdm12/gluetun/internal/configuration" + "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/firewall" + "github.com/qdm12/gluetun/internal/loopstate" + "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/portforward/state" + "github.com/qdm12/golibs/logging" +) + +var _ Looper = (*Loop)(nil) + +type Looper interface { + Runner + loopstate.Getter + StartStopper + SettingsGetSetter + Getter +} + +type Loop struct { + statusManager loopstate.Manager + state state.Manager + // Objects + client *http.Client + portAllower firewall.PortAllower + logger logging.Logger + // Internal channels and locks + start chan struct{} + running chan models.LoopStatus + stop chan struct{} + stopped chan struct{} + startMu sync.Mutex + backoffTime time.Duration + userTrigger bool +} + +const defaultBackoffTime = 5 * time.Second + +func NewLoop(settings configuration.PortForwarding, + client *http.Client, portAllower firewall.PortAllower, + logger logging.Logger) *Loop { + start := make(chan struct{}) + running := make(chan models.LoopStatus) + stop := make(chan struct{}) + stopped := make(chan struct{}) + + statusManager := loopstate.New(constants.Stopped, start, running, stop, stopped) + state := state.New(statusManager, settings) + + return &Loop{ + statusManager: statusManager, + state: state, + // Objects + client: client, + portAllower: portAllower, + logger: logger, + start: start, + running: running, + stop: stop, + stopped: stopped, + userTrigger: true, + backoffTime: defaultBackoffTime, + } +} diff --git a/internal/portforward/run.go b/internal/portforward/run.go new file mode 100644 index 00000000..ada5d07f --- /dev/null +++ b/internal/portforward/run.go @@ -0,0 +1,97 @@ +package portforward + +import ( + "context" + "strconv" + + "github.com/qdm12/gluetun/internal/constants" +) + +type Runner interface { + Run(ctx context.Context, done chan<- struct{}) +} + +func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { + defer close(done) + + select { + case <-l.start: // l.state.SetStartData called beforehand + case <-ctx.Done(): + return + } + + for ctx.Err() == nil { + pfCtx, pfCancel := context.WithCancel(ctx) + + portCh := make(chan uint16) + errorCh := make(chan error) + + startData := l.state.GetStartData() + + go func(ctx context.Context, startData StartData) { + port, err := startData.PortForwarder.PortForward(ctx, l.client, l.logger, + startData.Gateway, startData.ServerName) + if err != nil { + errorCh <- err + return + } + portCh <- port + + // Infinite loop + err = startData.PortForwarder.KeepPortForward(ctx, l.client, l.logger, + port, startData.Gateway, startData.ServerName) + errorCh <- err + }(pfCtx, startData) + + if l.userTrigger { + l.userTrigger = false + l.running <- constants.Running + } else { // crash + l.backoffTime = defaultBackoffTime + l.statusManager.SetStatus(constants.Running) + } + + stayHere := true + for stayHere { + select { + case <-ctx.Done(): + pfCancel() + <-errorCh + close(errorCh) + close(portCh) + l.removePortForwardedFile() + l.firewallBlockPort(ctx) + l.state.SetPortForwarded(0) + return + case <-l.start: + l.userTrigger = true + l.logger.Info("starting") + pfCancel() + stayHere = false + case <-l.stop: + l.userTrigger = true + l.logger.Info("stopping") + pfCancel() + <-errorCh + l.removePortForwardedFile() + l.firewallBlockPort(ctx) + l.state.SetPortForwarded(0) + l.stopped <- struct{}{} + case port := <-portCh: + l.logger.Info("port forwarded is " + strconv.Itoa(int(port))) + l.firewallBlockPort(ctx) + l.state.SetPortForwarded(port) + l.firewallAllowPort(ctx) + l.writePortForwardedFile(port) + case err := <-errorCh: + pfCancel() + close(errorCh) + close(portCh) + l.statusManager.SetStatus(constants.Crashed) + l.logAndWait(ctx, err) + stayHere = false + } + } + pfCancel() // for linting + } +} diff --git a/internal/portforward/settings.go b/internal/portforward/settings.go new file mode 100644 index 00000000..9736e4ab --- /dev/null +++ b/internal/portforward/settings.go @@ -0,0 +1,19 @@ +package portforward + +import ( + "context" + + "github.com/qdm12/gluetun/internal/configuration" + "github.com/qdm12/gluetun/internal/portforward/state" +) + +type SettingsGetSetter = state.SettingsGetSetter + +func (l *Loop) GetSettings() (settings configuration.PortForwarding) { + return l.state.GetSettings() +} + +func (l *Loop) SetSettings(ctx context.Context, settings configuration.PortForwarding) ( + outcome string) { + return l.state.SetSettings(ctx, settings) +} diff --git a/internal/openvpn/state/portforwarded.go b/internal/portforward/state/portforwarded.go similarity index 100% rename from internal/openvpn/state/portforwarded.go rename to internal/portforward/state/portforwarded.go diff --git a/internal/portforward/state/settings.go b/internal/portforward/state/settings.go new file mode 100644 index 00000000..d85a0c75 --- /dev/null +++ b/internal/portforward/state/settings.go @@ -0,0 +1,55 @@ +package state + +import ( + "context" + "os" + "reflect" + + "github.com/qdm12/gluetun/internal/configuration" + "github.com/qdm12/gluetun/internal/constants" +) + +type SettingsGetSetter interface { + GetSettings() (settings configuration.PortForwarding) + SetSettings(ctx context.Context, + settings configuration.PortForwarding) (outcome string) +} + +func (s *State) GetSettings() (settings configuration.PortForwarding) { + s.settingsMu.RLock() + defer s.settingsMu.RUnlock() + return s.settings +} + +func (s *State) SetSettings(ctx context.Context, settings configuration.PortForwarding) ( + outcome string) { + s.settingsMu.Lock() + + settingsUnchanged := reflect.DeepEqual(s.settings, settings) + if settingsUnchanged { + s.settingsMu.Unlock() + return "settings left unchanged" + } + + if s.settings.Filepath != settings.Filepath { + _ = os.Rename(s.settings.Filepath, settings.Filepath) + } + + newEnabled := settings.Enabled + previousEnabled := s.settings.Enabled + + s.settings = settings + s.settingsMu.Unlock() + + switch { + case !newEnabled && !previousEnabled: + case newEnabled && previousEnabled: + // no need to restart for now since we os.Rename the file here. + case newEnabled && !previousEnabled: + _, _ = s.statusApplier.ApplyStatus(ctx, constants.Running) + case !newEnabled && previousEnabled: + _, _ = s.statusApplier.ApplyStatus(ctx, constants.Stopped) + } + + return "settings updated" +} diff --git a/internal/portforward/state/startdata.go b/internal/portforward/state/startdata.go new file mode 100644 index 00000000..292ecc0d --- /dev/null +++ b/internal/portforward/state/startdata.go @@ -0,0 +1,39 @@ +package state + +import ( + "net" + + "github.com/qdm12/gluetun/internal/provider" +) + +type StartData struct { + PortForwarder provider.PortForwarder + Gateway net.IP // needed for PIA + ServerName string // needed for PIA + Interface string // tun0 or wg0 for example +} + +type StartDataGetterSetter interface { + StartDataGetter + StartDataSetter +} + +type StartDataGetter interface { + GetStartData() (startData StartData) +} + +func (s *State) GetStartData() (startData StartData) { + s.startDataMu.RLock() + defer s.startDataMu.RUnlock() + return s.startData +} + +type StartDataSetter interface { + SetStartData(startData StartData) +} + +func (s *State) SetStartData(startData StartData) { + s.startDataMu.Lock() + defer s.startDataMu.Unlock() + s.startData = startData +} diff --git a/internal/portforward/state/state.go b/internal/portforward/state/state.go new file mode 100644 index 00000000..79d1d5a5 --- /dev/null +++ b/internal/portforward/state/state.go @@ -0,0 +1,37 @@ +package state + +import ( + "sync" + + "github.com/qdm12/gluetun/internal/configuration" + "github.com/qdm12/gluetun/internal/loopstate" +) + +var _ Manager = (*State)(nil) + +type Manager interface { + SettingsGetSetter + PortForwardedGetterSetter + StartDataGetterSetter +} + +func New(statusApplier loopstate.Applier, + settings configuration.PortForwarding) *State { + return &State{ + statusApplier: statusApplier, + settings: settings, + } +} + +type State struct { + statusApplier loopstate.Applier + + settings configuration.PortForwarding + settingsMu sync.RWMutex + + portForwarded uint16 + portForwardedMu sync.RWMutex + + startData StartData + startDataMu sync.RWMutex +} diff --git a/internal/portforward/status.go b/internal/portforward/status.go new file mode 100644 index 00000000..60017843 --- /dev/null +++ b/internal/portforward/status.go @@ -0,0 +1,33 @@ +package portforward + +import ( + "context" + + "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/portforward/state" +) + +func (l *Loop) GetStatus() (status models.LoopStatus) { + return l.statusManager.GetStatus() +} + +type StartData = state.StartData + +type StartStopper interface { + Start(ctx context.Context, data StartData) ( + outcome string, err error) + Stop(ctx context.Context) (outcome string, err error) +} + +func (l *Loop) Start(ctx context.Context, data StartData) ( + outcome string, err error) { + l.startMu.Lock() + defer l.startMu.Unlock() + l.state.SetStartData(data) + return l.statusManager.ApplyStatus(ctx, constants.Running) +} + +func (l *Loop) Stop(ctx context.Context) (outcome string, err error) { + return l.statusManager.ApplyStatus(ctx, constants.Stopped) +} diff --git a/internal/provider/privateinternetaccess/connection.go b/internal/provider/privateinternetaccess/connection.go index 1c770b6e..7539e828 100644 --- a/internal/provider/privateinternetaccess/connection.go +++ b/internal/provider/privateinternetaccess/connection.go @@ -31,6 +31,7 @@ func (p *PIA) GetOpenVPNConnection(selection configuration.ServerSelection) ( IP: IP, Port: port, Protocol: protocol, + Hostname: server.ServerName, // used for port forwarding TLS } connections = append(connections, connection) } diff --git a/internal/provider/privateinternetaccess/portforward.go b/internal/provider/privateinternetaccess/portforward.go index 30b9407e..175323bc 100644 --- a/internal/provider/privateinternetaccess/portforward.go +++ b/internal/provider/privateinternetaccess/portforward.go @@ -15,48 +15,51 @@ import ( "strings" "time" - "github.com/qdm12/gluetun/internal/constants" - "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/golibs/format" "github.com/qdm12/golibs/logging" ) var ( - ErrBindPort = errors.New("cannot bind port") + ErrGatewayIPIsNil = errors.New("gateway IP address is nil") + ErrServerNameEmpty = errors.New("server name is empty") + ErrCreateHTTPClient = errors.New("cannot create custom HTTP client") + ErrReadSavedPortForwardData = errors.New("cannot read saved port forwarded data") + ErrRefreshPortForwardData = errors.New("cannot refresh port forward data") + ErrBindPort = errors.New("cannot bind port") ) // PortForward obtains a VPN server side port forwarded from PIA. -//nolint:gocognit func (p *PIA) PortForward(ctx context.Context, client *http.Client, - logger logging.Logger, gateway net.IP, portAllower firewall.PortAllower, - syncState func(port uint16) (pfFilepath string)) { - commonName := p.activeServer.ServerName - if !p.activeServer.PortForward { - logger.Error("The server " + commonName + - " (region " + p.activeServer.Region + ") does not support port forwarding") - return - } + logger logging.Logger, gateway net.IP, serverName string) ( + port uint16, err error) { + // commonName := p.activeServer.ServerName + // if !p.activeServer.PortForward { + // logger.Error("The server " + commonName + + // " (region " + p.activeServer.Region + ") does not support port forwarding") + // return + // } if gateway == nil { - logger.Error("aborting because: VPN gateway IP address was not found") - return + return 0, ErrGatewayIPIsNil + } else if serverName == "" { + return 0, ErrServerNameEmpty } - privateIPClient, err := newHTTPClient(commonName) + privateIPClient, err := newHTTPClient(serverName) if err != nil { - logger.Error("aborting because: " + err.Error()) - return + return 0, fmt.Errorf("%w: %s", ErrCreateHTTPClient, err) } data, err := readPIAPortForwardData(p.portForwardPath) if err != nil { - logger.Error(err.Error()) + return 0, fmt.Errorf("%w: %s", ErrReadSavedPortForwardData, err) } + dataFound := data.Port > 0 durationToExpiration := data.Expiration.Sub(p.timeNow()) expired := durationToExpiration <= 0 if dataFound { - logger.Info("Found persistent forwarded port data for port " + strconv.Itoa(int(data.Port))) + logger.Info("Found saved forwarded port data for port " + strconv.Itoa(int(data.Port))) if expired { logger.Warn("Forwarded port data expired on " + data.Expiration.Format(time.RFC1123) + ", getting another one") @@ -66,99 +69,65 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client, } if !dataFound || expired { - tryUntilSuccessful(ctx, logger, func() error { - data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, - p.portForwardPath, p.authFilePath) - return err - }) - if ctx.Err() != nil { - return + data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, + p.portForwardPath, p.authFilePath) + if err != nil { + return 0, fmt.Errorf("%w: %s", ErrRefreshPortForwardData, err) } durationToExpiration = data.Expiration.Sub(p.timeNow()) } - logger.Info("Port forwarded is " + strconv.Itoa(int(data.Port)) + - " expiring in " + format.FriendlyDuration(durationToExpiration)) + logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration)) // First time binding - tryUntilSuccessful(ctx, logger, func() error { - if err := bindPort(ctx, privateIPClient, gateway, data); err != nil { - return fmt.Errorf("%w: %s", ErrBindPort, err) - } - return nil - }) - if ctx.Err() != nil { - return + if err := bindPort(ctx, privateIPClient, gateway, data); err != nil { + return 0, fmt.Errorf("%w: %s", ErrBindPort, err) } - filepath := syncState(data.Port) - logger.Info("Writing port to " + filepath) - if err := writePortForwardedToFile(filepath, data.Port); err != nil { - logger.Error(err.Error()) + return data.Port, nil +} + +var ( + ErrPortForwardedExpired = errors.New("port forwarded data expired") +) + +func (p *PIA) KeepPortForward(ctx context.Context, client *http.Client, + logger logging.Logger, port uint16, gateway net.IP, serverName string) ( + err error) { + privateIPClient, err := newHTTPClient(serverName) + if err != nil { + return fmt.Errorf("%w: %s", ErrCreateHTTPClient, err) } - if err := portAllower.SetAllowedPort(ctx, data.Port, string(constants.TUN)); err != nil { - logger.Error(err.Error()) + data, err := readPIAPortForwardData(p.portForwardPath) + if err != nil { + return fmt.Errorf("%w: %s", ErrReadSavedPortForwardData, err) } + durationToExpiration := data.Expiration.Sub(p.timeNow()) expiryTimer := time.NewTimer(durationToExpiration) const keepAlivePeriod = 15 * time.Minute // Timer behaving as a ticker keepAliveTimer := time.NewTimer(keepAlivePeriod) + for { select { case <-ctx.Done(): - removeCtx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - if err := portAllower.RemoveAllowedPort(removeCtx, data.Port); err != nil { - logger.Error(err.Error()) - } if !keepAliveTimer.Stop() { <-keepAliveTimer.C } if !expiryTimer.Stop() { <-expiryTimer.C } - return + return ctx.Err() case <-keepAliveTimer.C: - if err := bindPort(ctx, privateIPClient, gateway, data); err != nil { - logger.Error("cannot bind port: " + err.Error()) + err := bindPort(ctx, privateIPClient, gateway, data) + if err != nil { + return fmt.Errorf("%w: %s", ErrBindPort, err) } keepAliveTimer.Reset(keepAlivePeriod) case <-expiryTimer.C: - logger.Warn("Forward port has expired on " + - data.Expiration.Format(time.RFC1123) + ", getting another one") - oldPort := data.Port - for { - data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, - p.portForwardPath, p.authFilePath) - if err != nil { - logger.Error(err.Error()) - continue - } - break - } - durationToExpiration := data.Expiration.Sub(p.timeNow()) - logger.Info("Port forwarded is " + strconv.Itoa(int(data.Port)) + - " expiring in " + format.FriendlyDuration(durationToExpiration)) - if err := portAllower.RemoveAllowedPort(ctx, oldPort); err != nil { - logger.Error(err.Error()) - } - if err := portAllower.SetAllowedPort(ctx, data.Port, string(constants.TUN)); err != nil { - logger.Error(err.Error()) - } - filepath := syncState(data.Port) - logger.Info("Writing port to " + filepath) - if err := writePortForwardedToFile(filepath, data.Port); err != nil { - logger.Error("Cannot write port forward data to file: " + err.Error()) - } - if err := bindPort(ctx, privateIPClient, gateway, data); err != nil { - logger.Error("Cannot bind port: " + err.Error()) - } - if !keepAliveTimer.Stop() { - <-keepAliveTimer.C - } - keepAliveTimer.Reset(keepAlivePeriod) - expiryTimer.Reset(durationToExpiration) + return fmt.Errorf("%w: on %s", ErrPortForwardedExpired, + data.Expiration.Format(time.RFC1123)) } } } @@ -463,21 +432,6 @@ func bindPort(ctx context.Context, client *http.Client, gateway net.IP, data pia return nil } -func writePortForwardedToFile(filepath string, port uint16) (err error) { - file, err := os.OpenFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) - if err != nil { - return err - } - - _, err = file.Write([]byte(fmt.Sprintf("%d", port))) - if err != nil { - _ = file.Close() - return err - } - - return file.Close() -} - // replaceInErr is used to remove sensitive information from errors. func replaceInErr(err error, substitutions map[string]string) error { s := replaceInString(err.Error(), substitutions) diff --git a/internal/provider/privateinternetaccess/try.go b/internal/provider/privateinternetaccess/try.go deleted file mode 100644 index effdcb1e..00000000 --- a/internal/provider/privateinternetaccess/try.go +++ /dev/null @@ -1,31 +0,0 @@ -package privateinternetaccess - -import ( - "context" - "time" - - "github.com/qdm12/golibs/logging" -) - -func tryUntilSuccessful(ctx context.Context, logger logging.Logger, fn func() error) { - const initialRetryPeriod = 5 * time.Second - retryPeriod := initialRetryPeriod - for { - err := fn() - if err == nil { - break - } - logger.Error(err.Error()) - logger.Info("Trying again in " + retryPeriod.String()) - timer := time.NewTimer(retryPeriod) - select { - case <-timer.C: - case <-ctx.Done(): - if !timer.Stop() { - <-timer.C - } - return - } - retryPeriod *= 2 - } -} diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 317119fa..1ff11ff2 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -10,7 +10,6 @@ import ( "github.com/qdm12/gluetun/internal/configuration" "github.com/qdm12/gluetun/internal/constants" - "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider/cyberghost" "github.com/qdm12/gluetun/internal/provider/fastestvpn" @@ -36,9 +35,16 @@ import ( type Provider interface { GetOpenVPNConnection(selection configuration.ServerSelection) (connection models.OpenVPNConnection, err error) BuildConf(connection models.OpenVPNConnection, username string, settings configuration.OpenVPN) (lines []string) + PortForwarder +} + +type PortForwarder interface { PortForward(ctx context.Context, client *http.Client, - pfLogger logging.Logger, gateway net.IP, portAllower firewall.PortAllower, - syncState func(port uint16) (pfFilepath string)) + logger logging.Logger, gateway net.IP, serverName string) ( + port uint16, err error) + KeepPortForward(ctx context.Context, client *http.Client, + logger logging.Logger, port uint16, gateway net.IP, serverName string) ( + err error) } func New(provider string, allServers models.AllServers, timeNow func() time.Time) Provider { diff --git a/internal/provider/utils/noportforward.go b/internal/provider/utils/noportforward.go index 844e3b64..7de73afc 100644 --- a/internal/provider/utils/noportforward.go +++ b/internal/provider/utils/noportforward.go @@ -2,17 +2,21 @@ package utils import ( "context" + "errors" + "fmt" "net" "net/http" - "github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/golibs/logging" ) type NoPortForwarder interface { PortForward(ctx context.Context, client *http.Client, - pfLogger logging.Logger, gateway net.IP, portAllower firewall.PortAllower, - syncState func(port uint16) (pfFilepath string)) + logger logging.Logger, gateway net.IP, serverName string) ( + port uint16, err error) + KeepPortForward(ctx context.Context, client *http.Client, + logger logging.Logger, port uint16, gateway net.IP, serverName string) ( + err error) } type NoPortForwarding struct { @@ -25,8 +29,16 @@ func NewNoPortForwarding(providerName string) *NoPortForwarding { } } +var ErrPortForwardingNotSupported = errors.New("custom port forwarding obtention is not supported") + func (n *NoPortForwarding) PortForward(ctx context.Context, client *http.Client, - pfLogger logging.Logger, gateway net.IP, portAllower firewall.PortAllower, - syncState func(port uint16) (pfFilepath string)) { - panic("custom port forwarding obtention is not supported for " + n.providerName) + logger logging.Logger, gateway net.IP, serverName string) ( + port uint16, err error) { + return 0, fmt.Errorf("%w: for %s", ErrPortForwardingNotSupported, n.providerName) +} + +func (n *NoPortForwarding) KeepPortForward(ctx context.Context, client *http.Client, + logger logging.Logger, port uint16, gateway net.IP, serverName string) ( + err error) { + return fmt.Errorf("%w: for %s", ErrPortForwardingNotSupported, n.providerName) } diff --git a/internal/server/openvpn.go b/internal/server/openvpn.go index cc3074ee..a46b171d 100644 --- a/internal/server/openvpn.go +++ b/internal/server/openvpn.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/qdm12/gluetun/internal/openvpn" + "github.com/qdm12/gluetun/internal/portforward" "github.com/qdm12/golibs/logging" ) @@ -22,6 +23,7 @@ func newOpenvpnHandler(ctx context.Context, looper openvpn.Looper, type openvpnHandler struct { ctx context.Context looper openvpn.Looper + pf portforward.Getter logger logging.Logger } @@ -105,7 +107,7 @@ func (h *openvpnHandler) getSettings(w http.ResponseWriter) { } func (h *openvpnHandler) getPortForwarded(w http.ResponseWriter) { - port := h.looper.GetPortForwarded() + port := h.pf.GetPortForwarded() encoder := json.NewEncoder(w) data := portWrapper{Port: port} if err := encoder.Encode(data); err != nil {