chore(errors): review all errors in codebase
This commit is contained in:
@@ -2,16 +2,9 @@ package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEnable = errors.New("failed enabling firewall")
|
||||
ErrDisable = errors.New("failed disabling firewall")
|
||||
ErrUserPostRules = errors.New("cannot run user post firewall rules")
|
||||
)
|
||||
|
||||
type Enabler interface {
|
||||
SetEnabled(ctx context.Context, enabled bool) (err error)
|
||||
}
|
||||
@@ -32,7 +25,7 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
|
||||
if !enabled {
|
||||
c.logger.Info("disabling...")
|
||||
if err = c.disable(ctx); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("cannot disable firewall: %w", err)
|
||||
}
|
||||
c.enabled = false
|
||||
c.logger.Info("disabled successfully")
|
||||
@@ -42,7 +35,7 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
|
||||
c.logger.Info("enabling...")
|
||||
|
||||
if err := c.enable(ctx); err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrEnable, err)
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
}
|
||||
c.enabled = true
|
||||
c.logger.Info("enabled successfully")
|
||||
@@ -52,13 +45,13 @@ func (c *Config) SetEnabled(ctx context.Context, enabled bool) (err error) {
|
||||
|
||||
func (c *Config) disable(ctx context.Context) (err error) {
|
||||
if err = c.clearAllRules(ctx); err != nil {
|
||||
return fmt.Errorf("cannot disable firewall: %w", err)
|
||||
return fmt.Errorf("cannot clear all rules: %w", err)
|
||||
}
|
||||
if err = c.setIPv4AllPolicies(ctx, "ACCEPT"); err != nil {
|
||||
return fmt.Errorf("cannot disable firewall: %w", err)
|
||||
return fmt.Errorf("cannot set ipv4 policies: %w", err)
|
||||
}
|
||||
if err = c.setIPv6AllPolicies(ctx, "ACCEPT"); err != nil {
|
||||
return fmt.Errorf("cannot disable firewall: %w", err)
|
||||
return fmt.Errorf("cannot set ipv6 policies: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -76,12 +69,12 @@ func (c *Config) fallbackToDisabled(ctx context.Context) {
|
||||
func (c *Config) enable(ctx context.Context) (err error) {
|
||||
touched := false
|
||||
if err = c.setIPv4AllPolicies(ctx, "DROP"); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
return err
|
||||
}
|
||||
touched = true
|
||||
|
||||
if err = c.setIPv6AllPolicies(ctx, "DROP"); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
const remove = false
|
||||
@@ -94,33 +87,33 @@ func (c *Config) enable(ctx context.Context) (err error) {
|
||||
|
||||
// Loopback traffic
|
||||
if err = c.acceptInputThroughInterface(ctx, "lo", remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
return err
|
||||
}
|
||||
if err = c.acceptOutputThroughInterface(ctx, "lo", remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if err = c.acceptEstablishedRelatedTraffic(ctx, remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
return err
|
||||
}
|
||||
if c.vpnConnection.IP != nil {
|
||||
if err = c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, c.vpnConnection, remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
return err
|
||||
}
|
||||
if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, network := range c.localNetworks {
|
||||
if err := c.acceptOutputFromIPToSubnet(ctx, network.InterfaceName, network.IP, *network.IPNet, remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, subnet := range c.outboundSubnets {
|
||||
if err := c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, subnet, remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,18 +121,18 @@ func (c *Config) enable(ctx context.Context) (err error) {
|
||||
// to reach Gluetun.
|
||||
for _, network := range c.localNetworks {
|
||||
if err := c.acceptInputToSubnet(ctx, network.InterfaceName, *network.IPNet, remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for port, intf := range c.allowedInputPorts {
|
||||
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
|
||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.runUserPostRules(ctx, c.customRulesPath, remove); err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrUserPostRules, err)
|
||||
return fmt.Errorf("cannot run user defined post firewall rules: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrIP6Tables = errors.New("failed ip6tables command")
|
||||
ErrIP6NotSupported = errors.New("ip6tables not supported")
|
||||
)
|
||||
|
||||
@@ -44,18 +43,18 @@ func (c *Config) runIP6tablesInstruction(ctx context.Context, instruction string
|
||||
flags := strings.Fields(instruction)
|
||||
cmd := exec.CommandContext(ctx, "ip6tables", flags...)
|
||||
if output, err := c.runner.Run(cmd); err != nil {
|
||||
return fmt.Errorf("%w: \"ip6tables %s\": %s: %s", ErrIP6Tables, instruction, output, err)
|
||||
return fmt.Errorf("command failed: \"ip6tables %s\": %s: %w", instruction, output, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var errPolicyNotValid = errors.New("policy is not valid")
|
||||
var ErrPolicyNotValid = errors.New("policy is not valid")
|
||||
|
||||
func (c *Config) setIPv6AllPolicies(ctx context.Context, policy string) error {
|
||||
switch policy {
|
||||
case "ACCEPT", "DROP":
|
||||
default:
|
||||
return fmt.Errorf("%w: %s", errPolicyNotValid, policy)
|
||||
return fmt.Errorf("%w: %s", ErrPolicyNotValid, policy)
|
||||
}
|
||||
return c.runIP6tablesInstructions(ctx, []string{
|
||||
"--policy INPUT " + policy,
|
||||
|
||||
@@ -16,10 +16,7 @@ import (
|
||||
|
||||
var (
|
||||
ErrIPTablesVersionTooShort = errors.New("iptables version string is too short")
|
||||
ErrIPTables = errors.New("failed iptables command")
|
||||
ErrPolicyUnknown = errors.New("unknown policy")
|
||||
ErrClearRules = errors.New("cannot clear all rules")
|
||||
ErrSetIPtablesPolicies = errors.New("cannot set iptables policies")
|
||||
ErrNeedIP6Tables = errors.New("ip6tables is required, please upgrade your kernel to support it")
|
||||
)
|
||||
|
||||
@@ -79,33 +76,30 @@ func (c *Config) runIptablesInstruction(ctx context.Context, instruction string)
|
||||
flags := strings.Fields(instruction)
|
||||
cmd := exec.CommandContext(ctx, "iptables", flags...)
|
||||
if output, err := c.runner.Run(cmd); err != nil {
|
||||
return fmt.Errorf("%w \"iptables %s\": %s: %s", ErrIPTables, instruction, output, err)
|
||||
return fmt.Errorf("command failed: \"iptables %s\": %s: %w", instruction, output, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) clearAllRules(ctx context.Context) error {
|
||||
if err := c.runMixedIptablesInstructions(ctx, []string{
|
||||
return c.runMixedIptablesInstructions(ctx, []string{
|
||||
"--flush", // flush all chains
|
||||
"--delete-chain", // delete all chains
|
||||
}); err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrClearRules, err.Error())
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Config) setIPv4AllPolicies(ctx context.Context, policy string) error {
|
||||
switch policy {
|
||||
case "ACCEPT", "DROP":
|
||||
default:
|
||||
return fmt.Errorf("%w: %s: %s", ErrSetIPtablesPolicies, ErrPolicyUnknown, policy)
|
||||
return fmt.Errorf("%w: %s", ErrPolicyUnknown, policy)
|
||||
}
|
||||
if err := c.runIptablesInstructions(ctx, []string{
|
||||
"--policy INPUT " + policy,
|
||||
"--policy OUTPUT " + policy,
|
||||
"--policy FORWARD " + policy,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("%w: %s", ErrSetIPtablesPolicies, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ func (c *Config) SetOutboundSubnets(ctx context.Context, subnets []net.IPNet) (e
|
||||
return nil
|
||||
}
|
||||
|
||||
c.logger.Info("setting allowed subnets through firewall...")
|
||||
c.logger.Info("setting allowed subnets...")
|
||||
|
||||
subnetsToAdd, subnetsToRemove := subnet.FindSubnetsToChange(c.outboundSubnets, subnets)
|
||||
if len(subnetsToAdd) == 0 && len(subnetsToRemove) == 0 {
|
||||
@@ -32,7 +32,7 @@ func (c *Config) SetOutboundSubnets(ctx context.Context, subnets []net.IPNet) (e
|
||||
|
||||
c.removeOutboundSubnets(ctx, subnetsToRemove)
|
||||
if err := c.addOutboundSubnets(ctx, subnetsToAdd); err != nil {
|
||||
return fmt.Errorf("cannot set allowed subnets through firewall: %w", err)
|
||||
return fmt.Errorf("cannot set allowed outbound subnets: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -42,7 +42,7 @@ func (c *Config) removeOutboundSubnets(ctx context.Context, subnets []net.IPNet)
|
||||
const remove = true
|
||||
for _, subNet := range subnets {
|
||||
if err := c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, subNet, remove); err != nil {
|
||||
c.logger.Error("cannot remove outdated outbound subnet through firewall: " + err.Error())
|
||||
c.logger.Error("cannot remove outdated outbound subnet: " + err.Error())
|
||||
continue
|
||||
}
|
||||
c.outboundSubnets = subnet.RemoveSubnetFromSubnets(c.outboundSubnets, subNet)
|
||||
@@ -53,7 +53,7 @@ func (c *Config) addOutboundSubnets(ctx context.Context, subnets []net.IPNet) er
|
||||
const remove = false
|
||||
for _, subnet := range subnets {
|
||||
if err := c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, subnet, remove); err != nil {
|
||||
return fmt.Errorf("cannot add allowed subnet through firewall: %w", err)
|
||||
return err
|
||||
}
|
||||
c.outboundSubnets = append(c.outboundSubnets, subnet)
|
||||
}
|
||||
|
||||
@@ -33,13 +33,13 @@ func (c *Config) SetAllowedPort(ctx context.Context, port uint16, intf string) (
|
||||
}
|
||||
const remove = true
|
||||
if err := c.acceptInputToPort(ctx, existingIntf, port, remove); err != nil {
|
||||
return fmt.Errorf("cannot remove old allowed port %d through interface %s: %w", port, existingIntf, err)
|
||||
return fmt.Errorf("cannot remove old allowed port %d: %w", port, err)
|
||||
}
|
||||
}
|
||||
|
||||
const remove = false
|
||||
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
|
||||
return fmt.Errorf("cannot set allowed port %d through interface %s: %w", port, intf, err)
|
||||
return fmt.Errorf("cannot allow input to port %d: %w", port, err)
|
||||
}
|
||||
c.allowedInputPorts[port] = intf
|
||||
|
||||
@@ -60,7 +60,7 @@ func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
|
||||
return nil
|
||||
}
|
||||
|
||||
c.logger.Info("removing allowed port " + strconv.Itoa(int(port)) + " through firewall...")
|
||||
c.logger.Info("removing allowed port " + strconv.Itoa(int(port)) + " ...")
|
||||
|
||||
intf, ok := c.allowedInputPorts[port]
|
||||
if !ok {
|
||||
@@ -69,7 +69,7 @@ func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
|
||||
|
||||
const remove = true
|
||||
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
|
||||
return fmt.Errorf("cannot remove allowed port %d through interface %s: %w", port, intf, err)
|
||||
return fmt.Errorf("cannot remove allowed port %d: %w", port, err)
|
||||
}
|
||||
delete(c.allowedInputPorts, port)
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ func (c *Config) SetVPNConnection(ctx context.Context,
|
||||
return nil
|
||||
}
|
||||
|
||||
c.logger.Info("setting VPN connection through firewall...")
|
||||
c.logger.Info("allowing VPN connection...")
|
||||
|
||||
if c.vpnConnection.Equal(connection) {
|
||||
return nil
|
||||
@@ -32,14 +32,14 @@ func (c *Config) SetVPNConnection(ctx context.Context,
|
||||
remove := true
|
||||
if c.vpnConnection.IP != nil {
|
||||
if err := c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, c.vpnConnection, remove); err != nil {
|
||||
c.logger.Error("cannot remove outdated VPN connection through firewall: " + err.Error())
|
||||
c.logger.Error("cannot remove outdated VPN connection rule: " + err.Error())
|
||||
}
|
||||
}
|
||||
c.vpnConnection = models.Connection{}
|
||||
|
||||
if c.vpnIntf != "" {
|
||||
if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
|
||||
c.logger.Error("cannot remove outdated VPN interface from firewall: " + err.Error())
|
||||
c.logger.Error("cannot remove outdated VPN interface rule: " + err.Error())
|
||||
}
|
||||
}
|
||||
c.vpnIntf = ""
|
||||
@@ -47,7 +47,7 @@ func (c *Config) SetVPNConnection(ctx context.Context,
|
||||
remove = false
|
||||
|
||||
if err := c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, connection, remove); err != nil {
|
||||
return fmt.Errorf("cannot set VPN connection through firewall: %w", err)
|
||||
return fmt.Errorf("cannot allow output traffic through VPN connection: %w", err)
|
||||
}
|
||||
c.vpnConnection = connection
|
||||
|
||||
|
||||
Reference in New Issue
Block a user