fix(firewall): delete chain rules by line number (#2411)
- Fix #2334 - Parsing of iptables chains, contributing to progress for #1856
This commit is contained in:
381
internal/firewall/list.go
Normal file
381
internal/firewall/list.go
Normal file
@@ -0,0 +1,381 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type chain struct {
|
||||
name string
|
||||
policy string
|
||||
packets uint64
|
||||
bytes uint64
|
||||
rules []chainRule
|
||||
}
|
||||
|
||||
type chainRule struct {
|
||||
lineNumber uint16 // starts from 1 and cannot be zero.
|
||||
packets uint64
|
||||
bytes uint64
|
||||
target string // "ACCEPT", "DROP", "REJECT" or "REDIRECT"
|
||||
protocol string // "tcp", "udp" or "" for all protocols.
|
||||
inputInterface string // input interface, for example "tun0" or "*""
|
||||
outputInterface string // output interface, for example "eth0" or "*""
|
||||
source netip.Prefix // source IP CIDR, for example 0.0.0.0/0. Must be valid.
|
||||
destination netip.Prefix // destination IP CIDR, for example 0.0.0.0/0. Must be valid.
|
||||
destinationPort uint16 // Not specified if set to zero.
|
||||
redirPorts []uint16 // Not specified if empty.
|
||||
ctstate []string // for example ["RELATED","ESTABLISHED"]. Can be empty.
|
||||
}
|
||||
|
||||
var (
|
||||
ErrChainListMalformed = errors.New("iptables chain list output is malformed")
|
||||
)
|
||||
|
||||
func parseChain(iptablesOutput string) (c chain, err error) {
|
||||
// Text example:
|
||||
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
// pkts bytes target prot opt in out source destination
|
||||
// 0 0 ACCEPT 17 -- tun0 * 0.0.0.0/0 0.0.0.0/0 udp dpt:55405
|
||||
// 0 0 ACCEPT 6 -- tun0 * 0.0.0.0/0 0.0.0.0/0 tcp dpt:55405
|
||||
// 0 0 DROP 0 -- tun0 * 0.0.0.0/0 0.0.0.0/0
|
||||
iptablesOutput = strings.TrimSpace(iptablesOutput)
|
||||
linesWithComments := strings.Split(iptablesOutput, "\n")
|
||||
|
||||
// Filter out lines starting with a '#' character
|
||||
lines := make([]string, 0, len(linesWithComments))
|
||||
for _, line := range linesWithComments {
|
||||
if strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
lines = append(lines, line)
|
||||
}
|
||||
|
||||
const minLines = 2 // chain general information line + legend line
|
||||
if len(lines) < minLines {
|
||||
return chain{}, fmt.Errorf("%w: not enough lines to process in: %s",
|
||||
ErrChainListMalformed, iptablesOutput)
|
||||
}
|
||||
|
||||
c, err = parseChainGeneralDataLine(lines[0])
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("parsing chain general data line: %w", err)
|
||||
}
|
||||
|
||||
// Sanity check for the legend line
|
||||
expectedLegendFields := []string{"num", "pkts", "bytes", "target", "prot", "opt", "in", "out", "source", "destination"}
|
||||
legendLine := strings.TrimSpace(lines[1])
|
||||
legendFields := strings.Fields(legendLine)
|
||||
if !slices.Equal(expectedLegendFields, legendFields) {
|
||||
return chain{}, fmt.Errorf("%w: legend %q is not the expected %q",
|
||||
ErrChainListMalformed, legendLine, strings.Join(expectedLegendFields, " "))
|
||||
}
|
||||
|
||||
lines = lines[2:] // remove chain general information line and legend line
|
||||
if len(lines) == 0 {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
c.rules = make([]chainRule, len(lines))
|
||||
for i, line := range lines {
|
||||
c.rules[i], err = parseChainRuleLine(line)
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("parsing chain rule %q: %w", line, err)
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// parseChainGeneralDataLine parses the first line of iptables chain list output.
|
||||
// For example, it can parse the following line:
|
||||
// Chain INPUT (policy ACCEPT 140K packets, 226M bytes)
|
||||
// It returns a chain struct with the parsed data.
|
||||
func parseChainGeneralDataLine(line string) (base chain, err error) {
|
||||
line = strings.TrimSpace(line)
|
||||
runesToRemove := []rune{'(', ')', ','}
|
||||
for _, r := range runesToRemove {
|
||||
line = strings.ReplaceAll(line, string(r), "")
|
||||
}
|
||||
|
||||
fields := strings.Fields(line)
|
||||
const expectedNumberOfFields = 8
|
||||
if len(fields) != expectedNumberOfFields {
|
||||
return chain{}, fmt.Errorf("%w: expected %d fields in %q",
|
||||
ErrChainListMalformed, expectedNumberOfFields, line)
|
||||
}
|
||||
|
||||
// Sanity checks
|
||||
indexToExpectedValue := map[int]string{
|
||||
0: "Chain",
|
||||
2: "policy",
|
||||
5: "packets",
|
||||
7: "bytes",
|
||||
}
|
||||
for index, expectedValue := range indexToExpectedValue {
|
||||
if fields[index] == expectedValue {
|
||||
continue
|
||||
}
|
||||
return chain{}, fmt.Errorf("%w: expected %q for field %d in %q",
|
||||
ErrChainListMalformed, expectedValue, index, line)
|
||||
}
|
||||
|
||||
base.name = fields[1] // chain name could be custom
|
||||
base.policy = fields[3]
|
||||
err = checkTarget(base.policy)
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("policy target in %q: %w", line, err)
|
||||
}
|
||||
|
||||
packets, err := parseMetricSize(fields[4])
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("parsing packets: %w", err)
|
||||
}
|
||||
base.packets = packets
|
||||
|
||||
bytes, err := parseMetricSize(fields[6])
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("parsing bytes: %w", err)
|
||||
}
|
||||
base.bytes = bytes
|
||||
|
||||
return base, nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrChainRuleMalformed = errors.New("chain rule is malformed")
|
||||
)
|
||||
|
||||
func parseChainRuleLine(line string) (rule chainRule, err error) {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
return chainRule{}, fmt.Errorf("%w: empty line", ErrChainRuleMalformed)
|
||||
}
|
||||
|
||||
fields := strings.Fields(line)
|
||||
|
||||
const minFields = 10
|
||||
if len(fields) < minFields {
|
||||
return chainRule{}, fmt.Errorf("%w: not enough fields", ErrChainRuleMalformed)
|
||||
}
|
||||
|
||||
for fieldIndex, field := range fields[:minFields] {
|
||||
err = parseChainRuleField(fieldIndex, field, &rule)
|
||||
if err != nil {
|
||||
return chainRule{}, fmt.Errorf("parsing chain rule field: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(fields) > minFields {
|
||||
err = parseChainRuleOptionalFields(fields[minFields:], &rule)
|
||||
if err != nil {
|
||||
return chainRule{}, fmt.Errorf("parsing optional fields: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func parseChainRuleField(fieldIndex int, field string, rule *chainRule) (err error) {
|
||||
if field == "" {
|
||||
return fmt.Errorf("%w: empty field at index %d", ErrChainRuleMalformed, fieldIndex)
|
||||
}
|
||||
|
||||
const (
|
||||
numIndex = iota
|
||||
packetsIndex
|
||||
bytesIndex
|
||||
targetIndex
|
||||
protocolIndex
|
||||
optIndex
|
||||
inputInterfaceIndex
|
||||
outputInterfaceIndex
|
||||
sourceIndex
|
||||
destinationIndex
|
||||
)
|
||||
|
||||
switch fieldIndex {
|
||||
case numIndex:
|
||||
rule.lineNumber, err = parseLineNumber(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing line number: %w", err)
|
||||
}
|
||||
case packetsIndex:
|
||||
rule.packets, err = parseMetricSize(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing packets: %w", err)
|
||||
}
|
||||
case bytesIndex:
|
||||
rule.bytes, err = parseMetricSize(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing bytes: %w", err)
|
||||
}
|
||||
case targetIndex:
|
||||
err = checkTarget(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking target: %w", err)
|
||||
}
|
||||
rule.target = field
|
||||
case protocolIndex:
|
||||
rule.protocol, err = parseProtocol(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing protocol: %w", err)
|
||||
}
|
||||
case optIndex: // ignored
|
||||
case inputInterfaceIndex:
|
||||
rule.inputInterface = field
|
||||
case outputInterfaceIndex:
|
||||
rule.outputInterface = field
|
||||
case sourceIndex:
|
||||
rule.source, err = parseIPPrefix(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing source IP CIDR: %w", err)
|
||||
}
|
||||
case destinationIndex:
|
||||
rule.destination, err = parseIPPrefix(field)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing destination IP CIDR: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseChainRuleOptionalFields(optionalFields []string, rule *chainRule) (err error) {
|
||||
for i := 0; i < len(optionalFields); i++ {
|
||||
key := optionalFields[i]
|
||||
switch key {
|
||||
case "tcp", "udp":
|
||||
i++
|
||||
value := optionalFields[i]
|
||||
value = strings.TrimPrefix(value, "dpt:")
|
||||
const base, bitLength = 10, 16
|
||||
destinationPort, err := strconv.ParseUint(value, base, bitLength)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing destination port %q: %w", value, err)
|
||||
}
|
||||
rule.destinationPort = uint16(destinationPort)
|
||||
case "redir":
|
||||
i++
|
||||
switch optionalFields[i] {
|
||||
case "ports":
|
||||
i++
|
||||
ports, err := parsePortsCSV(optionalFields[i])
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing redirection ports: %w", err)
|
||||
}
|
||||
rule.redirPorts = ports
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected optional field: %s",
|
||||
ErrChainRuleMalformed, optionalFields[i])
|
||||
}
|
||||
case "ctstate":
|
||||
i++
|
||||
rule.ctstate = strings.Split(optionalFields[i], ",")
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected optional field: %s", ErrChainRuleMalformed, key)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parsePortsCSV(s string) (ports []uint16, err error) {
|
||||
if s == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
fields := strings.Split(s, ",")
|
||||
ports = make([]uint16, len(fields))
|
||||
for i, field := range fields {
|
||||
const base, bitLength = 10, 16
|
||||
port, err := strconv.ParseUint(field, base, bitLength)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing port %q: %w", field, err)
|
||||
}
|
||||
ports[i] = uint16(port)
|
||||
}
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrLineNumberIsZero = errors.New("line number is zero")
|
||||
)
|
||||
|
||||
func parseLineNumber(s string) (n uint16, err error) {
|
||||
const base, bitLength = 10, 16
|
||||
lineNumber, err := strconv.ParseUint(s, base, bitLength)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
} else if lineNumber == 0 {
|
||||
return 0, fmt.Errorf("%w", ErrLineNumberIsZero)
|
||||
}
|
||||
return uint16(lineNumber), nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrTargetUnknown = errors.New("unknown target")
|
||||
)
|
||||
|
||||
func checkTarget(target string) (err error) {
|
||||
switch target {
|
||||
case "ACCEPT", "DROP", "REJECT", "REDIRECT":
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("%w: %s", ErrTargetUnknown, target)
|
||||
}
|
||||
|
||||
var (
|
||||
ErrProtocolUnknown = errors.New("unknown protocol")
|
||||
)
|
||||
|
||||
func parseProtocol(s string) (protocol string, err error) {
|
||||
switch s {
|
||||
case "0":
|
||||
case "6":
|
||||
protocol = "tcp"
|
||||
case "17":
|
||||
protocol = "udp"
|
||||
default:
|
||||
return "", fmt.Errorf("%w: %s", ErrProtocolUnknown, s)
|
||||
}
|
||||
return protocol, nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrMetricSizeMalformed = errors.New("metric size is malformed")
|
||||
)
|
||||
|
||||
// parseMetricSize parses a metric size string like 140K or 226M and
|
||||
// returns the raw integer matching it.
|
||||
func parseMetricSize(size string) (n uint64, err error) {
|
||||
if size == "" {
|
||||
return n, fmt.Errorf("%w: empty string", ErrMetricSizeMalformed)
|
||||
}
|
||||
|
||||
//nolint:gomnd
|
||||
multiplerLetterToValue := map[byte]uint64{
|
||||
'K': 1000,
|
||||
'M': 1000000,
|
||||
'G': 1000000000,
|
||||
'T': 1000000000000,
|
||||
}
|
||||
|
||||
lastCharacter := size[len(size)-1]
|
||||
multiplier, ok := multiplerLetterToValue[lastCharacter]
|
||||
if ok { // multiplier present
|
||||
size = size[:len(size)-1]
|
||||
} else {
|
||||
multiplier = 1
|
||||
}
|
||||
|
||||
const base, bitLength = 10, 64
|
||||
n, err = strconv.ParseUint(size, base, bitLength)
|
||||
if err != nil {
|
||||
return n, fmt.Errorf("%w: %w", ErrMetricSizeMalformed, err)
|
||||
}
|
||||
n *= multiplier
|
||||
return n, nil
|
||||
}
|
||||
Reference in New Issue
Block a user