From 85890520ab6d85a179ee741b75df9913bcbac3b9 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Fri, 17 Oct 2025 01:45:50 +0200 Subject: [PATCH] feat(healthcheck): combination of ICMP and TCP+TLS checks (#2923) - New option: `HEALTH_ICMP_TARGET_IP` defaults to `0.0.0.0` meaning use the VPN server public IP address. - Options removed: `HEALTH_VPN_INITIAL_DURATION` and `HEALTH_VPN_ADDITIONAL_DURATION` - times and retries are handpicked and hardcoded. - Less aggressive checks and less false positive detection --- Dockerfile | 4 +- cmd/gluetun/main.go | 17 +- internal/configuration/settings/deprecated.go | 8 +- internal/configuration/settings/health.go | 47 ++-- .../configuration/settings/healthywait.go | 76 ------ .../configuration/settings/settings_test.go | 7 +- internal/healthcheck/checker.go | 239 ++++++++++++++++++ .../{health_test.go => checker_test.go} | 25 +- internal/healthcheck/dns/dns.go | 39 +++ internal/healthcheck/handler.go | 4 +- internal/healthcheck/health.go | 122 --------- internal/healthcheck/icmp/apple_ipv4.go | 49 ++++ internal/healthcheck/icmp/echo.go | 190 ++++++++++++++ internal/healthcheck/icmp/interfaces.go | 6 + internal/healthcheck/icmp/listen.go | 35 +++ .../healthcheck/{logger.go => interfaces.go} | 3 +- internal/healthcheck/openvpn.go | 25 -- internal/healthcheck/run.go | 4 - internal/healthcheck/server.go | 24 +- internal/vpn/interfaces.go | 10 + internal/vpn/loop.go | 65 ++--- internal/vpn/openvpn.go | 18 +- internal/vpn/run.go | 16 +- internal/vpn/tunnelup.go | 42 ++- internal/vpn/wireguard.go | 13 +- 25 files changed, 722 insertions(+), 366 deletions(-) delete mode 100644 internal/configuration/settings/healthywait.go create mode 100644 internal/healthcheck/checker.go rename internal/healthcheck/{health_test.go => checker_test.go} (79%) create mode 100644 internal/healthcheck/dns/dns.go delete mode 100644 internal/healthcheck/health.go create mode 100644 internal/healthcheck/icmp/apple_ipv4.go create mode 100644 internal/healthcheck/icmp/echo.go create mode 100644 internal/healthcheck/icmp/interfaces.go create mode 100644 internal/healthcheck/icmp/listen.go rename internal/healthcheck/{logger.go => interfaces.go} (52%) delete mode 100644 internal/healthcheck/openvpn.go diff --git a/Dockerfile b/Dockerfile index 5e2fa98a..3539a9f5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -164,9 +164,7 @@ ENV VPN_SERVICE_PROVIDER=pia \ # Health HEALTH_SERVER_ADDRESS=127.0.0.1:9999 \ HEALTH_TARGET_ADDRESS=cloudflare.com:443 \ - HEALTH_SUCCESS_WAIT_DURATION=5s \ - HEALTH_VPN_DURATION_INITIAL=6s \ - HEALTH_VPN_DURATION_ADDITION=5s \ + HEALTH_ICMP_TARGET_IP=0.0.0.0 \ # DNS over TLS DOT=on \ DOT_PROVIDERS=cloudflare \ diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index c5b47126..79061182 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -414,6 +414,13 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, return fmt.Errorf("starting public ip loop: %w", err) } + healthLogger := logger.New(log.SetComponent("healthcheck")) + healthcheckServer := healthcheck.NewServer(allSettings.Health, healthLogger) + healthServerHandler, healthServerCtx, healthServerDone := goshutdown.NewGoRoutineHandler( + "HTTP health server", goroutine.OptionTimeout(defaultShutdownTimeout)) + go healthcheckServer.Run(healthServerCtx, healthServerDone) + healthChecker := healthcheck.NewChecker(healthLogger) + updaterLogger := logger.New(log.SetComponent("updater")) unzipper := unzip.New(httpClient) @@ -424,8 +431,8 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, vpnLogger := logger.New(log.SetComponent("vpn")) vpnLooper := vpn.NewLoop(allSettings.VPN, ipv6Supported, allSettings.Firewall.VPNInputPorts, - providers, storage, ovpnConf, netLinker, firewallConf, routingConf, portForwardLooper, - cmder, publicIPLooper, dnsLooper, vpnLogger, httpClient, + providers, storage, allSettings.Health, healthChecker, healthcheckServer, ovpnConf, netLinker, firewallConf, + routingConf, portForwardLooper, cmder, publicIPLooper, dnsLooper, vpnLogger, httpClient, buildInfo, *allSettings.Version.Enabled) vpnHandler, vpnCtx, vpnDone := goshutdown.NewGoRoutineHandler( "vpn", goroutine.OptionTimeout(time.Second)) @@ -476,12 +483,6 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, <-httpServerReady controlGroupHandler.Add(httpServerHandler) - healthLogger := logger.New(log.SetComponent("healthcheck")) - healthcheckServer := healthcheck.NewServer(allSettings.Health, healthLogger, vpnLooper) - healthServerHandler, healthServerCtx, healthServerDone := goshutdown.NewGoRoutineHandler( - "HTTP health server", goroutine.OptionTimeout(defaultShutdownTimeout)) - go healthcheckServer.Run(healthServerCtx, healthServerDone) - orderHandler := goshutdown.NewOrderHandler("gluetun", order.OptionTimeout(totalShutdownTimeout), order.OptionOnSuccess(defaultShutdownOnSuccess), diff --git a/internal/configuration/settings/deprecated.go b/internal/configuration/settings/deprecated.go index 77010781..1e745eb0 100644 --- a/internal/configuration/settings/deprecated.go +++ b/internal/configuration/settings/deprecated.go @@ -9,9 +9,11 @@ import ( func readObsolete(r *reader.Reader) (warnings []string) { keyToMessage := map[string]string{ - "DOT_VERBOSITY": "DOT_VERBOSITY is obsolete, use LOG_LEVEL instead.", - "DOT_VERBOSITY_DETAILS": "DOT_VERBOSITY_DETAILS is obsolete because it was specific to Unbound.", - "DOT_VALIDATION_LOGLEVEL": "DOT_VALIDATION_LOGLEVEL is obsolete because DNSSEC validation is not implemented.", + "DOT_VERBOSITY": "DOT_VERBOSITY is obsolete, use LOG_LEVEL instead.", + "DOT_VERBOSITY_DETAILS": "DOT_VERBOSITY_DETAILS is obsolete because it was specific to Unbound.", + "DOT_VALIDATION_LOGLEVEL": "DOT_VALIDATION_LOGLEVEL is obsolete because DNSSEC validation is not implemented.", + "HEALTH_VPN_DURATION_INITIAL": "HEALTH_VPN_DURATION_INITIAL is obsolete", + "HEALTH_VPN_DURATION_ADDITION": "HEALTH_VPN_DURATION_ADDITION is obsolete", } sortedKeys := maps.Keys(keyToMessage) slices.Sort(sortedKeys) diff --git a/internal/configuration/settings/health.go b/internal/configuration/settings/health.go index 4035e209..c3c263fb 100644 --- a/internal/configuration/settings/health.go +++ b/internal/configuration/settings/health.go @@ -2,6 +2,7 @@ package settings import ( "fmt" + "net/netip" "os" "time" @@ -24,16 +25,13 @@ type Health struct { // HTTP server. It defaults to 500 milliseconds. ReadTimeout time.Duration // TargetAddress is the address (host or host:port) - // to TCP dial to periodically for the health check. + // to TCP TLS dial to periodically for the health check. // It cannot be the empty string in the internal state. TargetAddress string - // SuccessWait is the duration to wait to re-run the - // healthcheck after a successful healthcheck. - // It defaults to 5 seconds and cannot be zero in - // the internal state. - SuccessWait time.Duration - // VPN has health settings specific to the VPN loop. - VPN HealthyWait + // ICMPTargetIP is the IP address to use for ICMP echo requests + // in the health checker. It can be set to an unspecified address + // such that the VPN server IP is used, which is also the default behavior. + ICMPTargetIP netip.Addr } func (h Health) Validate() (err error) { @@ -42,11 +40,6 @@ func (h Health) Validate() (err error) { return fmt.Errorf("server listening address is not valid: %w", err) } - err = h.VPN.validate() - if err != nil { - return fmt.Errorf("health VPN settings: %w", err) - } - return nil } @@ -56,8 +49,7 @@ func (h *Health) copy() (copied Health) { ReadHeaderTimeout: h.ReadHeaderTimeout, ReadTimeout: h.ReadTimeout, TargetAddress: h.TargetAddress, - SuccessWait: h.SuccessWait, - VPN: h.VPN.copy(), + ICMPTargetIP: h.ICMPTargetIP, } } @@ -69,8 +61,7 @@ func (h *Health) OverrideWith(other Health) { h.ReadHeaderTimeout = gosettings.OverrideWithComparable(h.ReadHeaderTimeout, other.ReadHeaderTimeout) h.ReadTimeout = gosettings.OverrideWithComparable(h.ReadTimeout, other.ReadTimeout) h.TargetAddress = gosettings.OverrideWithComparable(h.TargetAddress, other.TargetAddress) - h.SuccessWait = gosettings.OverrideWithComparable(h.SuccessWait, other.SuccessWait) - h.VPN.overrideWith(other.VPN) + h.ICMPTargetIP = gosettings.OverrideWithComparable(h.ICMPTargetIP, other.ICMPTargetIP) } func (h *Health) SetDefaults() { @@ -80,9 +71,7 @@ func (h *Health) SetDefaults() { const defaultReadTimeout = 500 * time.Millisecond h.ReadTimeout = gosettings.DefaultComparable(h.ReadTimeout, defaultReadTimeout) h.TargetAddress = gosettings.DefaultComparable(h.TargetAddress, "cloudflare.com:443") - const defaultSuccessWait = 5 * time.Second - h.SuccessWait = gosettings.DefaultComparable(h.SuccessWait, defaultSuccessWait) - h.VPN.setDefaults() + h.ICMPTargetIP = gosettings.DefaultComparable(h.ICMPTargetIP, netip.IPv4Unspecified()) // use the VPN server IP } func (h Health) String() string { @@ -93,10 +82,11 @@ func (h Health) toLinesNode() (node *gotree.Node) { node = gotree.New("Health settings:") node.Appendf("Server listening address: %s", h.ServerAddress) node.Appendf("Target address: %s", h.TargetAddress) - node.Appendf("Duration to wait after success: %s", h.SuccessWait) - node.Appendf("Read header timeout: %s", h.ReadHeaderTimeout) - node.Appendf("Read timeout: %s", h.ReadTimeout) - node.AppendNode(h.VPN.toLinesNode("VPN")) + icmpTarget := "VPN server IP" + if !h.ICMPTargetIP.IsUnspecified() { + icmpTarget = h.ICMPTargetIP.String() + } + node.Appendf("ICMP target IP: %s", icmpTarget) return node } @@ -104,16 +94,9 @@ func (h *Health) Read(r *reader.Reader) (err error) { h.ServerAddress = r.String("HEALTH_SERVER_ADDRESS") h.TargetAddress = r.String("HEALTH_TARGET_ADDRESS", reader.RetroKeys("HEALTH_ADDRESS_TO_PING")) - - h.SuccessWait, err = r.Duration("HEALTH_SUCCESS_WAIT_DURATION") + h.ICMPTargetIP, err = r.NetipAddr("HEALTH_ICMP_TARGET_IP") if err != nil { return err } - - err = h.VPN.read(r) - if err != nil { - return fmt.Errorf("VPN health settings: %w", err) - } - return nil } diff --git a/internal/configuration/settings/healthywait.go b/internal/configuration/settings/healthywait.go deleted file mode 100644 index e30d610e..00000000 --- a/internal/configuration/settings/healthywait.go +++ /dev/null @@ -1,76 +0,0 @@ -package settings - -import ( - "time" - - "github.com/qdm12/gosettings" - "github.com/qdm12/gosettings/reader" - "github.com/qdm12/gotree" -) - -type HealthyWait struct { - // Initial is the initial duration to wait for the program - // to be healthy before taking action. - // It cannot be nil in the internal state. - Initial *time.Duration - // Addition is the duration to add to the Initial duration - // after Initial has expired to wait longer for the program - // to be healthy. - // It cannot be nil in the internal state. - Addition *time.Duration -} - -func (h HealthyWait) validate() (err error) { - return nil -} - -func (h *HealthyWait) copy() (copied HealthyWait) { - return HealthyWait{ - Initial: gosettings.CopyPointer(h.Initial), - Addition: gosettings.CopyPointer(h.Addition), - } -} - -// overrideWith overrides fields of the receiver -// settings object with any field set in the other -// settings. -func (h *HealthyWait) overrideWith(other HealthyWait) { - h.Initial = gosettings.OverrideWithPointer(h.Initial, other.Initial) - h.Addition = gosettings.OverrideWithPointer(h.Addition, other.Addition) -} - -func (h *HealthyWait) setDefaults() { - const initialDurationDefault = 6 * time.Second - const additionDurationDefault = 5 * time.Second - h.Initial = gosettings.DefaultPointer(h.Initial, initialDurationDefault) - h.Addition = gosettings.DefaultPointer(h.Addition, additionDurationDefault) -} - -func (h HealthyWait) String() string { - return h.toLinesNode("Health").String() -} - -func (h HealthyWait) toLinesNode(kind string) (node *gotree.Node) { - node = gotree.New(kind + " wait durations:") - node.Appendf("Initial duration: %s", *h.Initial) - node.Appendf("Additional duration: %s", *h.Addition) - return node -} - -func (h *HealthyWait) read(r *reader.Reader) (err error) { - h.Initial, err = r.DurationPtr( - "HEALTH_VPN_DURATION_INITIAL", - reader.RetroKeys("HEALTH_OPENVPN_DURATION_INITIAL")) - if err != nil { - return err - } - - h.Addition, err = r.DurationPtr( - "HEALTH_VPN_DURATION_ADDITION", - reader.RetroKeys("HEALTH_OPENVPN_DURATION_ADDITION")) - if err != nil { - return err - } - - return nil -} diff --git a/internal/configuration/settings/settings_test.go b/internal/configuration/settings/settings_test.go index 7aa30a15..955aa349 100644 --- a/internal/configuration/settings/settings_test.go +++ b/internal/configuration/settings/settings_test.go @@ -58,12 +58,7 @@ func Test_Settings_String(t *testing.T) { ├── Health settings: | ├── Server listening address: 127.0.0.1:9999 | ├── Target address: cloudflare.com:443 -| ├── Duration to wait after success: 5s -| ├── Read header timeout: 100ms -| ├── Read timeout: 500ms -| └── VPN wait durations: -| ├── Initial duration: 6s -| └── Additional duration: 5s +| └── ICMP target IP: VPN server IP ├── Shadowsocks server settings: | └── Enabled: no ├── HTTP proxy settings: diff --git a/internal/healthcheck/checker.go b/internal/healthcheck/checker.go new file mode 100644 index 00000000..4775021a --- /dev/null +++ b/internal/healthcheck/checker.go @@ -0,0 +1,239 @@ +package healthcheck + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net" + "net/netip" + "strings" + "sync" + "time" + + "github.com/qdm12/gluetun/internal/healthcheck/dns" + "github.com/qdm12/gluetun/internal/healthcheck/icmp" +) + +type Checker struct { + tlsDialAddr string + dialer *net.Dialer + echoer *icmp.Echoer + dnsClient *dns.Client + logger Logger + icmpTarget netip.Addr + configMutex sync.Mutex + + icmpNotPermitted bool + + // Internal periodic service signals + stop context.CancelFunc + done <-chan struct{} +} + +func NewChecker(logger Logger) *Checker { + return &Checker{ + dialer: &net.Dialer{ + Resolver: &net.Resolver{ + PreferGo: true, + }, + }, + echoer: icmp.NewEchoer(logger), + dnsClient: dns.New(), + logger: logger, + } +} + +// SetConfig sets the TCP+TLS dial address and the ICMP echo IP address +// to target by the [Checker]. +// This function MUST be called before calling [Checker.Start]. +func (c *Checker) SetConfig(tlsDialAddr string, icmpTarget netip.Addr) { + c.configMutex.Lock() + defer c.configMutex.Unlock() + c.tlsDialAddr = tlsDialAddr + c.icmpTarget = icmpTarget +} + +// Start starts the checker by first running a blocking 2s-timed TCP+TLS check, +// and, on success, starts the periodic checks in a separate goroutine: +// - a "small" ICMP echo check every 15 seconds +// - a "full" TCP+TLS check every 5 minutes +// It returns a channel `runError` that receives an error if one of the periodic checks fail. +// It returns an error if the initial TCP+TLS check fails. +func (c *Checker) Start(ctx context.Context) (runError <-chan error, err error) { + if c.tlsDialAddr == "" || c.icmpTarget.IsUnspecified() { + panic("call Checker.SetConfig with non empty values before Checker.Start") + } + + // connection isn't under load yet when the checker starts, so a short + // 6 seconds timeout suffices and provides quick enough feedback that + // the new connection is not working. + const timeout = 6 * time.Second + tcpTLSCheckCtx, tcpTLSCheckCancel := context.WithTimeout(ctx, timeout) + err = tcpTLSCheck(tcpTLSCheckCtx, c.dialer, c.tlsDialAddr) + tcpTLSCheckCancel() + if err != nil { + return nil, fmt.Errorf("startup check: %w", err) + } + + ready := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + c.stop = cancel + done := make(chan struct{}) + c.done = done + const smallCheckPeriod = 15 * time.Second + smallCheckTimer := time.NewTimer(smallCheckPeriod) + const fullCheckPeriod = 5 * time.Minute + fullCheckTimer := time.NewTimer(fullCheckPeriod) + runErrorCh := make(chan error) + runError = runErrorCh + go func() { + defer close(done) + close(ready) + for { + select { + case <-ctx.Done(): + fullCheckTimer.Stop() + smallCheckTimer.Stop() + return + case <-smallCheckTimer.C: + err := c.smallPeriodicCheck(ctx) + if err != nil { + runErrorCh <- fmt.Errorf("periodic small check: %w", err) + return + } + smallCheckTimer.Reset(smallCheckPeriod) + case <-fullCheckTimer.C: + err := c.fullPeriodicCheck(ctx) + if err != nil { + runErrorCh <- fmt.Errorf("periodic full check: %w", err) + return + } + fullCheckTimer.Reset(fullCheckPeriod) + } + } + }() + <-ready + return runError, nil +} + +func (c *Checker) Stop() error { + c.stop() + <-c.done + c.icmpTarget = netip.Addr{} + return nil +} + +func (c *Checker) smallPeriodicCheck(ctx context.Context) error { + c.configMutex.Lock() + ip := c.icmpTarget + c.configMutex.Unlock() + const maxTries = 3 + const timeout = 3 * time.Second + const extraTryTime = time.Second // 1s added for each subsequent retry + check := func(ctx context.Context) error { + if c.icmpNotPermitted { + return c.dnsClient.Check(ctx) + } + err := c.echoer.Echo(ctx, ip) + if errors.Is(err, icmp.ErrNotPermitted) { + c.icmpNotPermitted = true + c.logger.Warnf("%s; permanently falling back to plaintext DNS checks.", err) + return c.dnsClient.Check(ctx) + } + return err + } + return withRetries(ctx, maxTries, timeout, extraTryTime, c.logger, "ICMP echo", check) +} + +func (c *Checker) fullPeriodicCheck(ctx context.Context) error { + const maxTries = 2 + // 10s timeout in case the connection is under stress + // See https://github.com/qdm12/gluetun/issues/2270 + const timeout = 10 * time.Second + const extraTryTime = 3 * time.Second // 3s added for each subsequent retry + check := func(ctx context.Context) error { + return tcpTLSCheck(ctx, c.dialer, c.tlsDialAddr) + } + return withRetries(ctx, maxTries, timeout, extraTryTime, c.logger, "TCP+TLS dial", check) +} + +func tcpTLSCheck(ctx context.Context, dialer *net.Dialer, targetAddress string) error { + // TODO use mullvad API if current provider is Mullvad + + address, err := makeAddressToDial(targetAddress) + if err != nil { + return err + } + + const dialNetwork = "tcp4" + connection, err := dialer.DialContext(ctx, dialNetwork, address) + if err != nil { + return fmt.Errorf("dialing: %w", err) + } + + if strings.HasSuffix(address, ":443") { + host, _, err := net.SplitHostPort(address) + if err != nil { + return fmt.Errorf("splitting host and port: %w", err) + } + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + ServerName: host, + } + tlsConnection := tls.Client(connection, tlsConfig) + err = tlsConnection.HandshakeContext(ctx) + if err != nil { + return fmt.Errorf("running TLS handshake: %w", err) + } + } + + err = connection.Close() + if err != nil { + return fmt.Errorf("closing connection: %w", err) + } + + return nil +} + +func makeAddressToDial(address string) (addressToDial string, err error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + addrErr := new(net.AddrError) + ok := errors.As(err, &addrErr) + if !ok || addrErr.Err != "missing port in address" { + return "", fmt.Errorf("splitting host and port from address: %w", err) + } + host = address + const defaultPort = "443" + port = defaultPort + } + address = net.JoinHostPort(host, port) + return address, nil +} + +var ErrAllCheckTriesFailed = errors.New("all check tries failed") + +func withRetries(ctx context.Context, maxTries uint, tryTimeout, extraTryTime time.Duration, + warner Logger, checkName string, check func(ctx context.Context) error, +) error { + try := uint(0) + for { + timeout := tryTimeout + time.Duration(try)*extraTryTime //nolint:gosec + checkCtx, cancel := context.WithTimeout(ctx, timeout) + err := check(checkCtx) + cancel() + switch { + case err == nil: + return nil + case ctx.Err() != nil: + return fmt.Errorf("%s context error: %w", checkName, ctx.Err()) + default: + warner.Warnf("%s attempt %d/%d failed: %v", checkName, try+1, maxTries, err) + try++ + if try == maxTries { + return fmt.Errorf("%w: %s: after %d attempts", ErrAllCheckTriesFailed, checkName, maxTries) + } + } + } +} diff --git a/internal/healthcheck/health_test.go b/internal/healthcheck/checker_test.go similarity index 79% rename from internal/healthcheck/health_test.go rename to internal/healthcheck/checker_test.go index 803cb8d0..209be53a 100644 --- a/internal/healthcheck/health_test.go +++ b/internal/healthcheck/checker_test.go @@ -7,12 +7,11 @@ import ( "testing" "time" - "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func Test_Server_healthCheck(t *testing.T) { +func Test_Checker_fullcheck(t *testing.T) { t.Parallel() t.Run("canceled real dialer", func(t *testing.T) { @@ -21,20 +20,18 @@ func Test_Server_healthCheck(t *testing.T) { dialer := &net.Dialer{} const address = "cloudflare.com:443" - server := &Server{ - dialer: dialer, - config: settings.Health{ - TargetAddress: address, - }, + checker := &Checker{ + dialer: dialer, + tlsDialAddr: address, } canceledCtx, cancel := context.WithCancel(context.Background()) cancel() - err := server.healthCheck(canceledCtx) + err := checker.fullPeriodicCheck(canceledCtx) require.Error(t, err) - assert.Contains(t, err.Error(), "operation was canceled") + assert.EqualError(t, err, "TCP+TLS dial context error: context canceled") }) t.Run("dial localhost:0", func(t *testing.T) { @@ -54,14 +51,12 @@ func Test_Server_healthCheck(t *testing.T) { listeningAddress := listener.Addr() dialer := &net.Dialer{} - server := &Server{ - dialer: dialer, - config: settings.Health{ - TargetAddress: listeningAddress.String(), - }, + checker := &Checker{ + dialer: dialer, + tlsDialAddr: listeningAddress.String(), } - err = server.healthCheck(ctx) + err = checker.fullPeriodicCheck(ctx) assert.NoError(t, err) }) diff --git a/internal/healthcheck/dns/dns.go b/internal/healthcheck/dns/dns.go new file mode 100644 index 00000000..13e7d591 --- /dev/null +++ b/internal/healthcheck/dns/dns.go @@ -0,0 +1,39 @@ +package dns + +import ( + "context" + "errors" + "fmt" + "net" +) + +// Client is a simple plaintext UDP DNS client, to be used for healthchecks. +// Note the client connects to a DNS server only over UDP on port 53, +// because we don't want to use DoT or DoH and impact the TCP connections +// when running a healthcheck. +type Client struct{} + +func New() *Client { + return &Client{} +} + +var ErrLookupNoIPs = errors.New("no IPs found from DNS lookup") + +func (c *Client) Check(ctx context.Context) error { + resolver := &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, _, _ string) (net.Conn, error) { + dialer := net.Dialer{} + return dialer.DialContext(ctx, "udp", "1.1.1.1:53") + }, + } + ips, err := resolver.LookupIP(ctx, "ip", "github.com") + switch { + case err != nil: + return err + case len(ips) == 0: + return fmt.Errorf("%w", ErrLookupNoIPs) + default: + return nil + } +} diff --git a/internal/healthcheck/handler.go b/internal/healthcheck/handler.go index 20b7e888..2b9ee99b 100644 --- a/internal/healthcheck/handler.go +++ b/internal/healthcheck/handler.go @@ -9,13 +9,15 @@ import ( type handler struct { healthErr error healthErrMu sync.RWMutex + logger Logger } var errHealthcheckNotRunYet = errors.New("healthcheck did not run yet") -func newHandler() *handler { +func newHandler(logger Logger) *handler { return &handler{ healthErr: errHealthcheckNotRunYet, + logger: logger, } } diff --git a/internal/healthcheck/health.go b/internal/healthcheck/health.go deleted file mode 100644 index c93b18b2..00000000 --- a/internal/healthcheck/health.go +++ /dev/null @@ -1,122 +0,0 @@ -package healthcheck - -import ( - "context" - "crypto/tls" - "errors" - "fmt" - "net" - "strings" - "time" -) - -func (s *Server) runHealthcheckLoop(ctx context.Context, done chan<- struct{}) { - defer close(done) - - timeoutIndex := 0 - healthcheckTimeouts := []time.Duration{ - 2 * time.Second, - 4 * time.Second, - 6 * time.Second, - 8 * time.Second, - // This can be useful when the connection is under stress - // See https://github.com/qdm12/gluetun/issues/2270 - 10 * time.Second, - } - s.vpn.healthyTimer = time.NewTimer(s.vpn.healthyWait) - - for { - previousErr := s.handler.getErr() - - timeout := healthcheckTimeouts[timeoutIndex] - healthcheckCtx, healthcheckCancel := context.WithTimeout( - ctx, timeout) - err := s.healthCheck(healthcheckCtx) - healthcheckCancel() - - s.handler.setErr(err) - - switch { - case previousErr != nil && err == nil: // First success - s.logger.Info("healthy!") - timeoutIndex = 0 - s.vpn.healthyTimer.Stop() - s.vpn.healthyWait = *s.config.VPN.Initial - case previousErr == nil && err != nil: // First failure - s.logger.Debug("unhealthy: " + err.Error()) - s.vpn.healthyTimer.Stop() - s.vpn.healthyTimer = time.NewTimer(s.vpn.healthyWait) - case previousErr != nil && err != nil: // Nth failure - if timeoutIndex < len(healthcheckTimeouts)-1 { - timeoutIndex++ - } - select { - case <-s.vpn.healthyTimer.C: - timeoutIndex = 0 // retry next with the smallest timeout - s.onUnhealthyVPN(ctx, err.Error()) - default: - } - case previousErr == nil && err == nil: // Nth success - timer := time.NewTimer(s.config.SuccessWait) - select { - case <-ctx.Done(): - return - case <-timer.C: - } - } - } -} - -func (s *Server) healthCheck(ctx context.Context) (err error) { - // TODO use mullvad API if current provider is Mullvad - - address, err := makeAddressToDial(s.config.TargetAddress) - if err != nil { - return err - } - - const dialNetwork = "tcp4" - connection, err := s.dialer.DialContext(ctx, dialNetwork, address) - if err != nil { - return fmt.Errorf("dialing: %w", err) - } - - if strings.HasSuffix(address, ":443") { - host, _, err := net.SplitHostPort(address) - if err != nil { - return fmt.Errorf("splitting host and port: %w", err) - } - tlsConfig := &tls.Config{ - MinVersion: tls.VersionTLS12, - ServerName: host, - } - tlsConnection := tls.Client(connection, tlsConfig) - err = tlsConnection.HandshakeContext(ctx) - if err != nil { - return fmt.Errorf("running TLS handshake: %w", err) - } - } - - err = connection.Close() - if err != nil { - return fmt.Errorf("closing connection: %w", err) - } - - return nil -} - -func makeAddressToDial(address string) (addressToDial string, err error) { - host, port, err := net.SplitHostPort(address) - if err != nil { - addrErr := new(net.AddrError) - ok := errors.As(err, &addrErr) - if !ok || addrErr.Err != "missing port in address" { - return "", fmt.Errorf("splitting host and port from address: %w", err) - } - host = address - const defaultPort = "443" - port = defaultPort - } - address = net.JoinHostPort(host, port) - return address, nil -} diff --git a/internal/healthcheck/icmp/apple_ipv4.go b/internal/healthcheck/icmp/apple_ipv4.go new file mode 100644 index 00000000..7f9c6484 --- /dev/null +++ b/internal/healthcheck/icmp/apple_ipv4.go @@ -0,0 +1,49 @@ +package icmp + +import ( + "net" + "time" + + "golang.org/x/net/ipv4" +) + +var _ net.PacketConn = &ipv4Wrapper{} + +// ipv4Wrapper is a wrapper around ipv4.PacketConn to implement +// the net.PacketConn interface. It's only used for Darwin or iOS. +type ipv4Wrapper struct { + ipv4Conn *ipv4.PacketConn +} + +func ipv4ToNetPacketConn(ipv4 *ipv4.PacketConn) *ipv4Wrapper { + return &ipv4Wrapper{ipv4Conn: ipv4} +} + +func (i *ipv4Wrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, _, addr, err = i.ipv4Conn.ReadFrom(p) + return n, addr, err +} + +func (i *ipv4Wrapper) WriteTo(p []byte, addr net.Addr) (n int, err error) { + return i.ipv4Conn.WriteTo(p, nil, addr) +} + +func (i *ipv4Wrapper) Close() error { + return i.ipv4Conn.Close() +} + +func (i *ipv4Wrapper) LocalAddr() net.Addr { + return i.ipv4Conn.LocalAddr() +} + +func (i *ipv4Wrapper) SetDeadline(t time.Time) error { + return i.ipv4Conn.SetDeadline(t) +} + +func (i *ipv4Wrapper) SetReadDeadline(t time.Time) error { + return i.ipv4Conn.SetReadDeadline(t) +} + +func (i *ipv4Wrapper) SetWriteDeadline(t time.Time) error { + return i.ipv4Conn.SetWriteDeadline(t) +} diff --git a/internal/healthcheck/icmp/echo.go b/internal/healthcheck/icmp/echo.go new file mode 100644 index 00000000..608abeff --- /dev/null +++ b/internal/healthcheck/icmp/echo.go @@ -0,0 +1,190 @@ +package icmp + +import ( + "bytes" + "context" + cryptorand "crypto/rand" + "encoding/binary" + "errors" + "fmt" + "io" + "math/rand/v2" + "net" + "net/netip" + "strings" + + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +var ( + ErrICMPBodyUnsupported = errors.New("ICMP body type is not supported") + ErrICMPEchoDataMismatch = errors.New("ICMP data mismatch") +) + +type Echoer struct { + buffer []byte + randomSource io.Reader + logger Logger +} + +func NewEchoer(logger Logger) *Echoer { + const maxICMPEchoSize = 1500 + buffer := make([]byte, maxICMPEchoSize) + var seed [32]byte + _, _ = cryptorand.Read(seed[:]) + randomSource := rand.NewChaCha8(seed) + return &Echoer{ + buffer: buffer, + randomSource: randomSource, + logger: logger, + } +} + +var ( + ErrTimedOut = errors.New("timed out waiting for ICMP echo reply") + ErrNotPermitted = errors.New("not permitted") +) + +func (i *Echoer) Echo(ctx context.Context, ip netip.Addr) (err error) { + var ipVersion string + var conn net.PacketConn + if ip.Is4() { + ipVersion = "v4" + conn, err = listenICMPv4(ctx) + } else { + ipVersion = "v6" + conn, err = listenICMPv6(ctx) + } + if err != nil { + if strings.HasSuffix(err.Error(), "socket: operation not permitted") { + err = fmt.Errorf("%w: you can try adding NET_RAW capability to resolve this", ErrNotPermitted) + } + return fmt.Errorf("listening for ICMP packets: %w", err) + } + + go func() { + <-ctx.Done() + conn.Close() + }() + + const echoDataSize = 32 + id, message := buildMessageToSend(ipVersion, echoDataSize, i.randomSource) + + encodedMessage, err := message.Marshal(nil) + if err != nil { + return fmt.Errorf("encoding ICMP message: %w", err) + } + + _, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()}) + if err != nil { + if strings.HasSuffix(err.Error(), "sendto: operation not permitted") { + err = fmt.Errorf("%w", ErrNotPermitted) + } + return fmt.Errorf("writing ICMP message: %w", err) + } + + receivedData, err := receiveEchoReply(conn, id, i.buffer, ipVersion, i.logger) + if err != nil { + if errors.Is(err, net.ErrClosed) && ctx.Err() != nil { + return fmt.Errorf("%w", ErrTimedOut) + } + return fmt.Errorf("receiving ICMP echo reply: %w", err) + } + + sentData := message.Body.(*icmp.Echo).Data //nolint:forcetypeassert + if !bytes.Equal(receivedData, sentData) { + return fmt.Errorf("%w: sent %x and received %x", ErrICMPEchoDataMismatch, sentData, receivedData) + } + + return nil +} + +func buildMessageToSend(ipVersion string, size uint, randomSource io.Reader) (id int, message *icmp.Message) { + const uint16Bytes = 2 + idBytes := make([]byte, uint16Bytes) + _, _ = randomSource.Read(idBytes) + id = int(binary.BigEndian.Uint16(idBytes)) + + var icmpType icmp.Type + switch ipVersion { + case "v4": + icmpType = ipv4.ICMPTypeEcho + case "v6": + icmpType = ipv6.ICMPTypeEchoRequest + default: + panic(fmt.Sprintf("IP version %q not supported", ipVersion)) + } + messageBodyData := make([]byte, size) + _, _ = randomSource.Read(messageBodyData) + + // See https://www.iana.org/assignments/icmp-parameters/icmp-parameters.xhtml#icmp-parameters-types + message = &icmp.Message{ + Type: icmpType, // echo request + Code: 0, // no code + Checksum: 0, // calculated at encoding (ipv4) or sending (ipv6) + Body: &icmp.Echo{ + ID: id, + Seq: 0, // only one packet + Data: messageBodyData, + }, + } + return id, message +} + +func receiveEchoReply(conn net.PacketConn, id int, buffer []byte, ipVersion string, logger Logger, +) (data []byte, err error) { + var icmpProtocol int + const ( + icmpv4Protocol = 1 + icmpv6Protocol = 58 + ) + switch ipVersion { + case "v4": + icmpProtocol = icmpv4Protocol + case "v6": + icmpProtocol = icmpv6Protocol + default: + panic(fmt.Sprintf("unknown IP version: %s", ipVersion)) + } + + for { + // Note we need to read the whole packet in one call to ReadFrom, so the buffer + // must be large enough to read the entire reply packet. See: + // https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J + bytesRead, _, err := conn.ReadFrom(buffer) + if err != nil { + return nil, fmt.Errorf("reading from ICMP connection: %w", err) + } + packetBytes := buffer[:bytesRead] + + // Parse the ICMP message + message, err := icmp.ParseMessage(icmpProtocol, packetBytes) + if err != nil { + return nil, fmt.Errorf("parsing message: %w", err) + } + + switch body := message.Body.(type) { + case *icmp.Echo: + if id != body.ID { + logger.Warnf("ignoring ICMP echo reply mismatching expected id %d (id: %d, type: %d, code: %d, length: %d)", + id, body.ID, message.Type, message.Code, len(packetBytes)) + continue // not the ID we are looking for + } + return body.Data, nil + case *icmp.DstUnreach: + logger.Debugf("ignoring ICMP destination unreachable message (type: 3, code: %d, expected-id %d)", message.Code, id) + // See https://github.com/qdm12/gluetun/pull/2923#issuecomment-3377532249 + // on why we ignore this message. If it is actually unreachable, the timeout on waiting for + // the echo reply will do instead of returning an error error. + continue + case *icmp.TimeExceeded: + logger.Debugf("ignoring ICMP time exceeded message (type: 11, code: %d, expected-id %d)", message.Code, id) + continue + default: + return nil, fmt.Errorf("%w: %T (type %d, code %d, expected-id %d)", + ErrICMPBodyUnsupported, body, message.Type, message.Code, id) + } + } +} diff --git a/internal/healthcheck/icmp/interfaces.go b/internal/healthcheck/icmp/interfaces.go new file mode 100644 index 00000000..62979247 --- /dev/null +++ b/internal/healthcheck/icmp/interfaces.go @@ -0,0 +1,6 @@ +package icmp + +type Logger interface { + Debugf(format string, args ...any) + Warnf(format string, args ...any) +} diff --git a/internal/healthcheck/icmp/listen.go b/internal/healthcheck/icmp/listen.go new file mode 100644 index 00000000..7c01c12c --- /dev/null +++ b/internal/healthcheck/icmp/listen.go @@ -0,0 +1,35 @@ +package icmp + +import ( + "context" + "fmt" + "net" + "runtime" + + "golang.org/x/net/ipv4" +) + +func listenICMPv4(ctx context.Context) (conn net.PacketConn, err error) { + var listenConfig net.ListenConfig + const listenAddress = "" + packetConn, err := listenConfig.ListenPacket(ctx, "ip4:icmp", listenAddress) + if err != nil { + return nil, fmt.Errorf("listening for ICMP packets: %w", err) + } + + if runtime.GOOS == "darwin" || runtime.GOOS == "ios" { + packetConn = ipv4ToNetPacketConn(ipv4.NewPacketConn(packetConn)) + } + + return packetConn, nil +} + +func listenICMPv6(ctx context.Context) (conn net.PacketConn, err error) { + var listenConfig net.ListenConfig + const listenAddress = "" + packetConn, err := listenConfig.ListenPacket(ctx, "ip6:ipv6-icmp", listenAddress) + if err != nil { + return nil, fmt.Errorf("listening for ICMPv6 packets: %w", err) + } + return packetConn, nil +} diff --git a/internal/healthcheck/logger.go b/internal/healthcheck/interfaces.go similarity index 52% rename from internal/healthcheck/logger.go rename to internal/healthcheck/interfaces.go index 52a7a546..87c349e6 100644 --- a/internal/healthcheck/logger.go +++ b/internal/healthcheck/interfaces.go @@ -1,7 +1,8 @@ package healthcheck type Logger interface { - Debug(s string) + Debugf(format string, args ...any) Info(s string) + Warnf(format string, args ...any) Error(s string) } diff --git a/internal/healthcheck/openvpn.go b/internal/healthcheck/openvpn.go deleted file mode 100644 index 6d82491f..00000000 --- a/internal/healthcheck/openvpn.go +++ /dev/null @@ -1,25 +0,0 @@ -package healthcheck - -import ( - "context" - "time" - - "github.com/qdm12/gluetun/internal/constants" -) - -type vpnHealth struct { - loop StatusApplier - healthyWait time.Duration - healthyTimer *time.Timer -} - -func (s *Server) onUnhealthyVPN(ctx context.Context, lastErrMessage string) { - s.logger.Info("program has been unhealthy for " + - s.vpn.healthyWait.String() + ": restarting VPN (healthcheck error: " + lastErrMessage + ")") - s.logger.Info("👉 See https://github.com/qdm12/gluetun-wiki/blob/main/faq/healthcheck.md") - s.logger.Info("DO NOT OPEN AN ISSUE UNLESS YOU HAVE READ AND TRIED EVERY POSSIBLE SOLUTION") - _, _ = s.vpn.loop.ApplyStatus(ctx, constants.Stopped) - _, _ = s.vpn.loop.ApplyStatus(ctx, constants.Running) - s.vpn.healthyWait += *s.config.VPN.Addition - s.vpn.healthyTimer = time.NewTimer(s.vpn.healthyWait) -} diff --git a/internal/healthcheck/run.go b/internal/healthcheck/run.go index 5f7bb7fc..3d092514 100644 --- a/internal/healthcheck/run.go +++ b/internal/healthcheck/run.go @@ -10,9 +10,6 @@ import ( func (s *Server) Run(ctx context.Context, done chan<- struct{}) { defer close(done) - loopDone := make(chan struct{}) - go s.runHealthcheckLoop(ctx, loopDone) - server := http.Server{ Addr: s.config.ServerAddress, Handler: s.handler, @@ -37,6 +34,5 @@ func (s *Server) Run(ctx context.Context, done chan<- struct{}) { s.logger.Error(err.Error()) } - <-loopDone <-serverDone } diff --git a/internal/healthcheck/server.go b/internal/healthcheck/server.go index c3a3a6be..4360ad69 100644 --- a/internal/healthcheck/server.go +++ b/internal/healthcheck/server.go @@ -2,7 +2,6 @@ package healthcheck import ( "context" - "net" "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/models" @@ -11,30 +10,21 @@ import ( type Server struct { logger Logger handler *handler - dialer *net.Dialer config settings.Health - vpn vpnHealth } -func NewServer(config settings.Health, - logger Logger, vpnLoop StatusApplier, -) *Server { +func NewServer(config settings.Health, logger Logger) *Server { return &Server{ logger: logger, - handler: newHandler(), - dialer: &net.Dialer{ - Resolver: &net.Resolver{ - PreferGo: true, - }, - }, - config: config, - vpn: vpnHealth{ - loop: vpnLoop, - healthyWait: *config.VPN.Initial, - }, + handler: newHandler(logger), + config: config, } } +func (s *Server) SetError(err error) { + s.handler.setErr(err) +} + type StatusApplier interface { ApplyStatus(ctx context.Context, status models.LoopStatus) ( outcome string, err error) diff --git a/internal/vpn/interfaces.go b/internal/vpn/interfaces.go index 68103690..12b880d9 100644 --- a/internal/vpn/interfaces.go +++ b/internal/vpn/interfaces.go @@ -99,3 +99,13 @@ type CmdStarter interface { stdoutLines, stderrLines <-chan string, waitError <-chan error, startErr error) } + +type HealthChecker interface { + SetConfig(tlsDialAddr string, icmpTarget netip.Addr) + Start(ctx context.Context) (runError <-chan error, err error) + Stop() error +} + +type HealthServer interface { + SetError(err error) +} diff --git a/internal/vpn/loop.go b/internal/vpn/loop.go index 40bbf488..c83ad11a 100644 --- a/internal/vpn/loop.go +++ b/internal/vpn/loop.go @@ -13,10 +13,13 @@ import ( ) type Loop struct { - statusManager *loopstate.State - state *state.State - providers Providers - storage Storage + statusManager *loopstate.State + state *state.State + providers Providers + storage Storage + healthSettings settings.Health + healthChecker HealthChecker + healthServer HealthServer // Fixed parameters buildInfo models.BuildInformation versionInfo bool @@ -49,7 +52,8 @@ const ( ) func NewLoop(vpnSettings settings.VPN, ipv6Supported bool, vpnInputPorts []uint16, - providers Providers, storage Storage, openvpnConf OpenVPN, + providers Providers, storage Storage, healthSettings settings.Health, + healthChecker HealthChecker, healthServer HealthServer, openvpnConf OpenVPN, netLinker NetLinker, fw Firewall, routing Routing, portForward PortForward, starter CmdStarter, publicip PublicIPLoop, dnsLooper DNSLoop, @@ -65,29 +69,32 @@ func NewLoop(vpnSettings settings.VPN, ipv6Supported bool, vpnInputPorts []uint1 state := state.New(statusManager, vpnSettings) return &Loop{ - statusManager: statusManager, - state: state, - providers: providers, - storage: storage, - buildInfo: buildInfo, - versionInfo: versionInfo, - ipv6Supported: ipv6Supported, - vpnInputPorts: vpnInputPorts, - openvpnConf: openvpnConf, - netLinker: netLinker, - fw: fw, - routing: routing, - portForward: portForward, - publicip: publicip, - dnsLooper: dnsLooper, - starter: starter, - logger: logger, - client: client, - start: start, - running: running, - stop: stop, - stopped: stopped, - userTrigger: true, - backoffTime: defaultBackoffTime, + statusManager: statusManager, + state: state, + providers: providers, + storage: storage, + healthSettings: healthSettings, + healthChecker: healthChecker, + healthServer: healthServer, + buildInfo: buildInfo, + versionInfo: versionInfo, + ipv6Supported: ipv6Supported, + vpnInputPorts: vpnInputPorts, + openvpnConf: openvpnConf, + netLinker: netLinker, + fw: fw, + routing: routing, + portForward: portForward, + publicip: publicip, + dnsLooper: dnsLooper, + starter: starter, + logger: logger, + client: client, + start: start, + running: running, + stop: stop, + stopped: stopped, + userTrigger: true, + backoffTime: defaultBackoffTime, } } diff --git a/internal/vpn/openvpn.go b/internal/vpn/openvpn.go index 102640e1..c6e8bc9e 100644 --- a/internal/vpn/openvpn.go +++ b/internal/vpn/openvpn.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/openvpn" "github.com/qdm12/gluetun/internal/provider" ) @@ -14,39 +15,38 @@ import ( func setupOpenVPN(ctx context.Context, fw Firewall, openvpnConf OpenVPN, providerConf provider.Provider, settings settings.VPN, ipv6Supported bool, starter CmdStarter, - logger openvpn.Logger) (runner *openvpn.Runner, serverName string, - canPortForward bool, err error, + logger openvpn.Logger) (runner *openvpn.Runner, connection models.Connection, err error, ) { - connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported) + connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported) if err != nil { - return nil, "", false, fmt.Errorf("finding a valid server connection: %w", err) + return nil, models.Connection{}, fmt.Errorf("finding a valid server connection: %w", err) } lines := providerConf.OpenVPNConfig(connection, settings.OpenVPN, ipv6Supported) if err := openvpnConf.WriteConfig(lines); err != nil { - return nil, "", false, fmt.Errorf("writing configuration to file: %w", err) + return nil, models.Connection{}, fmt.Errorf("writing configuration to file: %w", err) } if *settings.OpenVPN.User != "" { err := openvpnConf.WriteAuthFile(*settings.OpenVPN.User, *settings.OpenVPN.Password) if err != nil { - return nil, "", false, fmt.Errorf("writing auth to file: %w", err) + return nil, models.Connection{}, fmt.Errorf("writing auth to file: %w", err) } } if *settings.OpenVPN.KeyPassphrase != "" { err := openvpnConf.WriteAskPassFile(*settings.OpenVPN.KeyPassphrase) if err != nil { - return nil, "", false, fmt.Errorf("writing askpass file: %w", err) + return nil, models.Connection{}, fmt.Errorf("writing askpass file: %w", err) } } if err := fw.SetVPNConnection(ctx, connection, settings.OpenVPN.Interface); err != nil { - return nil, "", false, fmt.Errorf("allowing VPN connection through firewall: %w", err) + return nil, models.Connection{}, fmt.Errorf("allowing VPN connection through firewall: %w", err) } runner = openvpn.NewRunner(settings.OpenVPN, starter, logger) - return runner, connection.ServerName, connection.PortForward, nil + return runner, connection, nil } diff --git a/internal/vpn/run.go b/internal/vpn/run.go index a0cc0274..6ca1f483 100644 --- a/internal/vpn/run.go +++ b/internal/vpn/run.go @@ -5,6 +5,7 @@ import ( "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants/vpn" + "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/log" ) @@ -28,17 +29,17 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { var vpnRunner interface { Run(ctx context.Context, waitError chan<- error, tunnelReady chan<- struct{}) } - var serverName, vpnInterface string - var canPortForward bool + var vpnInterface string + var connection models.Connection var err error subLogger := l.logger.New(log.SetComponent(settings.Type)) if settings.Type == vpn.OpenVPN { vpnInterface = settings.OpenVPN.Interface - vpnRunner, serverName, canPortForward, err = setupOpenVPN(ctx, l.fw, + vpnRunner, connection, err = setupOpenVPN(ctx, l.fw, l.openvpnConf, providerConf, settings, l.ipv6Supported, l.starter, subLogger) } else { // Wireguard vpnInterface = settings.Wireguard.Interface - vpnRunner, serverName, canPortForward, err = setupWireguard(ctx, l.netLinker, l.fw, + vpnRunner, connection, err = setupWireguard(ctx, l.netLinker, l.fw, providerConf, settings, l.ipv6Supported, subLogger) } if err != nil { @@ -46,8 +47,9 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { continue } tunnelUpData := tunnelUpData{ - serverName: serverName, - canPortForward: canPortForward, + serverIP: connection.IP, + serverName: connection.ServerName, + canPortForward: connection.PortForward, portForwarder: portForwarder, vpnIntf: vpnInterface, username: settings.Provider.PortForwarding.Username, @@ -73,7 +75,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { for stayHere { select { case <-tunnelReady: - go l.onTunnelUp(openvpnCtx, tunnelUpData) + go l.onTunnelUp(openvpnCtx, ctx, tunnelUpData) case <-ctx.Done(): l.cleanup() openvpnCancel() diff --git a/internal/vpn/tunnelup.go b/internal/vpn/tunnelup.go index 103d65dd..536a3ce6 100644 --- a/internal/vpn/tunnelup.go +++ b/internal/vpn/tunnelup.go @@ -2,6 +2,7 @@ package vpn import ( "context" + "net/netip" "github.com/qdm12/dns/v2/pkg/check" "github.com/qdm12/gluetun/internal/constants" @@ -9,6 +10,8 @@ import ( ) type tunnelUpData struct { + // Healthcheck + serverIP netip.Addr // Port forwarding vpnIntf string serverName string // used for PIA @@ -18,7 +21,7 @@ type tunnelUpData struct { portForwarder PortForwarder } -func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) { +func (l *Loop) onTunnelUp(ctx, loopCtx context.Context, data tunnelUpData) { l.client.CloseIdleConnections() for _, vpnPort := range l.vpnInputPorts { @@ -28,6 +31,24 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) { } } + icmpTarget := l.healthSettings.ICMPTargetIP + if icmpTarget.IsUnspecified() { + icmpTarget = data.serverIP + } + l.healthChecker.SetConfig(l.healthSettings.TargetAddress, icmpTarget) + + healthErrCh, err := l.healthChecker.Start(ctx) + l.healthServer.SetError(err) + if err != nil { + // Note this restart call must be done in a separate goroutine + // from the VPN loop goroutine. + l.restartVPN(loopCtx, err) + return + } + defer func() { + _ = l.healthChecker.Stop() + }() + if *l.dnsLooper.GetSettings().DoT.Enabled { _, _ = l.dnsLooper.ApplyStatus(ctx, constants.Running) } else { @@ -37,7 +58,7 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) { } } - err := l.publicip.RunOnce(ctx) + err = l.publicip.RunOnce(ctx) if err != nil { l.logger.Error("getting public IP address information: " + err.Error()) } @@ -56,4 +77,21 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) { if err != nil { l.logger.Error(err.Error()) } + + select { + case <-ctx.Done(): + case healthErr := <-healthErrCh: + l.healthServer.SetError(healthErr) + // Note this restart call must be done in a separate goroutine + // from the VPN loop goroutine. + l.restartVPN(loopCtx, healthErr) + } +} + +func (l *Loop) restartVPN(ctx context.Context, healthErr error) { + l.logger.Warnf("restarting VPN because it failed to pass the healthcheck: %s", healthErr) + l.logger.Info("👉 See https://github.com/qdm12/gluetun-wiki/blob/main/faq/healthcheck.md") + l.logger.Info("DO NOT OPEN AN ISSUE UNLESS YOU HAVE READ AND TRIED EVERY POSSIBLE SOLUTION") + _, _ = l.ApplyStatus(ctx, constants.Stopped) + _, _ = l.ApplyStatus(ctx, constants.Running) } diff --git a/internal/vpn/wireguard.go b/internal/vpn/wireguard.go index 7f5c4246..60fc9afd 100644 --- a/internal/vpn/wireguard.go +++ b/internal/vpn/wireguard.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/provider/utils" "github.com/qdm12/gluetun/internal/wireguard" @@ -16,11 +17,11 @@ import ( func setupWireguard(ctx context.Context, netlinker NetLinker, fw Firewall, providerConf provider.Provider, settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) ( - wireguarder *wireguard.Wireguard, serverName string, canPortForward bool, err error, + wireguarder *wireguard.Wireguard, connection models.Connection, err error, ) { - connection, err := providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported) + connection, err = providerConf.GetConnection(settings.Provider.ServerSelection, ipv6Supported) if err != nil { - return nil, "", false, fmt.Errorf("finding a VPN server: %w", err) + return nil, models.Connection{}, fmt.Errorf("finding a VPN server: %w", err) } wireguardSettings := utils.BuildWireguardSettings(connection, settings.Wireguard, ipv6Supported) @@ -31,13 +32,13 @@ func setupWireguard(ctx context.Context, netlinker NetLinker, wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger) if err != nil { - return nil, "", false, fmt.Errorf("creating Wireguard: %w", err) + return nil, models.Connection{}, fmt.Errorf("creating Wireguard: %w", err) } err = fw.SetVPNConnection(ctx, connection, settings.Wireguard.Interface) if err != nil { - return nil, "", false, fmt.Errorf("setting firewall: %w", err) + return nil, models.Connection{}, fmt.Errorf("setting firewall: %w", err) } - return wireguarder, connection.ServerName, connection.PortForward, nil + return wireguarder, connection, nil }