Firewall refactoring
- Ability to enable and disable rules in various loops - Simplified code overall - Port forwarding moved into openvpn loop - Route addition and removal improved
This commit is contained in:
112
internal/firewall/vpn.go
Normal file
112
internal/firewall/vpn.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
)
|
||||
|
||||
func (c *configurator) SetVPNConnections(ctx context.Context, connections []models.OpenVPNConnection) (err error) {
|
||||
c.stateMutex.Lock()
|
||||
defer c.stateMutex.Unlock()
|
||||
|
||||
if !c.enabled {
|
||||
c.logger.Info("firewall disabled, only updating VPN connections internal list")
|
||||
c.vpnConnections = make([]models.OpenVPNConnection, len(connections))
|
||||
copy(c.vpnConnections, connections)
|
||||
return nil
|
||||
}
|
||||
|
||||
c.logger.Info("setting VPN connections through firewall...")
|
||||
|
||||
connectionsToAdd := findConnectionsToAdd(c.vpnConnections, connections)
|
||||
connectionsToRemove := findConnectionsToRemove(c.vpnConnections, connections)
|
||||
if len(connectionsToAdd) == 0 && len(connectionsToRemove) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
defaultInterface, _, _, err := c.routing.DefaultRoute()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot set VPN connections through firewall: %w", err)
|
||||
}
|
||||
|
||||
// TODO remove elsewhere?
|
||||
if err := c.acceptOutputThroughInterface(ctx, string(constants.TUN), false); err != nil {
|
||||
return fmt.Errorf("cannot allow traffic through tunnel: %w", err)
|
||||
}
|
||||
|
||||
c.removeConnections(ctx, connectionsToRemove, defaultInterface)
|
||||
if err := c.addConnections(ctx, connectionsToAdd, defaultInterface); err != nil {
|
||||
return fmt.Errorf("cannot set VPN connections through firewall: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeConnectionFromConnections(connections []models.OpenVPNConnection, connection models.OpenVPNConnection) []models.OpenVPNConnection {
|
||||
L := len(connections)
|
||||
for i := range connections {
|
||||
if connection.Equal(connections[i]) {
|
||||
connections[i] = connections[L-1]
|
||||
connections = connections[:L-1]
|
||||
break
|
||||
}
|
||||
}
|
||||
return connections
|
||||
}
|
||||
|
||||
func findConnectionsToAdd(oldConnections, newConnections []models.OpenVPNConnection) (connectionsToAdd []models.OpenVPNConnection) {
|
||||
for _, newConnection := range newConnections {
|
||||
found := false
|
||||
for _, oldConnection := range oldConnections {
|
||||
if oldConnection.Equal(newConnection) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
connectionsToAdd = append(connectionsToAdd, newConnection)
|
||||
}
|
||||
}
|
||||
return connectionsToAdd
|
||||
}
|
||||
|
||||
func findConnectionsToRemove(oldConnections, newConnections []models.OpenVPNConnection) (connectionsToRemove []models.OpenVPNConnection) {
|
||||
for _, oldConnection := range oldConnections {
|
||||
found := false
|
||||
for _, newConnection := range newConnections {
|
||||
if oldConnection.Equal(newConnection) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
connectionsToRemove = append(connectionsToRemove, oldConnection)
|
||||
}
|
||||
}
|
||||
return connectionsToRemove
|
||||
}
|
||||
|
||||
func (c *configurator) removeConnections(ctx context.Context, connections []models.OpenVPNConnection, defaultInterface string) {
|
||||
for _, conn := range connections {
|
||||
const remove = true
|
||||
if err := c.acceptOutputTrafficToVPN(ctx, defaultInterface, conn, remove); err != nil {
|
||||
c.logger.Error("cannot remove outdated VPN connection through firewall: %s", err)
|
||||
continue
|
||||
}
|
||||
c.vpnConnections = removeConnectionFromConnections(c.vpnConnections, conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *configurator) addConnections(ctx context.Context, connections []models.OpenVPNConnection, defaultInterface string) error {
|
||||
const remove = false
|
||||
for _, conn := range connections {
|
||||
if err := c.acceptOutputTrafficToVPN(ctx, defaultInterface, conn, remove); err != nil {
|
||||
return err
|
||||
}
|
||||
c.vpnConnections = append(c.vpnConnections, conn)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user