diff --git a/internal/provider/privateinternetaccess/httpclient.go b/internal/provider/privateinternetaccess/httpclient.go index 2cbd4213..97a42be0 100644 --- a/internal/provider/privateinternetaccess/httpclient.go +++ b/internal/provider/privateinternetaccess/httpclient.go @@ -2,51 +2,31 @@ package privateinternetaccess import ( "crypto/tls" - "crypto/x509" - "encoding/base64" - "fmt" "net" "net/http" "time" - - "github.com/qdm12/gluetun/internal/constants" ) -func newHTTPClient(serverName string) (client *http.Client, err error) { - certificateBytes, err := base64.StdEncoding.DecodeString(constants.PiaCAStrong) - if err != nil { - return nil, fmt.Errorf("cannot parse X509 certificate: %w", err) - } - certificate, err := x509.ParseCertificate(certificateBytes) - if err != nil { - return nil, fmt.Errorf("cannot parse X509 certificate: %w", err) - } - +func newHTTPClient(serverName string) (client *http.Client) { //nolint:gomnd - transport := &http.Transport{ - // Settings taken from http.DefaultTransport - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, - ForceAttemptHTTP2: true, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - } - rootCAs := x509.NewCertPool() - rootCAs.AddCert(certificate) - transport.TLSClientConfig = &tls.Config{ - RootCAs: rootCAs, - MinVersion: tls.VersionTLS12, - ServerName: serverName, - } - - const httpTimeout = 30 * time.Second return &http.Client{ - Transport: transport, - Timeout: httpTimeout, - }, nil + Transport: &http.Transport{ + // Settings taken from http.DefaultTransport + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + ServerName: serverName, + }, + }, + Timeout: 30 * time.Second, + } } diff --git a/internal/provider/privateinternetaccess/httpclient_test.go b/internal/provider/privateinternetaccess/httpclient_test.go new file mode 100644 index 00000000..f764a492 --- /dev/null +++ b/internal/provider/privateinternetaccess/httpclient_test.go @@ -0,0 +1,31 @@ +package privateinternetaccess + +import ( + "crypto/tls" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_newHTTPClient(t *testing.T) { + t.Parallel() + + const serverName = "testserver" + + expectedPIATransportTLSConfig := &tls.Config{ + // Can't directly compare RootCAs because of private fields + RootCAs: nil, + MinVersion: tls.VersionTLS12, + ServerName: serverName, + } + + piaClient := newHTTPClient(serverName) + + // Verify pia transport TLS config is set + piaTransport, ok := piaClient.Transport.(*http.Transport) + require.True(t, ok) + piaTransport.TLSClientConfig.RootCAs = nil + assert.Equal(t, expectedPIATransportTLSConfig, piaTransport.TLSClientConfig) +} diff --git a/internal/provider/privateinternetaccess/portforward.go b/internal/provider/privateinternetaccess/portforward.go index 6b8239fc..f0d07dd4 100644 --- a/internal/provider/privateinternetaccess/portforward.go +++ b/internal/provider/privateinternetaccess/portforward.go @@ -47,10 +47,7 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client, return 0, ErrServerNameEmpty } - privateIPClient, err := newHTTPClient(serverName) - if err != nil { - return 0, fmt.Errorf("cannot create custom HTTP client: %w", err) - } + privateIPClient := newHTTPClient(serverName) data, err := readPIAPortForwardData(p.portForwardPath) if err != nil { @@ -95,10 +92,7 @@ var ( func (p *PIA) KeepPortForward(ctx context.Context, client *http.Client, port uint16, gateway net.IP, serverName string) (err error) { - privateIPClient, err := newHTTPClient(serverName) - if err != nil { - return fmt.Errorf("cannot create custom HTTP client: %w", err) - } + privateIPClient := newHTTPClient(serverName) data, err := readPIAPortForwardData(p.portForwardPath) if err != nil { diff --git a/internal/provider/privateinternetaccess/portforward_test.go b/internal/provider/privateinternetaccess/portforward_test.go index cfcd99ae..f1498484 100644 --- a/internal/provider/privateinternetaccess/portforward_test.go +++ b/internal/provider/privateinternetaccess/portforward_test.go @@ -1,53 +1,16 @@ package privateinternetaccess import ( - "crypto/tls" - "crypto/x509" "encoding/base64" "encoding/json" "errors" - "net/http" "testing" "time" - "github.com/qdm12/gluetun/internal/constants" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func Test_newHTTPClient(t *testing.T) { - t.Parallel() - - const serverName = "testserver" - - certificateBytes, err := base64.StdEncoding.DecodeString(constants.PiaCAStrong) - require.NoError(t, err) - certificate, err := x509.ParseCertificate(certificateBytes) - require.NoError(t, err) - rootCAs := x509.NewCertPool() - rootCAs.AddCert(certificate) - expectedRootCAsSubjects := rootCAs.Subjects() - - expectedPIATransportTLSConfig := &tls.Config{ - // Can't directly compare RootCAs because of private fields - RootCAs: nil, - MinVersion: tls.VersionTLS12, - ServerName: serverName, - } - - piaClient, err := newHTTPClient(serverName) - - require.NoError(t, err) - - // Verify pia transport TLS config is set - piaTransport, ok := piaClient.Transport.(*http.Transport) - require.True(t, ok) - rootCAsSubjects := piaTransport.TLSClientConfig.RootCAs.Subjects() - assert.Equal(t, expectedRootCAsSubjects, rootCAsSubjects) - piaTransport.TLSClientConfig.RootCAs = nil - assert.Equal(t, expectedPIATransportTLSConfig, piaTransport.TLSClientConfig) -} - func Test_unpackPayload(t *testing.T) { t.Parallel()