Feat: OPENVPN_INTERFACE defaulting to tun0

- Fix: custom config with custom network interface name for firewall
- Keep VPN tunnel interface in firewall state
- Vul fix: only allow traffic through vpn interface when needed
- Adapt code to adapt to network interface name
- Remove outdated TUN and TAP constants
This commit is contained in:
Quentin McGaw (desktop)
2021-08-19 23:22:55 +00:00
parent 7191d4e911
commit bec8ff27ae
20 changed files with 219 additions and 89 deletions

View File

@@ -76,6 +76,7 @@ ENV VPNSP=pia \
OPENVPN_TARGET_IP= \
OPENVPN_IPV6=off \
OPENVPN_CUSTOM_CONFIG= \
OPENVPN_INTERFACE=tun0 \
TZ= \
PUID= \
PGID= \

View File

@@ -290,7 +290,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
}
for _, vpnPort := range allSettings.Firewall.VPNInputPorts {
err = firewallConf.SetAllowedPort(ctx, vpnPort, string(constants.TUN))
err = firewallConf.SetAllowedPort(ctx, vpnPort, allSettings.VPN.OpenVPN.Interface)
if err != nil {
return err
}

View File

@@ -1,7 +1,9 @@
package configuration
import (
"errors"
"fmt"
"regexp"
"strconv"
"strings"
@@ -26,6 +28,7 @@ type OpenVPN struct {
EncPreset string `json:"encryption_preset"` // PIA
IPv6 bool `json:"ipv6"` // Mullvad
ProcUser string `json:"procuser"` // Process username
Interface string `json:"interface"`
}
func (settings *OpenVPN) String() string {
@@ -39,6 +42,8 @@ func (settings *OpenVPN) lines() (lines []string) {
lines = append(lines, indent+lastIndent+"Verbosity level: "+strconv.Itoa(settings.Verbosity))
lines = append(lines, indent+lastIndent+"Network interface: "+settings.Interface)
if len(settings.Flags) > 0 {
lines = append(lines, indent+lastIndent+"Flags: "+strings.Join(settings.Flags, " "))
}
@@ -148,6 +153,11 @@ func (settings *OpenVPN) read(r reader, serviceProvider string) (err error) {
return fmt.Errorf("environment variable OPENVPN_IPV6: %w", err)
}
settings.Interface, err = readInterface(r.env)
if err != nil {
return err
}
settings.EncPreset, err = getPIAEncryptionPreset(r)
if err != nil {
return err
@@ -173,3 +183,22 @@ func readProtocol(env params.Env) (tcp bool, err error) {
}
return protocol == constants.TCP, nil
}
const openvpnIntfRegexString = `^.*[0-9]$`
var openvpnIntfRegexp = regexp.MustCompile(openvpnIntfRegexString)
var errInterfaceNameNotValid = errors.New("interface name is not valid")
func readInterface(env params.Env) (intf string, err error) {
intf, err = env.Get("OPENVPN_INTERFACE", params.Default("tun0"))
if err != nil {
return "", fmt.Errorf("environment variable OPENVPN_INTERFACE: %w", err)
}
if !openvpnIntfRegexp.MatchString(intf) {
return "", fmt.Errorf("%w: does not match regex %s: %s",
errInterfaceNameNotValid, openvpnIntfRegexString, intf)
}
return intf, nil
}

View File

@@ -29,7 +29,8 @@ func Test_OpenVPN_JSON(t *testing.T) {
"version": "",
"encryption_preset": "",
"ipv6": false,
"procuser": ""
"procuser": "",
"interface": ""
}`, string(data))
var out OpenVPN
err = json.Unmarshal(data, &out)

View File

@@ -23,6 +23,7 @@ func Test_Settings_lines(t *testing.T) {
},
OpenVPN: OpenVPN{
Version: constants.Openvpn25,
Interface: "tun",
},
},
},
@@ -33,6 +34,7 @@ func Test_Settings_lines(t *testing.T) {
" |--OpenVPN:",
" |--Version: 2.5",
" |--Verbosity level: 0",
" |--Network interface: tun",
" |--Mullvad settings:",
" |--OpenVPN selection:",
" |--Protocol: udp",

View File

@@ -1,10 +1,5 @@
package constants
const (
TUN = "tun0"
TAP = "tap0"
)
const (
AES128cbc = "aes-128-cbc"
AES256cbc = "aes-256-cbc"

View File

@@ -4,8 +4,6 @@ import (
"context"
"errors"
"fmt"
"github.com/qdm12/gluetun/internal/constants"
)
var (
@@ -109,10 +107,10 @@ func (c *Config) enable(ctx context.Context) (err error) {
if err = c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, c.vpnConnection, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
}
if err = c.acceptOutputThroughInterface(ctx, string(constants.TUN), remove); err != nil {
if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
return fmt.Errorf("cannot enable firewall: %w", err)
}
}
for _, network := range c.localNetworks {
if err := c.acceptOutputFromIPToSubnet(ctx, network.InterfaceName, network.IP, *network.IPNet, remove); err != nil {

View File

@@ -40,6 +40,7 @@ type Config struct { //nolint:maligned
// State
enabled bool
vpnConnection models.Connection
vpnIntf string
outboundSubnets []net.IPNet
allowedInputPorts map[uint16]string // port to interface mapping
stateMutex sync.Mutex

View File

@@ -8,10 +8,12 @@ import (
)
type VPNConnectionSetter interface {
SetVPNConnection(ctx context.Context, connection models.Connection) error
SetVPNConnection(ctx context.Context,
connection models.Connection, vpnIntf string) error
}
func (c *Config) SetVPNConnection(ctx context.Context, connection models.Connection) (err error) {
func (c *Config) SetVPNConnection(ctx context.Context,
connection models.Connection, vpnIntf string) (err error) {
c.stateMutex.Lock()
defer c.stateMutex.Unlock()
@@ -34,10 +36,25 @@ func (c *Config) SetVPNConnection(ctx context.Context, connection models.Connect
}
}
c.vpnConnection = models.Connection{}
if c.vpnIntf != "" {
if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
c.logger.Error("cannot remove outdated VPN interface from firewall: " + err.Error())
}
}
c.vpnIntf = ""
remove = false
if err := c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, connection, remove); err != nil {
return fmt.Errorf("cannot set VPN connection through firewall: %w", err)
}
c.vpnConnection = connection
if err = c.acceptOutputThroughInterface(ctx, vpnIntf, remove); err != nil {
return fmt.Errorf("cannot accept output traffic through interface %s: %w", vpnIntf, err)
}
c.vpnIntf = vpnIntf
return nil
}

View File

@@ -33,18 +33,15 @@ func (c Connection) OpenVPNProtoLine() (line string) {
}
// UpdateEmptyWith updates each field of the connection where the
// value is not set using the value from the other connection.
func (c *Connection) UpdateEmptyWith(connection Connection) {
// value is not set using the value given as arguments.
func (c *Connection) UpdateEmptyWith(ip net.IP, port uint16, protocol string) {
if c.IP == nil {
c.IP = connection.IP
c.IP = ip
}
if c.Port == 0 {
c.Port = connection.Port
c.Port = port
}
if c.Protocol == "" {
c.Protocol = connection.Protocol
}
if c.Hostname == "" {
c.Hostname = connection.Hostname
c.Protocol = protocol
}
}

View File

@@ -14,18 +14,22 @@ var (
)
func BuildConfig(settings configuration.OpenVPN) (
lines []string, connection models.Connection, err error) {
lines []string, connection models.Connection, intf string, err error) {
lines, err = readCustomConfigLines(settings.Config)
if err != nil {
return nil, connection, fmt.Errorf("%w: %s", ErrReadCustomConfig, err)
return nil, connection, "", fmt.Errorf("%w: %s", ErrReadCustomConfig, err)
}
connection, err = extractConnectionFromLines(lines)
connection, intf, err = extractDataFromLines(lines)
if err != nil {
return nil, connection, fmt.Errorf("%w: %s", ErrExtractConnection, err)
return nil, connection, "", fmt.Errorf("%w: %s", ErrExtractConnection, err)
}
lines = modifyCustomConfig(lines, settings, connection)
return lines, connection, nil
if intf == "" {
intf = settings.Interface
}
lines = modifyCustomConfig(lines, settings, connection, intf)
return lines, connection, intf, nil
}

View File

@@ -30,15 +30,17 @@ func Test_BuildConfig(t *testing.T) {
Cipher: "cipher",
MSSFix: 999,
Config: file.Name(),
Interface: "tun",
}
lines, connection, err := BuildConfig(settings)
lines, connection, intf, err := BuildConfig(settings)
assert.NoError(t, err)
expectedLines := []string{
"keep me",
"proto udp",
"remote 1.9.8.7 1194",
"dev tun",
"mute-replay-warnings",
"auth-nocache",
"pull-filter ignore \"auth-token\"",
@@ -60,4 +62,6 @@ func Test_BuildConfig(t *testing.T) {
Protocol: constants.UDP,
}
assert.Equal(t, expectedConnection, connection)
assert.Equal(t, "tun", intf)
}

View File

@@ -15,23 +15,24 @@ var (
errRemoteLineNotFound = errors.New("remote line not found")
)
// extractConnectionFromLines always takes the first remote line only.
func extractConnectionFromLines(lines []string) (
connection models.Connection, err error) {
func extractDataFromLines(lines []string) (
connection models.Connection, intf string, err error) {
for i, line := range lines {
newConnectionData, err := extractConnectionFromLine(line)
ip, port, protocol, intfFound, err := extractDataFromLine(line)
if err != nil {
return connection, fmt.Errorf("on line %d: %w", i+1, err)
return connection, "", fmt.Errorf("on line %d: %w", i+1, err)
}
connection.UpdateEmptyWith(newConnectionData)
if connection.Protocol != "" && connection.IP != nil {
intf = intfFound
connection.UpdateEmptyWith(ip, port, protocol)
if connection.Protocol != "" && connection.IP != nil && intf != "" {
break
}
}
if connection.IP == nil {
return connection, errRemoteLineNotFound
return connection, "", errRemoteLineNotFound
}
if connection.Protocol == "" {
@@ -45,32 +46,41 @@ func extractConnectionFromLines(lines []string) (
}
}
return connection, nil
return connection, intf, nil
}
var (
errExtractProto = errors.New("failed extracting protocol from proto line")
errExtractRemote = errors.New("failed extracting protocol from remote line")
errExtractRemote = errors.New("failed extracting from remote line")
errExtractDev = errors.New("failed extracting network interface from dev line")
)
func extractConnectionFromLine(line string) (
connection models.Connection, err error) {
func extractDataFromLine(line string) (
ip net.IP, port uint16, protocol, intf string, err error) {
switch {
case strings.HasPrefix(line, "proto "):
connection.Protocol, err = extractProto(line)
protocol, err = extractProto(line)
if err != nil {
return connection, fmt.Errorf("%w: %s", errExtractProto, err)
return nil, 0, "", "", fmt.Errorf("%w: %s", errExtractProto, err)
}
return nil, 0, protocol, "", nil
case strings.HasPrefix(line, "remote "):
ip, port, protocol, err = extractRemote(line)
if err != nil {
return nil, 0, "", "", fmt.Errorf("%w: %s", errExtractRemote, err)
}
return ip, port, protocol, "", nil
case strings.HasPrefix(line, "dev "):
intf, err = extractInterfaceFromLine(line)
if err != nil {
return nil, 0, "", "", fmt.Errorf("%w: %s", errExtractDev, err)
}
return nil, 0, "", intf, nil
}
// only take the first remote line
case strings.HasPrefix(line, "remote ") && connection.IP == nil:
connection.IP, connection.Port, connection.Protocol, err = extractRemote(line)
if err != nil {
return connection, fmt.Errorf("%w: %s", errExtractRemote, err)
}
}
return connection, nil
return nil, 0, "", "", nil
}
var (
@@ -137,3 +147,16 @@ func extractRemote(line string) (ip net.IP, port uint16,
return ip, port, protocol, nil
}
var (
errDevLineFieldsCount = errors.New("dev line has not 2 fields as expected")
)
func extractInterfaceFromLine(line string) (intf string, err error) {
fields := strings.Fields(line)
if len(fields) != 2 { //nolint:gomnd
return "", fmt.Errorf("%w: %s", errDevLineFieldsCount, line)
}
return fields[1], nil
}

View File

@@ -11,21 +11,23 @@ import (
"github.com/stretchr/testify/require"
)
func Test_extractConnectionFromLines(t *testing.T) {
func Test_extractDataFromLines(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
lines []string
connection models.Connection
intf string
err error
}{
"success": {
lines: []string{"bla bla", "proto tcp", "remote 1.2.3.4 1194 tcp"},
lines: []string{"bla bla", "proto tcp", "remote 1.2.3.4 1194 tcp", "dev tun6"},
connection: models.Connection{
IP: net.IPv4(1, 2, 3, 4),
Port: 1194,
Protocol: constants.TCP,
},
intf: "tun6",
},
"extraction error": {
lines: []string{"bla bla", "proto bad", "remote 1.2.3.4 1194 tcp"},
@@ -69,7 +71,7 @@ func Test_extractConnectionFromLines(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()
connection, err := extractConnectionFromLines(testCase.lines)
connection, intf, err := extractDataFromLines(testCase.lines)
if testCase.err != nil {
require.Error(t, err)
@@ -79,16 +81,20 @@ func Test_extractConnectionFromLines(t *testing.T) {
}
assert.Equal(t, testCase.connection, connection)
assert.Equal(t, testCase.intf, intf)
})
}
}
func Test_extractConnectionFromLine(t *testing.T) {
func Test_extractDataFromLine(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
line string
connection models.Connection
ip net.IP
port uint16
protocol string
intf string
isErr error
}{
"irrelevant line": {
@@ -100,9 +106,15 @@ func Test_extractConnectionFromLine(t *testing.T) {
},
"extract proto success": {
line: "proto tcp",
connection: models.Connection{
Protocol: constants.TCP,
protocol: constants.TCP,
},
"extract intf error": {
line: "dev ",
isErr: errExtractDev,
},
"extract intf success": {
line: "dev tun3",
intf: "tun3",
},
"extract remote error": {
line: "remote bad",
@@ -110,11 +122,9 @@ func Test_extractConnectionFromLine(t *testing.T) {
},
"extract remote success": {
line: "remote 1.2.3.4 1194 udp",
connection: models.Connection{
IP: net.IPv4(1, 2, 3, 4),
Port: 1194,
Protocol: constants.UDP,
},
ip: net.IPv4(1, 2, 3, 4),
port: 1194,
protocol: constants.UDP,
},
}
@@ -123,7 +133,7 @@ func Test_extractConnectionFromLine(t *testing.T) {
t.Run(name, func(t *testing.T) {
t.Parallel()
connection, err := extractConnectionFromLine(testCase.line)
ip, port, protocol, intf, err := extractDataFromLine(testCase.line)
if testCase.isErr != nil {
assert.ErrorIs(t, err, testCase.isErr)
@@ -131,7 +141,10 @@ func Test_extractConnectionFromLine(t *testing.T) {
assert.NoError(t, err)
}
assert.Equal(t, testCase.connection, connection)
assert.Equal(t, testCase.ip, ip)
assert.Equal(t, testCase.port, port)
assert.Equal(t, testCase.protocol, protocol)
assert.Equal(t, testCase.intf, intf)
})
}
}
@@ -260,3 +273,44 @@ func Test_extractRemote(t *testing.T) {
})
}
}
func Test_extractInterface(t *testing.T) {
t.Parallel()
testCases := map[string]struct {
line string
intf string
err error
}{
"found": {
line: "dev tun3",
intf: "tun3",
},
"not enough fields": {
line: "dev ",
err: errors.New("dev line has not 2 fields as expected: dev "),
},
"too many fields": {
line: "dev one two",
err: errors.New("dev line has not 2 fields as expected: dev one two"),
},
}
for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
intf, err := extractInterfaceFromLine(testCase.line)
if testCase.err != nil {
require.Error(t, err)
assert.Equal(t, testCase.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, testCase.intf, intf)
})
}
}

View File

@@ -11,7 +11,7 @@ import (
)
func modifyCustomConfig(lines []string, settings configuration.OpenVPN,
connection models.Connection) (modified []string) {
connection models.Connection, intf string) (modified []string) {
// Remove some lines
for _, line := range lines {
switch {
@@ -22,6 +22,7 @@ func modifyCustomConfig(lines []string, settings configuration.OpenVPN,
strings.HasPrefix(line, "user "),
strings.HasPrefix(line, "proto "),
strings.HasPrefix(line, "remote "),
strings.HasPrefix(line, "dev "),
settings.Cipher != "" && strings.HasPrefix(line, "cipher "),
settings.Cipher != "" && strings.HasPrefix(line, "data-ciphers "),
settings.Auth != "" && strings.HasPrefix(line, "auth "),
@@ -35,6 +36,7 @@ func modifyCustomConfig(lines []string, settings configuration.OpenVPN,
// Add values
modified = append(modified, connection.OpenVPNProtoLine())
modified = append(modified, connection.OpenVPNRemoteLine())
modified = append(modified, "dev "+intf)
modified = append(modified, "mute-replay-warnings")
modified = append(modified, "auth-nocache")
modified = append(modified, "pull-filter ignore \"auth-token\"") // prevent auth failed loop

View File

@@ -17,6 +17,7 @@ func Test_modifyCustomConfig(t *testing.T) {
lines []string
settings configuration.OpenVPN
connection models.Connection
intf string
modified []string
}{
"mixed": {
@@ -41,10 +42,12 @@ func Test_modifyCustomConfig(t *testing.T) {
Port: 1194,
Protocol: constants.UDP,
},
intf: "tun3",
modified: []string{
"keep me here",
"proto udp",
"remote 1.2.3.4 1194",
"dev tun3",
"mute-replay-warnings",
"auth-nocache",
"pull-filter ignore \"auth-token\"",
@@ -69,7 +72,7 @@ func Test_modifyCustomConfig(t *testing.T) {
t.Parallel()
modified := modifyCustomConfig(testCase.lines,
testCase.settings, testCase.connection)
testCase.settings, testCase.connection, testCase.intf)
assert.Equal(t, testCase.modified, modified)
})

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"net"
"github.com/qdm12/gluetun/internal/constants"
"github.com/vishvananda/netlink"
)
@@ -242,10 +241,10 @@ func (r *routing) VPNDestinationIP() (ip net.IP, err error) {
}
type VPNLocalGatewayIPGetter interface {
VPNLocalGatewayIP() (ip net.IP, err error)
VPNLocalGatewayIP(vpnIntf string) (ip net.IP, err error)
}
func (r *routing) VPNLocalGatewayIP() (ip net.IP, err error) {
func (r *routing) VPNLocalGatewayIP(vpnIntf string) (ip net.IP, err error) {
routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL)
if err != nil {
return nil, fmt.Errorf("%w: %s", ErrRoutesList, err)
@@ -256,7 +255,7 @@ func (r *routing) VPNLocalGatewayIP() (ip net.IP, err error) {
return nil, fmt.Errorf("%w: %s", ErrLinkByIndex, err)
}
interfaceName := link.Attrs().Name
if interfaceName == string(constants.TUN) &&
if interfaceName == vpnIntf &&
route.Dst != nil &&
route.Dst.IP.Equal(net.IP{0, 0, 0, 0}) {
return route.Gw, nil

View File

@@ -29,14 +29,16 @@ func setupOpenVPN(ctx context.Context, fw firewall.VPNConnectionSetter,
settings configuration.VPN, starter command.Starter, logger logging.Logger) (
runner vpnRunner, serverName string, err error) {
var connection models.Connection
var netInterface string
var lines []string
if settings.OpenVPN.Config == "" {
netInterface = settings.OpenVPN.Interface
connection, err = providerConf.GetConnection(settings.Provider.ServerSelection)
if err == nil {
lines = providerConf.BuildConf(connection, settings.OpenVPN)
}
} else {
lines, connection, err = custom.BuildConfig(settings.OpenVPN)
lines, connection, netInterface, err = custom.BuildConfig(settings.OpenVPN)
}
if err != nil {
return nil, "", fmt.Errorf("%w: %s", errBuildConfig, err)
@@ -53,7 +55,7 @@ func setupOpenVPN(ctx context.Context, fw firewall.VPNConnectionSetter,
}
}
if err := fw.SetVPNConnection(ctx, connection); err != nil {
if err := fw.SetVPNConnection(ctx, connection, netInterface); err != nil {
return nil, "", fmt.Errorf("%w: %s", errFirewall, err)
}

View File

@@ -6,9 +6,7 @@ import (
"fmt"
"time"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/portforward"
"github.com/qdm12/gluetun/internal/provider"
)
var (
@@ -16,24 +14,23 @@ var (
errStartPortForwarding = errors.New("cannot start port forwarding")
)
func (l *Loop) startPortForwarding(ctx context.Context, enabled bool,
portForwarder provider.PortForwarder, serverName string) (err error) {
if !enabled {
func (l *Loop) startPortForwarding(ctx context.Context, data tunnelUpData) (err error) {
if !data.portForwarding {
return nil
}
// only used for PIA for now
gateway, err := l.routing.VPNLocalGatewayIP()
gateway, err := l.routing.VPNLocalGatewayIP(data.vpnIntf)
if err != nil {
return fmt.Errorf("%w: %s", errObtainVPNLocalGateway, err)
}
l.logger.Info("VPN gateway IP address: " + gateway.String())
pfData := portforward.StartData{
PortForwarder: portForwarder,
PortForwarder: data.portForwarder,
Gateway: gateway,
ServerName: serverName,
Interface: constants.TUN,
ServerName: data.serverName,
Interface: data.vpnIntf,
}
_, err = l.portForward.Start(ctx, pfData)
if err != nil {

View File

@@ -11,6 +11,7 @@ import (
type tunnelUpData struct {
// Port forwarding
portForwarding bool
vpnIntf string
serverName string
portForwarder provider.PortForwarder
}
@@ -39,7 +40,7 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) {
}
}
err = l.startPortForwarding(ctx, data.portForwarding, data.portForwarder, data.serverName)
err = l.startPortForwarding(ctx, data)
if err != nil {
l.logger.Error(err.Error())
}