diff --git a/internal/vpn/run.go b/internal/vpn/run.go index 65b44fac..971e76c0 100644 --- a/internal/vpn/run.go +++ b/internal/vpn/run.go @@ -47,7 +47,6 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { continue } tunnelUpData := tunnelUpData{ - serverIP: connection.IP, vpnType: settings.Type, serverName: connection.ServerName, canPortForward: connection.PortForward, diff --git a/internal/vpn/tunnelup.go b/internal/vpn/tunnelup.go index e5c82ce3..3bba4dc0 100644 --- a/internal/vpn/tunnelup.go +++ b/internal/vpn/tunnelup.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "net/netip" "time" "github.com/qdm12/dns/v2/pkg/check" @@ -18,8 +17,6 @@ type tunnelUpData struct { // vpnIntf is the name of the VPN network interface // which is used both for port forwarding and MTU discovery vpnIntf string - // serverIP is used for path MTU discovery - serverIP netip.Addr // vpnType is used for path MTU discovery to find the protocol overhead. // It can be "wireguard" or "openvpn". vpnType string @@ -35,11 +32,10 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) { l.client.CloseIdleConnections() mtuLogger := l.logger.New(log.SetComponent("MTU discovery")) - mtuLogger.Info("finding maximum MTU, this can take up to 4 seconds") - err := updateToMaxMTU(ctx, data.vpnIntf, data.serverIP, data.vpnType, - l.netLinker, mtuLogger) + err := updateToMaxMTU(ctx, data.vpnIntf, data.vpnType, + l.netLinker, l.routing, mtuLogger) if err != nil { - l.logger.Error(err.Error()) + mtuLogger.Error(err.Error()) } for _, vpnPort := range l.vpnInputPorts { @@ -82,8 +78,15 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) { var errVPNTypeUnknown = errors.New("unknown VPN type") func updateToMaxMTU(ctx context.Context, vpnInterface string, - serverIP netip.Addr, vpnType string, netlinker NetLinker, logger *log.Logger, + vpnType string, netlinker NetLinker, routing Routing, logger *log.Logger, ) error { + logger.Info("finding maximum MTU, this can take up to 4 seconds") + + vpnGatewayIP, err := routing.VPNLocalGatewayIP(vpnInterface) + if err != nil { + return fmt.Errorf("getting VPN gateway IP address: %w", err) + } + link, err := netlinker.LinkByName(vpnInterface) if err != nil { return fmt.Errorf("getting VPN interface by name: %w", err) @@ -114,7 +117,7 @@ func updateToMaxMTU(ctx context.Context, vpnInterface string, } const pingTimeout = time.Second - vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, serverIP, vpnLinkMTU, pingTimeout, logger) + vpnLinkMTU, err = pmtud.PathMTUDiscover(ctx, vpnGatewayIP, vpnLinkMTU, pingTimeout, logger) switch { case err == nil: logger.Infof("Setting VPN interface %s MTU to maximum valid MTU %d", vpnInterface, vpnLinkMTU)