diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 5c801bb9..000451c6 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -52,7 +52,8 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go var err error switch args[1] { case "healthcheck": - err = cli.HealthCheck() + client := &http.Client{Timeout: time.Second} + err = cli.HealthCheck(background, client) case "clientkey": err = cli.ClientKey(args[2:]) case "openvpnconfig": @@ -403,7 +404,7 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dn if !versionInformation { break } - message, err := versionpkg.GetMessage(version, commit, httpClient) + message, err := versionpkg.GetMessage(ctx, version, commit, httpClient) if err != nil { logger.Error(err) break diff --git a/go.mod b/go.mod index f2f6e8ce..213b08af 100644 --- a/go.mod +++ b/go.mod @@ -9,5 +9,6 @@ require ( github.com/qdm12/golibs v0.0.0-20200712151944-a0325873bf5a github.com/qdm12/ss-server v0.0.0-20200819005413-6b516c299307 github.com/stretchr/testify v1.6.1 + golang.org/x/net v0.0.0-20190620200207-3b0461eec859 golang.org/x/sys v0.0.0-20200814200057-3d37ad5750ed ) diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 8e3bd3ef..825ec5aa 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -17,6 +17,7 @@ import ( "github.com/qdm12/gluetun/internal/updater" "github.com/qdm12/golibs/files" "github.com/qdm12/golibs/logging" + "golang.org/x/net/context/ctxhttp" ) func ClientKey(args []string) error { @@ -39,9 +40,9 @@ func ClientKey(args []string) error { return nil } -func HealthCheck() error { - client := &http.Client{Timeout: time.Second} - response, err := client.Get("http://localhost:8000/health") +func HealthCheck(ctx context.Context, client *http.Client) error { + const url = "http://localhost:8000/health" + response, err := ctxhttp.Get(ctx, client, url) if err != nil { return err } diff --git a/internal/provider/piav3.go b/internal/provider/piav3.go index 84cf851c..8d286b48 100644 --- a/internal/provider/piav3.go +++ b/internal/provider/piav3.go @@ -16,6 +16,7 @@ import ( "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/golibs/files" "github.com/qdm12/golibs/logging" + "golang.org/x/net/context/ctxhttp" ) type piaV3 struct { @@ -89,7 +90,7 @@ func (p *piaV3) PortForward(ctx context.Context, client *http.Client, } clientID := hex.EncodeToString(b) url := fmt.Sprintf("%s/?client_id=%s", constants.PIAPortForwardURL, clientID) - response, err := client.Get(url) // TODO add ctx + response, err := ctxhttp.Get(ctx, client, url) if err != nil { pfLogger.Error(err) return diff --git a/internal/provider/piav4.go b/internal/provider/piav4.go index f364ccc3..60662c84 100644 --- a/internal/provider/piav4.go +++ b/internal/provider/piav4.go @@ -21,6 +21,7 @@ import ( "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/golibs/files" "github.com/qdm12/golibs/logging" + "golang.org/x/net/context/ctxhttp" ) type piaV4 struct { @@ -151,7 +152,7 @@ func (p *piaV4) PortForward(ctx context.Context, client *http.Client, if !dataFound || expired { tryUntilSuccessful(ctx, pfLogger, func() error { - data, err = refreshPIAPortForwardData(client, gateway, fileManager) + data, err = refreshPIAPortForwardData(ctx, client, gateway, fileManager) return err }) if ctx.Err() != nil { @@ -163,7 +164,7 @@ func (p *piaV4) PortForward(ctx context.Context, client *http.Client, // First time binding tryUntilSuccessful(ctx, pfLogger, func() error { - return bindPIAPort(client, gateway, data) + return bindPIAPort(ctx, client, gateway, data) }) if ctx.Err() != nil { return @@ -202,7 +203,7 @@ func (p *piaV4) PortForward(ctx context.Context, client *http.Client, } return case <-keepAliveTimer.C: - if err := bindPIAPort(client, gateway, data); err != nil { + if err := bindPIAPort(ctx, client, gateway, data); err != nil { pfLogger.Error(err) } keepAliveTimer.Reset(keepAlivePeriod) @@ -210,7 +211,7 @@ func (p *piaV4) PortForward(ctx context.Context, client *http.Client, pfLogger.Warn("Forward port has expired on %s, getting another one", data.Expiration.Format(time.RFC1123)) oldPort := data.Port for { - data, err = refreshPIAPortForwardData(client, gateway, fileManager) + data, err = refreshPIAPortForwardData(ctx, client, gateway, fileManager) if err != nil { pfLogger.Error(err) continue @@ -233,7 +234,7 @@ func (p *piaV4) PortForward(ctx context.Context, client *http.Client, ); err != nil { pfLogger.Error(err) } - if err := bindPIAPort(client, gateway, data); err != nil { + if err := bindPIAPort(ctx, client, gateway, data); err != nil { pfLogger.Error(err) } if !keepAliveTimer.Stop() { @@ -292,12 +293,12 @@ func newPIAv4HTTPClient(serverName string) (client *http.Client, err error) { return client, nil } -func refreshPIAPortForwardData(client *http.Client, gateway net.IP, fileManager files.FileManager) (data piaPortForwardData, err error) { +func refreshPIAPortForwardData(ctx context.Context, client *http.Client, gateway net.IP, fileManager files.FileManager) (data piaPortForwardData, err error) { data.Token, err = fetchPIAToken(fileManager, client) if err != nil { return data, fmt.Errorf("cannot obtain token: %w", err) } - data.Port, data.Signature, data.Expiration, err = fetchPIAPortForwardData(client, gateway, data.Token) + data.Port, data.Signature, data.Expiration, err = fetchPIAPortForwardData(ctx, client, gateway, data.Token) if err != nil { return data, fmt.Errorf("cannot obtain port forwarding data: %w", err) } @@ -429,7 +430,7 @@ func getOpenvpnCredentials(fileManager files.FileManager) (username, password st return username, password, nil } -func fetchPIAPortForwardData(client *http.Client, gateway net.IP, token string) (port uint16, signature string, expiration time.Time, err error) { +func fetchPIAPortForwardData(ctx context.Context, client *http.Client, gateway net.IP, token string) (port uint16, signature string, expiration time.Time, err error) { queryParams := url.Values{} queryParams.Add("token", token) url := url.URL{ @@ -438,7 +439,7 @@ func fetchPIAPortForwardData(client *http.Client, gateway net.IP, token string) Path: "/getSignature", RawQuery: queryParams.Encode(), } - response, err := client.Get(url.String()) + response, err := ctxhttp.Get(ctx, client, url.String()) if err != nil { return 0, "", expiration, fmt.Errorf("cannot obtain signature: %w", err) } @@ -465,7 +466,7 @@ func fetchPIAPortForwardData(client *http.Client, gateway net.IP, token string) return port, data.Signature, expiration, err } -func bindPIAPort(client *http.Client, gateway net.IP, data piaPortForwardData) (err error) { +func bindPIAPort(ctx context.Context, client *http.Client, gateway net.IP, data piaPortForwardData) (err error) { payload, err := packPIAPayload(data.Port, data.Token, data.Expiration) if err != nil { return err @@ -480,7 +481,7 @@ func bindPIAPort(client *http.Client, gateway net.IP, data piaPortForwardData) ( RawQuery: queryParams.Encode(), } - response, err := client.Get(url.String()) + response, err := ctxhttp.Get(ctx, client, url.String()) if err != nil { return fmt.Errorf("cannot bind port: %w", err) } diff --git a/internal/version/github.go b/internal/version/github.go index 46fc490a..1c363513 100644 --- a/internal/version/github.go +++ b/internal/version/github.go @@ -1,10 +1,13 @@ package version import ( + "context" "encoding/json" "io/ioutil" "net/http" "time" + + "golang.org/x/net/context/ctxhttp" ) type githubRelease struct { @@ -23,9 +26,9 @@ type githubCommit struct { } } -func getGithubReleases(client *http.Client) (releases []githubRelease, err error) { +func getGithubReleases(ctx context.Context, client *http.Client) (releases []githubRelease, err error) { const url = "https://api.github.com/repos/qdm12/gluetun/releases" - response, err := client.Get(url) + response, err := ctxhttp.Get(ctx, client, url) if err != nil { return nil, err } @@ -40,9 +43,9 @@ func getGithubReleases(client *http.Client) (releases []githubRelease, err error return releases, nil } -func getGithubCommits(client *http.Client) (commits []githubCommit, err error) { +func getGithubCommits(ctx context.Context, client *http.Client) (commits []githubCommit, err error) { const url = "https://api.github.com/repos/qdm12/gluetun/commits" - response, err := client.Get(url) + response, err := ctxhttp.Get(ctx, client, url) if err != nil { return nil, err } diff --git a/internal/version/version.go b/internal/version/version.go index 3d6e3c55..7ca79f6f 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -1,6 +1,7 @@ package version import ( + "context" "fmt" "net/http" "time" @@ -10,10 +11,10 @@ import ( // GetMessage returns a message for the user describing if there is a newer version // available. It should only be called once the tunnel is established. -func GetMessage(version, commitShort string, client *http.Client) (message string, err error) { +func GetMessage(ctx context.Context, version, commitShort string, client *http.Client) (message string, err error) { if version == "latest" { // Find # of commits between current commit and latest commit - commitsSince, err := getCommitsSince(client, commitShort) + commitsSince, err := getCommitsSince(ctx, client, commitShort) if err != nil { return "", fmt.Errorf("cannot get version information: %w", err) } else if commitsSince == 0 { @@ -25,7 +26,7 @@ func GetMessage(version, commitShort string, client *http.Client) (message strin } return fmt.Sprintf("You are running %d %s behind the most recent %s", commitsSince, commits, version), nil } - tagName, name, releaseTime, err := getLatestRelease(client) + tagName, name, releaseTime, err := getLatestRelease(ctx, client) if err != nil { return "", fmt.Errorf("cannot get version information: %w", err) } @@ -38,8 +39,8 @@ func GetMessage(version, commitShort string, client *http.Client) (message strin nil } -func getLatestRelease(client *http.Client) (tagName, name string, time time.Time, err error) { - releases, err := getGithubReleases(client) +func getLatestRelease(ctx context.Context, client *http.Client) (tagName, name string, time time.Time, err error) { + releases, err := getGithubReleases(ctx, client) if err != nil { return "", "", time, err } @@ -52,8 +53,8 @@ func getLatestRelease(client *http.Client) (tagName, name string, time time.Time return "", "", time, fmt.Errorf("no releases found") } -func getCommitsSince(client *http.Client, commitShort string) (n int, err error) { - commits, err := getGithubCommits(client) +func getCommitsSince(ctx context.Context, client *http.Client, commitShort string) (n int, err error) { + commits, err := getGithubCommits(ctx, client) if err != nil { return 0, err }