PIA nextgen portforward (#242)
* Split provider/pia.go in piav3.go and piav4.go * Change port forwarding signature * Enable port forwarding parameter for PIA v4 * Fix VPN gateway IP obtention * Setup HTTP client for TLS with custom cert * Error message for regions not supporting pf
This commit is contained in:
@@ -115,8 +115,8 @@ docker run --rm --network=container:gluetun alpine:3.12 wget -qO- https://ipinfo
|
||||
| 🏁 `PASSWORD` | | | Your password |
|
||||
| `REGION` | | One of the [PIA regions](https://www.privateinternetaccess.com/pages/network/) | VPN server region |
|
||||
| `PIA_ENCRYPTION` | `strong` | `normal`, `strong` | Encryption preset |
|
||||
| `PORT_FORWARDING` | `off` | `on`, `off` | Enable port forwarding on the VPN server **for old only** |
|
||||
| `PORT_FORWARDING_STATUS_FILE` | `/tmp/gluetun/forwarded_port` | Any filepath | Filepath to store the forwarded port number **for old only** |
|
||||
| `PORT_FORWARDING` | `off` | `on`, `off` | Enable port forwarding on the VPN server |
|
||||
| `PORT_FORWARDING_STATUS_FILE` | `/tmp/gluetun/forwarded_port` | Any filepath | Filepath to store the forwarded port number |
|
||||
|
||||
- Mullvad
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
@@ -188,7 +189,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
|
||||
go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady)
|
||||
|
||||
openvpnLooper := openvpn.NewLooper(allSettings.VPNSP, allSettings.OpenVPN, uid, gid, allServers,
|
||||
ovpnConf, firewallConf, logger, client, fileManager, streamMerger, cancel)
|
||||
ovpnConf, firewallConf, routingConf, logger, httpClient, fileManager, streamMerger, cancel)
|
||||
wg.Add(1)
|
||||
// wait for restartOpenvpn
|
||||
go openvpnLooper.Run(ctx, wg)
|
||||
@@ -341,10 +342,11 @@ func collectStreamLines(ctx context.Context, streamMerger command.StreamMerger,
|
||||
})
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dnsReadyCh <-chan struct{},
|
||||
unboundLooper dns.Looper, updaterLooper updater.Looper, publicIPLooper publicip.Looper,
|
||||
routing routing.Routing, logger logging.Logger, httpClient *http.Client,
|
||||
versionInformation, portForwardingEnabled bool, startPortForward func()) {
|
||||
versionInformation, portForwardingEnabled bool, startPortForward func(vpnGateway net.IP)) {
|
||||
defer wg.Done()
|
||||
tickerWg := &sync.WaitGroup{}
|
||||
// for linters only
|
||||
@@ -364,18 +366,35 @@ func routeReadyEvents(ctx context.Context, wg *sync.WaitGroup, tunnelReadyCh, dn
|
||||
tickerWg.Add(2)
|
||||
go unboundLooper.RunRestartTicker(restartTickerContext, tickerWg)
|
||||
go updaterLooper.RunRestartTicker(restartTickerContext, tickerWg)
|
||||
if portForwardingEnabled {
|
||||
time.AfterFunc(5*time.Second, startPortForward)
|
||||
}
|
||||
defaultInterface, _, err := routing.DefaultRoute()
|
||||
if err != nil {
|
||||
logger.Warn(err)
|
||||
} else {
|
||||
vpnGatewayIP, err := routing.VPNGatewayIP(defaultInterface)
|
||||
vpnDestination, err := routing.VPNDestinationIP(defaultInterface)
|
||||
if err != nil {
|
||||
logger.Warn(err)
|
||||
} else {
|
||||
logger.Info("Gateway VPN IP address: %s", vpnGatewayIP)
|
||||
logger.Info("VPN routing IP address: %s", vpnDestination)
|
||||
}
|
||||
}
|
||||
if portForwardingEnabled {
|
||||
// TODO make instantaneous once v3 go out of service
|
||||
const waitDuration = 5 * time.Second
|
||||
timer := time.NewTimer(waitDuration)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
continue
|
||||
case <-timer.C:
|
||||
// vpnGateway required only for PIA v4
|
||||
vpnGateway, err := routing.VPNLocalGatewayIP()
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
logger.Info("VPN gateway IP address: %s", vpnGateway)
|
||||
startPortForward(vpnGateway)
|
||||
}
|
||||
}
|
||||
case <-dnsReadyCh:
|
||||
|
||||
@@ -15,6 +15,8 @@ const (
|
||||
OpenVPNAuthConf models.Filepath = "/etc/openvpn/auth.conf"
|
||||
// OpenVPNConf is the file path to the OpenVPN client configuration file
|
||||
OpenVPNConf models.Filepath = "/etc/openvpn/target.ovpn"
|
||||
// PIAPortForward is the file path to the port forwarding JSON information for PIA v4 servers
|
||||
PIAPortForward models.Filepath = "/gluetun/piaportforward.json"
|
||||
// TunnelDevice is the file path to tun device
|
||||
TunnelDevice models.Filepath = "/dev/net/tun"
|
||||
// NetRoute is the path to the file containing information on the network route
|
||||
|
||||
@@ -2,9 +2,9 @@ package constants
|
||||
|
||||
const (
|
||||
// Announcement is a message announcement
|
||||
Announcement = "Update servers information see https://github.com/qdm12/gluetun/wiki/Update-servers-information"
|
||||
Announcement = "Port forwarding is working for PIA v4 servers"
|
||||
// AnnouncementExpiration is the expiration date of the announcement in format yyyy-mm-dd
|
||||
AnnouncementExpiration = "2020-10-10"
|
||||
AnnouncementExpiration = "2020-11-15"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
29
internal/logging/duration.go
Normal file
29
internal/logging/duration.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func FormatDuration(duration time.Duration) string {
|
||||
switch {
|
||||
case duration < time.Minute:
|
||||
seconds := int(duration.Round(time.Second).Seconds())
|
||||
if seconds < 2 {
|
||||
return fmt.Sprintf("%d second", seconds)
|
||||
}
|
||||
return fmt.Sprintf("%d seconds", seconds)
|
||||
case duration <= time.Hour:
|
||||
minutes := int(duration.Round(time.Minute).Minutes())
|
||||
if minutes == 1 {
|
||||
return "1 minute"
|
||||
}
|
||||
return fmt.Sprintf("%d minutes", minutes)
|
||||
case duration < 48*time.Hour:
|
||||
hours := int(duration.Truncate(time.Hour).Hours())
|
||||
return fmt.Sprintf("%d hours", hours)
|
||||
default:
|
||||
days := int(duration.Truncate(time.Hour).Hours() / 24)
|
||||
return fmt.Sprintf("%d days", days)
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package version
|
||||
package logging
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_formatDuration(t *testing.T) {
|
||||
func Test_FormatDuration(t *testing.T) {
|
||||
t.Parallel()
|
||||
testCases := map[string]struct {
|
||||
duration time.Duration
|
||||
@@ -57,7 +57,7 @@ func Test_formatDuration(t *testing.T) {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := formatDuration(testCase.duration)
|
||||
s := FormatDuration(testCase.duration)
|
||||
assert.Equal(t, testCase.s, s)
|
||||
})
|
||||
}
|
||||
@@ -90,6 +90,7 @@ func (p *ProviderSettings) String() string {
|
||||
settingsList = append(settingsList,
|
||||
"Region: "+p.ServerSelection.Region,
|
||||
"Encryption preset: "+p.ExtraConfigOptions.EncryptionPreset,
|
||||
"Port forwarding: "+p.PortForwarding.String(),
|
||||
)
|
||||
case "mullvad":
|
||||
settingsList = append(settingsList,
|
||||
|
||||
@@ -2,7 +2,8 @@ package openvpn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -10,17 +11,17 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
"github.com/qdm12/gluetun/internal/routing"
|
||||
"github.com/qdm12/gluetun/internal/settings"
|
||||
"github.com/qdm12/golibs/command"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/network"
|
||||
)
|
||||
|
||||
type Looper interface {
|
||||
Run(ctx context.Context, wg *sync.WaitGroup)
|
||||
Restart()
|
||||
PortForward()
|
||||
PortForward(vpnGatewayIP net.IP)
|
||||
GetSettings() (settings settings.OpenVPN)
|
||||
SetSettings(settings settings.OpenVPN)
|
||||
GetPortForwarded() (portForwarded uint16)
|
||||
@@ -42,21 +43,22 @@ type looper struct {
|
||||
// Configurators
|
||||
conf Configurator
|
||||
fw firewall.Configurator
|
||||
routing routing.Routing
|
||||
// Other objects
|
||||
logger logging.Logger
|
||||
client network.Client
|
||||
logger, pfLogger logging.Logger
|
||||
client *http.Client
|
||||
fileManager files.FileManager
|
||||
streamMerger command.StreamMerger
|
||||
cancel context.CancelFunc
|
||||
// Internal channels
|
||||
restart chan struct{}
|
||||
portForwardSignals chan struct{}
|
||||
portForwardSignals chan net.IP
|
||||
}
|
||||
|
||||
func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
|
||||
uid, gid int, allServers models.AllServers,
|
||||
conf Configurator, fw firewall.Configurator,
|
||||
logger logging.Logger, client network.Client, fileManager files.FileManager,
|
||||
conf Configurator, fw firewall.Configurator, routing routing.Routing,
|
||||
logger logging.Logger, client *http.Client, fileManager files.FileManager,
|
||||
streamMerger command.StreamMerger, cancel context.CancelFunc) Looper {
|
||||
return &looper{
|
||||
provider: provider,
|
||||
@@ -66,18 +68,20 @@ func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
|
||||
allServers: allServers,
|
||||
conf: conf,
|
||||
fw: fw,
|
||||
routing: routing,
|
||||
logger: logger.WithPrefix("openvpn: "),
|
||||
pfLogger: logger.WithPrefix("port forwarding: "),
|
||||
client: client,
|
||||
fileManager: fileManager,
|
||||
streamMerger: streamMerger,
|
||||
cancel: cancel,
|
||||
restart: make(chan struct{}),
|
||||
portForwardSignals: make(chan struct{}),
|
||||
portForwardSignals: make(chan net.IP),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *looper) Restart() { l.restart <- struct{}{} }
|
||||
func (l *looper) PortForward() { l.portForwardSignals <- struct{}{} }
|
||||
func (l *looper) PortForward(vpnGateway net.IP) { l.portForwardSignals <- vpnGateway }
|
||||
|
||||
func (l *looper) GetSettings() (settings settings.OpenVPN) {
|
||||
l.settingsMutex.RLock()
|
||||
@@ -158,10 +162,12 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
||||
go func(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
// TODO have a way to disable pf with a context
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-l.portForwardSignals:
|
||||
l.portForward(ctx, providerConf, l.client)
|
||||
case gateway := <-l.portForwardSignals:
|
||||
wg.Add(1)
|
||||
go l.portForward(ctx, wg, providerConf, l.client, gateway)
|
||||
}
|
||||
}
|
||||
}(openvpnCtx)
|
||||
@@ -200,43 +206,25 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
|
||||
<-ctx.Done()
|
||||
}
|
||||
|
||||
func (l *looper) portForward(ctx context.Context, providerConf provider.Provider, client network.Client) {
|
||||
// portForward is a blocking operation which may or may not be infinite.
|
||||
// You should therefore always call it in a goroutine
|
||||
func (l *looper) portForward(ctx context.Context, wg *sync.WaitGroup,
|
||||
providerConf provider.Provider, client *http.Client, gateway net.IP) {
|
||||
defer wg.Done()
|
||||
settings := l.GetSettings()
|
||||
if !settings.Provider.PortForwarding.Enabled {
|
||||
return
|
||||
}
|
||||
var port uint16
|
||||
err := fmt.Errorf("")
|
||||
for err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
port, err = providerConf.GetPortForward(client)
|
||||
if err != nil {
|
||||
l.logAndWait(ctx, err)
|
||||
}
|
||||
}
|
||||
|
||||
l.logger.Info("port forwarded is %d", port)
|
||||
syncState := func(port uint16) (pfFilepath models.Filepath) {
|
||||
l.portForwardedMutex.Lock()
|
||||
if err := l.fw.RemoveAllowedPort(ctx, l.portForwarded); err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
if err := l.fw.SetAllowedPort(ctx, port, string(constants.TUN)); err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
l.portForwarded = port
|
||||
l.portForwardedMutex.Unlock()
|
||||
|
||||
filepath := settings.Provider.PortForwarding.Filepath
|
||||
l.logger.Info("writing forwarded port to %s", filepath)
|
||||
err = l.fileManager.WriteLinesToFile(
|
||||
string(filepath), []string{fmt.Sprintf("%d", port)},
|
||||
files.Ownership(l.uid, l.gid), files.Permissions(0400),
|
||||
)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
settings := l.GetSettings()
|
||||
return settings.Provider.PortForwarding.Filepath
|
||||
}
|
||||
providerConf.PortForward(ctx,
|
||||
client, l.fileManager, l.pfLogger,
|
||||
gateway, l.fw, syncState)
|
||||
}
|
||||
|
||||
func (l *looper) GetPortForwarded() (portForwarded uint16) {
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/network"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
type cyberghost struct {
|
||||
@@ -135,6 +140,8 @@ func (c *cyberghost) BuildConf(connections []models.OpenVPNConnection, verbosity
|
||||
return lines
|
||||
}
|
||||
|
||||
func (c *cyberghost) GetPortForward(client network.Client) (port uint16, err error) {
|
||||
func (c *cyberghost) PortForward(ctx context.Context, client *http.Client,
|
||||
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath models.Filepath)) {
|
||||
panic("port forwarding is not supported for cyberghost")
|
||||
}
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/network"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
type mullvad struct {
|
||||
@@ -134,6 +139,8 @@ func (m *mullvad) BuildConf(connections []models.OpenVPNConnection, verbosity, u
|
||||
return lines
|
||||
}
|
||||
|
||||
func (m *mullvad) GetPortForward(client network.Client) (port uint16, err error) {
|
||||
func (m *mullvad) PortForward(ctx context.Context, client *http.Client,
|
||||
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath models.Filepath)) {
|
||||
panic("port forwarding is not supported for mullvad")
|
||||
}
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/network"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
type nordvpn struct {
|
||||
@@ -142,6 +147,8 @@ func (n *nordvpn) BuildConf(connections []models.OpenVPNConnection, verbosity, u
|
||||
return lines
|
||||
}
|
||||
|
||||
func (n *nordvpn) GetPortForward(client network.Client) (port uint16, err error) {
|
||||
func (n *nordvpn) PortForward(ctx context.Context, client *http.Client,
|
||||
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath models.Filepath)) {
|
||||
panic("port forwarding is not supported for nordvpn")
|
||||
}
|
||||
|
||||
@@ -1,35 +1,18 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/crypto/random"
|
||||
"github.com/qdm12/golibs/network"
|
||||
)
|
||||
|
||||
type pia struct {
|
||||
random random.Random
|
||||
servers []models.PIAServer
|
||||
}
|
||||
|
||||
func newPrivateInternetAccess(servers []models.PIAServer) *pia {
|
||||
return &pia{
|
||||
random: random.NewRandom(),
|
||||
servers: servers,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *pia) filterServers(region string) (servers []models.PIAServer) {
|
||||
func filterPIAServers(servers []models.PIAServer, region string) (filtered []models.PIAServer) {
|
||||
if len(region) == 0 {
|
||||
return p.servers
|
||||
return servers
|
||||
}
|
||||
for _, server := range p.servers {
|
||||
for _, server := range servers {
|
||||
if strings.EqualFold(server.Region, region) {
|
||||
return []models.PIAServer{server}
|
||||
}
|
||||
@@ -37,8 +20,8 @@ func (p *pia) filterServers(region string) (servers []models.PIAServer) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *pia) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) {
|
||||
servers := p.filterServers(selection.Region)
|
||||
func getPIAOpenVPNConnections(allServers []models.PIAServer, selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) {
|
||||
servers := filterPIAServers(allServers, selection.Region)
|
||||
if len(servers) == 0 {
|
||||
return nil, fmt.Errorf("no server found for region %q", selection.Region)
|
||||
}
|
||||
@@ -87,7 +70,7 @@ func (p *pia) GetOpenVPNConnections(selection models.ServerSelection) (connectio
|
||||
return connections, nil
|
||||
}
|
||||
|
||||
func (p *pia) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) {
|
||||
func buildPIAConf(connections []models.OpenVPNConnection, verbosity int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) {
|
||||
var X509CRL, certificate string
|
||||
if extras.EncryptionPreset == constants.PIAEncryptionPresetNormal {
|
||||
if len(cipher) == 0 {
|
||||
@@ -161,28 +144,3 @@ func (p *pia) BuildConf(connections []models.OpenVPNConnection, verbosity, uid,
|
||||
}...)
|
||||
return lines
|
||||
}
|
||||
|
||||
func (p *pia) GetPortForward(client network.Client) (port uint16, err error) {
|
||||
b, err := p.random.GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
clientID := hex.EncodeToString(b)
|
||||
url := fmt.Sprintf("%s/?client_id=%s", constants.PIAPortForwardURL, clientID)
|
||||
content, status, err := client.GetContent(url) // TODO add ctx
|
||||
switch {
|
||||
case err != nil:
|
||||
return 0, err
|
||||
case status != http.StatusOK:
|
||||
return 0, fmt.Errorf("status is %d for %s; does your PIA server support port forwarding?", status, url)
|
||||
case len(content) == 0:
|
||||
return 0, fmt.Errorf("port forwarding is already activated on this connection, has expired, or you are not connected to a PIA region that supports port forwarding")
|
||||
}
|
||||
body := struct {
|
||||
Port uint16 `json:"port"`
|
||||
}{}
|
||||
if err := json.Unmarshal(content, &body); err != nil {
|
||||
return 0, fmt.Errorf("port forwarding response: %w", err)
|
||||
}
|
||||
return body.Port, nil
|
||||
}
|
||||
|
||||
94
internal/provider/piav3.go
Normal file
94
internal/provider/piav3.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/crypto/random"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
type piaV3 struct {
|
||||
random random.Random
|
||||
servers []models.PIAServer
|
||||
}
|
||||
|
||||
func newPrivateInternetAccessV3(servers []models.PIAServer) *piaV3 {
|
||||
return &piaV3{
|
||||
random: random.NewRandom(),
|
||||
servers: servers,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *piaV3) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) {
|
||||
return getPIAOpenVPNConnections(p.servers, selection)
|
||||
}
|
||||
|
||||
func (p *piaV3) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) {
|
||||
return buildPIAConf(connections, verbosity, root, cipher, auth, extras)
|
||||
}
|
||||
|
||||
func (p *piaV3) PortForward(ctx context.Context, client *http.Client,
|
||||
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath models.Filepath)) {
|
||||
b, err := p.random.GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
pfLogger.Error(err)
|
||||
return
|
||||
}
|
||||
clientID := hex.EncodeToString(b)
|
||||
url := fmt.Sprintf("%s/?client_id=%s", constants.PIAPortForwardURL, clientID)
|
||||
response, err := client.Get(url) // TODO add ctx
|
||||
if err != nil {
|
||||
pfLogger.Error(err)
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
if response.StatusCode != http.StatusOK {
|
||||
pfLogger.Error(fmt.Errorf("%s for %s; does your PIA server support port forwarding?", response.Status, url))
|
||||
return
|
||||
}
|
||||
b, err = ioutil.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
pfLogger.Error(err)
|
||||
return
|
||||
} else if len(b) == 0 {
|
||||
pfLogger.Error(fmt.Errorf("port forwarding is already activated on this connection, has expired, or you are not connected to a PIA region that supports port forwarding"))
|
||||
return
|
||||
}
|
||||
body := struct {
|
||||
Port uint16 `json:"port"`
|
||||
}{}
|
||||
if err := json.Unmarshal(b, &body); err != nil {
|
||||
pfLogger.Error(fmt.Errorf("port forwarding response: %w", err))
|
||||
return
|
||||
}
|
||||
port := body.Port
|
||||
|
||||
filepath := syncState(port)
|
||||
pfLogger.Info("Writing port to %s", filepath)
|
||||
if err := fileManager.WriteToFile(
|
||||
string(filepath), []byte(fmt.Sprintf("%d", port)),
|
||||
files.Permissions(0666),
|
||||
); err != nil {
|
||||
pfLogger.Error(err)
|
||||
}
|
||||
|
||||
if err := fw.SetAllowedPort(ctx, port, string(constants.TUN)); err != nil {
|
||||
pfLogger.Error(err)
|
||||
}
|
||||
|
||||
<-ctx.Done()
|
||||
if err := fw.RemoveAllowedPort(ctx, port); err != nil {
|
||||
pfLogger.Error(err)
|
||||
}
|
||||
}
|
||||
412
internal/provider/piav4.go
Normal file
412
internal/provider/piav4.go
Normal file
@@ -0,0 +1,412 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
gluetunLog "github.com/qdm12/gluetun/internal/logging"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
type piaV4 struct {
|
||||
servers []models.PIAServer
|
||||
timeNow func() time.Time
|
||||
}
|
||||
|
||||
func newPrivateInternetAccessV4(servers []models.PIAServer) *piaV4 {
|
||||
return &piaV4{
|
||||
servers: servers,
|
||||
timeNow: time.Now,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *piaV4) GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error) {
|
||||
return getPIAOpenVPNConnections(p.servers, selection)
|
||||
}
|
||||
|
||||
func (p *piaV4) BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string) {
|
||||
return buildPIAConf(connections, verbosity, root, cipher, auth, extras)
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (p *piaV4) PortForward(ctx context.Context, client *http.Client,
|
||||
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath models.Filepath)) {
|
||||
if gateway == nil {
|
||||
pfLogger.Error("aborting because: VPN gateway IP address was not found")
|
||||
return
|
||||
}
|
||||
client, err := newPIAv4HTTPClient()
|
||||
if err != nil {
|
||||
pfLogger.Error("aborting because: %s", err)
|
||||
return
|
||||
}
|
||||
defer pfLogger.Warn("loop exited")
|
||||
data, err := readPIAPortForwardData(fileManager)
|
||||
if err != nil {
|
||||
pfLogger.Error(err)
|
||||
}
|
||||
dataFound := data.Port > 0
|
||||
durationToExpiration := data.Expiration.Sub(p.timeNow())
|
||||
expired := durationToExpiration <= 0
|
||||
|
||||
if dataFound {
|
||||
pfLogger.Info("Found persistent forwarded port data for port %d", data.Port)
|
||||
if expired {
|
||||
pfLogger.Warn("Forwarded port data expired on %s, getting another one", data.Expiration.Format(time.RFC1123))
|
||||
} else {
|
||||
pfLogger.Info("Forwarded port data expires in %s", gluetunLog.FormatDuration(durationToExpiration))
|
||||
}
|
||||
}
|
||||
|
||||
if !dataFound || expired {
|
||||
tryUntilSuccessful(ctx, pfLogger, func() error {
|
||||
data, err = refreshPIAPortForwardData(client, gateway, fileManager)
|
||||
return err
|
||||
})
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
durationToExpiration = data.Expiration.Sub(p.timeNow())
|
||||
}
|
||||
pfLogger.Info("Port forwarded is %d expiring in %s", data.Port, gluetunLog.FormatDuration(durationToExpiration))
|
||||
|
||||
// First time binding
|
||||
tryUntilSuccessful(ctx, pfLogger, func() error {
|
||||
return bindPIAPort(client, gateway, data)
|
||||
})
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
filepath := syncState(data.Port)
|
||||
pfLogger.Info("Writing port to %s", filepath)
|
||||
if err := fileManager.WriteToFile(
|
||||
string(filepath), []byte(fmt.Sprintf("%d", data.Port)),
|
||||
files.Permissions(0666),
|
||||
); err != nil {
|
||||
pfLogger.Error(err)
|
||||
}
|
||||
|
||||
if err := fw.SetAllowedPort(ctx, data.Port, string(constants.TUN)); err != nil {
|
||||
pfLogger.Error(err)
|
||||
}
|
||||
|
||||
expiryTimer := time.NewTimer(durationToExpiration)
|
||||
defer expiryTimer.Stop()
|
||||
const keepAlivePeriod = 15 * time.Minute
|
||||
keepAliveTicker := time.NewTicker(keepAlivePeriod)
|
||||
defer keepAliveTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
removeCtx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := fw.RemoveAllowedPort(removeCtx, data.Port); err != nil {
|
||||
pfLogger.Error(err)
|
||||
}
|
||||
return
|
||||
case <-keepAliveTicker.C:
|
||||
if err := bindPIAPort(client, gateway, data); err != nil {
|
||||
pfLogger.Error(err)
|
||||
}
|
||||
case <-expiryTimer.C:
|
||||
pfLogger.Warn("Forward port has expired on %s, getting another one", data.Expiration.Format(time.RFC1123))
|
||||
oldPort := data.Port
|
||||
for {
|
||||
data, err = refreshPIAPortForwardData(client, gateway, fileManager)
|
||||
if err != nil {
|
||||
pfLogger.Error(err)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
durationToExpiration := data.Expiration.Sub(p.timeNow())
|
||||
pfLogger.Info("Port forwarded is %d expiring in %s", data.Port, gluetunLog.FormatDuration(durationToExpiration))
|
||||
if err := fw.RemoveAllowedPort(ctx, oldPort); err != nil {
|
||||
pfLogger.Error(err)
|
||||
}
|
||||
if err := fw.SetAllowedPort(ctx, data.Port, string(constants.TUN)); err != nil {
|
||||
pfLogger.Error(err)
|
||||
}
|
||||
filepath := syncState(data.Port)
|
||||
pfLogger.Info("Writing port to %s", filepath)
|
||||
if err := fileManager.WriteToFile(
|
||||
string(filepath), []byte(fmt.Sprintf("%d", data.Port)),
|
||||
files.Permissions(0666),
|
||||
); err != nil {
|
||||
pfLogger.Error(err)
|
||||
}
|
||||
if err := bindPIAPort(client, gateway, data); err != nil {
|
||||
pfLogger.Error(err)
|
||||
}
|
||||
keepAliveTicker.Reset(keepAlivePeriod)
|
||||
expiryTimer.Reset(durationToExpiration)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newPIAv4HTTPClient() (client *http.Client, err error) {
|
||||
certificateBytes, err := base64.StdEncoding.DecodeString(constants.PIACertificateStrong)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot decode PIA root certificate: %w", err)
|
||||
}
|
||||
certificate, err := x509.ParseCertificate(certificateBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot parse PIA root certificate: %w", err)
|
||||
}
|
||||
rootCAs := x509.NewCertPool()
|
||||
rootCAs.AddCert(certificate)
|
||||
TLSClientConfig := &tls.Config{
|
||||
RootCAs: rootCAs,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
InsecureSkipVerify: true, //nolint:gosec
|
||||
} // TODO fix and remove InsecureSkipVerify
|
||||
transport := http.Transport{
|
||||
TLSClientConfig: TLSClientConfig,
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
}).DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
const httpTimeout = 5 * time.Second
|
||||
client = &http.Client{Transport: &transport, Timeout: httpTimeout}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func refreshPIAPortForwardData(client *http.Client, gateway net.IP, fileManager files.FileManager) (data piaPortForwardData, err error) {
|
||||
data.Token, err = fetchPIAToken(fileManager, client)
|
||||
if err != nil {
|
||||
return data, fmt.Errorf("cannot obtain token: %w", err)
|
||||
}
|
||||
data.Port, data.Signature, data.Expiration, err = fetchPIAPortForwardData(client, gateway, data.Token)
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), "connection refused") {
|
||||
return data, fmt.Errorf("cannot obtain port forwarding data: connection was refused, are you sure the region you are using supports port forwarding ;)")
|
||||
}
|
||||
return data, fmt.Errorf("cannot obtain port forwarding data: %w", err)
|
||||
}
|
||||
if err := writePIAPortForwardData(fileManager, data); err != nil {
|
||||
return data, fmt.Errorf("cannot persist port forwarding information to file: %w", err)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
type piaPayload struct {
|
||||
Token string `json:"token"`
|
||||
Port uint16 `json:"port"`
|
||||
Expiration time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
type piaPortForwardData struct {
|
||||
Port uint16 `json:"port"`
|
||||
Token string `json:"token"`
|
||||
Signature string `json:"signature"`
|
||||
Expiration time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
func readPIAPortForwardData(fileManager files.FileManager) (data piaPortForwardData, err error) {
|
||||
const filepath = string(constants.PIAPortForward)
|
||||
exists, err := fileManager.FileExists(filepath)
|
||||
if err != nil {
|
||||
return data, err
|
||||
} else if !exists {
|
||||
return data, nil
|
||||
}
|
||||
b, err := fileManager.ReadFile(filepath)
|
||||
if err != nil {
|
||||
return data, err
|
||||
}
|
||||
if err := json.Unmarshal(b, &data); err != nil {
|
||||
return data, err
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func writePIAPortForwardData(fileManager files.FileManager, data piaPortForwardData) (err error) {
|
||||
b, err := json.Marshal(&data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot encode data: %w", err)
|
||||
}
|
||||
err = fileManager.WriteToFile(string(constants.PIAPortForward), b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func unpackPIAPayload(payload string) (port uint16, token string, expiration time.Time, err error) {
|
||||
b, err := base64.RawStdEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return 0, "", expiration, fmt.Errorf("cannot decode payload: %w", err)
|
||||
}
|
||||
var payloadData piaPayload
|
||||
if err := json.Unmarshal(b, &payloadData); err != nil {
|
||||
return 0, "", expiration, fmt.Errorf("cannot parse payload data: %w", err)
|
||||
}
|
||||
return payloadData.Port, payloadData.Token, payloadData.Expiration, nil
|
||||
}
|
||||
|
||||
func packPIAPayload(port uint16, token string, expiration time.Time) (payload string, err error) {
|
||||
payloadData := piaPayload{
|
||||
Token: token,
|
||||
Port: port,
|
||||
Expiration: expiration,
|
||||
}
|
||||
b, err := json.Marshal(&payloadData)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot serialize payload data: %w", err)
|
||||
}
|
||||
payload = base64.RawStdEncoding.EncodeToString(b)
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func fetchPIAToken(fileManager files.FileManager, client *http.Client) (token string, err error) {
|
||||
username, password, err := getOpenvpnCredentials(fileManager)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot get Openvpn credentials: %w", err)
|
||||
}
|
||||
url := url.URL{
|
||||
Scheme: "https",
|
||||
User: url.UserPassword(username, password),
|
||||
Host: "10.0.0.1",
|
||||
Path: "/authv3/generateToken",
|
||||
}
|
||||
request, err := http.NewRequest(http.MethodGet, url.String(), nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
b, err := ioutil.ReadAll(response.Body)
|
||||
if response.StatusCode != http.StatusOK {
|
||||
shortenMessage := string(b)
|
||||
shortenMessage = strings.ReplaceAll(shortenMessage, "\n", "")
|
||||
shortenMessage = strings.ReplaceAll(shortenMessage, " ", " ")
|
||||
return "", fmt.Errorf("%s: response received: %q", response.Status, shortenMessage)
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var result struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
if err := json.Unmarshal(b, &result); err != nil {
|
||||
return "", err
|
||||
} else if len(result.Token) == 0 {
|
||||
return "", fmt.Errorf("token is empty")
|
||||
}
|
||||
return result.Token, nil
|
||||
}
|
||||
|
||||
func getOpenvpnCredentials(fileManager files.FileManager) (username, password string, err error) {
|
||||
authData, err := fileManager.ReadFile(string(constants.OpenVPNAuthConf))
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("cannot read openvpn auth file: %w", err)
|
||||
}
|
||||
lines := strings.Split(string(authData), "\n")
|
||||
if len(lines) < 2 {
|
||||
return "", "", fmt.Errorf("not enough lines (%d) in openvpn auth file", len(lines))
|
||||
}
|
||||
username, password = lines[0], lines[1]
|
||||
return username, password, nil
|
||||
}
|
||||
|
||||
func fetchPIAPortForwardData(client *http.Client, gateway net.IP, token string) (port uint16, signature string, expiration time.Time, err error) {
|
||||
queryParams := url.Values{}
|
||||
queryParams.Add("token", token)
|
||||
url := url.URL{
|
||||
Scheme: "https",
|
||||
Host: net.JoinHostPort(gateway.String(), "19999"),
|
||||
Path: "/getSignature",
|
||||
RawQuery: queryParams.Encode(),
|
||||
}
|
||||
response, err := client.Get(url.String())
|
||||
if err != nil {
|
||||
return 0, "", expiration, fmt.Errorf("cannot obtain signature: %w", err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return 0, "", expiration, fmt.Errorf("cannot obtain signature: %s", response.Status)
|
||||
}
|
||||
b, err := ioutil.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return 0, "", expiration, fmt.Errorf("cannot obtain signature: %w", err)
|
||||
}
|
||||
var data struct {
|
||||
Status string `json:"status"`
|
||||
Payload string `json:"payload"`
|
||||
Signature string `json:"signature"`
|
||||
}
|
||||
if err := json.Unmarshal(b, &data); err != nil {
|
||||
return 0, "", expiration, fmt.Errorf("cannot decode received data: %w", err)
|
||||
} else if data.Status != "OK" {
|
||||
return 0, "", expiration, fmt.Errorf("response received from PIA has status %s", data.Status)
|
||||
}
|
||||
|
||||
port, _, expiration, err = unpackPIAPayload(data.Payload)
|
||||
return port, data.Signature, expiration, err
|
||||
}
|
||||
|
||||
func bindPIAPort(client *http.Client, gateway net.IP, data piaPortForwardData) (err error) {
|
||||
payload, err := packPIAPayload(data.Port, data.Token, data.Expiration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
queryParams := url.Values{}
|
||||
queryParams.Add("payload", payload)
|
||||
queryParams.Add("signature", data.Signature)
|
||||
url := url.URL{
|
||||
Scheme: "https",
|
||||
Host: net.JoinHostPort(gateway.String(), "19999"),
|
||||
Path: "/bindPort",
|
||||
RawQuery: queryParams.Encode(),
|
||||
}
|
||||
|
||||
response, err := client.Get(url.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot bind port: %w", err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("cannot bind port: %s", response.Status)
|
||||
}
|
||||
b, err := ioutil.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot bind port: %w", err)
|
||||
}
|
||||
var responseData struct {
|
||||
Status string `json:"status"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
if err := json.Unmarshal(b, &responseData); err != nil {
|
||||
return fmt.Errorf("cannot bind port: %w", err)
|
||||
} else if responseData.Status != "OK" {
|
||||
return fmt.Errorf("response received from PIA: %s (%s)", responseData.Status, responseData.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,24 +1,32 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/network"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
// Provider contains methods to read and modify the openvpn configuration to connect as a client
|
||||
type Provider interface {
|
||||
GetOpenVPNConnections(selection models.ServerSelection) (connections []models.OpenVPNConnection, err error)
|
||||
BuildConf(connections []models.OpenVPNConnection, verbosity, uid, gid int, root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string)
|
||||
GetPortForward(client network.Client) (port uint16, err error)
|
||||
PortForward(ctx context.Context, client *http.Client,
|
||||
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath models.Filepath))
|
||||
}
|
||||
|
||||
func New(provider models.VPNProvider, allServers models.AllServers) Provider {
|
||||
switch provider {
|
||||
case constants.PrivateInternetAccess:
|
||||
return newPrivateInternetAccess(allServers.Pia.Servers)
|
||||
return newPrivateInternetAccessV4(allServers.Pia.Servers)
|
||||
case constants.PrivateInternetAccessOld:
|
||||
return newPrivateInternetAccess(allServers.PiaOld.Servers)
|
||||
return newPrivateInternetAccessV3(allServers.PiaOld.Servers)
|
||||
case constants.Mullvad:
|
||||
return newMullvad(allServers.Mullvad.Servers)
|
||||
case constants.Windscribe:
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/network"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
type purevpn struct {
|
||||
@@ -157,6 +162,8 @@ func (p *purevpn) BuildConf(connections []models.OpenVPNConnection, verbosity, u
|
||||
return lines
|
||||
}
|
||||
|
||||
func (p *purevpn) GetPortForward(client network.Client) (port uint16, err error) {
|
||||
func (p *purevpn) PortForward(ctx context.Context, client *http.Client,
|
||||
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath models.Filepath)) {
|
||||
panic("port forwarding is not supported for purevpn")
|
||||
}
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/network"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
type surfshark struct {
|
||||
@@ -135,6 +140,8 @@ func (s *surfshark) BuildConf(connections []models.OpenVPNConnection, verbosity,
|
||||
return lines
|
||||
}
|
||||
|
||||
func (s *surfshark) GetPortForward(client network.Client) (port uint16, err error) {
|
||||
func (s *surfshark) PortForward(ctx context.Context, client *http.Client,
|
||||
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath models.Filepath)) {
|
||||
panic("port forwarding is not supported for surfshark")
|
||||
}
|
||||
|
||||
29
internal/provider/utils.go
Normal file
29
internal/provider/utils.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
func tryUntilSuccessful(ctx context.Context, logger logging.Logger, fn func() error) {
|
||||
const retryPeriod = 10 * time.Second
|
||||
for {
|
||||
err := fn()
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
logger.Error(err)
|
||||
logger.Info("Trying again in %s", retryPeriod)
|
||||
timer := time.NewTimer(retryPeriod)
|
||||
select {
|
||||
case <-timer.C:
|
||||
case <-ctx.Done():
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,17 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/network"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
type vyprvpn struct {
|
||||
@@ -121,6 +126,8 @@ func (v *vyprvpn) BuildConf(connections []models.OpenVPNConnection, verbosity, u
|
||||
return lines
|
||||
}
|
||||
|
||||
func (v *vyprvpn) GetPortForward(client network.Client) (port uint16, err error) {
|
||||
func (v *vyprvpn) PortForward(ctx context.Context, client *http.Client,
|
||||
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath models.Filepath)) {
|
||||
panic("port forwarding is not supported for vyprvpn")
|
||||
}
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/network"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
type windscribe struct {
|
||||
@@ -133,6 +138,8 @@ func (w *windscribe) BuildConf(connections []models.OpenVPNConnection, verbosity
|
||||
return lines
|
||||
}
|
||||
|
||||
func (w *windscribe) GetPortForward(client network.Client) (port uint16, err error) {
|
||||
func (w *windscribe) PortForward(ctx context.Context, client *http.Client,
|
||||
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath models.Filepath)) {
|
||||
panic("port forwarding is not supported for windscribe")
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/golibs/files"
|
||||
)
|
||||
|
||||
func parseRoutingTable(data []byte) (entries []routingEntry, err error) {
|
||||
@@ -23,12 +24,16 @@ func parseRoutingTable(data []byte) (entries []routingEntry, err error) {
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) {
|
||||
data, err := r.fileManager.ReadFile(string(constants.NetRoute))
|
||||
func getRoutingEntries(fileManager files.FileManager) (entries []routingEntry, err error) {
|
||||
data, err := fileManager.ReadFile(string(constants.NetRoute))
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
return nil, err
|
||||
}
|
||||
entries, err := parseRoutingTable(data)
|
||||
return parseRoutingTable(data)
|
||||
}
|
||||
|
||||
func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) {
|
||||
entries, err := getRoutingEntries(r.fileManager)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
@@ -52,11 +57,7 @@ func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP
|
||||
}
|
||||
|
||||
func (r *routing) LocalSubnet() (defaultSubnet net.IPNet, err error) {
|
||||
data, err := r.fileManager.ReadFile(string(constants.NetRoute))
|
||||
if err != nil {
|
||||
return defaultSubnet, err
|
||||
}
|
||||
entries, err := parseRoutingTable(data)
|
||||
entries, err := getRoutingEntries(r.fileManager)
|
||||
if err != nil {
|
||||
return defaultSubnet, err
|
||||
}
|
||||
@@ -79,11 +80,7 @@ func (r *routing) LocalSubnet() (defaultSubnet net.IPNet, err error) {
|
||||
}
|
||||
|
||||
func (r *routing) routeExists(subnet net.IPNet) (exists bool, err error) {
|
||||
data, err := r.fileManager.ReadFile(string(constants.NetRoute))
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("cannot check route existence: %w", err)
|
||||
}
|
||||
entries, err := parseRoutingTable(data)
|
||||
entries, err := getRoutingEntries(r.fileManager)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("cannot check route existence: %w", err)
|
||||
}
|
||||
@@ -96,12 +93,8 @@ func (r *routing) routeExists(subnet net.IPNet) (exists bool, err error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *routing) VPNGatewayIP(defaultInterface string) (ip net.IP, err error) {
|
||||
data, err := r.fileManager.ReadFile(string(constants.NetRoute))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot find VPN gateway IP address: %w", err)
|
||||
}
|
||||
entries, err := parseRoutingTable(data)
|
||||
func (r *routing) VPNDestinationIP(defaultInterface string) (ip net.IP, err error) {
|
||||
entries, err := getRoutingEntries(r.fileManager)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot find VPN gateway IP address: %w", err)
|
||||
}
|
||||
@@ -115,6 +108,20 @@ func (r *routing) VPNGatewayIP(defaultInterface string) (ip net.IP, err error) {
|
||||
return nil, fmt.Errorf("cannot find VPN gateway IP address from ip routes")
|
||||
}
|
||||
|
||||
func (r *routing) VPNLocalGatewayIP() (ip net.IP, err error) {
|
||||
entries, err := getRoutingEntries(r.fileManager)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot find VPN local gateway IP address: %w", err)
|
||||
}
|
||||
for _, entry := range entries {
|
||||
if entry.iface == string(constants.TUN) &&
|
||||
entry.destination.Equal(net.IP{0, 0, 0, 0}) {
|
||||
return entry.gateway, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("cannot find VPN local gateway IP address from ip routes")
|
||||
}
|
||||
|
||||
func ipIsPrivate(ip net.IP) bool {
|
||||
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||
return true
|
||||
|
||||
@@ -291,7 +291,7 @@ eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF
|
||||
}
|
||||
}
|
||||
|
||||
func Test_VPNGatewayIP(t *testing.T) {
|
||||
func Test_VPNDestinationIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
defaultInterface string
|
||||
@@ -334,7 +334,7 @@ eth0 x
|
||||
filemanager.EXPECT().ReadFile(string(constants.NetRoute)).
|
||||
Return(tc.data, tc.readErr).Times(1)
|
||||
r := &routing{fileManager: filemanager}
|
||||
ip, err := r.VPNGatewayIP(tc.defaultInterface)
|
||||
ip, err := r.VPNDestinationIP(tc.defaultInterface)
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
|
||||
@@ -14,7 +14,8 @@ type Routing interface {
|
||||
DeleteRouteVia(ctx context.Context, subnet net.IPNet) (err error)
|
||||
DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error)
|
||||
LocalSubnet() (defaultSubnet net.IPNet, err error)
|
||||
VPNGatewayIP(defaultInterface string) (ip net.IP, err error)
|
||||
VPNDestinationIP(defaultInterface string) (ip net.IP, err error)
|
||||
VPNLocalGatewayIP() (ip net.IP, err error)
|
||||
SetDebug()
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,16 @@ import (
|
||||
|
||||
// GetPIASettings obtains PIA settings from environment variables using the params package.
|
||||
func GetPIASettings(paramsReader params.Reader) (settings models.ProviderSettings, err error) {
|
||||
settings.Name = constants.PrivateInternetAccess
|
||||
return getPIASettings(paramsReader, constants.PrivateInternetAccess)
|
||||
}
|
||||
|
||||
// GetPIAOldSettings obtains PIA settings for the older PIA servers (pre summer 2020) from environment variables using the params package.
|
||||
func GetPIAOldSettings(paramsReader params.Reader) (settings models.ProviderSettings, err error) {
|
||||
return getPIASettings(paramsReader, constants.PrivateInternetAccessOld)
|
||||
}
|
||||
|
||||
func getPIASettings(paramsReader params.Reader, name models.VPNProvider) (settings models.ProviderSettings, err error) {
|
||||
settings.Name = name
|
||||
settings.ServerSelection.Protocol, err = paramsReader.GetNetworkProtocol()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
@@ -29,30 +38,6 @@ func GetPIASettings(paramsReader params.Reader) (settings models.ProviderSetting
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
// GetPIAOldSettings obtains PIA settings for the older PIA servers (pre summer 2020) from environment variables using the params package.
|
||||
func GetPIAOldSettings(paramsReader params.Reader) (settings models.ProviderSettings, err error) {
|
||||
settings.Name = constants.PrivateInternetAccessOld
|
||||
settings.ServerSelection.Protocol, err = paramsReader.GetNetworkProtocol()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.ServerSelection.TargetIP, err = paramsReader.GetTargetIP()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
encryptionPreset, err := paramsReader.GetPIAEncryptionPreset()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.ServerSelection.EncryptionPreset = encryptionPreset
|
||||
settings.ExtraConfigOptions.EncryptionPreset = encryptionPreset
|
||||
settings.ServerSelection.Region, err = paramsReader.GetPIAOldRegion()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
}
|
||||
settings.PortForwarding.Enabled, err = paramsReader.GetPortForwarding()
|
||||
if err != nil {
|
||||
return settings, err
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/logging"
|
||||
)
|
||||
|
||||
// GetMessage returns a message for the user describing if there is a newer version
|
||||
@@ -30,35 +32,12 @@ func GetMessage(version, commitShort string, client *http.Client) (message strin
|
||||
if tagName == version {
|
||||
return fmt.Sprintf("You are running the latest release %s", version), nil
|
||||
}
|
||||
timeSinceRelease := formatDuration(time.Since(releaseTime))
|
||||
timeSinceRelease := logging.FormatDuration(time.Since(releaseTime))
|
||||
return fmt.Sprintf("There is a new release %s (%s) created %s ago",
|
||||
tagName, name, timeSinceRelease),
|
||||
nil
|
||||
}
|
||||
|
||||
func formatDuration(duration time.Duration) string {
|
||||
switch {
|
||||
case duration < time.Minute:
|
||||
seconds := int(duration.Round(time.Second).Seconds())
|
||||
if seconds < 2 {
|
||||
return fmt.Sprintf("%d second", seconds)
|
||||
}
|
||||
return fmt.Sprintf("%d seconds", seconds)
|
||||
case duration <= time.Hour:
|
||||
minutes := int(duration.Round(time.Minute).Minutes())
|
||||
if minutes == 1 {
|
||||
return "1 minute"
|
||||
}
|
||||
return fmt.Sprintf("%d minutes", minutes)
|
||||
case duration < 48*time.Hour:
|
||||
hours := int(duration.Truncate(time.Hour).Hours())
|
||||
return fmt.Sprintf("%d hours", hours)
|
||||
default:
|
||||
days := int(duration.Truncate(time.Hour).Hours() / 24)
|
||||
return fmt.Sprintf("%d days", days)
|
||||
}
|
||||
}
|
||||
|
||||
func getLatestRelease(client *http.Client) (tagName, name string, time time.Time, err error) {
|
||||
releases, err := getGithubReleases(client)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user