PIA nextgen portforward (#242)

* Split provider/pia.go in piav3.go and piav4.go
* Change port forwarding signature
* Enable port forwarding parameter for PIA v4
* Fix VPN gateway IP obtention
* Setup HTTP client for TLS with custom cert
* Error message for regions not supporting pf
This commit is contained in:
Quentin McGaw
2020-10-12 10:55:08 -04:00
committed by GitHub
parent fbecbc1c82
commit ec157f102b
25 changed files with 763 additions and 202 deletions

View File

@@ -115,8 +115,8 @@ docker run --rm --network=container:gluetun alpine:3.12 wget -qO- https://ipinfo
| 🏁 `PASSWORD` | | | Your password |
| `REGION` | | One of the [PIA regions](https://www.privateinternetaccess.com/pages/network/) | VPN server region |
| `PIA_ENCRYPTION` | `strong` | `normal`, `strong` | Encryption preset |
| `PORT_FORWARDING` | `off` | `on`, `off` | Enable port forwarding on the VPN server **for old only** |
| `PORT_FORWARDING_STATUS_FILE` | `/tmp/gluetun/forwarded_port` | Any filepath | Filepath to store the forwarded port number **for old only** |
| `PORT_FORWARDING` | `off` | `on`, `off` | Enable port forwarding on the VPN server |
| `PORT_FORWARDING_STATUS_FILE` | `/tmp/gluetun/forwarded_port` | Any filepath | Filepath to store the forwarded port number |
- Mullvad

View File

@@ -3,6 +3,7 @@ package main
import (
"context"
"fmt"
"net"
"net/http"
"os"
"os/signal"
@@ -188,7 +189,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady)
openvpnLooper := openvpn.NewLooper(allSettings.VPNSP, allSettings.OpenVPN, uid, gid, allServers,
ovpnConf, firewallConf, logger, client, fileManager, streamMerger, cancel)
ovpnConf, firewallConf, routingConf, logger, httpClient, fileManager, streamMerger, cancel)
wg.Add(1)
// wait for restartOpenvpn
go openvpnLooper.Run(ctx, wg)
@@ -341,10 +342,11 @@ func collectStreamLines(ctx context.Context, streamMerger command.StreamMerger,
})
}
//nolint:gocognit
func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dnsReadyCh <-chan struct{},
unboundLooper dns.Looper, updaterLooper updater.Looper, publicIPLooper publicip.Looper,
routing routing.Routing, logger logging.Logger, httpClient *http.Client,
versionInformation, portForwardingEnabled bool, startPortForward func()) {
versionInformation, portForwardingEnabled bool, startPortForward func(vpnGateway net.IP)) {
defer wg.Done()
tickerWg := &sync.WaitGroup{}
// for linters only
@@ -364,18 +366,35 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dn
tickerWg.Add(2)
go unboundLooper.RunRestartTicker(restartTickerContext, tickerWg)
go updaterLooper.RunRestartTicker(restartTickerContext, tickerWg)
if portForwardingEnabled {
time.AfterFunc(5*time.Second, startPortForward)
}
defaultInterface, _, err := routing.DefaultRoute()
if err != nil {
logger.Warn(err)
} else {
vpnGatewayIP, err := routing.VPNGatewayIP(defaultInterface)
vpnDestination, err := routing.VPNDestinationIP(defaultInterface)
if err != nil {
logger.Warn(err)
} else {
logger.Info("Gateway VPN IP address: %s", vpnGatewayIP)
logger.Info("VPN routing IP address: %s", vpnDestination)
}
}
if portForwardingEnabled {
// TODO make instantaneous once v3 go out of service
const waitDuration = 5 * time.Second
timer := time.NewTimer(waitDuration)
select {
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
continue
case <-timer.C:
// vpnGateway required only for PIA v4
vpnGateway, err := routing.VPNLocalGatewayIP()
if err != nil {
logger.Error(err)
}
logger.Info("VPN gateway IP address: %s", vpnGateway)
startPortForward(vpnGateway)
}
}
case <-dnsReadyCh:

View File

@@ -15,6 +15,8 @@ const (
OpenVPNAuthConf models.Filepath = "/etc/openvpn/auth.conf"
// OpenVPNConf is the file path to the OpenVPN client configuration file
OpenVPNConf models.Filepath = "/etc/openvpn/target.ovpn"
// PIAPortForward is the file path to the port forwarding JSON information for PIA v4 servers
PIAPortForward models.Filepath = "/gluetun/piaportforward.json"
// TunnelDevice is the file path to tun device
TunnelDevice models.Filepath = "/dev/net/tun"
// NetRoute is the path to the file containing information on the network route

View File

@@ -2,9 +2,9 @@ package constants
const (
// Announcement is a message announcement
Announcement = "Update servers information see https://github.com/qdm12/gluetun/wiki/Update-servers-information"
Announcement = "Port forwarding is working for PIA v4 servers"
// AnnouncementExpiration is the expiration date of the announcement in format yyyy-mm-dd
AnnouncementExpiration = "2020-10-10"
AnnouncementExpiration = "2020-11-15"
)
const (

View File

@@ -0,0 +1,29 @@
package logging
import (
"fmt"
"time"
)
func FormatDuration(duration time.Duration) string {
switch {
case duration < time.Minute:
seconds := int(duration.Round(time.Second).Seconds())
if seconds < 2 {
return fmt.Sprintf("%d second", seconds)
}
return fmt.Sprintf("%d seconds", seconds)
case duration <= time.Hour:
minutes := int(duration.Round(time.Minute).Minutes())
if minutes == 1 {
return "1 minute"
}
return fmt.Sprintf("%d minutes", minutes)
case duration < 48*time.Hour:
hours := int(duration.Truncate(time.Hour).Hours())
return fmt.Sprintf("%d hours", hours)
default:
days := int(duration.Truncate(time.Hour).Hours() / 24)
return fmt.Sprintf("%d days", days)
}
}

View File

@@ -1,4 +1,4 @@
package version
package logging
import (
"testing"
@@ -7,7 +7,7 @@ import (
"github.com/stretchr/testify/assert"
)
func Test_formatDuration(t *testing.T) {
func Test_FormatDuration(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
duration time.Duration
@@ -57,7 +57,7 @@ func Test_formatDuration(t *testing.T) {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
s := formatDuration(testCase.duration)
s := FormatDuration(testCase.duration)
assert.Equal(t, testCase.s, s)
})
}

View File

@@ -90,6 +90,7 @@ func (p *ProviderSettings) String() string {
settingsList = append(settingsList,
"Region: "+p.ServerSelection.Region,
"Encryption preset: "+p.ExtraConfigOptions.EncryptionPreset,
"Port forwarding: "+p.PortForwarding.String(),
)
case "mullvad":
settingsList = append(settingsList,

View File

@@ -2,7 +2,8 @@ package openvpn
import (
"context"
"fmt"
"net"
"net/http"
"sync"
"time"
@@ -10,17 +11,17 @@ import (
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gluetun/internal/routing"
"github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network"
)
type Looper interface {
Run(ctx context.Context, wg *sync.WaitGroup)
Restart()
PortForward()
PortForward(vpnGatewayIP net.IP)
GetSettings() (settings settings.OpenVPN)
SetSettings(settings settings.OpenVPN)
GetPortForwarded() (portForwarded uint16)
@@ -42,21 +43,22 @@ type looper struct {
// Configurators
conf Configurator
fw firewall.Configurator
routing routing.Routing
// Other objects
logger logging.Logger
client network.Client
logger, pfLogger logging.Logger
client *http.Client
fileManager files.FileManager
streamMerger command.StreamMerger
cancel context.CancelFunc
// Internal channels
restart chan struct{}
portForwardSignals chan struct{}
portForwardSignals chan net.IP
}
func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
uid, gid int, allServers models.AllServers,
conf Configurator, fw firewall.Configurator,
logger logging.Logger, client network.Client, fileManager files.FileManager,
conf Configurator, fw firewall.Configurator, routing routing.Routing,
logger logging.Logger, client *http.Client, fileManager files.FileManager,
streamMerger command.StreamMerger, cancel context.CancelFunc) Looper {
return &looper{
provider: provider,
@@ -66,18 +68,20 @@ func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
allServers: allServers,
conf: conf,
fw: fw,
routing: routing,
logger: logger.WithPrefix("openvpn: "),
pfLogger: logger.WithPrefix("port forwarding: "),
client: client,
fileManager: fileManager,
streamMerger: streamMerger,
cancel: cancel,
restart: make(chan struct{}),
portForwardSignals: make(chan struct{}),
portForwardSignals: make(chan net.IP),
}
}
func (l *looper) Restart() { l.restart <- struct{}{} }
func (l *looper) PortForward() { l.portForwardSignals <- struct{}{} }
func (l *looper) PortForward(vpnGateway net.IP) { l.portForwardSignals <- vpnGateway }
func (l *looper) GetSettings() (settings settings.OpenVPN) {
l.settingsMutex.RLock()
@@ -158,10 +162,12 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
go func(ctx context.Context) {
for {
select {
// TODO have a way to disable pf with a context
case <-ctx.Done():
return
case <-l.portForwardSignals:
l.portForward(ctx, providerConf, l.client)
case gateway := <-l.portForwardSignals:
wg.Add(1)
go l.portForward(ctx, wg, providerConf, l.client, gateway)
}
}
}(openvpnCtx)
@@ -200,43 +206,25 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
<-ctx.Done()
}
func (l *looper) portForward(ctx context.Context, providerConf provider.Provider, client network.Client) {
// portForward is a blocking operation which may or may not be infinite.
// You should therefore always call it in a goroutine
func (l *looper) portForward(ctx context.Context, wg *sync.WaitGroup,
providerConf provider.Provider, client *http.Client, gateway net.IP) {
defer wg.Done()
settings := l.GetSettings()
if !settings.Provider.PortForwarding.Enabled {
return
}
var port uint16
err := fmt.Errorf("")
for err != nil {
if ctx.Err() != nil {
return
}
port, err = providerConf.GetPortForward(client)
if err != nil {
l.logAndWait(ctx, err)
}
}
l.logger.Info("port forwarded is %d", port)
syncState := func(port uint16) (pfFilepath models.Filepath) {
l.portForwardedMutex.Lock()
if err := l.fw.RemoveAllowedPort(ctx, l.portForwarded); err != nil {
l.logger.Error(err)
}
if err := l.fw.SetAllowedPort(ctx, port, string(constants.TUN)); err != nil {
l.logger.Error(err)
}
l.portForwarded = port
l.portForwardedMutex.Unlock()
filepath := settings.Provider.PortForwarding.Filepath
l.logger.Info("writing forwarded port to %s", filepath)
err = l.fileManager.WriteLinesToFile(
string(filepath), []string{fmt.Sprintf("%d", port)},
files.Ownership(l.uid, l.gid), files.Permissions(0400),
)
if err != nil {
l.logger.Error(err)
settings := l.GetSettings()
return settings.Provider.PortForwarding.Filepath
}
providerConf.PortForward(ctx,
client, l.fileManager, l.pfLogger,
gateway, l.fw, syncState)
}
func (l *looper) GetPortForwarded() (portForwarded uint16) {

View File

@@ -1,12 +1,17 @@
package provider
import (
"context"
"fmt"
"net"
"net/http"
"strings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/network"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
)
type cyberghost struct {
@@ -135,6 +140,8 @@ func (c *cyberghost) BuildConf(connections []models.OpenVPNConnection, verbosity
return lines
}
func (c *cyberghost) GetPortForward(client network.Client) (port uint16, err error) {
func (c *cyberghost) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for cyberghost")
}

View File

@@ -1,12 +1,17 @@
package provider
import (
"context"
"fmt"
"net"
"net/http"
"strings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/network"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
)
type mullvad struct {
@@ -134,6 +139,8 @@ func (m *mullvad) BuildConf(connections []models.OpenVPNConnection, verbosity, u
return lines
}
func (m *mullvad) GetPortForward(client network.Client) (port uint16, err error) {
func (m *mullvad) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for mullvad")
}

View File

@@ -1,12 +1,17 @@
package provider
import (
"context"
"fmt"
"net"
"net/http"
"strings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/network"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
)
type nordvpn struct {
@@ -142,6 +147,8 @@ func (n *nordvpn) BuildConf(connections []models.OpenVPNConnection, verbosity, u
return lines
}
func (n *nordvpn) GetPortForward(client network.Client) (port uint16, err error) {
func (n *nordvpn) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for nordvpn")
}

View File

@@ -1,35 +1,18 @@
package provider
import (
"encoding/hex"
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/crypto/random"
"github.com/qdm12/golibs/network"
)
type pia struct {
random random.Random
servers []models.PIAServer
}
func newPrivateInternetAccess(servers []models.PIAServer) *pia {
return &pia{
random: random.NewRandom(),
servers: servers,
}
}
func (p *pia) filterServers(region string) (servers []models.PIAServer) {
func filterPIAServers(servers []models.PIAServer, region string) (filtered []models.PIAServer) {
if len(region) == 0 {
return p.servers
return servers
}
for _, server := range p.servers {
for _, server := range servers {
if strings.EqualFold(server.Region, region) {
return []models.PIAServer{server}
}
@@ -37,8 +20,8 @@ func (p *pia) filterServers(region string) (servers []models.PIAServer) {
return nil
}
func (p *pia) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) {
servers := p.filterServers(selection.Region)
func getPIAOpenVPNConnections(allServers []models.PIAServer, selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) {
servers := filterPIAServers(allServers, selection.Region)
if len(servers) == 0 {
return nil, fmt.Errorf("no server found for region %q", selection.Region)
}
@@ -87,7 +70,7 @@ func (p *pia) GetOpenVPNConnections(selection models.ServerSelection) (connectio
return connections, nil
}
func (p *pia) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) {
func buildPIAConf(connections []models.OpenVPNConnection, verbosity int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) {
var X509CRL, certificate string
if extras.EncryptionPreset == constants.PIAEncryptionPresetNormal {
if len(cipher) == 0 {
@@ -161,28 +144,3 @@ func (p *pia) BuildConf(connections []models.OpenVPNConnection, verbosity, uid,
}...)
return lines
}
func (p *pia) GetPortForward(client network.Client) (port uint16, err error) {
b, err := p.random.GenerateRandomBytes(32)
if err != nil {
return 0, err
}
clientID := hex.EncodeToString(b)
url := fmt.Sprintf("%s/?client_id=%s", constants.PIAPortForwardURL, clientID)
content, status, err := client.GetContent(url) // TODO add ctx
switch {
case err != nil:
return 0, err
case status != http.StatusOK:
return 0, fmt.Errorf("status is %d for %s; does your PIA server support port forwarding?", status, url)
case len(content) == 0:
return 0, fmt.Errorf("port forwarding is already activated on this connection, has expired, or you are not connected to a PIA region that supports port forwarding")
}
body := struct {
Port uint16 `json:"port"`
}{}
if err := json.Unmarshal(content, &body); err != nil {
return 0, fmt.Errorf("port forwarding response: %w", err)
}
return body.Port, nil
}

View File

@@ -0,0 +1,94 @@
package provider
import (
"context"
"encoding/hex"
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/crypto/random"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
)
type piaV3 struct {
random random.Random
servers []models.PIAServer
}
func newPrivateInternetAccessV3(servers []models.PIAServer) *piaV3 {
return &piaV3{
random: random.NewRandom(),
servers: servers,
}
}
func (p *piaV3) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) {
return getPIAOpenVPNConnections(p.servers, selection)
}
func (p *piaV3) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) {
return buildPIAConf(connections, verbosity, root, cipher, auth, extras)
}
func (p *piaV3) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
b, err := p.random.GenerateRandomBytes(32)
if err != nil {
pfLogger.Error(err)
return
}
clientID := hex.EncodeToString(b)
url := fmt.Sprintf("%s/?client_id=%s", constants.PIAPortForwardURL, clientID)
response, err := client.Get(url) // TODO add ctx
if err != nil {
pfLogger.Error(err)
return
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
pfLogger.Error(fmt.Errorf("%s for %s; does your PIA server support port forwarding?", response.Status, url))
return
}
b, err = ioutil.ReadAll(response.Body)
if err != nil {
pfLogger.Error(err)
return
} else if len(b) == 0 {
pfLogger.Error(fmt.Errorf("port forwarding is already activated on this connection, has expired, or you are not connected to a PIA region that supports port forwarding"))
return
}
body := struct {
Port uint16 `json:"port"`
}{}
if err := json.Unmarshal(b, &body); err != nil {
pfLogger.Error(fmt.Errorf("port forwarding response: %w", err))
return
}
port := body.Port
filepath := syncState(port)
pfLogger.Info("Writing port to %s", filepath)
if err := fileManager.WriteToFile(
string(filepath), []byte(fmt.Sprintf("%d", port)),
files.Permissions(0666),
); err != nil {
pfLogger.Error(err)
}
if err := fw.SetAllowedPort(ctx, port, string(constants.TUN)); err != nil {
pfLogger.Error(err)
}
<-ctx.Done()
if err := fw.RemoveAllowedPort(ctx, port); err != nil {
pfLogger.Error(err)
}
}

412
internal/provider/piav4.go Normal file
View File

@@ -0,0 +1,412 @@
package provider
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
gluetunLog "github.com/qdm12/gluetun/internal/logging"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
)
type piaV4 struct {
servers []models.PIAServer
timeNow func() time.Time
}
func newPrivateInternetAccessV4(servers []models.PIAServer) *piaV4 {
return &piaV4{
servers: servers,
timeNow: time.Now,
}
}
func (p *piaV4) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) {
return getPIAOpenVPNConnections(p.servers, selection)
}
func (p *piaV4) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) {
return buildPIAConf(connections, verbosity, root, cipher, auth, extras)
}
//nolint:gocognit
func (p *piaV4) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
if gateway == nil {
pfLogger.Error("aborting because: VPN gateway IP address was not found")
return
}
client, err := newPIAv4HTTPClient()
if err != nil {
pfLogger.Error("aborting because: %s", err)
return
}
defer pfLogger.Warn("loop exited")
data, err := readPIAPortForwardData(fileManager)
if err != nil {
pfLogger.Error(err)
}
dataFound := data.Port > 0
durationToExpiration := data.Expiration.Sub(p.timeNow())
expired := durationToExpiration <= 0
if dataFound {
pfLogger.Info("Found persistent forwarded port data for port %d", data.Port)
if expired {
pfLogger.Warn("Forwarded port data expired on %s, getting another one", data.Expiration.Format(time.RFC1123))
} else {
pfLogger.Info("Forwarded port data expires in %s", gluetunLog.FormatDuration(durationToExpiration))
}
}
if !dataFound || expired {
tryUntilSuccessful(ctx, pfLogger, func() error {
data, err = refreshPIAPortForwardData(client, gateway, fileManager)
return err
})
if ctx.Err() != nil {
return
}
durationToExpiration = data.Expiration.Sub(p.timeNow())
}
pfLogger.Info("Port forwarded is %d expiring in %s", data.Port, gluetunLog.FormatDuration(durationToExpiration))
// First time binding
tryUntilSuccessful(ctx, pfLogger, func() error {
return bindPIAPort(client, gateway, data)
})
if ctx.Err() != nil {
return
}
filepath := syncState(data.Port)
pfLogger.Info("Writing port to %s", filepath)
if err := fileManager.WriteToFile(
string(filepath), []byte(fmt.Sprintf("%d", data.Port)),
files.Permissions(0666),
); err != nil {
pfLogger.Error(err)
}
if err := fw.SetAllowedPort(ctx, data.Port, string(constants.TUN)); err != nil {
pfLogger.Error(err)
}
expiryTimer := time.NewTimer(durationToExpiration)
defer expiryTimer.Stop()
const keepAlivePeriod = 15 * time.Minute
keepAliveTicker := time.NewTicker(keepAlivePeriod)
defer keepAliveTicker.Stop()
for {
select {
case <-ctx.Done():
removeCtx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := fw.RemoveAllowedPort(removeCtx, data.Port); err != nil {
pfLogger.Error(err)
}
return
case <-keepAliveTicker.C:
if err := bindPIAPort(client, gateway, data); err != nil {
pfLogger.Error(err)
}
case <-expiryTimer.C:
pfLogger.Warn("Forward port has expired on %s, getting another one", data.Expiration.Format(time.RFC1123))
oldPort := data.Port
for {
data, err = refreshPIAPortForwardData(client, gateway, fileManager)
if err != nil {
pfLogger.Error(err)
continue
}
break
}
durationToExpiration := data.Expiration.Sub(p.timeNow())
pfLogger.Info("Port forwarded is %d expiring in %s", data.Port, gluetunLog.FormatDuration(durationToExpiration))
if err := fw.RemoveAllowedPort(ctx, oldPort); err != nil {
pfLogger.Error(err)
}
if err := fw.SetAllowedPort(ctx, data.Port, string(constants.TUN)); err != nil {
pfLogger.Error(err)
}
filepath := syncState(data.Port)
pfLogger.Info("Writing port to %s", filepath)
if err := fileManager.WriteToFile(
string(filepath), []byte(fmt.Sprintf("%d", data.Port)),
files.Permissions(0666),
); err != nil {
pfLogger.Error(err)
}
if err := bindPIAPort(client, gateway, data); err != nil {
pfLogger.Error(err)
}
keepAliveTicker.Reset(keepAlivePeriod)
expiryTimer.Reset(durationToExpiration)
}
}
}
func newPIAv4HTTPClient() (client *http.Client, err error) {
certificateBytes, err := base64.StdEncoding.DecodeString(constants.PIACertificateStrong)
if err != nil {
return nil, fmt.Errorf("cannot decode PIA root certificate: %w", err)
}
certificate, err := x509.ParseCertificate(certificateBytes)
if err != nil {
return nil, fmt.Errorf("cannot parse PIA root certificate: %w", err)
}
rootCAs := x509.NewCertPool()
rootCAs.AddCert(certificate)
TLSClientConfig := &tls.Config{
RootCAs: rootCAs,
MinVersion: tls.VersionTLS12,
InsecureSkipVerify: true, //nolint:gosec
} // TODO fix and remove InsecureSkipVerify
transport := http.Transport{
TLSClientConfig: TLSClientConfig,
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
const httpTimeout = 5 * time.Second
client = &http.Client{Transport: &transport, Timeout: httpTimeout}
return client, nil
}
func refreshPIAPortForwardData(client *http.Client, gateway net.IP, fileManager files.FileManager) (data piaPortForwardData, err error) {
data.Token, err = fetchPIAToken(fileManager, client)
if err != nil {
return data, fmt.Errorf("cannot obtain token: %w", err)
}
data.Port, data.Signature, data.Expiration, err = fetchPIAPortForwardData(client, gateway, data.Token)
if err != nil {
if strings.HasSuffix(err.Error(), "connection refused") {
return data, fmt.Errorf("cannot obtain port forwarding data: connection was refused, are you sure the region you are using supports port forwarding ;)")
}
return data, fmt.Errorf("cannot obtain port forwarding data: %w", err)
}
if err := writePIAPortForwardData(fileManager, data); err != nil {
return data, fmt.Errorf("cannot persist port forwarding information to file: %w", err)
}
return data, nil
}
type piaPayload struct {
Token string `json:"token"`
Port uint16 `json:"port"`
Expiration time.Time `json:"expires_at"`
}
type piaPortForwardData struct {
Port uint16 `json:"port"`
Token string `json:"token"`
Signature string `json:"signature"`
Expiration time.Time `json:"expires_at"`
}
func readPIAPortForwardData(fileManager files.FileManager) (data piaPortForwardData, err error) {
const filepath = string(constants.PIAPortForward)
exists, err := fileManager.FileExists(filepath)
if err != nil {
return data, err
} else if !exists {
return data, nil
}
b, err := fileManager.ReadFile(filepath)
if err != nil {
return data, err
}
if err := json.Unmarshal(b, &data); err != nil {
return data, err
}
return data, nil
}
func writePIAPortForwardData(fileManager files.FileManager, data piaPortForwardData) (err error) {
b, err := json.Marshal(&data)
if err != nil {
return fmt.Errorf("cannot encode data: %w", err)
}
err = fileManager.WriteToFile(string(constants.PIAPortForward), b)
if err != nil {
return err
}
return nil
}
func unpackPIAPayload(payload string) (port uint16, token string, expiration time.Time, err error) {
b, err := base64.RawStdEncoding.DecodeString(payload)
if err != nil {
return 0, "", expiration, fmt.Errorf("cannot decode payload: %w", err)
}
var payloadData piaPayload
if err := json.Unmarshal(b, &payloadData); err != nil {
return 0, "", expiration, fmt.Errorf("cannot parse payload data: %w", err)
}
return payloadData.Port, payloadData.Token, payloadData.Expiration, nil
}
func packPIAPayload(port uint16, token string, expiration time.Time) (payload string, err error) {
payloadData := piaPayload{
Token: token,
Port: port,
Expiration: expiration,
}
b, err := json.Marshal(&payloadData)
if err != nil {
return "", fmt.Errorf("cannot serialize payload data: %w", err)
}
payload = base64.RawStdEncoding.EncodeToString(b)
return payload, nil
}
func fetchPIAToken(fileManager files.FileManager, client *http.Client) (token string, err error) {
username, password, err := getOpenvpnCredentials(fileManager)
if err != nil {
return "", fmt.Errorf("cannot get Openvpn credentials: %w", err)
}
url := url.URL{
Scheme: "https",
User: url.UserPassword(username, password),
Host: "10.0.0.1",
Path: "/authv3/generateToken",
}
request, err := http.NewRequest(http.MethodGet, url.String(), nil)
if err != nil {
return "", err
}
response, err := client.Do(request)
if err != nil {
return "", err
}
defer response.Body.Close()
b, err := ioutil.ReadAll(response.Body)
if response.StatusCode != http.StatusOK {
shortenMessage := string(b)
shortenMessage = strings.ReplaceAll(shortenMessage, "\n", "")
shortenMessage = strings.ReplaceAll(shortenMessage, " ", " ")
return "", fmt.Errorf("%s: response received: %q", response.Status, shortenMessage)
} else if err != nil {
return "", err
}
var result struct {
Token string `json:"token"`
}
if err := json.Unmarshal(b, &result); err != nil {
return "", err
} else if len(result.Token) == 0 {
return "", fmt.Errorf("token is empty")
}
return result.Token, nil
}
func getOpenvpnCredentials(fileManager files.FileManager) (username, password string, err error) {
authData, err := fileManager.ReadFile(string(constants.OpenVPNAuthConf))
if err != nil {
return "", "", fmt.Errorf("cannot read openvpn auth file: %w", err)
}
lines := strings.Split(string(authData), "\n")
if len(lines) < 2 {
return "", "", fmt.Errorf("not enough lines (%d) in openvpn auth file", len(lines))
}
username, password = lines[0], lines[1]
return username, password, nil
}
func fetchPIAPortForwardData(client *http.Client, gateway net.IP, token string) (port uint16, signature string, expiration time.Time, err error) {
queryParams := url.Values{}
queryParams.Add("token", token)
url := url.URL{
Scheme: "https",
Host: net.JoinHostPort(gateway.String(), "19999"),
Path: "/getSignature",
RawQuery: queryParams.Encode(),
}
response, err := client.Get(url.String())
if err != nil {
return 0, "", expiration, fmt.Errorf("cannot obtain signature: %w", err)
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
return 0, "", expiration, fmt.Errorf("cannot obtain signature: %s", response.Status)
}
b, err := ioutil.ReadAll(response.Body)
if err != nil {
return 0, "", expiration, fmt.Errorf("cannot obtain signature: %w", err)
}
var data struct {
Status string `json:"status"`
Payload string `json:"payload"`
Signature string `json:"signature"`
}
if err := json.Unmarshal(b, &data); err != nil {
return 0, "", expiration, fmt.Errorf("cannot decode received data: %w", err)
} else if data.Status != "OK" {
return 0, "", expiration, fmt.Errorf("response received from PIA has status %s", data.Status)
}
port, _, expiration, err = unpackPIAPayload(data.Payload)
return port, data.Signature, expiration, err
}
func bindPIAPort(client *http.Client, gateway net.IP, data piaPortForwardData) (err error) {
payload, err := packPIAPayload(data.Port, data.Token, data.Expiration)
if err != nil {
return err
}
queryParams := url.Values{}
queryParams.Add("payload", payload)
queryParams.Add("signature", data.Signature)
url := url.URL{
Scheme: "https",
Host: net.JoinHostPort(gateway.String(), "19999"),
Path: "/bindPort",
RawQuery: queryParams.Encode(),
}
response, err := client.Get(url.String())
if err != nil {
return fmt.Errorf("cannot bind port: %w", err)
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
return fmt.Errorf("cannot bind port: %s", response.Status)
}
b, err := ioutil.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("cannot bind port: %w", err)
}
var responseData struct {
Status string `json:"status"`
Message string `json:"message"`
}
if err := json.Unmarshal(b, &responseData); err != nil {
return fmt.Errorf("cannot bind port: %w", err)
} else if responseData.Status != "OK" {
return fmt.Errorf("response received from PIA: %s (%s)", responseData.Status, responseData.Message)
}
return nil
}

View File

@@ -1,24 +1,32 @@
package provider
import (
"context"
"net"
"net/http"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/network"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
)
// Provider contains methods to read and modify the openvpn configuration to connect as a client
type Provider interface {
GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error)
BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string)
GetPortForward(client network.Client) (port uint16, err error)
PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath))
}
func New(provider models.VPNProvider, allServers models.AllServers) Provider {
switch provider {
case constants.PrivateInternetAccess:
return newPrivateInternetAccess(allServers.Pia.Servers)
return newPrivateInternetAccessV4(allServers.Pia.Servers)
case constants.PrivateInternetAccessOld:
return newPrivateInternetAccess(allServers.PiaOld.Servers)
return newPrivateInternetAccessV3(allServers.PiaOld.Servers)
case constants.Mullvad:
return newMullvad(allServers.Mullvad.Servers)
case constants.Windscribe:

View File

@@ -1,12 +1,17 @@
package provider
import (
"context"
"fmt"
"net"
"net/http"
"strings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/network"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
)
type purevpn struct {
@@ -157,6 +162,8 @@ func (p *purevpn) BuildConf(connections []models.OpenVPNConnection, verbosity, u
return lines
}
func (p *purevpn) GetPortForward(client network.Client) (port uint16, err error) {
func (p *purevpn) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for purevpn")
}

View File

@@ -1,12 +1,17 @@
package provider
import (
"context"
"fmt"
"net"
"net/http"
"strings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/network"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
)
type surfshark struct {
@@ -135,6 +140,8 @@ func (s *surfshark) BuildConf(connections []models.OpenVPNConnection, verbosity,
return lines
}
func (s *surfshark) GetPortForward(client network.Client) (port uint16, err error) {
func (s *surfshark) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for surfshark")
}

View File

@@ -0,0 +1,29 @@
package provider
import (
"context"
"time"
"github.com/qdm12/golibs/logging"
)
func tryUntilSuccessful(ctx context.Context, logger logging.Logger, fn func() error) {
const retryPeriod = 10 * time.Second
for {
err := fn()
if err == nil {
break
}
logger.Error(err)
logger.Info("Trying again in %s", retryPeriod)
timer := time.NewTimer(retryPeriod)
select {
case <-timer.C:
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
return
}
}
}

View File

@@ -1,12 +1,17 @@
package provider
import (
"context"
"fmt"
"net"
"net/http"
"strings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/network"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
)
type vyprvpn struct {
@@ -121,6 +126,8 @@ func (v *vyprvpn) BuildConf(connections []models.OpenVPNConnection, verbosity, u
return lines
}
func (v *vyprvpn) GetPortForward(client network.Client) (port uint16, err error) {
func (v *vyprvpn) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for vyprvpn")
}

View File

@@ -1,12 +1,17 @@
package provider
import (
"context"
"fmt"
"net"
"net/http"
"strings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/network"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
)
type windscribe struct {
@@ -133,6 +138,8 @@ func (w *windscribe) BuildConf(connections []models.OpenVPNConnection, verbosity
return lines
}
func (w *windscribe) GetPortForward(client network.Client) (port uint16, err error) {
func (w *windscribe) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for windscribe")
}

View File

@@ -8,6 +8,7 @@ import (
"strings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/golibs/files"
)
func parseRoutingTable(data []byte) (entries []routingEntry, err error) {
@@ -23,12 +24,16 @@ func parseRoutingTable(data []byte) (entries []routingEntry, err error) {
return entries, nil
}
func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) {
data, err := r.fileManager.ReadFile(string(constants.NetRoute))
func getRoutingEntries(fileManager files.FileManager) (entries []routingEntry, err error) {
data, err := fileManager.ReadFile(string(constants.NetRoute))
if err != nil {
return "", nil, err
return nil, err
}
entries, err := parseRoutingTable(data)
return parseRoutingTable(data)
}
func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) {
entries, err := getRoutingEntries(r.fileManager)
if err != nil {
return "", nil, err
}
@@ -52,11 +57,7 @@ func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP
}
func (r *routing) LocalSubnet() (defaultSubnet net.IPNet, err error) {
data, err := r.fileManager.ReadFile(string(constants.NetRoute))
if err != nil {
return defaultSubnet, err
}
entries, err := parseRoutingTable(data)
entries, err := getRoutingEntries(r.fileManager)
if err != nil {
return defaultSubnet, err
}
@@ -79,11 +80,7 @@ func (r *routing) LocalSubnet() (defaultSubnet net.IPNet, err error) {
}
func (r *routing) routeExists(subnet net.IPNet) (exists bool, err error) {
data, err := r.fileManager.ReadFile(string(constants.NetRoute))
if err != nil {
return false, fmt.Errorf("cannot check route existence: %w", err)
}
entries, err := parseRoutingTable(data)
entries, err := getRoutingEntries(r.fileManager)
if err != nil {
return false, fmt.Errorf("cannot check route existence: %w", err)
}
@@ -96,12 +93,8 @@ func (r *routing) routeExists(subnet net.IPNet) (exists bool, err error) {
return false, nil
}
func (r *routing) VPNGatewayIP(defaultInterface string) (ip net.IP, err error) {
data, err := r.fileManager.ReadFile(string(constants.NetRoute))
if err != nil {
return nil, fmt.Errorf("cannot find VPN gateway IP address: %w", err)
}
entries, err := parseRoutingTable(data)
func (r *routing) VPNDestinationIP(defaultInterface string) (ip net.IP, err error) {
entries, err := getRoutingEntries(r.fileManager)
if err != nil {
return nil, fmt.Errorf("cannot find VPN gateway IP address: %w", err)
}
@@ -115,6 +108,20 @@ func (r *routing) VPNGatewayIP(defaultInterface string) (ip net.IP, err error) {
return nil, fmt.Errorf("cannot find VPN gateway IP address from ip routes")
}
func (r *routing) VPNLocalGatewayIP() (ip net.IP, err error) {
entries, err := getRoutingEntries(r.fileManager)
if err != nil {
return nil, fmt.Errorf("cannot find VPN local gateway IP address: %w", err)
}
for _, entry := range entries {
if entry.iface == string(constants.TUN) &&
entry.destination.Equal(net.IP{0, 0, 0, 0}) {
return entry.gateway, nil
}
}
return nil, fmt.Errorf("cannot find VPN local gateway IP address from ip routes")
}
func ipIsPrivate(ip net.IP) bool {
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
return true

View File

@@ -291,7 +291,7 @@ eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF
}
}
func Test_VPNGatewayIP(t *testing.T) {
func Test_VPNDestinationIP(t *testing.T) {
t.Parallel()
tests := map[string]struct {
defaultInterface string
@@ -334,7 +334,7 @@ eth0 x
filemanager.EXPECT().ReadFile(string(constants.NetRoute)).
Return(tc.data, tc.readErr).Times(1)
r := &routing{fileManager: filemanager}
ip, err := r.VPNGatewayIP(tc.defaultInterface)
ip, err := r.VPNDestinationIP(tc.defaultInterface)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())

View File

@@ -14,7 +14,8 @@ type Routing interface {
DeleteRouteVia(ctx context.Context, subnet net.IPNet) (err error)
DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error)
LocalSubnet() (defaultSubnet net.IPNet, err error)
VPNGatewayIP(defaultInterface string) (ip net.IP, err error)
VPNDestinationIP(defaultInterface string) (ip net.IP, err error)
VPNLocalGatewayIP() (ip net.IP, err error)
SetDebug()
}

View File

@@ -10,7 +10,16 @@ import (
// GetPIASettings obtains PIA settings from environment variables using the params package.
func GetPIASettings(paramsReader params.Reader) (settings models.ProviderSettings, err error) {
settings.Name = constants.PrivateInternetAccess
return getPIASettings(paramsReader, constants.PrivateInternetAccess)
}
// GetPIAOldSettings obtains PIA settings for the older PIA servers (pre summer 2020) from environment variables using the params package.
func GetPIAOldSettings(paramsReader params.Reader) (settings models.ProviderSettings, err error) {
return getPIASettings(paramsReader, constants.PrivateInternetAccessOld)
}
func getPIASettings(paramsReader params.Reader, name models.VPNProvider) (settings models.ProviderSettings, err error) {
settings.Name = name
settings.ServerSelection.Protocol, err = paramsReader.GetNetworkProtocol()
if err != nil {
return settings, err
@@ -29,30 +38,6 @@ func GetPIASettings(paramsReader params.Reader) (settings models.ProviderSetting
if err != nil {
return settings, err
}
return settings, nil
}
// GetPIAOldSettings obtains PIA settings for the older PIA servers (pre summer 2020) from environment variables using the params package.
func GetPIAOldSettings(paramsReader params.Reader) (settings models.ProviderSettings, err error) {
settings.Name = constants.PrivateInternetAccessOld
settings.ServerSelection.Protocol, err = paramsReader.GetNetworkProtocol()
if err != nil {
return settings, err
}
settings.ServerSelection.TargetIP, err = paramsReader.GetTargetIP()
if err != nil {
return settings, err
}
encryptionPreset, err := paramsReader.GetPIAEncryptionPreset()
if err != nil {
return settings, err
}
settings.ServerSelection.EncryptionPreset = encryptionPreset
settings.ExtraConfigOptions.EncryptionPreset = encryptionPreset
settings.ServerSelection.Region, err = paramsReader.GetPIAOldRegion()
if err != nil {
return settings, err
}
settings.PortForwarding.Enabled, err = paramsReader.GetPortForwarding()
if err != nil {
return settings, err

View File

@@ -4,6 +4,8 @@ import (
"fmt"
"net/http"
"time"
"github.com/qdm12/gluetun/internal/logging"
)
// GetMessage returns a message for the user describing if there is a newer version
@@ -30,35 +32,12 @@ func GetMessage(version, commitShort string, client *http.Client) (message strin
if tagName == version {
return fmt.Sprintf("You are running the latest release %s", version), nil
}
timeSinceRelease := formatDuration(time.Since(releaseTime))
timeSinceRelease := logging.FormatDuration(time.Since(releaseTime))
return fmt.Sprintf("There is a new release %s (%s) created %s ago",
tagName, name, timeSinceRelease),
nil
}
func formatDuration(duration time.Duration) string {
switch {
case duration < time.Minute:
seconds := int(duration.Round(time.Second).Seconds())
if seconds < 2 {
return fmt.Sprintf("%d second", seconds)
}
return fmt.Sprintf("%d seconds", seconds)
case duration <= time.Hour:
minutes := int(duration.Round(time.Minute).Minutes())
if minutes == 1 {
return "1 minute"
}
return fmt.Sprintf("%d minutes", minutes)
case duration < 48*time.Hour:
hours := int(duration.Truncate(time.Hour).Hours())
return fmt.Sprintf("%d hours", hours)
default:
days := int(duration.Truncate(time.Hour).Hours() / 24)
return fmt.Sprintf("%d days", days)
}
}
func getLatestRelease(client *http.Client) (tagName, name string, time time.Time, err error) {
releases, err := getGithubReleases(client)
if err != nil {