PIA nextgen portforward (#242)

* Split provider/pia.go in piav3.go and piav4.go
* Change port forwarding signature
* Enable port forwarding parameter for PIA v4
* Fix VPN gateway IP obtention
* Setup HTTP client for TLS with custom cert
* Error message for regions not supporting pf
This commit is contained in:
Quentin McGaw
2020-10-12 10:55:08 -04:00
committed by GitHub
parent fbecbc1c82
commit ec157f102b
25 changed files with 763 additions and 202 deletions

View File

@@ -115,8 +115,8 @@ docker run --rm --network=container:gluetun alpine:3.12 wget -qO- https://ipinfo
| 🏁 `PASSWORD` | | | Your password | | 🏁 `PASSWORD` | | | Your password |
| `REGION` | | One of the [PIA regions](https://www.privateinternetaccess.com/pages/network/) | VPN server region | | `REGION` | | One of the [PIA regions](https://www.privateinternetaccess.com/pages/network/) | VPN server region |
| `PIA_ENCRYPTION` | `strong` | `normal`, `strong` | Encryption preset | | `PIA_ENCRYPTION` | `strong` | `normal`, `strong` | Encryption preset |
| `PORT_FORWARDING` | `off` | `on`, `off` | Enable port forwarding on the VPN server **for old only** | | `PORT_FORWARDING` | `off` | `on`, `off` | Enable port forwarding on the VPN server |
| `PORT_FORWARDING_STATUS_FILE` | `/tmp/gluetun/forwarded_port` | Any filepath | Filepath to store the forwarded port number **for old only** | | `PORT_FORWARDING_STATUS_FILE` | `/tmp/gluetun/forwarded_port` | Any filepath | Filepath to store the forwarded port number |
- Mullvad - Mullvad

View File

@@ -3,6 +3,7 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
@@ -188,7 +189,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady) go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady)
openvpnLooper := openvpn.NewLooper(allSettings.VPNSP, allSettings.OpenVPN, uid, gid, allServers, openvpnLooper := openvpn.NewLooper(allSettings.VPNSP, allSettings.OpenVPN, uid, gid, allServers,
ovpnConf, firewallConf, logger, client, fileManager, streamMerger, cancel) ovpnConf, firewallConf, routingConf, logger, httpClient, fileManager, streamMerger, cancel)
wg.Add(1) wg.Add(1)
// wait for restartOpenvpn // wait for restartOpenvpn
go openvpnLooper.Run(ctx, wg) go openvpnLooper.Run(ctx, wg)
@@ -341,10 +342,11 @@ func collectStreamLines(ctx context.Context, streamMerger command.StreamMerger,
}) })
} }
//nolint:gocognit
func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dnsReadyCh <-chan struct{}, func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dnsReadyCh <-chan struct{},
unboundLooper dns.Looper, updaterLooper updater.Looper, publicIPLooper publicip.Looper, unboundLooper dns.Looper, updaterLooper updater.Looper, publicIPLooper publicip.Looper,
routing routing.Routing, logger logging.Logger, httpClient *http.Client, routing routing.Routing, logger logging.Logger, httpClient *http.Client,
versionInformation, portForwardingEnabled bool, startPortForward func()) { versionInformation, portForwardingEnabled bool, startPortForward func(vpnGateway net.IP)) {
defer wg.Done() defer wg.Done()
tickerWg := &sync.WaitGroup{} tickerWg := &sync.WaitGroup{}
// for linters only // for linters only
@@ -364,18 +366,35 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dn
tickerWg.Add(2) tickerWg.Add(2)
go unboundLooper.RunRestartTicker(restartTickerContext, tickerWg) go unboundLooper.RunRestartTicker(restartTickerContext, tickerWg)
go updaterLooper.RunRestartTicker(restartTickerContext, tickerWg) go updaterLooper.RunRestartTicker(restartTickerContext, tickerWg)
if portForwardingEnabled {
time.AfterFunc(5*time.Second, startPortForward)
}
defaultInterface, _, err := routing.DefaultRoute() defaultInterface, _, err := routing.DefaultRoute()
if err != nil { if err != nil {
logger.Warn(err) logger.Warn(err)
} else { } else {
vpnGatewayIP, err := routing.VPNGatewayIP(defaultInterface) vpnDestination, err := routing.VPNDestinationIP(defaultInterface)
if err != nil { if err != nil {
logger.Warn(err) logger.Warn(err)
} else { } else {
logger.Info("Gateway VPN IP address: %s", vpnGatewayIP) logger.Info("VPN routing IP address: %s", vpnDestination)
}
}
if portForwardingEnabled {
// TODO make instantaneous once v3 go out of service
const waitDuration = 5 * time.Second
timer := time.NewTimer(waitDuration)
select {
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
continue
case <-timer.C:
// vpnGateway required only for PIA v4
vpnGateway, err := routing.VPNLocalGatewayIP()
if err != nil {
logger.Error(err)
}
logger.Info("VPN gateway IP address: %s", vpnGateway)
startPortForward(vpnGateway)
} }
} }
case <-dnsReadyCh: case <-dnsReadyCh:

View File

@@ -15,6 +15,8 @@ const (
OpenVPNAuthConf models.Filepath = "/etc/openvpn/auth.conf" OpenVPNAuthConf models.Filepath = "/etc/openvpn/auth.conf"
// OpenVPNConf is the file path to the OpenVPN client configuration file // OpenVPNConf is the file path to the OpenVPN client configuration file
OpenVPNConf models.Filepath = "/etc/openvpn/target.ovpn" OpenVPNConf models.Filepath = "/etc/openvpn/target.ovpn"
// PIAPortForward is the file path to the port forwarding JSON information for PIA v4 servers
PIAPortForward models.Filepath = "/gluetun/piaportforward.json"
// TunnelDevice is the file path to tun device // TunnelDevice is the file path to tun device
TunnelDevice models.Filepath = "/dev/net/tun" TunnelDevice models.Filepath = "/dev/net/tun"
// NetRoute is the path to the file containing information on the network route // NetRoute is the path to the file containing information on the network route

View File

@@ -2,9 +2,9 @@ package constants
const ( const (
// Announcement is a message announcement // Announcement is a message announcement
Announcement = "Update servers information see https://github.com/qdm12/gluetun/wiki/Update-servers-information" Announcement = "Port forwarding is working for PIA v4 servers"
// AnnouncementExpiration is the expiration date of the announcement in format yyyy-mm-dd // AnnouncementExpiration is the expiration date of the announcement in format yyyy-mm-dd
AnnouncementExpiration = "2020-10-10" AnnouncementExpiration = "2020-11-15"
) )
const ( const (

View File

@@ -0,0 +1,29 @@
package logging
import (
"fmt"
"time"
)
func FormatDuration(duration time.Duration) string {
switch {
case duration < time.Minute:
seconds := int(duration.Round(time.Second).Seconds())
if seconds < 2 {
return fmt.Sprintf("%d second", seconds)
}
return fmt.Sprintf("%d seconds", seconds)
case duration <= time.Hour:
minutes := int(duration.Round(time.Minute).Minutes())
if minutes == 1 {
return "1 minute"
}
return fmt.Sprintf("%d minutes", minutes)
case duration < 48*time.Hour:
hours := int(duration.Truncate(time.Hour).Hours())
return fmt.Sprintf("%d hours", hours)
default:
days := int(duration.Truncate(time.Hour).Hours() / 24)
return fmt.Sprintf("%d days", days)
}
}

View File

@@ -1,4 +1,4 @@
package version package logging
import ( import (
"testing" "testing"
@@ -7,7 +7,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test_formatDuration(t *testing.T) { func Test_FormatDuration(t *testing.T) {
t.Parallel() t.Parallel()
testCases := map[string]struct { testCases := map[string]struct {
duration time.Duration duration time.Duration
@@ -57,7 +57,7 @@ func Test_formatDuration(t *testing.T) {
testCase := testCase testCase := testCase
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Parallel() t.Parallel()
s := formatDuration(testCase.duration) s := FormatDuration(testCase.duration)
assert.Equal(t, testCase.s, s) assert.Equal(t, testCase.s, s)
}) })
} }

View File

@@ -90,6 +90,7 @@ func (p *ProviderSettings) String() string {
settingsList = append(settingsList, settingsList = append(settingsList,
"Region: "+p.ServerSelection.Region, "Region: "+p.ServerSelection.Region,
"Encryption preset: "+p.ExtraConfigOptions.EncryptionPreset, "Encryption preset: "+p.ExtraConfigOptions.EncryptionPreset,
"Port forwarding: "+p.PortForwarding.String(),
) )
case "mullvad": case "mullvad":
settingsList = append(settingsList, settingsList = append(settingsList,

View File

@@ -2,7 +2,8 @@ package openvpn
import ( import (
"context" "context"
"fmt" "net"
"net/http"
"sync" "sync"
"time" "time"
@@ -10,17 +11,17 @@ import (
"github.com/qdm12/gluetun/internal/firewall" "github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gluetun/internal/routing"
"github.com/qdm12/gluetun/internal/settings" "github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/golibs/command" "github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/files" "github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging" "github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network"
) )
type Looper interface { type Looper interface {
Run(ctx context.Context, wg *sync.WaitGroup) Run(ctx context.Context, wg *sync.WaitGroup)
Restart() Restart()
PortForward() PortForward(vpnGatewayIP net.IP)
GetSettings() (settings settings.OpenVPN) GetSettings() (settings settings.OpenVPN)
SetSettings(settings settings.OpenVPN) SetSettings(settings settings.OpenVPN)
GetPortForwarded() (portForwarded uint16) GetPortForwarded() (portForwarded uint16)
@@ -40,23 +41,24 @@ type looper struct {
uid int uid int
gid int gid int
// Configurators // Configurators
conf Configurator conf Configurator
fw firewall.Configurator fw firewall.Configurator
routing routing.Routing
// Other objects // Other objects
logger logging.Logger logger, pfLogger logging.Logger
client network.Client client *http.Client
fileManager files.FileManager fileManager files.FileManager
streamMerger command.StreamMerger streamMerger command.StreamMerger
cancel context.CancelFunc cancel context.CancelFunc
// Internal channels // Internal channels
restart chan struct{} restart chan struct{}
portForwardSignals chan struct{} portForwardSignals chan net.IP
} }
func NewLooper(provider models.VPNProvider, settings settings.OpenVPN, func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
uid, gid int, allServers models.AllServers, uid, gid int, allServers models.AllServers,
conf Configurator, fw firewall.Configurator, conf Configurator, fw firewall.Configurator, routing routing.Routing,
logger logging.Logger, client network.Client, fileManager files.FileManager, logger logging.Logger, client *http.Client, fileManager files.FileManager,
streamMerger command.StreamMerger, cancel context.CancelFunc) Looper { streamMerger command.StreamMerger, cancel context.CancelFunc) Looper {
return &looper{ return &looper{
provider: provider, provider: provider,
@@ -66,18 +68,20 @@ func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
allServers: allServers, allServers: allServers,
conf: conf, conf: conf,
fw: fw, fw: fw,
routing: routing,
logger: logger.WithPrefix("openvpn: "), logger: logger.WithPrefix("openvpn: "),
pfLogger: logger.WithPrefix("port forwarding: "),
client: client, client: client,
fileManager: fileManager, fileManager: fileManager,
streamMerger: streamMerger, streamMerger: streamMerger,
cancel: cancel, cancel: cancel,
restart: make(chan struct{}), restart: make(chan struct{}),
portForwardSignals: make(chan struct{}), portForwardSignals: make(chan net.IP),
} }
} }
func (l *looper) Restart() { l.restart <- struct{}{} } func (l *looper) Restart() { l.restart <- struct{}{} }
func (l *looper) PortForward() { l.portForwardSignals <- struct{}{} } func (l *looper) PortForward(vpnGateway net.IP) { l.portForwardSignals <- vpnGateway }
func (l *looper) GetSettings() (settings settings.OpenVPN) { func (l *looper) GetSettings() (settings settings.OpenVPN) {
l.settingsMutex.RLock() l.settingsMutex.RLock()
@@ -158,10 +162,12 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
go func(ctx context.Context) { go func(ctx context.Context) {
for { for {
select { select {
// TODO have a way to disable pf with a context
case <-ctx.Done(): case <-ctx.Done():
return return
case <-l.portForwardSignals: case gateway := <-l.portForwardSignals:
l.portForward(ctx, providerConf, l.client) wg.Add(1)
go l.portForward(ctx, wg, providerConf, l.client, gateway)
} }
} }
}(openvpnCtx) }(openvpnCtx)
@@ -200,43 +206,25 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
<-ctx.Done() <-ctx.Done()
} }
func (l *looper) portForward(ctx context.Context, providerConf provider.Provider, client network.Client) { // portForward is a blocking operation which may or may not be infinite.
// You should therefore always call it in a goroutine
func (l *looper) portForward(ctx context.Context, wg *sync.WaitGroup,
providerConf provider.Provider, client *http.Client, gateway net.IP) {
defer wg.Done()
settings := l.GetSettings() settings := l.GetSettings()
if !settings.Provider.PortForwarding.Enabled { if !settings.Provider.PortForwarding.Enabled {
return return
} }
var port uint16 syncState := func(port uint16) (pfFilepath models.Filepath) {
err := fmt.Errorf("") l.portForwardedMutex.Lock()
for err != nil { l.portForwarded = port
if ctx.Err() != nil { l.portForwardedMutex.Unlock()
return settings := l.GetSettings()
} return settings.Provider.PortForwarding.Filepath
port, err = providerConf.GetPortForward(client)
if err != nil {
l.logAndWait(ctx, err)
}
}
l.logger.Info("port forwarded is %d", port)
l.portForwardedMutex.Lock()
if err := l.fw.RemoveAllowedPort(ctx, l.portForwarded); err != nil {
l.logger.Error(err)
}
if err := l.fw.SetAllowedPort(ctx, port, string(constants.TUN)); err != nil {
l.logger.Error(err)
}
l.portForwarded = port
l.portForwardedMutex.Unlock()
filepath := settings.Provider.PortForwarding.Filepath
l.logger.Info("writing forwarded port to %s", filepath)
err = l.fileManager.WriteLinesToFile(
string(filepath), []string{fmt.Sprintf("%d", port)},
files.Ownership(l.uid, l.gid), files.Permissions(0400),
)
if err != nil {
l.logger.Error(err)
} }
providerConf.PortForward(ctx,
client, l.fileManager, l.pfLogger,
gateway, l.fw, syncState)
} }
func (l *looper) GetPortForwarded() (portForwarded uint16) { func (l *looper) GetPortForwarded() (portForwarded uint16) {

View File

@@ -1,12 +1,17 @@
package provider package provider
import ( import (
"context"
"fmt" "fmt"
"net"
"net/http"
"strings" "strings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/network" "github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
) )
type cyberghost struct { type cyberghost struct {
@@ -135,6 +140,8 @@ func (c *cyberghost) BuildConf(connections []models.OpenVPNConnection, verbosity
return lines return lines
} }
func (c *cyberghost) GetPortForward(client network.Client) (port uint16, err error) { func (c *cyberghost) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for cyberghost") panic("port forwarding is not supported for cyberghost")
} }

View File

@@ -1,12 +1,17 @@
package provider package provider
import ( import (
"context"
"fmt" "fmt"
"net"
"net/http"
"strings" "strings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/network" "github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
) )
type mullvad struct { type mullvad struct {
@@ -134,6 +139,8 @@ func (m *mullvad) BuildConf(connections []models.OpenVPNConnection, verbosity, u
return lines return lines
} }
func (m *mullvad) GetPortForward(client network.Client) (port uint16, err error) { func (m *mullvad) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for mullvad") panic("port forwarding is not supported for mullvad")
} }

View File

@@ -1,12 +1,17 @@
package provider package provider
import ( import (
"context"
"fmt" "fmt"
"net"
"net/http"
"strings" "strings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/network" "github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
) )
type nordvpn struct { type nordvpn struct {
@@ -142,6 +147,8 @@ func (n *nordvpn) BuildConf(connections []models.OpenVPNConnection, verbosity, u
return lines return lines
} }
func (n *nordvpn) GetPortForward(client network.Client) (port uint16, err error) { func (n *nordvpn) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for nordvpn") panic("port forwarding is not supported for nordvpn")
} }

View File

@@ -1,35 +1,18 @@
package provider package provider
import ( import (
"encoding/hex"
"encoding/json"
"fmt" "fmt"
"net/http"
"strings" "strings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/crypto/random"
"github.com/qdm12/golibs/network"
) )
type pia struct { func filterPIAServers(servers []models.PIAServer, region string) (filtered []models.PIAServer) {
random random.Random
servers []models.PIAServer
}
func newPrivateInternetAccess(servers []models.PIAServer) *pia {
return &pia{
random: random.NewRandom(),
servers: servers,
}
}
func (p *pia) filterServers(region string) (servers []models.PIAServer) {
if len(region) == 0 { if len(region) == 0 {
return p.servers return servers
} }
for _, server := range p.servers { for _, server := range servers {
if strings.EqualFold(server.Region, region) { if strings.EqualFold(server.Region, region) {
return []models.PIAServer{server} return []models.PIAServer{server}
} }
@@ -37,8 +20,8 @@ func (p *pia) filterServers(region string) (servers []models.PIAServer) {
return nil return nil
} }
func (p *pia) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) { func getPIAOpenVPNConnections(allServers []models.PIAServer, selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) {
servers := p.filterServers(selection.Region) servers := filterPIAServers(allServers, selection.Region)
if len(servers) == 0 { if len(servers) == 0 {
return nil, fmt.Errorf("no server found for region %q", selection.Region) return nil, fmt.Errorf("no server found for region %q", selection.Region)
} }
@@ -87,7 +70,7 @@ func (p *pia) GetOpenVPNConnections(selection models.ServerSelection) (connectio
return connections, nil return connections, nil
} }
func (p *pia) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) { func buildPIAConf(connections []models.OpenVPNConnection, verbosity int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) {
var X509CRL, certificate string var X509CRL, certificate string
if extras.EncryptionPreset == constants.PIAEncryptionPresetNormal { if extras.EncryptionPreset == constants.PIAEncryptionPresetNormal {
if len(cipher) == 0 { if len(cipher) == 0 {
@@ -161,28 +144,3 @@ func (p *pia) BuildConf(connections []models.OpenVPNConnection, verbosity, uid,
}...) }...)
return lines return lines
} }
func (p *pia) GetPortForward(client network.Client) (port uint16, err error) {
b, err := p.random.GenerateRandomBytes(32)
if err != nil {
return 0, err
}
clientID := hex.EncodeToString(b)
url := fmt.Sprintf("%s/?client_id=%s", constants.PIAPortForwardURL, clientID)
content, status, err := client.GetContent(url) // TODO add ctx
switch {
case err != nil:
return 0, err
case status != http.StatusOK:
return 0, fmt.Errorf("status is %d for %s; does your PIA server support port forwarding?", status, url)
case len(content) == 0:
return 0, fmt.Errorf("port forwarding is already activated on this connection, has expired, or you are not connected to a PIA region that supports port forwarding")
}
body := struct {
Port uint16 `json:"port"`
}{}
if err := json.Unmarshal(content, &body); err != nil {
return 0, fmt.Errorf("port forwarding response: %w", err)
}
return body.Port, nil
}

View File

@@ -0,0 +1,94 @@
package provider
import (
"context"
"encoding/hex"
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/crypto/random"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
)
type piaV3 struct {
random random.Random
servers []models.PIAServer
}
func newPrivateInternetAccessV3(servers []models.PIAServer) *piaV3 {
return &piaV3{
random: random.NewRandom(),
servers: servers,
}
}
func (p *piaV3) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) {
return getPIAOpenVPNConnections(p.servers, selection)
}
func (p *piaV3) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) {
return buildPIAConf(connections, verbosity, root, cipher, auth, extras)
}
func (p *piaV3) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
b, err := p.random.GenerateRandomBytes(32)
if err != nil {
pfLogger.Error(err)
return
}
clientID := hex.EncodeToString(b)
url := fmt.Sprintf("%s/?client_id=%s", constants.PIAPortForwardURL, clientID)
response, err := client.Get(url) // TODO add ctx
if err != nil {
pfLogger.Error(err)
return
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
pfLogger.Error(fmt.Errorf("%s for %s; does your PIA server support port forwarding?", response.Status, url))
return
}
b, err = ioutil.ReadAll(response.Body)
if err != nil {
pfLogger.Error(err)
return
} else if len(b) == 0 {
pfLogger.Error(fmt.Errorf("port forwarding is already activated on this connection, has expired, or you are not connected to a PIA region that supports port forwarding"))
return
}
body := struct {
Port uint16 `json:"port"`
}{}
if err := json.Unmarshal(b, &body); err != nil {
pfLogger.Error(fmt.Errorf("port forwarding response: %w", err))
return
}
port := body.Port
filepath := syncState(port)
pfLogger.Info("Writing port to %s", filepath)
if err := fileManager.WriteToFile(
string(filepath), []byte(fmt.Sprintf("%d", port)),
files.Permissions(0666),
); err != nil {
pfLogger.Error(err)
}
if err := fw.SetAllowedPort(ctx, port, string(constants.TUN)); err != nil {
pfLogger.Error(err)
}
<-ctx.Done()
if err := fw.RemoveAllowedPort(ctx, port); err != nil {
pfLogger.Error(err)
}
}

412
internal/provider/piav4.go Normal file
View File

@@ -0,0 +1,412 @@
package provider
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
gluetunLog "github.com/qdm12/gluetun/internal/logging"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
)
type piaV4 struct {
servers []models.PIAServer
timeNow func() time.Time
}
func newPrivateInternetAccessV4(servers []models.PIAServer) *piaV4 {
return &piaV4{
servers: servers,
timeNow: time.Now,
}
}
func (p *piaV4) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) {
return getPIAOpenVPNConnections(p.servers, selection)
}
func (p *piaV4) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) {
return buildPIAConf(connections, verbosity, root, cipher, auth, extras)
}
//nolint:gocognit
func (p *piaV4) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
if gateway == nil {
pfLogger.Error("aborting because: VPN gateway IP address was not found")
return
}
client, err := newPIAv4HTTPClient()
if err != nil {
pfLogger.Error("aborting because: %s", err)
return
}
defer pfLogger.Warn("loop exited")
data, err := readPIAPortForwardData(fileManager)
if err != nil {
pfLogger.Error(err)
}
dataFound := data.Port > 0
durationToExpiration := data.Expiration.Sub(p.timeNow())
expired := durationToExpiration <= 0
if dataFound {
pfLogger.Info("Found persistent forwarded port data for port %d", data.Port)
if expired {
pfLogger.Warn("Forwarded port data expired on %s, getting another one", data.Expiration.Format(time.RFC1123))
} else {
pfLogger.Info("Forwarded port data expires in %s", gluetunLog.FormatDuration(durationToExpiration))
}
}
if !dataFound || expired {
tryUntilSuccessful(ctx, pfLogger, func() error {
data, err = refreshPIAPortForwardData(client, gateway, fileManager)
return err
})
if ctx.Err() != nil {
return
}
durationToExpiration = data.Expiration.Sub(p.timeNow())
}
pfLogger.Info("Port forwarded is %d expiring in %s", data.Port, gluetunLog.FormatDuration(durationToExpiration))
// First time binding
tryUntilSuccessful(ctx, pfLogger, func() error {
return bindPIAPort(client, gateway, data)
})
if ctx.Err() != nil {
return
}
filepath := syncState(data.Port)
pfLogger.Info("Writing port to %s", filepath)
if err := fileManager.WriteToFile(
string(filepath), []byte(fmt.Sprintf("%d", data.Port)),
files.Permissions(0666),
); err != nil {
pfLogger.Error(err)
}
if err := fw.SetAllowedPort(ctx, data.Port, string(constants.TUN)); err != nil {
pfLogger.Error(err)
}
expiryTimer := time.NewTimer(durationToExpiration)
defer expiryTimer.Stop()
const keepAlivePeriod = 15 * time.Minute
keepAliveTicker := time.NewTicker(keepAlivePeriod)
defer keepAliveTicker.Stop()
for {
select {
case <-ctx.Done():
removeCtx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := fw.RemoveAllowedPort(removeCtx, data.Port); err != nil {
pfLogger.Error(err)
}
return
case <-keepAliveTicker.C:
if err := bindPIAPort(client, gateway, data); err != nil {
pfLogger.Error(err)
}
case <-expiryTimer.C:
pfLogger.Warn("Forward port has expired on %s, getting another one", data.Expiration.Format(time.RFC1123))
oldPort := data.Port
for {
data, err = refreshPIAPortForwardData(client, gateway, fileManager)
if err != nil {
pfLogger.Error(err)
continue
}
break
}
durationToExpiration := data.Expiration.Sub(p.timeNow())
pfLogger.Info("Port forwarded is %d expiring in %s", data.Port, gluetunLog.FormatDuration(durationToExpiration))
if err := fw.RemoveAllowedPort(ctx, oldPort); err != nil {
pfLogger.Error(err)
}
if err := fw.SetAllowedPort(ctx, data.Port, string(constants.TUN)); err != nil {
pfLogger.Error(err)
}
filepath := syncState(data.Port)
pfLogger.Info("Writing port to %s", filepath)
if err := fileManager.WriteToFile(
string(filepath), []byte(fmt.Sprintf("%d", data.Port)),
files.Permissions(0666),
); err != nil {
pfLogger.Error(err)
}
if err := bindPIAPort(client, gateway, data); err != nil {
pfLogger.Error(err)
}
keepAliveTicker.Reset(keepAlivePeriod)
expiryTimer.Reset(durationToExpiration)
}
}
}
func newPIAv4HTTPClient() (client *http.Client, err error) {
certificateBytes, err := base64.StdEncoding.DecodeString(constants.PIACertificateStrong)
if err != nil {
return nil, fmt.Errorf("cannot decode PIA root certificate: %w", err)
}
certificate, err := x509.ParseCertificate(certificateBytes)
if err != nil {
return nil, fmt.Errorf("cannot parse PIA root certificate: %w", err)
}
rootCAs := x509.NewCertPool()
rootCAs.AddCert(certificate)
TLSClientConfig := &tls.Config{
RootCAs: rootCAs,
MinVersion: tls.VersionTLS12,
InsecureSkipVerify: true, //nolint:gosec
} // TODO fix and remove InsecureSkipVerify
transport := http.Transport{
TLSClientConfig: TLSClientConfig,
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
const httpTimeout = 5 * time.Second
client = &http.Client{Transport: &transport, Timeout: httpTimeout}
return client, nil
}
func refreshPIAPortForwardData(client *http.Client, gateway net.IP, fileManager files.FileManager) (data piaPortForwardData, err error) {
data.Token, err = fetchPIAToken(fileManager, client)
if err != nil {
return data, fmt.Errorf("cannot obtain token: %w", err)
}
data.Port, data.Signature, data.Expiration, err = fetchPIAPortForwardData(client, gateway, data.Token)
if err != nil {
if strings.HasSuffix(err.Error(), "connection refused") {
return data, fmt.Errorf("cannot obtain port forwarding data: connection was refused, are you sure the region you are using supports port forwarding ;)")
}
return data, fmt.Errorf("cannot obtain port forwarding data: %w", err)
}
if err := writePIAPortForwardData(fileManager, data); err != nil {
return data, fmt.Errorf("cannot persist port forwarding information to file: %w", err)
}
return data, nil
}
type piaPayload struct {
Token string `json:"token"`
Port uint16 `json:"port"`
Expiration time.Time `json:"expires_at"`
}
type piaPortForwardData struct {
Port uint16 `json:"port"`
Token string `json:"token"`
Signature string `json:"signature"`
Expiration time.Time `json:"expires_at"`
}
func readPIAPortForwardData(fileManager files.FileManager) (data piaPortForwardData, err error) {
const filepath = string(constants.PIAPortForward)
exists, err := fileManager.FileExists(filepath)
if err != nil {
return data, err
} else if !exists {
return data, nil
}
b, err := fileManager.ReadFile(filepath)
if err != nil {
return data, err
}
if err := json.Unmarshal(b, &data); err != nil {
return data, err
}
return data, nil
}
func writePIAPortForwardData(fileManager files.FileManager, data piaPortForwardData) (err error) {
b, err := json.Marshal(&data)
if err != nil {
return fmt.Errorf("cannot encode data: %w", err)
}
err = fileManager.WriteToFile(string(constants.PIAPortForward), b)
if err != nil {
return err
}
return nil
}
func unpackPIAPayload(payload string) (port uint16, token string, expiration time.Time, err error) {
b, err := base64.RawStdEncoding.DecodeString(payload)
if err != nil {
return 0, "", expiration, fmt.Errorf("cannot decode payload: %w", err)
}
var payloadData piaPayload
if err := json.Unmarshal(b, &payloadData); err != nil {
return 0, "", expiration, fmt.Errorf("cannot parse payload data: %w", err)
}
return payloadData.Port, payloadData.Token, payloadData.Expiration, nil
}
func packPIAPayload(port uint16, token string, expiration time.Time) (payload string, err error) {
payloadData := piaPayload{
Token: token,
Port: port,
Expiration: expiration,
}
b, err := json.Marshal(&payloadData)
if err != nil {
return "", fmt.Errorf("cannot serialize payload data: %w", err)
}
payload = base64.RawStdEncoding.EncodeToString(b)
return payload, nil
}
func fetchPIAToken(fileManager files.FileManager, client *http.Client) (token string, err error) {
username, password, err := getOpenvpnCredentials(fileManager)
if err != nil {
return "", fmt.Errorf("cannot get Openvpn credentials: %w", err)
}
url := url.URL{
Scheme: "https",
User: url.UserPassword(username, password),
Host: "10.0.0.1",
Path: "/authv3/generateToken",
}
request, err := http.NewRequest(http.MethodGet, url.String(), nil)
if err != nil {
return "", err
}
response, err := client.Do(request)
if err != nil {
return "", err
}
defer response.Body.Close()
b, err := ioutil.ReadAll(response.Body)
if response.StatusCode != http.StatusOK {
shortenMessage := string(b)
shortenMessage = strings.ReplaceAll(shortenMessage, "\n", "")
shortenMessage = strings.ReplaceAll(shortenMessage, " ", " ")
return "", fmt.Errorf("%s: response received: %q", response.Status, shortenMessage)
} else if err != nil {
return "", err
}
var result struct {
Token string `json:"token"`
}
if err := json.Unmarshal(b, &result); err != nil {
return "", err
} else if len(result.Token) == 0 {
return "", fmt.Errorf("token is empty")
}
return result.Token, nil
}
func getOpenvpnCredentials(fileManager files.FileManager) (username, password string, err error) {
authData, err := fileManager.ReadFile(string(constants.OpenVPNAuthConf))
if err != nil {
return "", "", fmt.Errorf("cannot read openvpn auth file: %w", err)
}
lines := strings.Split(string(authData), "\n")
if len(lines) < 2 {
return "", "", fmt.Errorf("not enough lines (%d) in openvpn auth file", len(lines))
}
username, password = lines[0], lines[1]
return username, password, nil
}
func fetchPIAPortForwardData(client *http.Client, gateway net.IP, token string) (port uint16, signature string, expiration time.Time, err error) {
queryParams := url.Values{}
queryParams.Add("token", token)
url := url.URL{
Scheme: "https",
Host: net.JoinHostPort(gateway.String(), "19999"),
Path: "/getSignature",
RawQuery: queryParams.Encode(),
}
response, err := client.Get(url.String())
if err != nil {
return 0, "", expiration, fmt.Errorf("cannot obtain signature: %w", err)
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
return 0, "", expiration, fmt.Errorf("cannot obtain signature: %s", response.Status)
}
b, err := ioutil.ReadAll(response.Body)
if err != nil {
return 0, "", expiration, fmt.Errorf("cannot obtain signature: %w", err)
}
var data struct {
Status string `json:"status"`
Payload string `json:"payload"`
Signature string `json:"signature"`
}
if err := json.Unmarshal(b, &data); err != nil {
return 0, "", expiration, fmt.Errorf("cannot decode received data: %w", err)
} else if data.Status != "OK" {
return 0, "", expiration, fmt.Errorf("response received from PIA has status %s", data.Status)
}
port, _, expiration, err = unpackPIAPayload(data.Payload)
return port, data.Signature, expiration, err
}
func bindPIAPort(client *http.Client, gateway net.IP, data piaPortForwardData) (err error) {
payload, err := packPIAPayload(data.Port, data.Token, data.Expiration)
if err != nil {
return err
}
queryParams := url.Values{}
queryParams.Add("payload", payload)
queryParams.Add("signature", data.Signature)
url := url.URL{
Scheme: "https",
Host: net.JoinHostPort(gateway.String(), "19999"),
Path: "/bindPort",
RawQuery: queryParams.Encode(),
}
response, err := client.Get(url.String())
if err != nil {
return fmt.Errorf("cannot bind port: %w", err)
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
return fmt.Errorf("cannot bind port: %s", response.Status)
}
b, err := ioutil.ReadAll(response.Body)
if err != nil {
return fmt.Errorf("cannot bind port: %w", err)
}
var responseData struct {
Status string `json:"status"`
Message string `json:"message"`
}
if err := json.Unmarshal(b, &responseData); err != nil {
return fmt.Errorf("cannot bind port: %w", err)
} else if responseData.Status != "OK" {
return fmt.Errorf("response received from PIA: %s (%s)", responseData.Status, responseData.Message)
}
return nil
}

View File

@@ -1,24 +1,32 @@
package provider package provider
import ( import (
"context"
"net"
"net/http"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/network" "github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
) )
// Provider contains methods to read and modify the openvpn configuration to connect as a client // Provider contains methods to read and modify the openvpn configuration to connect as a client
type Provider interface { type Provider interface {
GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error)
BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string)
GetPortForward(client network.Client) (port uint16, err error) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath))
} }
func New(provider models.VPNProvider, allServers models.AllServers) Provider { func New(provider models.VPNProvider, allServers models.AllServers) Provider {
switch provider { switch provider {
case constants.PrivateInternetAccess: case constants.PrivateInternetAccess:
return newPrivateInternetAccess(allServers.Pia.Servers) return newPrivateInternetAccessV4(allServers.Pia.Servers)
case constants.PrivateInternetAccessOld: case constants.PrivateInternetAccessOld:
return newPrivateInternetAccess(allServers.PiaOld.Servers) return newPrivateInternetAccessV3(allServers.PiaOld.Servers)
case constants.Mullvad: case constants.Mullvad:
return newMullvad(allServers.Mullvad.Servers) return newMullvad(allServers.Mullvad.Servers)
case constants.Windscribe: case constants.Windscribe:

View File

@@ -1,12 +1,17 @@
package provider package provider
import ( import (
"context"
"fmt" "fmt"
"net"
"net/http"
"strings" "strings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/network" "github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
) )
type purevpn struct { type purevpn struct {
@@ -157,6 +162,8 @@ func (p *purevpn) BuildConf(connections []models.OpenVPNConnection, verbosity, u
return lines return lines
} }
func (p *purevpn) GetPortForward(client network.Client) (port uint16, err error) { func (p *purevpn) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for purevpn") panic("port forwarding is not supported for purevpn")
} }

View File

@@ -1,12 +1,17 @@
package provider package provider
import ( import (
"context"
"fmt" "fmt"
"net"
"net/http"
"strings" "strings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/network" "github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
) )
type surfshark struct { type surfshark struct {
@@ -135,6 +140,8 @@ func (s *surfshark) BuildConf(connections []models.OpenVPNConnection, verbosity,
return lines return lines
} }
func (s *surfshark) GetPortForward(client network.Client) (port uint16, err error) { func (s *surfshark) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for surfshark") panic("port forwarding is not supported for surfshark")
} }

View File

@@ -0,0 +1,29 @@
package provider
import (
"context"
"time"
"github.com/qdm12/golibs/logging"
)
func tryUntilSuccessful(ctx context.Context, logger logging.Logger, fn func() error) {
const retryPeriod = 10 * time.Second
for {
err := fn()
if err == nil {
break
}
logger.Error(err)
logger.Info("Trying again in %s", retryPeriod)
timer := time.NewTimer(retryPeriod)
select {
case <-timer.C:
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
return
}
}
}

View File

@@ -1,12 +1,17 @@
package provider package provider
import ( import (
"context"
"fmt" "fmt"
"net"
"net/http"
"strings" "strings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/network" "github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
) )
type vyprvpn struct { type vyprvpn struct {
@@ -121,6 +126,8 @@ func (v *vyprvpn) BuildConf(connections []models.OpenVPNConnection, verbosity, u
return lines return lines
} }
func (v *vyprvpn) GetPortForward(client network.Client) (port uint16, err error) { func (v *vyprvpn) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for vyprvpn") panic("port forwarding is not supported for vyprvpn")
} }

View File

@@ -1,12 +1,17 @@
package provider package provider
import ( import (
"context"
"fmt" "fmt"
"net"
"net/http"
"strings" "strings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/network" "github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
) )
type windscribe struct { type windscribe struct {
@@ -133,6 +138,8 @@ func (w *windscribe) BuildConf(connections []models.OpenVPNConnection, verbosity
return lines return lines
} }
func (w *windscribe) GetPortForward(client network.Client) (port uint16, err error) { func (w *windscribe) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for windscribe") panic("port forwarding is not supported for windscribe")
} }

View File

@@ -8,6 +8,7 @@ import (
"strings" "strings"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/golibs/files"
) )
func parseRoutingTable(data []byte) (entries []routingEntry, err error) { func parseRoutingTable(data []byte) (entries []routingEntry, err error) {
@@ -23,12 +24,16 @@ func parseRoutingTable(data []byte) (entries []routingEntry, err error) {
return entries, nil return entries, nil
} }
func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) { func getRoutingEntries(fileManager files.FileManager) (entries []routingEntry, err error) {
data, err := r.fileManager.ReadFile(string(constants.NetRoute)) data, err := fileManager.ReadFile(string(constants.NetRoute))
if err != nil { if err != nil {
return "", nil, err return nil, err
} }
entries, err := parseRoutingTable(data) return parseRoutingTable(data)
}
func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) {
entries, err := getRoutingEntries(r.fileManager)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
@@ -52,11 +57,7 @@ func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP
} }
func (r *routing) LocalSubnet() (defaultSubnet net.IPNet, err error) { func (r *routing) LocalSubnet() (defaultSubnet net.IPNet, err error) {
data, err := r.fileManager.ReadFile(string(constants.NetRoute)) entries, err := getRoutingEntries(r.fileManager)
if err != nil {
return defaultSubnet, err
}
entries, err := parseRoutingTable(data)
if err != nil { if err != nil {
return defaultSubnet, err return defaultSubnet, err
} }
@@ -79,11 +80,7 @@ func (r *routing) LocalSubnet() (defaultSubnet net.IPNet, err error) {
} }
func (r *routing) routeExists(subnet net.IPNet) (exists bool, err error) { func (r *routing) routeExists(subnet net.IPNet) (exists bool, err error) {
data, err := r.fileManager.ReadFile(string(constants.NetRoute)) entries, err := getRoutingEntries(r.fileManager)
if err != nil {
return false, fmt.Errorf("cannot check route existence: %w", err)
}
entries, err := parseRoutingTable(data)
if err != nil { if err != nil {
return false, fmt.Errorf("cannot check route existence: %w", err) return false, fmt.Errorf("cannot check route existence: %w", err)
} }
@@ -96,12 +93,8 @@ func (r *routing) routeExists(subnet net.IPNet) (exists bool, err error) {
return false, nil return false, nil
} }
func (r *routing) VPNGatewayIP(defaultInterface string) (ip net.IP, err error) { func (r *routing) VPNDestinationIP(defaultInterface string) (ip net.IP, err error) {
data, err := r.fileManager.ReadFile(string(constants.NetRoute)) entries, err := getRoutingEntries(r.fileManager)
if err != nil {
return nil, fmt.Errorf("cannot find VPN gateway IP address: %w", err)
}
entries, err := parseRoutingTable(data)
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot find VPN gateway IP address: %w", err) return nil, fmt.Errorf("cannot find VPN gateway IP address: %w", err)
} }
@@ -115,6 +108,20 @@ func (r *routing) VPNGatewayIP(defaultInterface string) (ip net.IP, err error) {
return nil, fmt.Errorf("cannot find VPN gateway IP address from ip routes") return nil, fmt.Errorf("cannot find VPN gateway IP address from ip routes")
} }
func (r *routing) VPNLocalGatewayIP() (ip net.IP, err error) {
entries, err := getRoutingEntries(r.fileManager)
if err != nil {
return nil, fmt.Errorf("cannot find VPN local gateway IP address: %w", err)
}
for _, entry := range entries {
if entry.iface == string(constants.TUN) &&
entry.destination.Equal(net.IP{0, 0, 0, 0}) {
return entry.gateway, nil
}
}
return nil, fmt.Errorf("cannot find VPN local gateway IP address from ip routes")
}
func ipIsPrivate(ip net.IP) bool { func ipIsPrivate(ip net.IP) bool {
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
return true return true

View File

@@ -291,7 +291,7 @@ eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF
} }
} }
func Test_VPNGatewayIP(t *testing.T) { func Test_VPNDestinationIP(t *testing.T) {
t.Parallel() t.Parallel()
tests := map[string]struct { tests := map[string]struct {
defaultInterface string defaultInterface string
@@ -334,7 +334,7 @@ eth0 x
filemanager.EXPECT().ReadFile(string(constants.NetRoute)). filemanager.EXPECT().ReadFile(string(constants.NetRoute)).
Return(tc.data, tc.readErr).Times(1) Return(tc.data, tc.readErr).Times(1)
r := &routing{fileManager: filemanager} r := &routing{fileManager: filemanager}
ip, err := r.VPNGatewayIP(tc.defaultInterface) ip, err := r.VPNDestinationIP(tc.defaultInterface)
if tc.err != nil { if tc.err != nil {
require.Error(t, err) require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error()) assert.Equal(t, tc.err.Error(), err.Error())

View File

@@ -14,7 +14,8 @@ type Routing interface {
DeleteRouteVia(ctx context.Context, subnet net.IPNet) (err error) DeleteRouteVia(ctx context.Context, subnet net.IPNet) (err error)
DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error)
LocalSubnet() (defaultSubnet net.IPNet, err error) LocalSubnet() (defaultSubnet net.IPNet, err error)
VPNGatewayIP(defaultInterface string) (ip net.IP, err error) VPNDestinationIP(defaultInterface string) (ip net.IP, err error)
VPNLocalGatewayIP() (ip net.IP, err error)
SetDebug() SetDebug()
} }

View File

@@ -10,7 +10,16 @@ import (
// GetPIASettings obtains PIA settings from environment variables using the params package. // GetPIASettings obtains PIA settings from environment variables using the params package.
func GetPIASettings(paramsReader params.Reader) (settings models.ProviderSettings, err error) { func GetPIASettings(paramsReader params.Reader) (settings models.ProviderSettings, err error) {
settings.Name = constants.PrivateInternetAccess return getPIASettings(paramsReader, constants.PrivateInternetAccess)
}
// GetPIAOldSettings obtains PIA settings for the older PIA servers (pre summer 2020) from environment variables using the params package.
func GetPIAOldSettings(paramsReader params.Reader) (settings models.ProviderSettings, err error) {
return getPIASettings(paramsReader, constants.PrivateInternetAccessOld)
}
func getPIASettings(paramsReader params.Reader, name models.VPNProvider) (settings models.ProviderSettings, err error) {
settings.Name = name
settings.ServerSelection.Protocol, err = paramsReader.GetNetworkProtocol() settings.ServerSelection.Protocol, err = paramsReader.GetNetworkProtocol()
if err != nil { if err != nil {
return settings, err return settings, err
@@ -29,30 +38,6 @@ func GetPIASettings(paramsReader params.Reader) (settings models.ProviderSetting
if err != nil { if err != nil {
return settings, err return settings, err
} }
return settings, nil
}
// GetPIAOldSettings obtains PIA settings for the older PIA servers (pre summer 2020) from environment variables using the params package.
func GetPIAOldSettings(paramsReader params.Reader) (settings models.ProviderSettings, err error) {
settings.Name = constants.PrivateInternetAccessOld
settings.ServerSelection.Protocol, err = paramsReader.GetNetworkProtocol()
if err != nil {
return settings, err
}
settings.ServerSelection.TargetIP, err = paramsReader.GetTargetIP()
if err != nil {
return settings, err
}
encryptionPreset, err := paramsReader.GetPIAEncryptionPreset()
if err != nil {
return settings, err
}
settings.ServerSelection.EncryptionPreset = encryptionPreset
settings.ExtraConfigOptions.EncryptionPreset = encryptionPreset
settings.ServerSelection.Region, err = paramsReader.GetPIAOldRegion()
if err != nil {
return settings, err
}
settings.PortForwarding.Enabled, err = paramsReader.GetPortForwarding() settings.PortForwarding.Enabled, err = paramsReader.GetPortForwarding()
if err != nil { if err != nil {
return settings, err return settings, err

View File

@@ -4,6 +4,8 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"time" "time"
"github.com/qdm12/gluetun/internal/logging"
) )
// GetMessage returns a message for the user describing if there is a newer version // GetMessage returns a message for the user describing if there is a newer version
@@ -30,35 +32,12 @@ func GetMessage(version, commitShort string, client *http.Client) (message strin
if tagName == version { if tagName == version {
return fmt.Sprintf("You are running the latest release %s", version), nil return fmt.Sprintf("You are running the latest release %s", version), nil
} }
timeSinceRelease := formatDuration(time.Since(releaseTime)) timeSinceRelease := logging.FormatDuration(time.Since(releaseTime))
return fmt.Sprintf("There is a new release %s (%s) created %s ago", return fmt.Sprintf("There is a new release %s (%s) created %s ago",
tagName, name, timeSinceRelease), tagName, name, timeSinceRelease),
nil nil
} }
func formatDuration(duration time.Duration) string {
switch {
case duration < time.Minute:
seconds := int(duration.Round(time.Second).Seconds())
if seconds < 2 {
return fmt.Sprintf("%d second", seconds)
}
return fmt.Sprintf("%d seconds", seconds)
case duration <= time.Hour:
minutes := int(duration.Round(time.Minute).Minutes())
if minutes == 1 {
return "1 minute"
}
return fmt.Sprintf("%d minutes", minutes)
case duration < 48*time.Hour:
hours := int(duration.Truncate(time.Hour).Hours())
return fmt.Sprintf("%d hours", hours)
default:
days := int(duration.Truncate(time.Hour).Hours() / 24)
return fmt.Sprintf("%d days", days)
}
}
func getLatestRelease(client *http.Client) (tagName, name string, time time.Time, err error) { func getLatestRelease(client *http.Client) (tagName, name string, time time.Time, err error) {
releases, err := getGithubReleases(client) releases, err := getGithubReleases(client)
if err != nil { if err != nil {