Maint: port forwarding refactoring (#543)

- portforward package
- portforward run loop
- Less functional arguments and cycles
This commit is contained in:
Quentin McGaw
2021-07-28 08:35:44 -07:00
committed by GitHub
parent c777f8d97d
commit 2998cf5e48
25 changed files with 639 additions and 255 deletions

View File

@@ -15,48 +15,51 @@ import (
"strings"
"time"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/golibs/format"
"github.com/qdm12/golibs/logging"
)
var (
ErrBindPort = errors.New("cannot bind port")
ErrGatewayIPIsNil = errors.New("gateway IP address is nil")
ErrServerNameEmpty = errors.New("server name is empty")
ErrCreateHTTPClient = errors.New("cannot create custom HTTP client")
ErrReadSavedPortForwardData = errors.New("cannot read saved port forwarded data")
ErrRefreshPortForwardData = errors.New("cannot refresh port forward data")
ErrBindPort = errors.New("cannot bind port")
)
// PortForward obtains a VPN server side port forwarded from PIA.
//nolint:gocognit
func (p *PIA) PortForward(ctx context.Context, client *http.Client,
logger logging.Logger, gateway net.IP, portAllower firewall.PortAllower,
syncState func(port uint16) (pfFilepath string)) {
commonName := p.activeServer.ServerName
if !p.activeServer.PortForward {
logger.Error("The server " + commonName +
" (region " + p.activeServer.Region + ") does not support port forwarding")
return
}
logger logging.Logger, gateway net.IP, serverName string) (
port uint16, err error) {
// commonName := p.activeServer.ServerName
// if !p.activeServer.PortForward {
// logger.Error("The server " + commonName +
// " (region " + p.activeServer.Region + ") does not support port forwarding")
// return
// }
if gateway == nil {
logger.Error("aborting because: VPN gateway IP address was not found")
return
return 0, ErrGatewayIPIsNil
} else if serverName == "" {
return 0, ErrServerNameEmpty
}
privateIPClient, err := newHTTPClient(commonName)
privateIPClient, err := newHTTPClient(serverName)
if err != nil {
logger.Error("aborting because: " + err.Error())
return
return 0, fmt.Errorf("%w: %s", ErrCreateHTTPClient, err)
}
data, err := readPIAPortForwardData(p.portForwardPath)
if err != nil {
logger.Error(err.Error())
return 0, fmt.Errorf("%w: %s", ErrReadSavedPortForwardData, err)
}
dataFound := data.Port > 0
durationToExpiration := data.Expiration.Sub(p.timeNow())
expired := durationToExpiration <= 0
if dataFound {
logger.Info("Found persistent forwarded port data for port " + strconv.Itoa(int(data.Port)))
logger.Info("Found saved forwarded port data for port " + strconv.Itoa(int(data.Port)))
if expired {
logger.Warn("Forwarded port data expired on " +
data.Expiration.Format(time.RFC1123) + ", getting another one")
@@ -66,99 +69,65 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
}
if !dataFound || expired {
tryUntilSuccessful(ctx, logger, func() error {
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
p.portForwardPath, p.authFilePath)
return err
})
if ctx.Err() != nil {
return
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
p.portForwardPath, p.authFilePath)
if err != nil {
return 0, fmt.Errorf("%w: %s", ErrRefreshPortForwardData, err)
}
durationToExpiration = data.Expiration.Sub(p.timeNow())
}
logger.Info("Port forwarded is " + strconv.Itoa(int(data.Port)) +
" expiring in " + format.FriendlyDuration(durationToExpiration))
logger.Info("Port forwarded data expires in " + format.FriendlyDuration(durationToExpiration))
// First time binding
tryUntilSuccessful(ctx, logger, func() error {
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
return fmt.Errorf("%w: %s", ErrBindPort, err)
}
return nil
})
if ctx.Err() != nil {
return
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
return 0, fmt.Errorf("%w: %s", ErrBindPort, err)
}
filepath := syncState(data.Port)
logger.Info("Writing port to " + filepath)
if err := writePortForwardedToFile(filepath, data.Port); err != nil {
logger.Error(err.Error())
return data.Port, nil
}
var (
ErrPortForwardedExpired = errors.New("port forwarded data expired")
)
func (p *PIA) KeepPortForward(ctx context.Context, client *http.Client,
logger logging.Logger, port uint16, gateway net.IP, serverName string) (
err error) {
privateIPClient, err := newHTTPClient(serverName)
if err != nil {
return fmt.Errorf("%w: %s", ErrCreateHTTPClient, err)
}
if err := portAllower.SetAllowedPort(ctx, data.Port, string(constants.TUN)); err != nil {
logger.Error(err.Error())
data, err := readPIAPortForwardData(p.portForwardPath)
if err != nil {
return fmt.Errorf("%w: %s", ErrReadSavedPortForwardData, err)
}
durationToExpiration := data.Expiration.Sub(p.timeNow())
expiryTimer := time.NewTimer(durationToExpiration)
const keepAlivePeriod = 15 * time.Minute
// Timer behaving as a ticker
keepAliveTimer := time.NewTimer(keepAlivePeriod)
for {
select {
case <-ctx.Done():
removeCtx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := portAllower.RemoveAllowedPort(removeCtx, data.Port); err != nil {
logger.Error(err.Error())
}
if !keepAliveTimer.Stop() {
<-keepAliveTimer.C
}
if !expiryTimer.Stop() {
<-expiryTimer.C
}
return
return ctx.Err()
case <-keepAliveTimer.C:
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
logger.Error("cannot bind port: " + err.Error())
err := bindPort(ctx, privateIPClient, gateway, data)
if err != nil {
return fmt.Errorf("%w: %s", ErrBindPort, err)
}
keepAliveTimer.Reset(keepAlivePeriod)
case <-expiryTimer.C:
logger.Warn("Forward port has expired on " +
data.Expiration.Format(time.RFC1123) + ", getting another one")
oldPort := data.Port
for {
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
p.portForwardPath, p.authFilePath)
if err != nil {
logger.Error(err.Error())
continue
}
break
}
durationToExpiration := data.Expiration.Sub(p.timeNow())
logger.Info("Port forwarded is " + strconv.Itoa(int(data.Port)) +
" expiring in " + format.FriendlyDuration(durationToExpiration))
if err := portAllower.RemoveAllowedPort(ctx, oldPort); err != nil {
logger.Error(err.Error())
}
if err := portAllower.SetAllowedPort(ctx, data.Port, string(constants.TUN)); err != nil {
logger.Error(err.Error())
}
filepath := syncState(data.Port)
logger.Info("Writing port to " + filepath)
if err := writePortForwardedToFile(filepath, data.Port); err != nil {
logger.Error("Cannot write port forward data to file: " + err.Error())
}
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
logger.Error("Cannot bind port: " + err.Error())
}
if !keepAliveTimer.Stop() {
<-keepAliveTimer.C
}
keepAliveTimer.Reset(keepAlivePeriod)
expiryTimer.Reset(durationToExpiration)
return fmt.Errorf("%w: on %s", ErrPortForwardedExpired,
data.Expiration.Format(time.RFC1123))
}
}
}
@@ -463,21 +432,6 @@ func bindPort(ctx context.Context, client *http.Client, gateway net.IP, data pia
return nil
}
func writePortForwardedToFile(filepath string, port uint16) (err error) {
file, err := os.OpenFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
if err != nil {
return err
}
_, err = file.Write([]byte(fmt.Sprintf("%d", port)))
if err != nil {
_ = file.Close()
return err
}
return file.Close()
}
// replaceInErr is used to remove sensitive information from errors.
func replaceInErr(err error, substitutions map[string]string) error {
s := replaceInString(err.Error(), substitutions)