diff --git a/internal/firewall/parse.go b/internal/firewall/parse.go index e8d3f100..bb8dd9e0 100644 --- a/internal/firewall/parse.go +++ b/internal/firewall/parse.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "net/netip" - "regexp" "slices" "strconv" "strings" @@ -153,11 +152,15 @@ func parseInstructionFlag(key, value string, instruction *iptablesInstruction) ( return nil } -var regexCidrSuffix = regexp.MustCompile(`/[0-9][0-9]{0,2}$`) - func parseIPPrefix(value string) (prefix netip.Prefix, err error) { - if !regexCidrSuffix.MatchString(value) { - value += "/32" + slashIndex := strings.Index(value, "/") + if slashIndex >= 0 { + return netip.ParsePrefix(value) } - return netip.ParsePrefix(value) + + ip, err := netip.ParseAddr(value) + if err != nil { + return netip.Prefix{}, fmt.Errorf("parsing IP address: %w", err) + } + return netip.PrefixFrom(ip, ip.BitLen()), nil } diff --git a/internal/firewall/parse_test.go b/internal/firewall/parse_test.go index ad102c6d..1244dc6c 100644 --- a/internal/firewall/parse_test.go +++ b/internal/firewall/parse_test.go @@ -82,3 +82,57 @@ func Test_parseIptablesInstruction(t *testing.T) { }) } } + +func Test_parseIPPrefix(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + value string + prefix netip.Prefix + errMessage string + }{ + "empty": { + errMessage: `parsing IP address: ParseAddr(""): unable to parse IP`, + }, + "invalid": { + value: "invalid", + errMessage: `parsing IP address: ParseAddr("invalid"): unable to parse IP`, + }, + "valid_ipv4_with_bits": { + value: "10.0.0.0/16", + prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 0, 0, 0}), 16), + }, + "valid_ipv4_without_bits": { + value: "10.0.0.4", + prefix: netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 0, 0, 4}), 32), + }, + "valid_ipv6_with_bits": { + value: "2001:db8::/32", + prefix: netip.PrefixFrom( + netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8}), + 32), + }, + "valid_ipv6_without_bits": { + value: "2001:db8::", + prefix: netip.PrefixFrom( + netip.AddrFrom16([16]byte{0x20, 0x01, 0x0d, 0xb8}), + 128), + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + prefix, err := parseIPPrefix(testCase.value) + + assert.Equal(t, testCase.prefix, prefix) + if testCase.errMessage != "" { + assert.EqualError(t, err, testCase.errMessage) + } else { + assert.NoError(t, err) + } + }) + } +}