Fixing extra subnets firewall rules
- Fix #194 - Fix #190 - Refers to #188
This commit is contained in:
@@ -102,17 +102,17 @@ func (c *configurator) enable(ctx context.Context) (err error) { //nolint:gocogn
|
|||||||
if err = c.acceptOutputThroughInterface(ctx, string(constants.TUN), remove); err != nil {
|
if err = c.acceptOutputThroughInterface(ctx, string(constants.TUN), remove); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
}
|
}
|
||||||
if err := c.acceptInputFromToSubnet(ctx, localSubnet, "*", remove); err != nil {
|
if err := c.acceptInputFromSubnetToSubnet(ctx, "*", localSubnet, localSubnet, remove); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
}
|
}
|
||||||
if err := c.acceptOutputFromToSubnet(ctx, localSubnet, "*", remove); err != nil {
|
if err := c.acceptOutputFromSubnetToSubnet(ctx, "*", localSubnet, localSubnet, remove); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
}
|
}
|
||||||
for _, subnet := range c.allowedSubnets {
|
for _, subnet := range c.allowedSubnets {
|
||||||
if err := c.acceptInputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
|
if err := c.acceptInputFromSubnetToSubnet(ctx, defaultInterface, subnet, localSubnet, remove); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
}
|
}
|
||||||
if err := c.acceptOutputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
|
if err := c.acceptOutputFromSubnetToSubnet(ctx, defaultInterface, localSubnet, subnet, remove); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -112,26 +112,24 @@ func (c *configurator) acceptOutputTrafficToVPN(ctx context.Context, defaultInte
|
|||||||
appendOrDelete(remove), connection.IP, defaultInterface, connection.Protocol, connection.Protocol, connection.Port))
|
appendOrDelete(remove), connection.IP, defaultInterface, connection.Protocol, connection.Protocol, connection.Port))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *configurator) acceptInputFromToSubnet(ctx context.Context, subnet net.IPNet, intf string, remove bool) error {
|
func (c *configurator) acceptInputFromSubnetToSubnet(ctx context.Context, intf string, sourceSubnet, destinationSubnet net.IPNet, remove bool) error {
|
||||||
subnetStr := subnet.String()
|
|
||||||
interfaceFlag := "-i " + intf
|
interfaceFlag := "-i " + intf
|
||||||
if intf == "*" { // all interfaces
|
if intf == "*" { // all interfaces
|
||||||
interfaceFlag = ""
|
interfaceFlag = ""
|
||||||
}
|
}
|
||||||
return c.runIptablesInstruction(ctx, fmt.Sprintf(
|
return c.runIptablesInstruction(ctx, fmt.Sprintf(
|
||||||
"%s INPUT %s -s %s -d %s -j ACCEPT", appendOrDelete(remove), interfaceFlag, subnetStr, subnetStr,
|
"%s INPUT %s -s %s -d %s -j ACCEPT", appendOrDelete(remove), interfaceFlag, sourceSubnet.String(), destinationSubnet.String(),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Thanks to @npawelek
|
// Thanks to @npawelek
|
||||||
func (c *configurator) acceptOutputFromToSubnet(ctx context.Context, subnet net.IPNet, intf string, remove bool) error {
|
func (c *configurator) acceptOutputFromSubnetToSubnet(ctx context.Context, intf string, sourceSubnet, destinationSubnet net.IPNet, remove bool) error {
|
||||||
subnetStr := subnet.String()
|
|
||||||
interfaceFlag := "-o " + intf
|
interfaceFlag := "-o " + intf
|
||||||
if intf == "*" { // all interfaces
|
if intf == "*" { // all interfaces
|
||||||
interfaceFlag = ""
|
interfaceFlag = ""
|
||||||
}
|
}
|
||||||
return c.runIptablesInstruction(ctx, fmt.Sprintf(
|
return c.runIptablesInstruction(ctx, fmt.Sprintf(
|
||||||
"%s OUTPUT %s -s %s -d %s -j ACCEPT", appendOrDelete(remove), interfaceFlag, subnetStr, subnetStr,
|
"%s OUTPUT %s -s %s -d %s -j ACCEPT", appendOrDelete(remove), interfaceFlag, sourceSubnet.String(), destinationSubnet.String(),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -32,9 +32,13 @@ func (c *configurator) SetAllowedSubnets(ctx context.Context, subnets []net.IPNe
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot set allowed subnets through firewall: %w", err)
|
return fmt.Errorf("cannot set allowed subnets through firewall: %w", err)
|
||||||
}
|
}
|
||||||
|
localSubnet, err := c.routing.LocalSubnet()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot set allowed subnets through firewall: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
c.removeSubnets(ctx, subnetsToRemove, defaultInterface)
|
c.removeSubnets(ctx, subnetsToRemove, defaultInterface, localSubnet)
|
||||||
if err := c.addSubnets(ctx, subnetsToAdd, defaultInterface, defaultGateway); err != nil {
|
if err := c.addSubnets(ctx, subnetsToAdd, defaultInterface, defaultGateway, localSubnet); err != nil {
|
||||||
return fmt.Errorf("cannot set allowed subnets through firewall: %w", err)
|
return fmt.Errorf("cannot set allowed subnets through firewall: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,15 +93,16 @@ func removeSubnetFromSubnets(subnets []net.IPNet, subnet net.IPNet) []net.IPNet
|
|||||||
return subnets
|
return subnets
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *configurator) removeSubnets(ctx context.Context, subnets []net.IPNet, defaultInterface string) {
|
func (c *configurator) removeSubnets(ctx context.Context, subnets []net.IPNet, defaultInterface string,
|
||||||
|
localSubnet net.IPNet) {
|
||||||
const remove = true
|
const remove = true
|
||||||
for _, subnet := range subnets {
|
for _, subnet := range subnets {
|
||||||
failed := false
|
failed := false
|
||||||
if err := c.acceptInputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
|
if err := c.acceptInputFromSubnetToSubnet(ctx, defaultInterface, subnet, localSubnet, remove); err != nil {
|
||||||
failed = true
|
failed = true
|
||||||
c.logger.Error("cannot remove outdated allowed subnet through firewall: %s", err)
|
c.logger.Error("cannot remove outdated allowed subnet through firewall: %s", err)
|
||||||
}
|
}
|
||||||
if err := c.acceptOutputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
|
if err := c.acceptOutputFromSubnetToSubnet(ctx, defaultInterface, subnet, localSubnet, remove); err != nil {
|
||||||
failed = true
|
failed = true
|
||||||
c.logger.Error("cannot remove outdated allowed subnet through firewall: %s", err)
|
c.logger.Error("cannot remove outdated allowed subnet through firewall: %s", err)
|
||||||
}
|
}
|
||||||
@@ -112,13 +117,14 @@ func (c *configurator) removeSubnets(ctx context.Context, subnets []net.IPNet, d
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *configurator) addSubnets(ctx context.Context, subnets []net.IPNet, defaultInterface string, defaultGateway net.IP) error {
|
func (c *configurator) addSubnets(ctx context.Context, subnets []net.IPNet, defaultInterface string,
|
||||||
|
defaultGateway net.IP, localSubnet net.IPNet) error {
|
||||||
const remove = false
|
const remove = false
|
||||||
for _, subnet := range subnets {
|
for _, subnet := range subnets {
|
||||||
if err := c.acceptInputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
|
if err := c.acceptInputFromSubnetToSubnet(ctx, defaultInterface, subnet, localSubnet, remove); err != nil {
|
||||||
return fmt.Errorf("cannot add allowed subnet through firewall: %w", err)
|
return fmt.Errorf("cannot add allowed subnet through firewall: %w", err)
|
||||||
}
|
}
|
||||||
if err := c.acceptOutputFromToSubnet(ctx, subnet, defaultInterface, remove); err != nil {
|
if err := c.acceptOutputFromSubnetToSubnet(ctx, defaultInterface, localSubnet, subnet, remove); err != nil {
|
||||||
return fmt.Errorf("cannot add allowed subnet through firewall: %w", err)
|
return fmt.Errorf("cannot add allowed subnet through firewall: %w", err)
|
||||||
}
|
}
|
||||||
if err := c.routing.AddRouteVia(ctx, subnet, defaultGateway, defaultInterface); err != nil {
|
if err := c.routing.AddRouteVia(ctx, subnet, defaultGateway, defaultInterface); err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user