Firewall refactoring
- Ability to enable and disable rules in various loops - Simplified code overall - Port forwarding moved into openvpn loop - Route addition and removal improved
This commit is contained in:
@@ -7,29 +7,34 @@ import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func (r *routing) AddRoutesVia(ctx context.Context, subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error {
|
||||
for _, subnet := range subnets {
|
||||
exists, err := r.routeExists(subnet)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if exists { // thanks to @npawelek https://github.com/npawelek
|
||||
if err := r.removeRoute(ctx, subnet); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
r.logger.Info("adding %s as route via %s", subnet.String(), defaultInterface)
|
||||
output, err := r.commander.Run(ctx, "ip", "route", "add", subnet.String(), "via", defaultGateway.String(), "dev", defaultInterface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot add route for %s via %s %s %s: %s: %w", subnet.String(), defaultGateway.String(), "dev", defaultInterface, output, err)
|
||||
}
|
||||
func (r *routing) AddRouteVia(ctx context.Context, subnet net.IPNet, defaultGateway net.IP, defaultInterface string) error {
|
||||
subnetStr := subnet.String()
|
||||
r.logger.Info("adding %s as route via %s %s", subnetStr, defaultGateway, defaultInterface)
|
||||
exists, err := r.routeExists(subnet)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if exists {
|
||||
return nil
|
||||
}
|
||||
output, err := r.commander.Run(ctx, "ip", "route", "add", subnetStr, "via", defaultGateway.String(), "dev", defaultInterface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot add route for %s via %s %s %s: %s: %w", subnetStr, defaultGateway, "dev", defaultInterface, output, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *routing) removeRoute(ctx context.Context, subnet net.IPNet) (err error) {
|
||||
output, err := r.commander.Run(ctx, "ip", "route", "del", subnet.String())
|
||||
func (r *routing) DeleteRouteVia(ctx context.Context, subnet net.IPNet) (err error) {
|
||||
subnetStr := subnet.String()
|
||||
r.logger.Info("deleting route for %s", subnetStr)
|
||||
exists, err := r.routeExists(subnet)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot delete route for %s: %s: %w", subnet.String(), output, err)
|
||||
return err
|
||||
} else if !exists { // thanks to @npawelek https://github.com/npawelek
|
||||
return nil
|
||||
}
|
||||
output, err := r.commander.Run(ctx, "ip", "route", "del", subnetStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot delete route for %s: %s: %w", subnetStr, output, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,12 +8,16 @@ import (
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/golibs/command/mock_command"
|
||||
"github.com/qdm12/golibs/files/mock_files"
|
||||
"github.com/qdm12/golibs/logging/mock_logging"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_removeRoute(t *testing.T) {
|
||||
func Test_DeleteRouteVia(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
tests := map[string]struct {
|
||||
subnet net.IPNet
|
||||
runOutput string
|
||||
@@ -22,26 +26,26 @@ func Test_removeRoute(t *testing.T) {
|
||||
}{
|
||||
"no output no error": {
|
||||
subnet: net.IPNet{
|
||||
IP: net.IP{192, 168, 1, 0},
|
||||
IP: net.IP{192, 168, 2, 0},
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
},
|
||||
},
|
||||
"error only": {
|
||||
subnet: net.IPNet{
|
||||
IP: net.IP{192, 168, 1, 0},
|
||||
IP: net.IP{192, 168, 2, 0},
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
},
|
||||
runErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("cannot delete route for 192.168.1.0/24: : error"),
|
||||
err: fmt.Errorf("cannot delete route for 192.168.2.0/24: : error"),
|
||||
},
|
||||
"error and output": {
|
||||
subnet: net.IPNet{
|
||||
IP: net.IP{192, 168, 1, 0},
|
||||
IP: net.IP{192, 168, 2, 0},
|
||||
Mask: net.IPMask{255, 255, 255, 0},
|
||||
},
|
||||
runErr: fmt.Errorf("error"),
|
||||
runOutput: "output",
|
||||
err: fmt.Errorf("cannot delete route for 192.168.1.0/24: output: error"),
|
||||
err: fmt.Errorf("cannot delete route for 192.168.2.0/24: output: error"),
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
@@ -50,12 +54,26 @@ func Test_removeRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
commander := mock_command.NewMockCommander(mockCtrl)
|
||||
|
||||
commander.EXPECT().Run(context.Background(), "ip", "route", "del", tc.subnet.String()).
|
||||
subnetStr := tc.subnet.String()
|
||||
|
||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||
logger.EXPECT().Info("deleting route for %s")
|
||||
commander := mock_command.NewMockCommander(mockCtrl)
|
||||
commander.EXPECT().Run(ctx, "ip", "route", "del", subnetStr).
|
||||
Return(tc.runOutput, tc.runErr).Times(1)
|
||||
r := &routing{commander: commander}
|
||||
err := r.removeRoute(context.Background(), tc.subnet)
|
||||
fileManager := mock_files.NewMockFileManager(mockCtrl)
|
||||
routesData := []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT
|
||||
eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0
|
||||
`)
|
||||
fileManager.EXPECT().ReadFile(string(constants.NetRoute)).Return(routesData, nil)
|
||||
r := &routing{
|
||||
logger: logger,
|
||||
commander: commander,
|
||||
fileManager: fileManager,
|
||||
}
|
||||
|
||||
err := r.DeleteRouteVia(ctx, tc.subnet)
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
|
||||
@@ -10,7 +10,8 @@ import (
|
||||
)
|
||||
|
||||
type Routing interface {
|
||||
AddRoutesVia(ctx context.Context, subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error
|
||||
AddRouteVia(ctx context.Context, subnet net.IPNet, defaultGateway net.IP, defaultInterface string) error
|
||||
DeleteRouteVia(ctx context.Context, subnet net.IPNet) (err error)
|
||||
DefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error)
|
||||
VPNGatewayIP(defaultInterface string) (ip net.IP, err error)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user