Maint: remove startPFCh from Openvpn loop

This commit is contained in:
Quentin McGaw (desktop)
2021-08-18 16:07:35 +00:00
parent 3ad4319163
commit bd110b960b
5 changed files with 37 additions and 17 deletions

View File

@@ -10,7 +10,7 @@ import (
)
func (l *Loop) collectLines(ctx context.Context, done chan<- struct{},
stdout, stderr chan string) {
stdout, stderr chan string, tunnelUpData tunnelUpData) {
defer close(done)
var line string
@@ -46,8 +46,7 @@ func (l *Loop) collectLines(ctx context.Context, done chan<- struct{},
l.logger.Error(line)
}
if strings.Contains(line, "Initialization Sequence Completed") {
l.onTunnelUp(ctx)
l.startPFCh <- struct{}{}
l.onTunnelUp(ctx, tunnelUpData)
}
}
}

View File

@@ -50,7 +50,6 @@ type Loop struct {
start <-chan struct{}
running chan<- models.LoopStatus
userTrigger bool
startPFCh chan struct{}
// Internal constant values
backoffTime time.Duration
}
@@ -99,7 +98,6 @@ func NewLoop(openVPNSettings configuration.OpenVPN,
stop: stop,
stopped: stopped,
userTrigger: true,
startPFCh: make(chan struct{}),
backoffTime: defaultBackoffTime,
}
}

View File

@@ -2,6 +2,8 @@ package openvpn
import (
"context"
"errors"
"fmt"
"time"
"github.com/qdm12/gluetun/internal/constants"
@@ -9,20 +11,24 @@ import (
"github.com/qdm12/gluetun/internal/provider"
)
func (l *Loop) startPortForwarding(ctx context.Context,
enabled bool, portForwarder provider.PortForwarder,
serverName string) {
var (
errObtainVPNLocalGateway = errors.New("cannot obtain VPN local gateway IP")
errStartPortForwarding = errors.New("cannot start port forwarding")
)
func (l *Loop) startPortForwarding(ctx context.Context, enabled bool,
portForwarder provider.PortForwarder, serverName string) (err error) {
if !enabled {
return
return nil
}
// only used for PIA for now
gateway, err := l.routing.VPNLocalGatewayIP()
if err != nil {
l.logger.Error("cannot obtain VPN local gateway IP: " + err.Error())
return
return fmt.Errorf("%w: %s", errObtainVPNLocalGateway, err)
}
l.logger.Info("VPN gateway IP address: " + gateway.String())
pfData := portforward.StartData{
PortForwarder: portForwarder,
Gateway: gateway,
@@ -31,8 +37,10 @@ func (l *Loop) startPortForwarding(ctx context.Context,
}
_, err = l.portForward.Start(ctx, pfData)
if err != nil {
l.logger.Error("cannot start port forwarding: " + err.Error())
return fmt.Errorf("%w: %s", errStartPortForwarding, err)
}
return nil
}
func (l *Loop) stopPortForwarding(ctx context.Context, enabled bool,

View File

@@ -73,8 +73,13 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
linesCollectionCtx, linesCollectionCancel := context.WithCancel(context.Background())
lineCollectionDone := make(chan struct{})
tunnelUpData := tunnelUpData{
portForwarding: providerSettings.PortForwarding.Enabled,
serverName: connection.Hostname,
portForwarder: providerConf,
}
go l.collectLines(linesCollectionCtx, lineCollectionDone,
stdoutLines, stderrLines)
stdoutLines, stderrLines, tunnelUpData)
closeStreams := func() {
linesCollectionCancel()
<-lineCollectionDone
@@ -86,9 +91,6 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
stayHere := true
for stayHere {
select {
case <-l.startPFCh:
l.startPortForwarding(ctx, providerSettings.PortForwarding.Enabled,
providerConf, connection.Hostname)
case <-ctx.Done():
const pfTimeout = 100 * time.Millisecond
l.stopPortForwarding(context.Background(),

View File

@@ -4,10 +4,18 @@ import (
"context"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gluetun/internal/version"
)
func (l *Loop) onTunnelUp(ctx context.Context) {
type tunnelUpData struct {
// Port forwarding
portForwarding bool
serverName string
portForwarder provider.PortForwarder
}
func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
vpnDestination, err := l.routing.VPNDestinationIP()
if err != nil {
l.logger.Warn(err.Error())
@@ -30,4 +38,9 @@ func (l *Loop) onTunnelUp(ctx context.Context) {
l.logger.Info(message)
}
}
err = l.startPortForwarding(ctx, data.portForwarding, data.portForwarder, data.serverName)
if err != nil {
l.logger.Error(err.Error())
}
}