Maint: port forwarding refactoring (#543)
- portforward package - portforward run loop - Less functional arguments and cycles
This commit is contained in:
@@ -15,48 +15,51 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/golibs/format"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrBindPort = errors.New("cannot bind port")
|
||||
ErrGatewayIPIsNil = errors.New("gateway IP address is nil")
|
||||
ErrServerNameEmpty = errors.New("server name is empty")
|
||||
ErrCreateHTTPClient = errors.New("cannot create custom HTTP client")
|
||||
ErrReadSavedPortForwardData = errors.New("cannot read saved port forwarded data")
|
||||
ErrRefreshPortForwardData = errors.New("cannot refresh port forward data")
|
||||
ErrBindPort = errors.New("cannot bind port")
|
||||
)
|
||||
|
||||
// PortForward obtains a VPN server side port forwarded from PIA.
|
||||
//nolint:gocognit
|
||||
func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
||||
logger logging.Logger, gateway net.IP, portAllower firewall.PortAllower,
|
||||
syncState func(port uint16) (pfFilepath string)) {
|
||||
commonName := p.activeServer.ServerName
|
||||
if !p.activeServer.PortForward {
|
||||
logger.Error("The server " + commonName +
|
||||
" (region " + p.activeServer.Region + ") does not support port forwarding")
|
||||
return
|
||||
}
|
||||
logger logging.Logger, gateway net.IP, serverName string) (
|
||||
port uint16, err error) {
|
||||
// commonName := p.activeServer.ServerName
|
||||
// if !p.activeServer.PortForward {
|
||||
// logger.Error("The server " + commonName +
|
||||
// " (region " + p.activeServer.Region + ") does not support port forwarding")
|
||||
// return
|
||||
// }
|
||||
if gateway == nil {
|
||||
logger.Error("aborting because: VPN gateway IP address was not found")
|
||||
return
|
||||
return 0, ErrGatewayIPIsNil
|
||||
} else if serverName == "" {
|
||||
return 0, ErrServerNameEmpty
|
||||
}
|
||||
|
||||
privateIPClient, err := newHTTPClient(commonName)
|
||||
privateIPClient, err := newHTTPClient(serverName)
|
||||
if err != nil {
|
||||
logger.Error("aborting because: " + err.Error())
|
||||
return
|
||||
return 0, fmt.Errorf("%w: %s", ErrCreateHTTPClient, err)
|
||||
}
|
||||
|
||||
data, err := readPIAPortForwardData(p.portForwardPath)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
return 0, fmt.Errorf("%w: %s", ErrReadSavedPortForwardData, err)
|
||||
}
|
||||
|
||||
dataFound := data.Port > 0
|
||||
durationToExpiration := data.Expiration.Sub(p.timeNow())
|
||||
expired := durationToExpiration <= 0
|
||||
|
||||
if dataFound {
|
||||
logger.Info("Found persistent forwarded port data for port " + strconv.Itoa(int(data.Port)))
|
||||
logger.Info("Found saved forwarded port data for port " + strconv.Itoa(int(data.Port)))
|
||||
if expired {
|
||||
logger.Warn("Forwarded port data expired on " +
|
||||
data.Expiration.Format(time.RFC1123) + ", getting another one")
|
||||
@@ -66,99 +69,65 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
||||
}
|
||||
|
||||
if !dataFound || expired {
|
||||
tryUntilSuccessful(ctx, logger, func() error {
|
||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
|
||||
p.portForwardPath, p.authFilePath)
|
||||
return err
|
||||
})
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
|
||||
p.portForwardPath, p.authFilePath)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("%w: %s", ErrRefreshPortForwardData, err)
|
||||
}
|
||||
durationToExpiration = data.Expiration.Sub(p.timeNow())
|
||||
}
|
||||
logger.Info("Port forwarded is " + strconv.Itoa(int(data.Port)) +
|
||||
" expiring in " + format.FriendlyDuration(durationToExpiration))
|
||||
logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration))
|
||||
|
||||
// First time binding
|
||||
tryUntilSuccessful(ctx, logger, func() error {
|
||||
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrBindPort, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
|
||||
return 0, fmt.Errorf("%w: %s", ErrBindPort, err)
|
||||
}
|
||||
|
||||
filepath := syncState(data.Port)
|
||||
logger.Info("Writing port to " + filepath)
|
||||
if err := writePortForwardedToFile(filepath, data.Port); err != nil {
|
||||
logger.Error(err.Error())
|
||||
return data.Port, nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrPortForwardedExpired = errors.New("port forwarded data expired")
|
||||
)
|
||||
|
||||
func (p *PIA) KeepPortForward(ctx context.Context, client *http.Client,
|
||||
logger logging.Logger, port uint16, gateway net.IP, serverName string) (
|
||||
err error) {
|
||||
privateIPClient, err := newHTTPClient(serverName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrCreateHTTPClient, err)
|
||||
}
|
||||
|
||||
if err := portAllower.SetAllowedPort(ctx, data.Port, string(constants.TUN)); err != nil {
|
||||
logger.Error(err.Error())
|
||||
data, err := readPIAPortForwardData(p.portForwardPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrReadSavedPortForwardData, err)
|
||||
}
|
||||
|
||||
durationToExpiration := data.Expiration.Sub(p.timeNow())
|
||||
expiryTimer := time.NewTimer(durationToExpiration)
|
||||
const keepAlivePeriod = 15 * time.Minute
|
||||
// Timer behaving as a ticker
|
||||
keepAliveTimer := time.NewTimer(keepAlivePeriod)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
removeCtx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := portAllower.RemoveAllowedPort(removeCtx, data.Port); err != nil {
|
||||
logger.Error(err.Error())
|
||||
}
|
||||
if !keepAliveTimer.Stop() {
|
||||
<-keepAliveTimer.C
|
||||
}
|
||||
if !expiryTimer.Stop() {
|
||||
<-expiryTimer.C
|
||||
}
|
||||
return
|
||||
return ctx.Err()
|
||||
case <-keepAliveTimer.C:
|
||||
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
|
||||
logger.Error("cannot bind port: " + err.Error())
|
||||
err := bindPort(ctx, privateIPClient, gateway, data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrBindPort, err)
|
||||
}
|
||||
keepAliveTimer.Reset(keepAlivePeriod)
|
||||
case <-expiryTimer.C:
|
||||
logger.Warn("Forward port has expired on " +
|
||||
data.Expiration.Format(time.RFC1123) + ", getting another one")
|
||||
oldPort := data.Port
|
||||
for {
|
||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
|
||||
p.portForwardPath, p.authFilePath)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
durationToExpiration := data.Expiration.Sub(p.timeNow())
|
||||
logger.Info("Port forwarded is " + strconv.Itoa(int(data.Port)) +
|
||||
" expiring in " + format.FriendlyDuration(durationToExpiration))
|
||||
if err := portAllower.RemoveAllowedPort(ctx, oldPort); err != nil {
|
||||
logger.Error(err.Error())
|
||||
}
|
||||
if err := portAllower.SetAllowedPort(ctx, data.Port, string(constants.TUN)); err != nil {
|
||||
logger.Error(err.Error())
|
||||
}
|
||||
filepath := syncState(data.Port)
|
||||
logger.Info("Writing port to " + filepath)
|
||||
if err := writePortForwardedToFile(filepath, data.Port); err != nil {
|
||||
logger.Error("Cannot write port forward data to file: " + err.Error())
|
||||
}
|
||||
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
|
||||
logger.Error("Cannot bind port: " + err.Error())
|
||||
}
|
||||
if !keepAliveTimer.Stop() {
|
||||
<-keepAliveTimer.C
|
||||
}
|
||||
keepAliveTimer.Reset(keepAlivePeriod)
|
||||
expiryTimer.Reset(durationToExpiration)
|
||||
return fmt.Errorf("%w: on %s", ErrPortForwardedExpired,
|
||||
data.Expiration.Format(time.RFC1123))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -463,21 +432,6 @@ func bindPort(ctx context.Context, client *http.Client, gateway net.IP, data pia
|
||||
return nil
|
||||
}
|
||||
|
||||
func writePortForwardedToFile(filepath string, port uint16) (err error) {
|
||||
file, err := os.OpenFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = file.Write([]byte(fmt.Sprintf("%d", port)))
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return file.Close()
|
||||
}
|
||||
|
||||
// replaceInErr is used to remove sensitive information from errors.
|
||||
func replaceInErr(err error, substitutions map[string]string) error {
|
||||
s := replaceInString(err.Error(), substitutions)
|
||||
|
||||
Reference in New Issue
Block a user