192 lines
5.7 KiB
Go
192 lines
5.7 KiB
Go
package vpn
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net/netip"
|
|
"time"
|
|
|
|
"github.com/qdm12/dns/v2/pkg/check"
|
|
"github.com/qdm12/gluetun/internal/constants"
|
|
"github.com/qdm12/gluetun/internal/pmtud"
|
|
"github.com/qdm12/gluetun/internal/version"
|
|
"github.com/qdm12/log"
|
|
)
|
|
|
|
type tunnelUpData struct {
|
|
// Healthcheck
|
|
serverIP netip.Addr
|
|
// vpnType is used for path MTU discovery to find the protocol overhead.
|
|
// It can be "wireguard" or "openvpn".
|
|
vpnType string
|
|
// Port forwarding
|
|
vpnIntf string
|
|
serverName string // used for PIA
|
|
canPortForward bool // used for PIA
|
|
username string // used for PIA
|
|
password string // used for PIA
|
|
portForwarder PortForwarder
|
|
}
|
|
|
|
func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) {
|
|
l.client.CloseIdleConnections()
|
|
|
|
for _, vpnPort := range l.vpnInputPorts {
|
|
err := l.fw.SetAllowedPort(ctx, vpnPort, data.vpnIntf)
|
|
if err != nil {
|
|
l.logger.Error("cannot allow input port through firewall: " + err.Error())
|
|
}
|
|
}
|
|
|
|
icmpTarget := l.healthSettings.ICMPTargetIP
|
|
if icmpTarget.IsUnspecified() {
|
|
icmpTarget = data.serverIP
|
|
}
|
|
l.healthChecker.SetConfig(l.healthSettings.TargetAddress, icmpTarget)
|
|
|
|
healthErrCh, err := l.healthChecker.Start(ctx)
|
|
l.healthServer.SetError(err)
|
|
if err != nil {
|
|
// Note this restart call must be done in a separate goroutine
|
|
// from the VPN loop goroutine.
|
|
l.restartVPN(loopCtx, err)
|
|
return
|
|
}
|
|
|
|
mtuLogger := l.logger.New(log.SetComponent("MTU discovery"))
|
|
err = updateToMaxMTU(ctx, data.vpnIntf, data.vpnType,
|
|
l.netLinker, l.routing, mtuLogger)
|
|
if err != nil {
|
|
mtuLogger.Error(err.Error())
|
|
}
|
|
|
|
if *l.dnsLooper.GetSettings().ServerEnabled {
|
|
_, _ = l.dnsLooper.ApplyStatus(ctx, constants.Running)
|
|
} else {
|
|
err := check.WaitForDNS(ctx, check.Settings{})
|
|
if err != nil {
|
|
l.logger.Error("waiting for DNS to be ready: " + err.Error())
|
|
}
|
|
}
|
|
|
|
err = l.publicip.RunOnce(ctx)
|
|
if err != nil {
|
|
l.logger.Error("getting public IP address information: " + err.Error())
|
|
}
|
|
|
|
if l.versionInfo {
|
|
l.versionInfo = false // only get the version information once
|
|
message, err := version.GetMessage(ctx, l.buildInfo, l.client)
|
|
if err != nil {
|
|
l.logger.Error("cannot get version information: " + err.Error())
|
|
} else {
|
|
l.logger.Info(message)
|
|
}
|
|
}
|
|
|
|
err = l.startPortForwarding(data)
|
|
if err != nil {
|
|
l.logger.Error(err.Error())
|
|
}
|
|
|
|
l.collectHealthErrors(ctx, loopCtx, healthErrCh)
|
|
}
|
|
|
|
func (l *Loop) collectHealthErrors(ctx, loopCtx context.Context, healthErrCh <-chan error) {
|
|
var previousHealthErr error
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
_ = l.healthChecker.Stop()
|
|
return
|
|
case healthErr := <-healthErrCh:
|
|
l.healthServer.SetError(healthErr)
|
|
if healthErr != nil {
|
|
if *l.healthSettings.RestartVPN {
|
|
// Note this restart call must be done in a separate goroutine
|
|
// from the VPN loop goroutine.
|
|
_ = l.healthChecker.Stop()
|
|
l.restartVPN(loopCtx, healthErr)
|
|
return
|
|
}
|
|
l.logger.Warnf("healthcheck failed: %s", healthErr)
|
|
l.logger.Info("👉 See https://github.com/qdm12/gluetun-wiki/blob/main/faq/healthcheck.md")
|
|
} else if previousHealthErr != nil {
|
|
l.logger.Info("healthcheck passed successfully after previous failure(s)")
|
|
}
|
|
previousHealthErr = healthErr
|
|
}
|
|
}
|
|
}
|
|
|
|
func (l *Loop) restartVPN(ctx context.Context, healthErr error) {
|
|
l.logger.Warnf("restarting VPN because it failed to pass the healthcheck: %s", healthErr)
|
|
l.logger.Info("👉 See https://github.com/qdm12/gluetun-wiki/blob/main/faq/healthcheck.md")
|
|
l.logger.Info("DO NOT OPEN AN ISSUE UNLESS YOU HAVE READ AND TRIED EVERY POSSIBLE SOLUTION")
|
|
_, _ = l.ApplyStatus(ctx, constants.Stopped)
|
|
_, _ = l.ApplyStatus(ctx, constants.Running)
|
|
}
|
|
|
|
// errVPNTypeUnknown is wrapped and returned by updateToMaxMTU when the VPN
// type given is neither "wireguard" nor "openvpn".
var errVPNTypeUnknown = errors.New("unknown VPN type")
|
|
|
|
func updateToMaxMTU(ctx context.Context, vpnInterface string,
|
|
vpnType string, netlinker NetLinker, routing Routing, logger *log.Logger,
|
|
) error {
|
|
logger.Info("finding maximum MTU, this can take up to 4 seconds")
|
|
|
|
vpnGatewayIP, err := routing.VPNLocalGatewayIP(vpnInterface)
|
|
if err != nil {
|
|
return fmt.Errorf("getting VPN gateway IP address: %w", err)
|
|
}
|
|
|
|
link, err := netlinker.LinkByName(vpnInterface)
|
|
if err != nil {
|
|
return fmt.Errorf("getting VPN interface by name: %w", err)
|
|
}
|
|
|
|
originalMTU := link.MTU
|
|
|
|
// Note: no point testing for an MTU of 1500, it will never work due to the VPN
|
|
// protocol overhead, so start lower than 1500 according to the protocol used.
|
|
const physicalLinkMTU = 1500
|
|
vpnLinkMTU := physicalLinkMTU
|
|
switch vpnType {
|
|
case "wireguard":
|
|
vpnLinkMTU -= 60 // Wireguard overhead
|
|
case "openvpn":
|
|
vpnLinkMTU -= 41 // OpenVPN overhead
|
|
default:
|
|
return fmt.Errorf("%w: %q", errVPNTypeUnknown, vpnType)
|
|
}
|
|
|
|
// Setting the VPN link MTU to 1500 might interrupt the connection until
|
|
// the new MTU is set again, but this is necessary to find the highest valid MTU.
|
|
logger.Debugf("VPN interface %s MTU temporarily set to %d", vpnInterface, vpnLinkMTU)
|
|
|
|
err = netlinker.LinkSetMTU(link, vpnLinkMTU)
|
|
if err != nil {
|
|
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
|
|
}
|
|
|
|
const pingTimeout = time.Second
|
|
vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, vpnGatewayIP, vpnLinkMTU, pingTimeout, logger)
|
|
switch {
|
|
case err == nil:
|
|
logger.Infof("setting VPN interface %s MTU to maximum valid MTU %d", vpnInterface, vpnLinkMTU)
|
|
case errors.Is(err, pmtud.ErrMTUNotFound) || errors.Is(err, pmtud.ErrICMPNotPermitted):
|
|
vpnLinkMTU = int(originalMTU)
|
|
logger.Infof("reverting VPN interface %s MTU to %d (due to: %s)",
|
|
vpnInterface, originalMTU, err)
|
|
default:
|
|
return fmt.Errorf("path MTU discovering: %w", err)
|
|
}
|
|
|
|
err = netlinker.LinkSetMTU(link, vpnLinkMTU)
|
|
if err != nil {
|
|
return fmt.Errorf("setting VPN interface %s MTU to %d: %w", vpnInterface, vpnLinkMTU, err)
|
|
}
|
|
|
|
return nil
|
|
}
|