Using context for HTTP requests

This commit is contained in:
Quentin McGaw
2020-10-17 21:54:09 +00:00
parent 0d2ca377df
commit 6f4be72785
7 changed files with 37 additions and 28 deletions

View File

@@ -16,6 +16,7 @@ import (
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
"golang.org/x/net/context/ctxhttp"
)
type piaV3 struct {
@@ -89,7 +90,7 @@ func (p *piaV3) PortForward(ctx context.Context, client *http.Client,
}
clientID := hex.EncodeToString(b)
url := fmt.Sprintf("%s/?client_id=%s", constants.PIAPortForwardURL, clientID)
response, err := client.Get(url) // TODO add ctx
response, err := ctxhttp.Get(ctx, client, url)
if err != nil {
pfLogger.Error(err)
return

View File

@@ -21,6 +21,7 @@ import (
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
"golang.org/x/net/context/ctxhttp"
)
type piaV4 struct {
@@ -151,7 +152,7 @@ func (p *piaV4) PortForward(ctx context.Context, client *http.Client,
if !dataFound || expired {
tryUntilSuccessful(ctx, pfLogger, func() error {
data, err = refreshPIAPortForwardData(client, gateway, fileManager)
data, err = refreshPIAPortForwardData(ctx, client, gateway, fileManager)
return err
})
if ctx.Err() != nil {
@@ -163,7 +164,7 @@ func (p *piaV4) PortForward(ctx context.Context, client *http.Client,
// First time binding
tryUntilSuccessful(ctx, pfLogger, func() error {
return bindPIAPort(client, gateway, data)
return bindPIAPort(ctx, client, gateway, data)
})
if ctx.Err() != nil {
return
@@ -202,7 +203,7 @@ func (p *piaV4) PortForward(ctx context.Context, client *http.Client,
}
return
case <-keepAliveTimer.C:
if err := bindPIAPort(client, gateway, data); err != nil {
if err := bindPIAPort(ctx, client, gateway, data); err != nil {
pfLogger.Error(err)
}
keepAliveTimer.Reset(keepAlivePeriod)
@@ -210,7 +211,7 @@ func (p *piaV4) PortForward(ctx context.Context, client *http.Client,
pfLogger.Warn("Forward port has expired on %s, getting another one", data.Expiration.Format(time.RFC1123))
oldPort := data.Port
for {
data, err = refreshPIAPortForwardData(client, gateway, fileManager)
data, err = refreshPIAPortForwardData(ctx, client, gateway, fileManager)
if err != nil {
pfLogger.Error(err)
continue
@@ -233,7 +234,7 @@ func (p *piaV4) PortForward(ctx context.Context, client *http.Client,
); err != nil {
pfLogger.Error(err)
}
if err := bindPIAPort(client, gateway, data); err != nil {
if err := bindPIAPort(ctx, client, gateway, data); err != nil {
pfLogger.Error(err)
}
if !keepAliveTimer.Stop() {
@@ -292,12 +293,12 @@ func newPIAv4HTTPClient(serverName string) (client *http.Client, err error) {
return client, nil
}
func refreshPIAPortForwardData(client *http.Client, gateway net.IP, fileManager files.FileManager) (data piaPortForwardData, err error) {
func refreshPIAPortForwardData(ctx context.Context, client *http.Client, gateway net.IP, fileManager files.FileManager) (data piaPortForwardData, err error) {
data.Token, err = fetchPIAToken(fileManager, client)
if err != nil {
return data, fmt.Errorf("cannot obtain token: %w", err)
}
data.Port, data.Signature, data.Expiration, err = fetchPIAPortForwardData(client, gateway, data.Token)
data.Port, data.Signature, data.Expiration, err = fetchPIAPortForwardData(ctx, client, gateway, data.Token)
if err != nil {
return data, fmt.Errorf("cannot obtain port forwarding data: %w", err)
}
@@ -429,7 +430,7 @@ func getOpenvpnCredentials(fileManager files.FileManager) (username, password st
return username, password, nil
}
func fetchPIAPortForwardData(client *http.Client, gateway net.IP, token string) (port uint16, signature string, expiration time.Time, err error) {
func fetchPIAPortForwardData(ctx context.Context, client *http.Client, gateway net.IP, token string) (port uint16, signature string, expiration time.Time, err error) {
queryParams := url.Values{}
queryParams.Add("token", token)
url := url.URL{
@@ -438,7 +439,7 @@ func fetchPIAPortForwardData(client *http.Client, gateway net.IP, token string)
Path: "/getSignature",
RawQuery: queryParams.Encode(),
}
response, err := client.Get(url.String())
response, err := ctxhttp.Get(ctx, client, url.String())
if err != nil {
return 0, "", expiration, fmt.Errorf("cannot obtain signature: %w", err)
}
@@ -465,7 +466,7 @@ func fetchPIAPortForwardData(client *http.Client, gateway net.IP, token string)
return port, data.Signature, expiration, err
}
func bindPIAPort(client *http.Client, gateway net.IP, data piaPortForwardData) (err error) {
func bindPIAPort(ctx context.Context, client *http.Client, gateway net.IP, data piaPortForwardData) (err error) {
payload, err := packPIAPayload(data.Port, data.Token, data.Expiration)
if err != nil {
return err
@@ -480,7 +481,7 @@ func bindPIAPort(client *http.Client, gateway net.IP, data piaPortForwardData) (
RawQuery: queryParams.Encode(),
}
response, err := client.Get(url.String())
response, err := ctxhttp.Get(ctx, client, url.String())
if err != nil {
return fmt.Errorf("cannot bind port: %w", err)
}