fix(pia): port forwarding certificate
- Do not use custom PIA certificate - Only use OS certificates - Update unit test
This commit is contained in:
@@ -2,51 +2,31 @@ package privateinternetaccess
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/constants"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func newHTTPClient(serverName string) (client *http.Client, err error) {
|
func newHTTPClient(serverName string) (client *http.Client) {
|
||||||
certificateBytes, err := base64.StdEncoding.DecodeString(constants.PiaCAStrong)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("cannot parse X509 certificate: %w", err)
|
|
||||||
}
|
|
||||||
certificate, err := x509.ParseCertificate(certificateBytes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("cannot parse X509 certificate: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
//nolint:gomnd
|
//nolint:gomnd
|
||||||
transport := &http.Transport{
|
|
||||||
// Settings taken from http.DefaultTransport
|
|
||||||
Proxy: http.ProxyFromEnvironment,
|
|
||||||
DialContext: (&net.Dialer{
|
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
KeepAlive: 30 * time.Second,
|
|
||||||
}).DialContext,
|
|
||||||
ForceAttemptHTTP2: true,
|
|
||||||
MaxIdleConns: 100,
|
|
||||||
IdleConnTimeout: 90 * time.Second,
|
|
||||||
TLSHandshakeTimeout: 10 * time.Second,
|
|
||||||
ExpectContinueTimeout: 1 * time.Second,
|
|
||||||
}
|
|
||||||
rootCAs := x509.NewCertPool()
|
|
||||||
rootCAs.AddCert(certificate)
|
|
||||||
transport.TLSClientConfig = &tls.Config{
|
|
||||||
RootCAs: rootCAs,
|
|
||||||
MinVersion: tls.VersionTLS12,
|
|
||||||
ServerName: serverName,
|
|
||||||
}
|
|
||||||
|
|
||||||
const httpTimeout = 30 * time.Second
|
|
||||||
return &http.Client{
|
return &http.Client{
|
||||||
Transport: transport,
|
Transport: &http.Transport{
|
||||||
Timeout: httpTimeout,
|
// Settings taken from http.DefaultTransport
|
||||||
}, nil
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
}).DialContext,
|
||||||
|
ForceAttemptHTTP2: true,
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
ServerName: serverName,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
31
internal/provider/privateinternetaccess/httpclient_test.go
Normal file
31
internal/provider/privateinternetaccess/httpclient_test.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package privateinternetaccess
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_newHTTPClient(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
const serverName = "testserver"
|
||||||
|
|
||||||
|
expectedPIATransportTLSConfig := &tls.Config{
|
||||||
|
// Can't directly compare RootCAs because of private fields
|
||||||
|
RootCAs: nil,
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
ServerName: serverName,
|
||||||
|
}
|
||||||
|
|
||||||
|
piaClient := newHTTPClient(serverName)
|
||||||
|
|
||||||
|
// Verify pia transport TLS config is set
|
||||||
|
piaTransport, ok := piaClient.Transport.(*http.Transport)
|
||||||
|
require.True(t, ok)
|
||||||
|
piaTransport.TLSClientConfig.RootCAs = nil
|
||||||
|
assert.Equal(t, expectedPIATransportTLSConfig, piaTransport.TLSClientConfig)
|
||||||
|
}
|
||||||
@@ -47,10 +47,7 @@ func (p *PIA) PortForward(ctx context.Context, client *http.Client,
|
|||||||
return 0, ErrServerNameEmpty
|
return 0, ErrServerNameEmpty
|
||||||
}
|
}
|
||||||
|
|
||||||
privateIPClient, err := newHTTPClient(serverName)
|
privateIPClient := newHTTPClient(serverName)
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("cannot create custom HTTP client: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := readPIAPortForwardData(p.portForwardPath)
|
data, err := readPIAPortForwardData(p.portForwardPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -95,10 +92,7 @@ var (
|
|||||||
|
|
||||||
func (p *PIA) KeepPortForward(ctx context.Context, client *http.Client,
|
func (p *PIA) KeepPortForward(ctx context.Context, client *http.Client,
|
||||||
port uint16, gateway net.IP, serverName string) (err error) {
|
port uint16, gateway net.IP, serverName string) (err error) {
|
||||||
privateIPClient, err := newHTTPClient(serverName)
|
privateIPClient := newHTTPClient(serverName)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("cannot create custom HTTP client: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := readPIAPortForwardData(p.portForwardPath)
|
data, err := readPIAPortForwardData(p.portForwardPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,53 +1,16 @@
|
|||||||
package privateinternetaccess
|
package privateinternetaccess
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/qdm12/gluetun/internal/constants"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_newHTTPClient(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
const serverName = "testserver"
|
|
||||||
|
|
||||||
certificateBytes, err := base64.StdEncoding.DecodeString(constants.PiaCAStrong)
|
|
||||||
require.NoError(t, err)
|
|
||||||
certificate, err := x509.ParseCertificate(certificateBytes)
|
|
||||||
require.NoError(t, err)
|
|
||||||
rootCAs := x509.NewCertPool()
|
|
||||||
rootCAs.AddCert(certificate)
|
|
||||||
expectedRootCAsSubjects := rootCAs.Subjects()
|
|
||||||
|
|
||||||
expectedPIATransportTLSConfig := &tls.Config{
|
|
||||||
// Can't directly compare RootCAs because of private fields
|
|
||||||
RootCAs: nil,
|
|
||||||
MinVersion: tls.VersionTLS12,
|
|
||||||
ServerName: serverName,
|
|
||||||
}
|
|
||||||
|
|
||||||
piaClient, err := newHTTPClient(serverName)
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Verify pia transport TLS config is set
|
|
||||||
piaTransport, ok := piaClient.Transport.(*http.Transport)
|
|
||||||
require.True(t, ok)
|
|
||||||
rootCAsSubjects := piaTransport.TLSClientConfig.RootCAs.Subjects()
|
|
||||||
assert.Equal(t, expectedRootCAsSubjects, rootCAsSubjects)
|
|
||||||
piaTransport.TLSClientConfig.RootCAs = nil
|
|
||||||
assert.Equal(t, expectedPIATransportTLSConfig, piaTransport.TLSClientConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_unpackPayload(t *testing.T) {
|
func Test_unpackPayload(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user