Get default route and local subnet only at start
This commit is contained in:
@@ -104,6 +104,18 @@ func _main(background context.Context, args []string) int {
|
|||||||
routingConf.SetDebug()
|
routingConf.SetDebug()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defaultInterface, defaultGateway, err := routingConf.DefaultRoute()
|
||||||
|
if err != nil {
|
||||||
|
fatalOnError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
localSubnet, err := routingConf.LocalSubnet()
|
||||||
|
if err != nil {
|
||||||
|
fatalOnError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
firewallConf.SetNetworkInformation(defaultInterface, defaultGateway, localSubnet)
|
||||||
|
|
||||||
if err := ovpnConf.CheckTUN(); err != nil {
|
if err := ovpnConf.CheckTUN(); err != nil {
|
||||||
logger.Warn(err)
|
logger.Warn(err)
|
||||||
err = ovpnConf.CreateTUN()
|
err = ovpnConf.CreateTUN()
|
||||||
|
|||||||
@@ -62,15 +62,6 @@ func (c *configurator) fallbackToDisabled(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *configurator) enable(ctx context.Context) (err error) { //nolint:gocognit
|
func (c *configurator) enable(ctx context.Context) (err error) { //nolint:gocognit
|
||||||
defaultInterface, defaultGateway, err := c.routing.DefaultRoute()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
|
||||||
}
|
|
||||||
localSubnet, err := c.routing.LocalSubnet()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = c.setAllPolicies(ctx, "DROP"); err != nil {
|
if err = c.setAllPolicies(ctx, "DROP"); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
}
|
}
|
||||||
@@ -95,30 +86,30 @@ func (c *configurator) enable(ctx context.Context) (err error) { //nolint:gocogn
|
|||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
}
|
}
|
||||||
for _, conn := range c.vpnConnections {
|
for _, conn := range c.vpnConnections {
|
||||||
if err = c.acceptOutputTrafficToVPN(ctx, defaultInterface, conn, remove); err != nil {
|
if err = c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, conn, remove); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err = c.acceptOutputThroughInterface(ctx, string(constants.TUN), remove); err != nil {
|
if err = c.acceptOutputThroughInterface(ctx, string(constants.TUN), remove); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
}
|
}
|
||||||
if err := c.acceptInputFromSubnetToSubnet(ctx, "*", localSubnet, localSubnet, remove); err != nil {
|
if err := c.acceptInputFromSubnetToSubnet(ctx, "*", c.localSubnet, c.localSubnet, remove); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
}
|
}
|
||||||
if err := c.acceptOutputFromSubnetToSubnet(ctx, "*", localSubnet, localSubnet, remove); err != nil {
|
if err := c.acceptOutputFromSubnetToSubnet(ctx, "*", c.localSubnet, c.localSubnet, remove); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
}
|
}
|
||||||
for _, subnet := range c.allowedSubnets {
|
for _, subnet := range c.allowedSubnets {
|
||||||
if err := c.acceptInputFromSubnetToSubnet(ctx, defaultInterface, subnet, localSubnet, remove); err != nil {
|
if err := c.acceptInputFromSubnetToSubnet(ctx, c.defaultInterface, subnet, c.localSubnet, remove); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
}
|
}
|
||||||
if err := c.acceptOutputFromSubnetToSubnet(ctx, defaultInterface, localSubnet, subnet, remove); err != nil {
|
if err := c.acceptOutputFromSubnetToSubnet(ctx, c.defaultInterface, c.localSubnet, subnet, remove); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Re-ensure all routes exist
|
// Re-ensure all routes exist
|
||||||
for _, subnet := range c.allowedSubnets {
|
for _, subnet := range c.allowedSubnets {
|
||||||
if err := c.routing.AddRouteVia(ctx, subnet, defaultGateway, defaultInterface); err != nil {
|
if err := c.routing.AddRouteVia(ctx, subnet, c.defaultGateway, c.defaultInterface); err != nil {
|
||||||
return fmt.Errorf("cannot enable firewall: %w", err)
|
return fmt.Errorf("cannot enable firewall: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,15 +22,21 @@ type Configurator interface {
|
|||||||
RemoveAllowedPort(ctx context.Context, port uint16) (err error)
|
RemoveAllowedPort(ctx context.Context, port uint16) (err error)
|
||||||
SetPortForward(ctx context.Context, port uint16) (err error)
|
SetPortForward(ctx context.Context, port uint16) (err error)
|
||||||
SetDebug()
|
SetDebug()
|
||||||
|
// SetNetworkInformation is meant to be called only once
|
||||||
|
SetNetworkInformation(defaultInterface string, defaultGateway net.IP, localSubnet net.IPNet)
|
||||||
}
|
}
|
||||||
|
|
||||||
type configurator struct { //nolint:maligned
|
type configurator struct { //nolint:maligned
|
||||||
commander command.Commander
|
commander command.Commander
|
||||||
logger logging.Logger
|
logger logging.Logger
|
||||||
routing routing.Routing
|
routing routing.Routing
|
||||||
fileManager files.FileManager // for custom iptables rules
|
fileManager files.FileManager // for custom iptables rules
|
||||||
iptablesMutex sync.Mutex
|
iptablesMutex sync.Mutex
|
||||||
debug bool
|
debug bool
|
||||||
|
defaultInterface string
|
||||||
|
defaultGateway net.IP
|
||||||
|
localSubnet net.IPNet
|
||||||
|
networkInfoMutex sync.Mutex
|
||||||
|
|
||||||
// State
|
// State
|
||||||
enabled bool
|
enabled bool
|
||||||
@@ -55,3 +61,11 @@ func NewConfigurator(logger logging.Logger, routing routing.Routing, fileManager
|
|||||||
func (c *configurator) SetDebug() {
|
func (c *configurator) SetDebug() {
|
||||||
c.debug = true
|
c.debug = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *configurator) SetNetworkInformation(defaultInterface string, defaultGateway net.IP, localSubnet net.IPNet) {
|
||||||
|
c.networkInfoMutex.Lock()
|
||||||
|
defer c.networkInfoMutex.Unlock()
|
||||||
|
c.defaultInterface = defaultInterface
|
||||||
|
c.defaultGateway = defaultGateway
|
||||||
|
c.localSubnet = localSubnet
|
||||||
|
}
|
||||||
|
|||||||
@@ -12,9 +12,7 @@ func (c *configurator) SetAllowedSubnets(ctx context.Context, subnets []net.IPNe
|
|||||||
|
|
||||||
if !c.enabled {
|
if !c.enabled {
|
||||||
c.logger.Info("firewall disabled, only updating allowed subnets internal list and updating routes")
|
c.logger.Info("firewall disabled, only updating allowed subnets internal list and updating routes")
|
||||||
if err := c.updateSubnetRoutes(ctx, c.allowedSubnets, subnets); err != nil {
|
c.updateSubnetRoutes(ctx, c.allowedSubnets, subnets)
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.allowedSubnets = make([]net.IPNet, len(subnets))
|
c.allowedSubnets = make([]net.IPNet, len(subnets))
|
||||||
copy(c.allowedSubnets, subnets)
|
copy(c.allowedSubnets, subnets)
|
||||||
return nil
|
return nil
|
||||||
@@ -28,17 +26,8 @@ func (c *configurator) SetAllowedSubnets(ctx context.Context, subnets []net.IPNe
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultInterface, defaultGateway, err := c.routing.DefaultRoute()
|
c.removeSubnets(ctx, subnetsToRemove, c.defaultInterface, c.localSubnet)
|
||||||
if err != nil {
|
if err := c.addSubnets(ctx, subnetsToAdd, c.defaultInterface, c.defaultGateway, c.localSubnet); err != nil {
|
||||||
return fmt.Errorf("cannot set allowed subnets through firewall: %w", err)
|
|
||||||
}
|
|
||||||
localSubnet, err := c.routing.LocalSubnet()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("cannot set allowed subnets through firewall: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.removeSubnets(ctx, subnetsToRemove, defaultInterface, localSubnet)
|
|
||||||
if err := c.addSubnets(ctx, subnetsToAdd, defaultInterface, defaultGateway, localSubnet); err != nil {
|
|
||||||
return fmt.Errorf("cannot set allowed subnets through firewall: %w", err)
|
return fmt.Errorf("cannot set allowed subnets through firewall: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -135,15 +124,12 @@ func (c *configurator) addSubnets(ctx context.Context, subnets []net.IPNet, defa
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *configurator) updateSubnetRoutes(ctx context.Context, oldSubnets, newSubnets []net.IPNet) error {
|
// updateSubnetRoutes does not return an error in order to try to run as many route commands as possible
|
||||||
|
func (c *configurator) updateSubnetRoutes(ctx context.Context, oldSubnets, newSubnets []net.IPNet) {
|
||||||
subnetsToAdd := findSubnetsToAdd(oldSubnets, newSubnets)
|
subnetsToAdd := findSubnetsToAdd(oldSubnets, newSubnets)
|
||||||
subnetsToRemove := findSubnetsToRemove(oldSubnets, newSubnets)
|
subnetsToRemove := findSubnetsToRemove(oldSubnets, newSubnets)
|
||||||
if len(subnetsToAdd) == 0 && len(subnetsToRemove) == 0 {
|
if len(subnetsToAdd) == 0 && len(subnetsToRemove) == 0 {
|
||||||
return nil
|
return
|
||||||
}
|
|
||||||
defaultInterface, defaultGateway, err := c.routing.DefaultRoute()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
for _, subnet := range subnetsToRemove {
|
for _, subnet := range subnetsToRemove {
|
||||||
if err := c.routing.DeleteRouteVia(ctx, subnet); err != nil {
|
if err := c.routing.DeleteRouteVia(ctx, subnet); err != nil {
|
||||||
@@ -151,9 +137,8 @@ func (c *configurator) updateSubnetRoutes(ctx context.Context, oldSubnets, newSu
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, subnet := range subnetsToAdd {
|
for _, subnet := range subnetsToAdd {
|
||||||
if err := c.routing.AddRouteVia(ctx, subnet, defaultGateway, defaultInterface); err != nil {
|
if err := c.routing.AddRouteVia(ctx, subnet, c.defaultGateway, c.defaultInterface); err != nil {
|
||||||
c.logger.Error("cannot add route for subnet: %s", err)
|
c.logger.Error("cannot add route for subnet: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,13 +26,8 @@ func (c *configurator) SetVPNConnections(ctx context.Context, connections []mode
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultInterface, _, err := c.routing.DefaultRoute()
|
c.removeConnections(ctx, connectionsToRemove, c.defaultInterface)
|
||||||
if err != nil {
|
if err := c.addConnections(ctx, connectionsToAdd, c.defaultInterface); err != nil {
|
||||||
return fmt.Errorf("cannot set VPN connections through firewall: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.removeConnections(ctx, connectionsToRemove, defaultInterface)
|
|
||||||
if err := c.addConnections(ctx, connectionsToAdd, defaultInterface); err != nil {
|
|
||||||
return fmt.Errorf("cannot set VPN connections through firewall: %w", err)
|
return fmt.Errorf("cannot set VPN connections through firewall: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user