Maint: do not mock os functions
- Use filepaths with /tmp for tests instead - Only mock functions where filepath can't be specified such as user.Lookup
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -18,7 +19,6 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/format"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -28,7 +28,7 @@ var (
|
||||
// PortForward obtains a VPN server side port forwarded from PIA.
|
||||
//nolint:gocognit
|
||||
func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
||||
openFile os.OpenFileFunc, logger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
logger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath string)) {
|
||||
commonName := p.activeServer.ServerName
|
||||
if !p.activeServer.PortForward {
|
||||
@@ -47,7 +47,7 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
||||
return
|
||||
}
|
||||
|
||||
data, err := readPIAPortForwardData(openFile)
|
||||
data, err := readPIAPortForwardData(p.portForwardPath)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
@@ -67,7 +67,8 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
||||
|
||||
if !dataFound || expired {
|
||||
tryUntilSuccessful(ctx, logger, func() error {
|
||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, openFile)
|
||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
|
||||
p.portForwardPath, p.authFilePath)
|
||||
return err
|
||||
})
|
||||
if ctx.Err() != nil {
|
||||
@@ -91,7 +92,7 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
||||
|
||||
filepath := syncState(data.Port)
|
||||
logger.Info("Writing port to " + filepath)
|
||||
if err := writePortForwardedToFile(openFile, filepath, data.Port); err != nil {
|
||||
if err := writePortForwardedToFile(filepath, data.Port); err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
|
||||
@@ -128,7 +129,8 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
||||
data.Expiration.Format(time.RFC1123) + ", getting another one")
|
||||
oldPort := data.Port
|
||||
for {
|
||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway, openFile)
|
||||
data, err = refreshPIAPortForwardData(ctx, client, privateIPClient, gateway,
|
||||
p.portForwardPath, p.authFilePath)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
continue
|
||||
@@ -146,7 +148,7 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
||||
}
|
||||
filepath := syncState(data.Port)
|
||||
logger.Info("Writing port to " + filepath)
|
||||
if err := writePortForwardedToFile(openFile, filepath, data.Port); err != nil {
|
||||
if err := writePortForwardedToFile(filepath, data.Port); err != nil {
|
||||
logger.Error("Cannot write port forward data to file: " + err.Error())
|
||||
}
|
||||
if err := bindPort(ctx, privateIPClient, gateway, data); err != nil {
|
||||
@@ -168,8 +170,8 @@ var (
|
||||
)
|
||||
|
||||
func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
|
||||
gateway net.IP, openFile os.OpenFileFunc) (data piaPortForwardData, err error) {
|
||||
data.Token, err = fetchToken(ctx, openFile, client)
|
||||
gateway net.IP, portForwardPath, authFilePath string) (data piaPortForwardData, err error) {
|
||||
data.Token, err = fetchToken(ctx, client, authFilePath)
|
||||
if err != nil {
|
||||
return data, fmt.Errorf("%w: %s", ErrFetchToken, err)
|
||||
}
|
||||
@@ -179,7 +181,7 @@ func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *htt
|
||||
return data, fmt.Errorf("%w: %s", ErrFetchPortForwarding, err)
|
||||
}
|
||||
|
||||
if err := writePIAPortForwardData(openFile, data); err != nil {
|
||||
if err := writePIAPortForwardData(portForwardPath, data); err != nil {
|
||||
return data, fmt.Errorf("%w: %s", ErrPersistPortForwarding, err)
|
||||
}
|
||||
|
||||
@@ -199,8 +201,8 @@ type piaPortForwardData struct {
|
||||
Expiration time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
func readPIAPortForwardData(openFile os.OpenFileFunc) (data piaPortForwardData, err error) {
|
||||
file, err := openFile(constants.PIAPortForward, os.O_RDONLY, 0)
|
||||
func readPIAPortForwardData(portForwardPath string) (data piaPortForwardData, err error) {
|
||||
file, err := os.Open(portForwardPath)
|
||||
if os.IsNotExist(err) {
|
||||
return data, nil
|
||||
} else if err != nil {
|
||||
@@ -216,8 +218,8 @@ func readPIAPortForwardData(openFile os.OpenFileFunc) (data piaPortForwardData,
|
||||
return data, file.Close()
|
||||
}
|
||||
|
||||
func writePIAPortForwardData(openFile os.OpenFileFunc, data piaPortForwardData) (err error) {
|
||||
file, err := openFile(constants.PIAPortForward, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
|
||||
func writePIAPortForwardData(portForwardPath string, data piaPortForwardData) (err error) {
|
||||
file, err := os.OpenFile(portForwardPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -269,9 +271,9 @@ var (
|
||||
errEmptyToken = errors.New("token received is empty")
|
||||
)
|
||||
|
||||
func fetchToken(ctx context.Context, openFile os.OpenFileFunc,
|
||||
client *http.Client) (token string, err error) {
|
||||
username, password, err := getOpenvpnCredentials(openFile)
|
||||
func fetchToken(ctx context.Context, client *http.Client,
|
||||
authFilePath string) (token string, err error) {
|
||||
username, password, err := getOpenvpnCredentials(authFilePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%w: %s", errGetCredentials, err)
|
||||
}
|
||||
@@ -321,8 +323,9 @@ var (
|
||||
errAuthFileMalformed = errors.New("authentication file is malformed")
|
||||
)
|
||||
|
||||
func getOpenvpnCredentials(openFile os.OpenFileFunc) (username, password string, err error) {
|
||||
file, err := openFile(constants.OpenVPNAuthConf, os.O_RDONLY, 0)
|
||||
func getOpenvpnCredentials(authFilePath string) (
|
||||
username, password string, err error) {
|
||||
file, err := os.Open(authFilePath)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("%w: %s", errAuthFileRead, err)
|
||||
}
|
||||
@@ -460,9 +463,8 @@ func bindPort(ctx context.Context, client *http.Client, gateway net.IP, data pia
|
||||
return nil
|
||||
}
|
||||
|
||||
func writePortForwardedToFile(openFile os.OpenFileFunc,
|
||||
filepath string, port uint16) (err error) {
|
||||
file, err := openFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
|
||||
func writePortForwardedToFile(filepath string, port uint16) (err error) {
|
||||
file, err := os.OpenFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
)
|
||||
|
||||
@@ -12,12 +13,18 @@ type PIA struct {
|
||||
randSource rand.Source
|
||||
timeNow func() time.Time
|
||||
activeServer models.PIAServer
|
||||
// Port forwarding
|
||||
portForwardPath string
|
||||
authFilePath string
|
||||
}
|
||||
|
||||
func New(servers []models.PIAServer, randSource rand.Source, timeNow func() time.Time) *PIA {
|
||||
func New(servers []models.PIAServer, randSource rand.Source,
|
||||
timeNow func() time.Time) *PIA {
|
||||
return &PIA{
|
||||
servers: servers,
|
||||
timeNow: timeNow,
|
||||
randSource: randSource,
|
||||
servers: servers,
|
||||
timeNow: timeNow,
|
||||
randSource: randSource,
|
||||
portForwardPath: constants.PIAPortForward,
|
||||
authFilePath: constants.OpenVPNAuthConf,
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user