PIA nextgen portforward (#242)
* Split provider/pia.go in piav3.go and piav4.go * Change port forwarding signature * Enable port forwarding parameter for PIA v4 * Fix VPN gateway IP obtention * Setup HTTP client for TLS with custom cert * Error message for regions not supporting pf
This commit is contained in:
@@ -2,7 +2,8 @@ package openvpn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -10,17 +11,17 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
"github.com/qdm12/gluetun/internal/routing"
|
||||
"github.com/qdm12/gluetun/internal/settings"
|
||||
"github.com/qdm12/golibs/command"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/network"
|
||||
)
|
||||
|
||||
type Looper interface {
|
||||
Run(ctx context.Context, wg *sync.WaitGroup)
|
||||
Restart()
|
||||
PortForward()
|
||||
PortForward(vpnGatewayIP net.IP)
|
||||
GetSettings() (settings settings.OpenVPN)
|
||||
SetSettings(settings settings.OpenVPN)
|
||||
GetPortForwarded() (portForwarded uint16)
|
||||
@@ -40,23 +41,24 @@ type looper struct {
|
||||
uid int
|
||||
gid int
|
||||
// Configurators
|
||||
conf Configurator
|
||||
fw firewall.Configurator
|
||||
conf Configurator
|
||||
fw firewall.Configurator
|
||||
routing routing.Routing
|
||||
// Other objects
|
||||
logger logging.Logger
|
||||
client network.Client
|
||||
fileManager files.FileManager
|
||||
streamMerger command.StreamMerger
|
||||
cancel context.CancelFunc
|
||||
logger, pfLogger logging.Logger
|
||||
client *http.Client
|
||||
fileManager files.FileManager
|
||||
streamMerger command.StreamMerger
|
||||
cancel context.CancelFunc
|
||||
// Internal channels
|
||||
restart chan struct{}
|
||||
portForwardSignals chan struct{}
|
||||
portForwardSignals chan net.IP
|
||||
}
|
||||
|
||||
func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
|
||||
uid, gid int, allServers models.AllServers,
|
||||
conf Configurator, fw firewall.Configurator,
|
||||
logger logging.Logger, client network.Client, fileManager files.FileManager,
|
||||
conf Configurator, fw firewall.Configurator, routing routing.Routing,
|
||||
logger logging.Logger, client *http.Client, fileManager files.FileManager,
|
||||
streamMerger command.StreamMerger, cancel context.CancelFunc) Looper {
|
||||
return &looper{
|
||||
provider: provider,
|
||||
@@ -66,18 +68,20 @@ func NewLooper(provider models.VPNProvider, settings settings.OpenVPN,
|
||||
allServers: allServers,
|
||||
conf: conf,
|
||||
fw: fw,
|
||||
routing: routing,
|
||||
logger: logger.WithPrefix("openvpn: "),
|
||||
pfLogger: logger.WithPrefix("port forwarding: "),
|
||||
client: client,
|
||||
fileManager: fileManager,
|
||||
streamMerger: streamMerger,
|
||||
cancel: cancel,
|
||||
restart: make(chan struct{}),
|
||||
portForwardSignals: make(chan struct{}),
|
||||
portForwardSignals: make(chan net.IP),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *looper) Restart() { l.restart <- struct{}{} }
|
||||
func (l *looper) PortForward() { l.portForwardSignals <- struct{}{} }
|
||||
func (l *looper) Restart() { l.restart <- struct{}{} }
|
||||
func (l *looper) PortForward(vpnGateway net.IP) { l.portForwardSignals <- vpnGateway }
|
||||
|
||||
func (l *looper) GetSettings() (settings settings.OpenVPN) {
|
||||
l.settingsMutex.RLock()
|
||||
@@ -158,10 +162,12 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
||||
go func(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
// TODO have a way to disable pf with a context
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-l.portForwardSignals:
|
||||
l.portForward(ctx, providerConf, l.client)
|
||||
case gateway := <-l.portForwardSignals:
|
||||
wg.Add(1)
|
||||
go l.portForward(ctx, wg, providerConf, l.client, gateway)
|
||||
}
|
||||
}
|
||||
}(openvpnCtx)
|
||||
@@ -200,43 +206,25 @@ func (l *looper) logAndWait(ctx context.Context, err error) {
|
||||
<-ctx.Done()
|
||||
}
|
||||
|
||||
func (l *looper) portForward(ctx context.Context, providerConf provider.Provider, client network.Client) {
|
||||
// portForward is a blocking operation which may or may not be infinite.
|
||||
// You should therefore always call it in a goroutine
|
||||
func (l *looper) portForward(ctx context.Context, wg *sync.WaitGroup,
|
||||
providerConf provider.Provider, client *http.Client, gateway net.IP) {
|
||||
defer wg.Done()
|
||||
settings := l.GetSettings()
|
||||
if !settings.Provider.PortForwarding.Enabled {
|
||||
return
|
||||
}
|
||||
var port uint16
|
||||
err := fmt.Errorf("")
|
||||
for err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
port, err = providerConf.GetPortForward(client)
|
||||
if err != nil {
|
||||
l.logAndWait(ctx, err)
|
||||
}
|
||||
}
|
||||
|
||||
l.logger.Info("port forwarded is %d", port)
|
||||
l.portForwardedMutex.Lock()
|
||||
if err := l.fw.RemoveAllowedPort(ctx, l.portForwarded); err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
if err := l.fw.SetAllowedPort(ctx, port, string(constants.TUN)); err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
l.portForwarded = port
|
||||
l.portForwardedMutex.Unlock()
|
||||
|
||||
filepath := settings.Provider.PortForwarding.Filepath
|
||||
l.logger.Info("writing forwarded port to %s", filepath)
|
||||
err = l.fileManager.WriteLinesToFile(
|
||||
string(filepath), []string{fmt.Sprintf("%d", port)},
|
||||
files.Ownership(l.uid, l.gid), files.Permissions(0400),
|
||||
)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
syncState := func(port uint16) (pfFilepath models.Filepath) {
|
||||
l.portForwardedMutex.Lock()
|
||||
l.portForwarded = port
|
||||
l.portForwardedMutex.Unlock()
|
||||
settings := l.GetSettings()
|
||||
return settings.Provider.PortForwarding.Filepath
|
||||
}
|
||||
providerConf.PortForward(ctx,
|
||||
client, l.fileManager, l.pfLogger,
|
||||
gateway, l.fw, syncState)
|
||||
}
|
||||
|
||||
func (l *looper) GetPortForwarded() (portForwarded uint16) {
|
||||
|
||||
Reference in New Issue
Block a user