package pmtud import ( "context" "errors" "fmt" "math" "net" "net/netip" "time" "golang.org/x/net/icmp" ) // PathMTUDiscover discovers the maximum MTU for the path to the given ip address. // If the physicalLinkMTU is zero, it defaults to 1500 which is the ethernet standard MTU. // If the pingTimeout is zero, it defaults to 1 second. // If the logger is nil, a no-op logger is used. func PathMTUDiscover(ctx context.Context, ip netip.Addr, physicalLinkMTU int, pingTimeout time.Duration, logger Logger) ( mtu int, err error, ) { if physicalLinkMTU == 0 { const ethernetStandardMTU = 1500 physicalLinkMTU = ethernetStandardMTU } if pingTimeout == 0 { pingTimeout = time.Second } if logger == nil { logger = &noopLogger{} } if ip.Is4() { logger.Debug("finding IPv4 next hop MTU") mtu, err = findIPv4NextHopMTU(ctx, ip, physicalLinkMTU, pingTimeout, logger) switch { case err == nil: return mtu, nil case errors.Is(err, net.ErrClosed): // blackhole default: return 0, fmt.Errorf("finding IPv4 next hop MTU: %w", err) } } else { logger.Debug("requesting IPv6 ICMP packet-too-big reply") mtu, err = getIPv6PacketTooBig(ctx, ip, physicalLinkMTU, pingTimeout, logger) switch { case err == nil: return mtu, nil case errors.Is(err, net.ErrClosed): // blackhole default: return 0, fmt.Errorf("getting IPv6 packet-too-big message: %w", err) } } // Fall back method: send echo requests with different packet // sizes and check which ones succeed to find the maximum MTU. logger.Debug("falling back to sending different sized echo packets") minMTU := minIPv4MTU if ip.Is6() { minMTU = minIPv6MTU } return pmtudMultiSizes(ctx, ip, minMTU, physicalLinkMTU, pingTimeout, logger) } type pmtudTestUnit struct { mtu int echoID uint16 sentBytes int ok bool } func pmtudMultiSizes(ctx context.Context, ip netip.Addr, minMTU, maxPossibleMTU int, pingTimeout time.Duration, logger Logger, ) (maxMTU int, err error) { var ipVersion string var conn net.PacketConn if ip.Is4() { ipVersion = "v4" conn, err = listenICMPv4(ctx) } else { ipVersion = "v6" conn, err = listenICMPv6(ctx) } if err != nil { return 0, fmt.Errorf("listening for ICMP packets: %w", err) } mtusToTest := makeMTUsToTest(minMTU, maxPossibleMTU) if len(mtusToTest) == 0 { return minMTU, nil } logger.Debugf("testing the following MTUs: %v", mtusToTest) tests := make([]pmtudTestUnit, len(mtusToTest)) for i := range mtusToTest { tests[i] = pmtudTestUnit{mtu: mtusToTest[i]} } timedCtx, cancel := context.WithTimeout(ctx, pingTimeout) defer cancel() go func() { <-timedCtx.Done() conn.Close() }() for i := range tests { id, message := buildMessageToSend(ipVersion, tests[i].mtu) tests[i].echoID = id encodedMessage, err := message.Marshal(nil) if err != nil { return 0, fmt.Errorf("encoding ICMP message: %w", err) } tests[i].sentBytes = len(encodedMessage) _, err = conn.WriteTo(encodedMessage, &net.IPAddr{IP: ip.AsSlice()}) if err != nil { return 0, fmt.Errorf("writing ICMP message: %w", err) } } err = collectReplies(conn, ipVersion, tests, logger) switch { case err == nil: // max possible MTU is working return tests[len(tests)-1].mtu, nil case err != nil && errors.Is(err, net.ErrClosed): // we have timeouts (IPv4 testing or IPv6 PMTUD blackholes) // so find the highest MTU which worked. // Note we start from index len(tests) - 2 since the max MTU // cannot be working if we had a timeout. for i := len(tests) - 2; i >= 0; i-- { //nolint:mnd if tests[i].ok { return pmtudMultiSizes(ctx, ip, tests[i].mtu, tests[i+1].mtu-1, pingTimeout, logger) } } if tests[0].mtu == minMTU+1 { // All MTUs failed. return minMTU, nil } // Re-test with MTUs between the minimum MTU // and the smallest next MTU we tested. return pmtudMultiSizes(ctx, ip, minMTU, tests[0].mtu-1, pingTimeout, logger) case err != nil: return 0, fmt.Errorf("collecting ICMP echo replies: %w", err) default: panic("unreachable") } } // Create the MTU slice of length 11 such that: // - the first element is the minMTU plus the step // - the last element is the maxMTU // - elements in-between are separated as close to each other // - Don't make the minMTU part of the MTUs to test since // it's assumed it's already working. // The number 11 is chosen to find the final MTU in 3 searches; // to find it in 2 searches requires 37 parallel queries which // could be blocked by firewalls. func makeMTUsToTest(minMTU, maxMTU int) (mtus []int) { const mtusLength = 11 // find the final MTU in 3 searches diff := maxMTU - minMTU switch { case minMTU > maxMTU: panic("minMTU > maxMTU") case diff <= mtusLength: mtus = make([]int, 0, diff) for mtu := minMTU + 1; mtu <= maxMTU; mtu++ { mtus = append(mtus, mtu) } case diff < mtusLength*2: step := float64(diff) / float64(mtusLength) mtus = make([]int, 0, mtusLength) for mtu := float64(minMTU) + step; len(mtus) < mtusLength-1; mtu += step { mtus = append(mtus, int(math.Round(mtu))) } mtus = append(mtus, maxMTU) // last element is the maxMTU default: step := diff / mtusLength mtus = make([]int, 0, mtusLength) for mtu := minMTU + step; len(mtus) < mtusLength-1; mtu += step { mtus = append(mtus, mtu) } mtus = append(mtus, maxMTU) // last element is the maxMTU } return mtus } func collectReplies(conn net.PacketConn, ipVersion string, tests []pmtudTestUnit, logger Logger, ) (err error) { echoIDToTestIndex := make(map[uint16]int, len(tests)) for i, test := range tests { echoIDToTestIndex[test.echoID] = i } // The theoretical limit is 4GiB for IPv6 MTU path discovery jumbograms, but that would // create huge buffers which we don't really want to support anyway. // The standard frame maximum MTU is 1500 bytes, and there are Jumbo frames with // a conventional maximum of 9000 bytes. However, some manufacturers support up // 9216-20 = 9196 bytes for the maximum MTU. We thus use buffers of size 9196 to // match eventual Jumbo frames. More information at: // https://en.wikipedia.org/wiki/Maximum_transmission_unit#MTUs_for_common_media const maxPossibleMTU = 9196 buffer := make([]byte, maxPossibleMTU) idsFound := 0 for idsFound < len(tests) { // Note we need to read the whole packet in one call to ReadFrom, so the buffer // must be large enough to read the entire reply packet. See: // https://groups.google.com/g/golang-nuts/c/5dy2Q4nPs08/m/KmuSQAGEtG4J bytesRead, _, err := conn.ReadFrom(buffer) if err != nil { return fmt.Errorf("reading from ICMP connection: %w", err) } packetBytes := buffer[:bytesRead] ipPacketLength := len(packetBytes) var icmpProtocol int switch ipVersion { case "v4": icmpProtocol = icmpv4Protocol case "v6": icmpProtocol = icmpv6Protocol default: panic(fmt.Sprintf("unknown IP version: %s", ipVersion)) } // Parse the ICMP message // Note: this parsing works for a truncated 556 bytes ICMP reply packet. message, err := icmp.ParseMessage(icmpProtocol, packetBytes) if err != nil { return fmt.Errorf("parsing message: %w", err) } echoBody, ok := message.Body.(*icmp.Echo) if !ok { return fmt.Errorf("%w: %T", ErrICMPBodyUnsupported, message.Body) } id := uint16(echoBody.ID) //nolint:gosec testIndex, testing := echoIDToTestIndex[id] if !testing { // not an id we expected so ignore it logger.Warnf("ignoring ICMP reply with unexpected ID %d (type: %d, code: %d, length: %d)", echoBody.ID, message.Type, message.Code, ipPacketLength) continue } idsFound++ sentBytes := tests[testIndex].sentBytes // echo reply should be at most the number of bytes sent, // and can be lower, more precisely 556 bytes, in case // the host we are reaching wants to stay out of trouble // and ensure its echo reply goes through without // fragmentation, see the following page: // https://datatracker.ietf.org/doc/html/rfc1122#page-59 const conservativeReplyLength = 556 truncated := ipPacketLength < sentBytes && ipPacketLength == conservativeReplyLength // Check the packet size is the same if the reply is not truncated if !truncated && sentBytes != ipPacketLength { return fmt.Errorf("%w: sent %dB and received %dB", ErrICMPEchoDataMismatch, sentBytes, ipPacketLength) } // Truncated reply or matching reply size tests[testIndex].ok = true } return nil }