diff --git a/internal/configuration/settings/errors.go b/internal/configuration/settings/errors.go index c11b646a..d3ebf02c 100644 --- a/internal/configuration/settings/errors.go +++ b/internal/configuration/settings/errors.go @@ -7,7 +7,8 @@ var ( ErrControlServerPrivilegedPort = errors.New("cannot use privileged port without running as root") ErrCountryNotValid = errors.New("the country specified is not valid") ErrFilepathMissing = errors.New("filepath is missing") - ErrFirewallZeroPort = errors.New("cannot have a zero port to block") + ErrFirewallZeroPort = errors.New("cannot have a zero port") + ErrFirewallPublicOutboundSubnet = errors.New("outbound subnet is public") ErrHostnameNotValid = errors.New("the hostname specified is not valid") ErrISPNotValid = errors.New("the ISP specified is not valid") ErrMinRatioNotValid = errors.New("minimum ratio is not valid") diff --git a/internal/configuration/settings/firewall.go b/internal/configuration/settings/firewall.go index 466072d2..9d8b61dd 100644 --- a/internal/configuration/settings/firewall.go +++ b/internal/configuration/settings/firewall.go @@ -26,6 +26,12 @@ func (f Firewall) validate() (err error) { return fmt.Errorf("input ports: %w", ErrFirewallZeroPort) } + for _, subnet := range f.OutboundSubnets { + if !subnet.Addr().IsPrivate() { + return fmt.Errorf("%w: %s", ErrFirewallPublicOutboundSubnet, subnet) + } + } + return nil } diff --git a/internal/configuration/settings/firewall_test.go b/internal/configuration/settings/firewall_test.go new file mode 100644 index 00000000..7a4a4e8f --- /dev/null +++ b/internal/configuration/settings/firewall_test.go @@ -0,0 +1,76 @@ +package settings + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_Firewall_validate(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + firewall Firewall + errWrapped error + errMessage string + }{ + "empty": {}, + "zero_vpn_input_port": { + firewall: Firewall{ + VPNInputPorts: []uint16{0}, + }, + errWrapped: ErrFirewallZeroPort, + errMessage: "VPN input ports: cannot have a zero port", + }, + "zero_input_port": { + firewall: Firewall{ + InputPorts: []uint16{0}, + }, + errWrapped: ErrFirewallZeroPort, + errMessage: "input ports: cannot have a zero port", + }, + "unspecified_outbound_subnet": { + firewall: Firewall{ + OutboundSubnets: []netip.Prefix{ + netip.MustParsePrefix("0.0.0.0/0"), + }, + }, + errWrapped: ErrFirewallPublicOutboundSubnet, + errMessage: "outbound subnet is public: 0.0.0.0/0", + }, + "public_outbound_subnet": { + firewall: Firewall{ + OutboundSubnets: []netip.Prefix{ + netip.MustParsePrefix("1.2.3.4/32"), + }, + }, + errWrapped: ErrFirewallPublicOutboundSubnet, + errMessage: "outbound subnet is public: 1.2.3.4/32", + }, + "valid_settings": { + firewall: Firewall{ + VPNInputPorts: []uint16{100, 101}, + InputPorts: []uint16{200, 201}, + OutboundSubnets: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("10.10.1.1/32"), + }, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + err := testCase.firewall.validate() + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +}