Using context for HTTP requests
This commit is contained in:
@@ -52,7 +52,8 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
|
|||||||
var err error
|
var err error
|
||||||
switch args[1] {
|
switch args[1] {
|
||||||
case "healthcheck":
|
case "healthcheck":
|
||||||
err = cli.HealthCheck()
|
client := &http.Client{Timeout: time.Second}
|
||||||
|
err = cli.HealthCheck(background, client)
|
||||||
case "clientkey":
|
case "clientkey":
|
||||||
err = cli.ClientKey(args[2:])
|
err = cli.ClientKey(args[2:])
|
||||||
case "openvpnconfig":
|
case "openvpnconfig":
|
||||||
@@ -403,7 +404,7 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dn
|
|||||||
if !versionInformation {
|
if !versionInformation {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
message, err := versionpkg.GetMessage(version, commit, httpClient)
|
message, err := versionpkg.GetMessage(ctx, version, commit, httpClient)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
break
|
break
|
||||||
|
|||||||
1
go.mod
1
go.mod
@@ -9,5 +9,6 @@ require (
|
|||||||
github.com/qdm12/golibs v0.0.0-20200712151944-a0325873bf5a
|
github.com/qdm12/golibs v0.0.0-20200712151944-a0325873bf5a
|
||||||
github.com/qdm12/ss-server v0.0.0-20200819005413-6b516c299307
|
github.com/qdm12/ss-server v0.0.0-20200819005413-6b516c299307
|
||||||
github.com/stretchr/testify v1.6.1
|
github.com/stretchr/testify v1.6.1
|
||||||
|
golang.org/x/net v0.0.0-20190620200207-3b0461eec859
|
||||||
golang.org/x/sys v0.0.0-20200814200057-3d37ad5750ed
|
golang.org/x/sys v0.0.0-20200814200057-3d37ad5750ed
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/updater"
|
"github.com/qdm12/gluetun/internal/updater"
|
||||||
"github.com/qdm12/golibs/files"
|
"github.com/qdm12/golibs/files"
|
||||||
"github.com/qdm12/golibs/logging"
|
"github.com/qdm12/golibs/logging"
|
||||||
|
"golang.org/x/net/context/ctxhttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ClientKey(args []string) error {
|
func ClientKey(args []string) error {
|
||||||
@@ -39,9 +40,9 @@ func ClientKey(args []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func HealthCheck() error {
|
func HealthCheck(ctx context.Context, client *http.Client) error {
|
||||||
client := &http.Client{Timeout: time.Second}
|
const url = "http://localhost:8000/health"
|
||||||
response, err := client.Get("http://localhost:8000/health")
|
response, err := ctxhttp.Get(ctx, client, url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/models"
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
"github.com/qdm12/golibs/files"
|
"github.com/qdm12/golibs/files"
|
||||||
"github.com/qdm12/golibs/logging"
|
"github.com/qdm12/golibs/logging"
|
||||||
|
"golang.org/x/net/context/ctxhttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
type piaV3 struct {
|
type piaV3 struct {
|
||||||
@@ -89,7 +90,7 @@ func (p *piaV3) PortForward(ctx context.Context, client *http.Client,
|
|||||||
}
|
}
|
||||||
clientID := hex.EncodeToString(b)
|
clientID := hex.EncodeToString(b)
|
||||||
url := fmt.Sprintf("%s/?client_id=%s", constants.PIAPortForwardURL, clientID)
|
url := fmt.Sprintf("%s/?client_id=%s", constants.PIAPortForwardURL, clientID)
|
||||||
response, err := client.Get(url) // TODO add ctx
|
response, err := ctxhttp.Get(ctx, client, url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pfLogger.Error(err)
|
pfLogger.Error(err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"github.com/qdm12/gluetun/internal/models"
|
"github.com/qdm12/gluetun/internal/models"
|
||||||
"github.com/qdm12/golibs/files"
|
"github.com/qdm12/golibs/files"
|
||||||
"github.com/qdm12/golibs/logging"
|
"github.com/qdm12/golibs/logging"
|
||||||
|
"golang.org/x/net/context/ctxhttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
type piaV4 struct {
|
type piaV4 struct {
|
||||||
@@ -151,7 +152,7 @@ func (p *piaV4) PortForward(ctx context.Context, client *http.Client,
|
|||||||
|
|
||||||
if !dataFound || expired {
|
if !dataFound || expired {
|
||||||
tryUntilSuccessful(ctx, pfLogger, func() error {
|
tryUntilSuccessful(ctx, pfLogger, func() error {
|
||||||
data, err = refreshPIAPortForwardData(client, gateway, fileManager)
|
data, err = refreshPIAPortForwardData(ctx, client, gateway, fileManager)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
@@ -163,7 +164,7 @@ func (p *piaV4) PortForward(ctx context.Context, client *http.Client,
|
|||||||
|
|
||||||
// First time binding
|
// First time binding
|
||||||
tryUntilSuccessful(ctx, pfLogger, func() error {
|
tryUntilSuccessful(ctx, pfLogger, func() error {
|
||||||
return bindPIAPort(client, gateway, data)
|
return bindPIAPort(ctx, client, gateway, data)
|
||||||
})
|
})
|
||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return
|
return
|
||||||
@@ -202,7 +203,7 @@ func (p *piaV4) PortForward(ctx context.Context, client *http.Client,
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
case <-keepAliveTimer.C:
|
case <-keepAliveTimer.C:
|
||||||
if err := bindPIAPort(client, gateway, data); err != nil {
|
if err := bindPIAPort(ctx, client, gateway, data); err != nil {
|
||||||
pfLogger.Error(err)
|
pfLogger.Error(err)
|
||||||
}
|
}
|
||||||
keepAliveTimer.Reset(keepAlivePeriod)
|
keepAliveTimer.Reset(keepAlivePeriod)
|
||||||
@@ -210,7 +211,7 @@ func (p *piaV4) PortForward(ctx context.Context, client *http.Client,
|
|||||||
pfLogger.Warn("Forward port has expired on %s, getting another one", data.Expiration.Format(time.RFC1123))
|
pfLogger.Warn("Forward port has expired on %s, getting another one", data.Expiration.Format(time.RFC1123))
|
||||||
oldPort := data.Port
|
oldPort := data.Port
|
||||||
for {
|
for {
|
||||||
data, err = refreshPIAPortForwardData(client, gateway, fileManager)
|
data, err = refreshPIAPortForwardData(ctx, client, gateway, fileManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pfLogger.Error(err)
|
pfLogger.Error(err)
|
||||||
continue
|
continue
|
||||||
@@ -233,7 +234,7 @@ func (p *piaV4) PortForward(ctx context.Context, client *http.Client,
|
|||||||
); err != nil {
|
); err != nil {
|
||||||
pfLogger.Error(err)
|
pfLogger.Error(err)
|
||||||
}
|
}
|
||||||
if err := bindPIAPort(client, gateway, data); err != nil {
|
if err := bindPIAPort(ctx, client, gateway, data); err != nil {
|
||||||
pfLogger.Error(err)
|
pfLogger.Error(err)
|
||||||
}
|
}
|
||||||
if !keepAliveTimer.Stop() {
|
if !keepAliveTimer.Stop() {
|
||||||
@@ -292,12 +293,12 @@ func newPIAv4HTTPClient(serverName string) (client *http.Client, err error) {
|
|||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func refreshPIAPortForwardData(client *http.Client, gateway net.IP, fileManager files.FileManager) (data piaPortForwardData, err error) {
|
func refreshPIAPortForwardData(ctx context.Context, client *http.Client, gateway net.IP, fileManager files.FileManager) (data piaPortForwardData, err error) {
|
||||||
data.Token, err = fetchPIAToken(fileManager, client)
|
data.Token, err = fetchPIAToken(fileManager, client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return data, fmt.Errorf("cannot obtain token: %w", err)
|
return data, fmt.Errorf("cannot obtain token: %w", err)
|
||||||
}
|
}
|
||||||
data.Port, data.Signature, data.Expiration, err = fetchPIAPortForwardData(client, gateway, data.Token)
|
data.Port, data.Signature, data.Expiration, err = fetchPIAPortForwardData(ctx, client, gateway, data.Token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return data, fmt.Errorf("cannot obtain port forwarding data: %w", err)
|
return data, fmt.Errorf("cannot obtain port forwarding data: %w", err)
|
||||||
}
|
}
|
||||||
@@ -429,7 +430,7 @@ func getOpenvpnCredentials(fileManager files.FileManager) (username, password st
|
|||||||
return username, password, nil
|
return username, password, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchPIAPortForwardData(client *http.Client, gateway net.IP, token string) (port uint16, signature string, expiration time.Time, err error) {
|
func fetchPIAPortForwardData(ctx context.Context, client *http.Client, gateway net.IP, token string) (port uint16, signature string, expiration time.Time, err error) {
|
||||||
queryParams := url.Values{}
|
queryParams := url.Values{}
|
||||||
queryParams.Add("token", token)
|
queryParams.Add("token", token)
|
||||||
url := url.URL{
|
url := url.URL{
|
||||||
@@ -438,7 +439,7 @@ func fetchPIAPortForwardData(client *http.Client, gateway net.IP, token string)
|
|||||||
Path: "/getSignature",
|
Path: "/getSignature",
|
||||||
RawQuery: queryParams.Encode(),
|
RawQuery: queryParams.Encode(),
|
||||||
}
|
}
|
||||||
response, err := client.Get(url.String())
|
response, err := ctxhttp.Get(ctx, client, url.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, "", expiration, fmt.Errorf("cannot obtain signature: %w", err)
|
return 0, "", expiration, fmt.Errorf("cannot obtain signature: %w", err)
|
||||||
}
|
}
|
||||||
@@ -465,7 +466,7 @@ func fetchPIAPortForwardData(client *http.Client, gateway net.IP, token string)
|
|||||||
return port, data.Signature, expiration, err
|
return port, data.Signature, expiration, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func bindPIAPort(client *http.Client, gateway net.IP, data piaPortForwardData) (err error) {
|
func bindPIAPort(ctx context.Context, client *http.Client, gateway net.IP, data piaPortForwardData) (err error) {
|
||||||
payload, err := packPIAPayload(data.Port, data.Token, data.Expiration)
|
payload, err := packPIAPayload(data.Port, data.Token, data.Expiration)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -480,7 +481,7 @@ func bindPIAPort(client *http.Client, gateway net.IP, data piaPortForwardData) (
|
|||||||
RawQuery: queryParams.Encode(),
|
RawQuery: queryParams.Encode(),
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := client.Get(url.String())
|
response, err := ctxhttp.Get(ctx, client, url.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot bind port: %w", err)
|
return fmt.Errorf("cannot bind port: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
package version
|
package version
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/context/ctxhttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
type githubRelease struct {
|
type githubRelease struct {
|
||||||
@@ -23,9 +26,9 @@ type githubCommit struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getGithubReleases(client *http.Client) (releases []githubRelease, err error) {
|
func getGithubReleases(ctx context.Context, client *http.Client) (releases []githubRelease, err error) {
|
||||||
const url = "https://api.github.com/repos/qdm12/gluetun/releases"
|
const url = "https://api.github.com/repos/qdm12/gluetun/releases"
|
||||||
response, err := client.Get(url)
|
response, err := ctxhttp.Get(ctx, client, url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -40,9 +43,9 @@ func getGithubReleases(client *http.Client) (releases []githubRelease, err error
|
|||||||
return releases, nil
|
return releases, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getGithubCommits(client *http.Client) (commits []githubCommit, err error) {
|
func getGithubCommits(ctx context.Context, client *http.Client) (commits []githubCommit, err error) {
|
||||||
const url = "https://api.github.com/repos/qdm12/gluetun/commits"
|
const url = "https://api.github.com/repos/qdm12/gluetun/commits"
|
||||||
response, err := client.Get(url)
|
response, err := ctxhttp.Get(ctx, client, url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package version
|
package version
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
@@ -10,10 +11,10 @@ import (
|
|||||||
|
|
||||||
// GetMessage returns a message for the user describing if there is a newer version
|
// GetMessage returns a message for the user describing if there is a newer version
|
||||||
// available. It should only be called once the tunnel is established.
|
// available. It should only be called once the tunnel is established.
|
||||||
func GetMessage(version, commitShort string, client *http.Client) (message string, err error) {
|
func GetMessage(ctx context.Context, version, commitShort string, client *http.Client) (message string, err error) {
|
||||||
if version == "latest" {
|
if version == "latest" {
|
||||||
// Find # of commits between current commit and latest commit
|
// Find # of commits between current commit and latest commit
|
||||||
commitsSince, err := getCommitsSince(client, commitShort)
|
commitsSince, err := getCommitsSince(ctx, client, commitShort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("cannot get version information: %w", err)
|
return "", fmt.Errorf("cannot get version information: %w", err)
|
||||||
} else if commitsSince == 0 {
|
} else if commitsSince == 0 {
|
||||||
@@ -25,7 +26,7 @@ func GetMessage(version, commitShort string, client *http.Client) (message strin
|
|||||||
}
|
}
|
||||||
return fmt.Sprintf("You are running %d %s behind the most recent %s", commitsSince, commits, version), nil
|
return fmt.Sprintf("You are running %d %s behind the most recent %s", commitsSince, commits, version), nil
|
||||||
}
|
}
|
||||||
tagName, name, releaseTime, err := getLatestRelease(client)
|
tagName, name, releaseTime, err := getLatestRelease(ctx, client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("cannot get version information: %w", err)
|
return "", fmt.Errorf("cannot get version information: %w", err)
|
||||||
}
|
}
|
||||||
@@ -38,8 +39,8 @@ func GetMessage(version, commitShort string, client *http.Client) (message strin
|
|||||||
nil
|
nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getLatestRelease(client *http.Client) (tagName, name string, time time.Time, err error) {
|
func getLatestRelease(ctx context.Context, client *http.Client) (tagName, name string, time time.Time, err error) {
|
||||||
releases, err := getGithubReleases(client)
|
releases, err := getGithubReleases(ctx, client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", time, err
|
return "", "", time, err
|
||||||
}
|
}
|
||||||
@@ -52,8 +53,8 @@ func getLatestRelease(client *http.Client) (tagName, name string, time time.Time
|
|||||||
return "", "", time, fmt.Errorf("no releases found")
|
return "", "", time, fmt.Errorf("no releases found")
|
||||||
}
|
}
|
||||||
|
|
||||||
func getCommitsSince(client *http.Client, commitShort string) (n int, err error) {
|
func getCommitsSince(ctx context.Context, client *http.Client, commitShort string) (n int, err error) {
|
||||||
commits, err := getGithubCommits(client)
|
commits, err := getGithubCommits(ctx, client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user