diff --git a/internal/openvpn/logs.go b/internal/openvpn/logs.go index 4ad2c912..c880e88d 100644 --- a/internal/openvpn/logs.go +++ b/internal/openvpn/logs.go @@ -10,7 +10,7 @@ import ( ) func (l *Loop) collectLines(ctx context.Context, done chan<- struct{}, - stdout, stderr chan string) { + stdout, stderr chan string, tunnelUpData tunnelUpData) { defer close(done) var line string @@ -46,8 +46,7 @@ func (l *Loop) collectLines(ctx context.Context, done chan<- struct{}, l.logger.Error(line) } if strings.Contains(line, "Initialization Sequence Completed") { - l.onTunnelUp(ctx) - l.startPFCh <- struct{}{} + l.onTunnelUp(ctx, tunnelUpData) } } } diff --git a/internal/openvpn/loop.go b/internal/openvpn/loop.go index 11200e2e..7cb00c53 100644 --- a/internal/openvpn/loop.go +++ b/internal/openvpn/loop.go @@ -50,7 +50,6 @@ type Loop struct { start <-chan struct{} running chan<- models.LoopStatus userTrigger bool - startPFCh chan struct{} // Internal constant values backoffTime time.Duration } @@ -99,7 +98,6 @@ func NewLoop(openVPNSettings configuration.OpenVPN, stop: stop, stopped: stopped, userTrigger: true, - startPFCh: make(chan struct{}), backoffTime: defaultBackoffTime, } } diff --git a/internal/openvpn/portforward.go b/internal/openvpn/portforward.go index 40db6f3f..0c02f34a 100644 --- a/internal/openvpn/portforward.go +++ b/internal/openvpn/portforward.go @@ -2,6 +2,8 @@ package openvpn import ( "context" + "errors" + "fmt" "time" "github.com/qdm12/gluetun/internal/constants" @@ -9,20 +11,24 @@ import ( "github.com/qdm12/gluetun/internal/provider" ) -func (l *Loop) startPortForwarding(ctx context.Context, - enabled bool, portForwarder provider.PortForwarder, - serverName string) { +var ( + errObtainVPNLocalGateway = errors.New("cannot obtain VPN local gateway IP") + errStartPortForwarding = errors.New("cannot start port forwarding") +) + +func (l *Loop) startPortForwarding(ctx context.Context, enabled bool, + portForwarder provider.PortForwarder, serverName string) (err error) { if !enabled { - return + return nil } // only used for PIA for now gateway, err := l.routing.VPNLocalGatewayIP() if err != nil { - l.logger.Error("cannot obtain VPN local gateway IP: " + err.Error()) - return + return fmt.Errorf("%w: %s", errObtainVPNLocalGateway, err) } l.logger.Info("VPN gateway IP address: " + gateway.String()) + pfData := portforward.StartData{ PortForwarder: portForwarder, Gateway: gateway, @@ -31,8 +37,10 @@ func (l *Loop) startPortForwarding(ctx context.Context, } _, err = l.portForward.Start(ctx, pfData) if err != nil { - l.logger.Error("cannot start port forwarding: " + err.Error()) + return fmt.Errorf("%w: %s", errStartPortForwarding, err) } + + return nil } func (l *Loop) stopPortForwarding(ctx context.Context, enabled bool, diff --git a/internal/openvpn/run.go b/internal/openvpn/run.go index 74b13ddf..26e01a77 100644 --- a/internal/openvpn/run.go +++ b/internal/openvpn/run.go @@ -73,8 +73,13 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { linesCollectionCtx, linesCollectionCancel := context.WithCancel(context.Background()) lineCollectionDone := make(chan struct{}) + tunnelUpData := tunnelUpData{ + portForwarding: providerSettings.PortForwarding.Enabled, + serverName: connection.Hostname, + portForwarder: providerConf, + } go l.collectLines(linesCollectionCtx, lineCollectionDone, - stdoutLines, stderrLines) + stdoutLines, stderrLines, tunnelUpData) closeStreams := func() { linesCollectionCancel() <-lineCollectionDone @@ -86,9 +91,6 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { stayHere := true for stayHere { select { - case <-l.startPFCh: - l.startPortForwarding(ctx, providerSettings.PortForwarding.Enabled, - providerConf, connection.Hostname) case <-ctx.Done(): const pfTimeout = 100 * time.Millisecond l.stopPortForwarding(context.Background(), diff --git a/internal/openvpn/tunnelup.go b/internal/openvpn/tunnelup.go index 7515e712..0c21ef25 100644 --- a/internal/openvpn/tunnelup.go +++ b/internal/openvpn/tunnelup.go @@ -4,10 +4,18 @@ import ( "context" "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/version" ) -func (l *Loop) onTunnelUp(ctx context.Context) { +type tunnelUpData struct { + // Port forwarding + portForwarding bool + serverName string + portForwarder provider.PortForwarder +} + +func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) { vpnDestination, err := l.routing.VPNDestinationIP() if err != nil { l.logger.Warn(err.Error()) @@ -30,4 +38,9 @@ func (l *Loop) onTunnelUp(ctx context.Context) { l.logger.Info(message) } } + + err = l.startPortForwarding(ctx, data.portForwarding, data.portForwarder, data.serverName) + if err != nil { + l.logger.Error(err.Error()) + } }