Maint: remove startPFCh from Openvpn loop

This commit is contained in:
Quentin McGaw (desktop)
2021-08-18 16:07:35 +00:00
parent 3ad4319163
commit bd110b960b
5 changed files with 37 additions and 17 deletions

View File

@@ -10,7 +10,7 @@ import (
) )
func (l *Loop) collectLines(ctx context.Context, done chan<- struct{}, func (l *Loop) collectLines(ctx context.Context, done chan<- struct{},
stdout, stderr chan string) { stdout, stderr chan string, tunnelUpData tunnelUpData) {
defer close(done) defer close(done)
var line string var line string
@@ -46,8 +46,7 @@ func (l *Loop) collectLines(ctx context.Context, done chan<- struct{},
l.logger.Error(line) l.logger.Error(line)
} }
if strings.Contains(line, "Initialization Sequence Completed") { if strings.Contains(line, "Initialization Sequence Completed") {
l.onTunnelUp(ctx) l.onTunnelUp(ctx, tunnelUpData)
l.startPFCh <- struct{}{}
} }
} }
} }

View File

@@ -50,7 +50,6 @@ type Loop struct {
start <-chan struct{} start <-chan struct{}
running chan<- models.LoopStatus running chan<- models.LoopStatus
userTrigger bool userTrigger bool
startPFCh chan struct{}
// Internal constant values // Internal constant values
backoffTime time.Duration backoffTime time.Duration
} }
@@ -99,7 +98,6 @@ func NewLoop(openVPNSettings configuration.OpenVPN,
stop: stop, stop: stop,
stopped: stopped, stopped: stopped,
userTrigger: true, userTrigger: true,
startPFCh: make(chan struct{}),
backoffTime: defaultBackoffTime, backoffTime: defaultBackoffTime,
} }
} }

View File

@@ -2,6 +2,8 @@ package openvpn
import ( import (
"context" "context"
"errors"
"fmt"
"time" "time"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
@@ -9,20 +11,24 @@ import (
"github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/provider"
) )
func (l *Loop) startPortForwarding(ctx context.Context, var (
enabled bool, portForwarder provider.PortForwarder, errObtainVPNLocalGateway = errors.New("cannot obtain VPN local gateway IP")
serverName string) { errStartPortForwarding = errors.New("cannot start port forwarding")
)
func (l *Loop) startPortForwarding(ctx context.Context, enabled bool,
portForwarder provider.PortForwarder, serverName string) (err error) {
if !enabled { if !enabled {
return return nil
} }
// only used for PIA for now // only used for PIA for now
gateway, err := l.routing.VPNLocalGatewayIP() gateway, err := l.routing.VPNLocalGatewayIP()
if err != nil { if err != nil {
l.logger.Error("cannot obtain VPN local gateway IP: " + err.Error()) return fmt.Errorf("%w: %s", errObtainVPNLocalGateway, err)
return
} }
l.logger.Info("VPN gateway IP address: " + gateway.String()) l.logger.Info("VPN gateway IP address: " + gateway.String())
pfData := portforward.StartData{ pfData := portforward.StartData{
PortForwarder: portForwarder, PortForwarder: portForwarder,
Gateway: gateway, Gateway: gateway,
@@ -31,8 +37,10 @@ func (l *Loop) startPortForwarding(ctx context.Context,
} }
_, err = l.portForward.Start(ctx, pfData) _, err = l.portForward.Start(ctx, pfData)
if err != nil { if err != nil {
l.logger.Error("cannot start port forwarding: " + err.Error()) return fmt.Errorf("%w: %s", errStartPortForwarding, err)
} }
return nil
} }
func (l *Loop) stopPortForwarding(ctx context.Context, enabled bool, func (l *Loop) stopPortForwarding(ctx context.Context, enabled bool,

View File

@@ -73,8 +73,13 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
linesCollectionCtx, linesCollectionCancel := context.WithCancel(context.Background()) linesCollectionCtx, linesCollectionCancel := context.WithCancel(context.Background())
lineCollectionDone := make(chan struct{}) lineCollectionDone := make(chan struct{})
tunnelUpData := tunnelUpData{
portForwarding: providerSettings.PortForwarding.Enabled,
serverName: connection.Hostname,
portForwarder: providerConf,
}
go l.collectLines(linesCollectionCtx, lineCollectionDone, go l.collectLines(linesCollectionCtx, lineCollectionDone,
stdoutLines, stderrLines) stdoutLines, stderrLines, tunnelUpData)
closeStreams := func() { closeStreams := func() {
linesCollectionCancel() linesCollectionCancel()
<-lineCollectionDone <-lineCollectionDone
@@ -86,9 +91,6 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) {
stayHere := true stayHere := true
for stayHere { for stayHere {
select { select {
case <-l.startPFCh:
l.startPortForwarding(ctx, providerSettings.PortForwarding.Enabled,
providerConf, connection.Hostname)
case <-ctx.Done(): case <-ctx.Done():
const pfTimeout = 100 * time.Millisecond const pfTimeout = 100 * time.Millisecond
l.stopPortForwarding(context.Background(), l.stopPortForwarding(context.Background(),

View File

@@ -4,10 +4,18 @@ import (
"context" "context"
"github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gluetun/internal/version" "github.com/qdm12/gluetun/internal/version"
) )
func (l *Loop) onTunnelUp(ctx context.Context) { type tunnelUpData struct {
// Port forwarding
portForwarding bool
serverName string
portForwarder provider.PortForwarder
}
func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
vpnDestination, err := l.routing.VPNDestinationIP() vpnDestination, err := l.routing.VPNDestinationIP()
if err != nil { if err != nil {
l.logger.Warn(err.Error()) l.logger.Warn(err.Error())
@@ -30,4 +38,9 @@ func (l *Loop) onTunnelUp(ctx context.Context) {
l.logger.Info(message) l.logger.Info(message)
} }
} }
err = l.startPortForwarding(ctx, data.portForwarding, data.portForwarder, data.serverName)
if err != nil {
l.logger.Error(err.Error())
}
} }