feat(settings): prevent public firewall outbound subnets
This commit is contained in:
@@ -7,7 +7,8 @@ var (
|
|||||||
ErrControlServerPrivilegedPort = errors.New("cannot use privileged port without running as root")
|
ErrControlServerPrivilegedPort = errors.New("cannot use privileged port without running as root")
|
||||||
ErrCountryNotValid = errors.New("the country specified is not valid")
|
ErrCountryNotValid = errors.New("the country specified is not valid")
|
||||||
ErrFilepathMissing = errors.New("filepath is missing")
|
ErrFilepathMissing = errors.New("filepath is missing")
|
||||||
ErrFirewallZeroPort = errors.New("cannot have a zero port to block")
|
ErrFirewallZeroPort = errors.New("cannot have a zero port")
|
||||||
|
ErrFirewallPublicOutboundSubnet = errors.New("outbound subnet is public")
|
||||||
ErrHostnameNotValid = errors.New("the hostname specified is not valid")
|
ErrHostnameNotValid = errors.New("the hostname specified is not valid")
|
||||||
ErrISPNotValid = errors.New("the ISP specified is not valid")
|
ErrISPNotValid = errors.New("the ISP specified is not valid")
|
||||||
ErrMinRatioNotValid = errors.New("minimum ratio is not valid")
|
ErrMinRatioNotValid = errors.New("minimum ratio is not valid")
|
||||||
|
|||||||
@@ -26,6 +26,12 @@ func (f Firewall) validate() (err error) {
|
|||||||
return fmt.Errorf("input ports: %w", ErrFirewallZeroPort)
|
return fmt.Errorf("input ports: %w", ErrFirewallZeroPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, subnet := range f.OutboundSubnets {
|
||||||
|
if !subnet.Addr().IsPrivate() {
|
||||||
|
return fmt.Errorf("%w: %s", ErrFirewallPublicOutboundSubnet, subnet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
76
internal/configuration/settings/firewall_test.go
Normal file
76
internal/configuration/settings/firewall_test.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
package settings
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_Firewall_validate(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
firewall Firewall
|
||||||
|
errWrapped error
|
||||||
|
errMessage string
|
||||||
|
}{
|
||||||
|
"empty": {},
|
||||||
|
"zero_vpn_input_port": {
|
||||||
|
firewall: Firewall{
|
||||||
|
VPNInputPorts: []uint16{0},
|
||||||
|
},
|
||||||
|
errWrapped: ErrFirewallZeroPort,
|
||||||
|
errMessage: "VPN input ports: cannot have a zero port",
|
||||||
|
},
|
||||||
|
"zero_input_port": {
|
||||||
|
firewall: Firewall{
|
||||||
|
InputPorts: []uint16{0},
|
||||||
|
},
|
||||||
|
errWrapped: ErrFirewallZeroPort,
|
||||||
|
errMessage: "input ports: cannot have a zero port",
|
||||||
|
},
|
||||||
|
"unspecified_outbound_subnet": {
|
||||||
|
firewall: Firewall{
|
||||||
|
OutboundSubnets: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("0.0.0.0/0"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
errWrapped: ErrFirewallPublicOutboundSubnet,
|
||||||
|
errMessage: "outbound subnet is public: 0.0.0.0/0",
|
||||||
|
},
|
||||||
|
"public_outbound_subnet": {
|
||||||
|
firewall: Firewall{
|
||||||
|
OutboundSubnets: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("1.2.3.4/32"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
errWrapped: ErrFirewallPublicOutboundSubnet,
|
||||||
|
errMessage: "outbound subnet is public: 1.2.3.4/32",
|
||||||
|
},
|
||||||
|
"valid_settings": {
|
||||||
|
firewall: Firewall{
|
||||||
|
VPNInputPorts: []uint16{100, 101},
|
||||||
|
InputPorts: []uint16{200, 201},
|
||||||
|
OutboundSubnets: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
netip.MustParsePrefix("10.10.1.1/32"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
err := testCase.firewall.validate()
|
||||||
|
|
||||||
|
assert.ErrorIs(t, err, testCase.errWrapped)
|
||||||
|
if testCase.errWrapped != nil {
|
||||||
|
assert.EqualError(t, err, testCase.errMessage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user